package ru.yandex.crypta.common.ws.json.protobuf;

import java.io.IOException;
import java.io.StringReader;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.PropertyNamingStrategy.PropertyNamingStrategyBase;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;

import ru.yandex.crypta.common.exception.Exceptions;
import ru.yandex.crypta.lib.proto.TWsIoConfig;

/**
 * Jackson deserializer that reads JSON into protobuf messages of type {@code T}.
 *
 * <p>Two modes are supported, selected by {@code config.getEnableJsonFormatDeserialization()}:
 * <ul>
 *     <li>enabled: the JSON value is read as a tree and handed to protobuf's {@link JsonFormat} parser;</li>
 *     <li>disabled: the Jackson token stream is walked field by field, consulting Jackson
 *     {@link DeserializationFeature}s and the configured property naming strategy.</li>
 * </ul>
 * Depending on the {@code build} flag the result is either a built {@link Message} or the populated
 * {@link Message.Builder}.
 *
 * @param <T> the protobuf message type produced by this deserializer
 */
public class ProtobufDeserializer<T extends Message> extends StdDeserializer<MessageOrBuilder> {

    /** Default instance of {@code T}; a fresh builder is created from it for every value. */
    private final T defaultInstance;
    /** {@code true} to return built messages, {@code false} to return the builder itself. */
    private final boolean build;
    /** IO configuration; selects between JsonFormat-based and token-based deserialization. */
    private final TWsIoConfig config;
    /** Cache of the Jackson deserializers resolved for nested message fields, keyed by field. */
    private final Map<FieldDescriptor, JsonDeserializer<Object>> deserializerCache;
    /** Protobuf JSON parser used when JsonFormat-based deserialization is enabled. */
    private final JsonFormat.Parser jsonFormat;

    /**
     * Creates a deserializer for the given generated message class.
     *
     * @param messageType generated protobuf class; must expose a static {@code getDefaultInstance()} method
     * @param build       {@code true} to return built messages, {@code false} to return builders
     * @param config      IO configuration selecting the deserialization mode
     * @throws RuntimeException if the default instance cannot be obtained reflectively
     */
    @SuppressWarnings("unchecked")
    public ProtobufDeserializer(Class<T> messageType, boolean build, TWsIoConfig config) {
        super(messageType);

        try {
            this.defaultInstance = (T) messageType.getMethod("getDefaultInstance").invoke(null);
        } catch (Exception e) {
            throw new RuntimeException("Unable to get default instance for type " + messageType, e);
        }

        this.build = build;
        this.config = config;
        this.deserializerCache = new ConcurrentHashMap<>();
        this.jsonFormat = JsonFormat.parser();
    }

    /**
     * Returns whether an unresolvable enum name may be silently dropped: either it is empty and empty strings
     * are accepted as null, or unknown enum values are configured to read as null.
     */
    private static boolean ignorableEnum(String value, DeserializationContext context) {
        return (acceptEmptyStringAsNull(context) && value.length() == 0) || ignoreUnknownEnums(context);
    }

    private static boolean acceptEmptyStringAsNull(DeserializationContext context) {
        return context.isEnabled(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT);
    }

    private static boolean allowNumbersForEnums(DeserializationContext context) {
        return !context.isEnabled(DeserializationFeature.FAIL_ON_NUMBERS_FOR_ENUMS);
    }

    private static boolean ignoreUnknownEnums(DeserializationContext context) {
        return context.isEnabled(DeserializationFeature.READ_UNKNOWN_ENUM_VALUES_AS_NULL);
    }

    /**
     * Formats the sorted numbers declared by {@code enumType} (e.g. {@code "[0,5,10]"}) for error messages.
     *
     * <p>Numbers, not declaration indices, are listed because numeric JSON values are resolved with
     * {@link EnumDescriptor#findValueByNumber(int)}.
     */
    private static String declaredNumbers(EnumDescriptor enumType) {
        List<Integer> numbers = Lists.transform(enumType.getValues(), value -> {
            assert value != null;
            return value.getNumber();
        });

        // Guava returns a non-modifiable view; copy it so it can be sorted.
        numbers = Lists.newArrayList(numbers);

        Collections.sort(numbers);

        return "[" + Joiner.on(',').join(numbers) + "]";
    }

    /**
     * Reports that the parser's current token cannot be used as a value of {@code field}.
     *
     * <p>Always throws. The declared return type only lets callers write {@code throw mappingException(...)}
     * so the compiler knows the branch does not complete normally.
     *
     * @throws IOException always, in the form built by {@link DeserializationContext#wrongTokenException}
     */
    private static JsonMappingException mappingException(FieldDescriptor field, DeserializationContext context)
            throws IOException
    {
        JsonParser parser = context.getParser();
        // NOTE(review): getDeclaringClass() of a FieldDescriptor.JavaType constant is the
        // FieldDescriptor.JavaType enum class itself, not the Java class of the field's values.
        JavaType javaType = context.constructType(field.getJavaType().getDeclaringClass());
        JsonToken jsonToken = parser.getCurrentToken();
        throw context.wrongTokenException(parser, javaType, jsonToken,
                "Wrong token for field '" + field.getFullName() + "'");
    }

    /**
     * Deserializes the current JSON value into a new message of type {@code T}.
     *
     * <p>When JsonFormat deserialization is enabled, the value is re-serialized from a Jackson tree and parsed by
     * protobuf's {@link JsonFormat}. Jackson features and naming strategy are not consulted in that mode, and
     * protobuf parse errors are rethrown via {@code Exceptions.wrongRequestException(..., "BAD_FORMAT")}.
     *
     * @return the built message if {@code build} was set, otherwise the populated builder
     */
    @Override
    public MessageOrBuilder deserialize(JsonParser parser, DeserializationContext context) throws IOException {
        Message.Builder builder = defaultInstance.newBuilderForType();

        if (config.getEnableJsonFormatDeserialization()) {
            TreeNode treeNode = parser.readValueAsTree();
            try {
                jsonFormat.merge(new StringReader(treeNode.toString()), builder);
            } catch (InvalidProtocolBufferException ex) {
                throw Exceptions.wrongRequestException(ex.getMessage(), "BAD_FORMAT");
            }
        } else {
            populate(builder, parser, context);
        }

        if (build) {
            return builder.build();
        } else {
            return builder;
        }
    }

    /**
     * Reads one JSON object from {@code parser} into {@code builder}, field by field.
     *
     * <p>On return the parser is positioned on the object's closing {@code END_OBJECT}. Unknown JSON properties
     * are passed to {@link DeserializationContext#handleUnknownProperty}, which either throws or lets the value
     * be skipped depending on the configuration.
     */
    private void populate(Message.Builder builder, JsonParser parser, DeserializationContext context)
            throws IOException
    {
        JsonToken token = parser.getCurrentToken();
        // A leading '[' is stepped over so that an object wrapped in an array is accepted.
        // NOTE(review): the matching END_ARRAY is not consumed by this method — confirm callers expect that.
        if (token == JsonToken.START_ARRAY) {
            token = parser.nextToken();
        }

        switch (token) {
            case END_OBJECT:
                return;
            case START_OBJECT:
                token = parser.nextToken();
                if (token == JsonToken.END_OBJECT) {
                    return;
                }
                break;
            default:
                break; // make findbugs happy
        }

        Descriptor descriptor = builder.getDescriptorForType();
        Map<String, FieldDescriptor> fieldLookup = buildFieldLookup(descriptor, context);

        do {
            // Identity comparison also turns an unexpected null token into a Jackson error instead of an NPE.
            if (token != JsonToken.FIELD_NAME) {
                throw context.wrongTokenException(parser, (JavaType) null, JsonToken.FIELD_NAME, "Missed field name");
            }

            String fieldName = parser.getCurrentName();
            FieldDescriptor field = fieldLookup.get(fieldName);

            // Step from the field name onto its value; both branches below expect to start on the value.
            parser.nextToken();

            if (field == null) {
                // Jackson's unknown-property handling expects the parser on the value: its default "ignore" path
                // only calls skipChildren(), which does nothing on a FIELD_NAME token. Skipping again afterwards is
                // a no-op if a handler already consumed the value (or if the value is a scalar).
                context.handleUnknownProperty(parser, this, builder, fieldName);
                parser.skipChildren();
                continue;
            }

            setField(builder, field, parser, context);
        } while ((token = parser.nextToken()) != JsonToken.END_OBJECT);
    }

    /**
     * Maps JSON property names, as produced by the context's property naming strategy, to the message's fields.
     */
    private Map<String, FieldDescriptor> buildFieldLookup(Descriptor descriptor, DeserializationContext context) {
        PropertyNamingStrategyBase namingStrategy =
                new PropertyNamingStrategyWrapper(context.getConfig().getPropertyNamingStrategy());

        Map<String, FieldDescriptor> fieldLookup = Maps.newHashMap();

        for (FieldDescriptor field : descriptor.getFields()) {
            fieldLookup.put(namingStrategy.translate(field.getName()), field);
        }

        return fieldLookup;
    }

    /**
     * Reads the value at the parser's current token and stores it into {@code field} of {@code builder}.
     *
     * <p>{@code null} values leave the field untouched. For repeated fields, a JSON array is appended element by
     * element. A single non-array value is appended only when {@code ACCEPT_SINGLE_VALUE_AS_ARRAY} is enabled.
     */
    private void setField(Message.Builder builder, FieldDescriptor field, JsonParser parser,
            DeserializationContext context) throws IOException
    {
        Object value = readValue(builder, field, parser, context);

        if (value == null) {
            return;
        }

        if (!field.isRepeated()) {
            builder.setField(field, value);
        } else if (value instanceof List) {
            // Only readArray produces a List. Checking for Iterable would also match ByteString
            // (an Iterable<Byte>) and wrongly split a single bytes value into individual bytes.
            for (Object element : (List<?>) value) {
                builder.addRepeatedField(field, element);
            }
        } else if (context.isEnabled(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY)) {
            builder.addRepeatedField(field, value);
        } else {
            throw mappingException(field, context);
        }
    }

    /**
     * Converts the JSON value at the parser's current token into the protobuf representation for {@code field}.
     *
     * @return the converted value; a {@code List} of converted elements for an array on a repeated field; or
     *         {@code null} for JSON null (on string/bytes/enum/message fields) and for ignorable unknown enums
     * @throws IOException if the token does not fit the field type
     */
    private Object readValue(Message.Builder builder, FieldDescriptor field, JsonParser parser,
            DeserializationContext context) throws IOException
    {
        final Object value;

        if (parser.getCurrentToken() == JsonToken.START_ARRAY) {
            if (field.isRepeated()) {
                return readArray(builder, field, parser, context);
            } else {
                throw mappingException(field, context);
            }
        }

        switch (field.getJavaType()) {
            case INT:
                value = _parseIntPrimitive(parser, context);
                break;
            case LONG:
                value = _parseLongPrimitive(parser, context);
                break;
            case FLOAT:
                value = _parseFloatPrimitive(parser, context);
                break;
            case DOUBLE:
                value = _parseDoublePrimitive(parser, context);
                break;
            case BOOLEAN:
                value = _parseBooleanPrimitive(parser, context);
                break;
            case STRING:
                switch (parser.getCurrentToken()) {
                    case VALUE_STRING:
                        value = parser.getText();
                        break;
                    case VALUE_NULL:
                        value = null;
                        break;
                    default:
                        value = _parseString(parser, context);
                }
                break;
            case BYTE_STRING:
                switch (parser.getCurrentToken()) {
                    case VALUE_STRING:
                        // Bytes are transported as base64 text using the context's configured variant.
                        value = ByteString.copyFrom(context.getBase64Variant().decode(parser.getText()));
                        break;
                    case VALUE_NULL:
                        value = null;
                        break;
                    default:
                        throw mappingException(field, context);
                }
                break;
            case ENUM:
                switch (parser.getCurrentToken()) {
                    case VALUE_STRING:
                        value = field.getEnumType().findValueByName(parser.getText());

                        if (value == null && !ignorableEnum(parser.getText().trim(), context)) {
                            throw context.weirdStringException(parser.getText(), field.getEnumType().getClass(),
                                    "value not one of declared Enum instance names");
                        }
                        break;
                    case VALUE_NUMBER_INT:
                        int intValue = parser.getIntValue();
                        if (allowNumbersForEnums(context)) {
                            value = field.getEnumType().findValueByNumber(intValue);

                            if (value == null && !ignoreUnknownEnums(context)) {
                                throw context.weirdNumberException(intValue, field.getEnumType().getClass(),
                                        "value not one of declared Enum instance numbers "
                                                + declaredNumbers(field.getEnumType()));
                            }
                        } else {
                            throw context.weirdNumberException(intValue,
                                    field.getJavaType().getDeclaringClass(),
                                    "Not allowed to deserialize Enum value out of JSON number " +
                                            "(disable DeserializationFeature.FAIL_ON_NUMBERS_FOR_ENUMS to allow)");
                        }
                        break;
                    case VALUE_NULL:
                        value = null;
                        break;
                    default:
                        throw mappingException(field, context);
                }
                break;
            case MESSAGE:
                switch (parser.getCurrentToken()) {
                    case START_OBJECT:
                        // Resolve (once per field) whatever deserializer Jackson has registered for the nested type.
                        JsonDeserializer<Object> deserializer = deserializerCache.get(field);
                        if (deserializer == null) {
                            Message.Builder subBuilder = builder.newBuilderForField(field);
                            Class<?> subType = subBuilder.getDefaultInstanceForType().getClass();

                            JavaType type = TypeFactory.defaultInstance().constructType(subType);
                            deserializer = context.findContextualValueDeserializer(type, null);
                            deserializerCache.put(field, deserializer);
                        }

                        value = deserializer.deserialize(parser, context);
                        break;
                    case VALUE_NULL:
                        value = null;
                        break;
                    default:
                        throw mappingException(field, context);
                }
                break;
            default:
                throw mappingException(field, context);
        }

        return value;
    }

    /**
     * Reads the elements of a JSON array (the parser is on its {@code START_ARRAY}) for a repeated field.
     *
     * <p>Null elements are dropped. Nested arrays are rejected, because protobuf repeated fields cannot hold lists.
     *
     * @return the converted, non-null elements in document order
     */
    private List<Object> readArray(Message.Builder builder, FieldDescriptor field, JsonParser parser,
            DeserializationContext context) throws IOException
    {
        List<Object> values = Lists.newArrayList();
        while (parser.nextToken() != JsonToken.END_ARRAY) {
            if (parser.getCurrentToken() == JsonToken.START_ARRAY) {
                // Without this check readValue would recurse and hand a List to addRepeatedField later.
                throw mappingException(field, context);
            }

            Object value = readValue(builder, field, parser, context);

            if (value != null) {
                values.add(value);
            }
        }
        return values;
    }
}
