# -*- coding: utf-8 -*-

from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Type

from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message

from travel.cpa.data_processing.lib.protobuf_utils import camel_to_snake, protobuf_to_dict, proto_from_base64_str
from travel.library.python.schematized import Schematized
from travel.library.python.schematized.fields import Any, Field, Float, String, UInt64, Int64, Boolean
# noinspection PyUnresolvedReferences
import mapreduce.yt.interface.protos.extension_pb2 as extension_pb2
# noinspection PyUnresolvedReferences
import travel.hotels.proto2.label_pb2 as label_hotels_pb2
# noinspection PyUnresolvedReferences
import travel.proto.avia.cpa.label_pb2 as label_avia_pb2
# noinspection PyUnresolvedReferences
import travel.proto.cpa.generic_label_pb2 as label_generic_pb2
# noinspection PyUnresolvedReferences
import travel.proto.cpa.proto_options.options_pb2 as options_pb2
# noinspection PyUnresolvedReferences
import travel.proto.suburban.cpa_label_pb2 as label_suburban_pb2
# noinspection PyUnresolvedReferences
import travel.proto.trains.label_params_pb2 as label_trains_pb2
# noinspection PyUnresolvedReferences
import travel.proto.cpa.tours.tours_label_pb2 as label_tours_pb2
# noinspection PyUnresolvedReferences
import travel.proto.buses.label_params_pb2 as label_buses_pb2


# Column template for each protobuf scalar field type, used for non-repeated
# label fields. The templates are shared module-level objects: callers must
# copy them before mutating (LabelConverter._get_label_fields deep-copies each
# one before setting ``hidden``). Field types missing from this map (bytes,
# messages, groups) are rejected with a TypeError when a converter is built.
TYPE_MAP = {
    FieldDescriptor.TYPE_STRING: String(default=''),
    FieldDescriptor.TYPE_ENUM: String(default=''),
    FieldDescriptor.TYPE_UINT32: UInt64(default=0),
    FieldDescriptor.TYPE_UINT64: UInt64(default=0),
    FieldDescriptor.TYPE_INT32: Int64(default=0),
    FieldDescriptor.TYPE_INT64: Int64(default=0),
    FieldDescriptor.TYPE_DOUBLE: Float(default=0.0),
    FieldDescriptor.TYPE_BOOL: Boolean(default=False),
    # The remaining scalar encodings differ from the types above only in wire
    # format; Python exposes them as the same int/float values, so they widen
    # to the same column templates as their uint/int/double counterparts.
    FieldDescriptor.TYPE_FIXED32: UInt64(default=0),
    FieldDescriptor.TYPE_FIXED64: UInt64(default=0),
    FieldDescriptor.TYPE_SINT32: Int64(default=0),
    FieldDescriptor.TYPE_SINT64: Int64(default=0),
    FieldDescriptor.TYPE_SFIXED32: Int64(default=0),
    FieldDescriptor.TYPE_SFIXED64: Int64(default=0),
    FieldDescriptor.TYPE_FLOAT: Float(default=0.0),
}


class LabelConverter(object):
    """Bridge between a label protobuf message class and a Schematized row class.

    On construction the message descriptor (``proto_cls.DESCRIPTOR``) is
    inspected to build two things:

    * ``modified_key_func``: renames a proto field name to its output column
      name. It uses the ``extension_pb2.column_name`` field option when that
      option is set, and otherwise ``'label_' + camel_to_snake(name)``.
    * ``fields_cls``: a ``Schematized`` subclass named ``fields_cls_name``
      whose ``__fields__`` holds one column template per proto field, keyed
      by the renamed column name.
    """

    def __init__(self, proto_cls: Type[Message], fields_cls_name: str):
        self.proto_cls = proto_cls
        key_func = self._get_modified_key_func(proto_cls)
        self.modified_key_func = key_func
        self.fields_cls = self._get_label_class(
            fields_cls_name,
            self._get_label_fields(proto_cls, key_func),
        )

    def str_to_proto(self, proto_str: str) -> Message:
        """Decode ``proto_str`` into a ``proto_cls`` instance.

        Delegates to ``proto_from_base64_str``. Judging by that helper's
        name, the input is expected to be base64 text.
        """
        return proto_from_base64_str(self.proto_cls, proto_str)

    def proto_to_dict(self, proto):
        """Convert ``proto`` to a dict whose keys are output column names.

        ``protobuf_to_dict`` produces the raw dict (its keys are presumably
        proto field names). Each key is passed through ``modified_key_func``.
        If two keys map to the same column, the later one wins.
        """
        renamed = {}
        for raw_key, value in protobuf_to_dict(proto).items():
            renamed[self.modified_key_func(raw_key)] = value
        return renamed

    @staticmethod
    def _get_modified_key_func(label_proto_class: Type[Message]) -> Callable[[str], str]:
        """Build the proto-field-name -> column-name mapping function.

        Fields whose ``extension_pb2.column_name`` option is non-empty map to
        that value verbatim. Every other name (including names that are not
        fields of ``label_proto_class``) maps to
        ``'label_' + camel_to_snake(name)``.
        """
        overrides = OrderedDict()
        for descriptor in label_proto_class.DESCRIPTOR.fields:
            column = descriptor.GetOptions().Extensions[extension_pb2.column_name]
            if column:
                overrides[descriptor.name] = column

        def modified_key_func(key: str) -> str:
            column = overrides.get(key)
            return 'label_' + camel_to_snake(key) if column is None else column

        return modified_key_func

    @staticmethod
    def _get_label_fields(
        proto_class: Type[Message],
        modified_key_func: Callable[[str], str],
        respect_reserved: bool = False,
    ):
        """Build the ordered column-name -> column-template mapping.

        Column templates are chosen as follows:

        * Repeated fields use ``Any([])``.
        * Scalar fields use their ``TYPE_MAP`` template. A type missing from
          the map raises ``TypeError``.
        * Enum fields with the ``options_pb2.UseEnumNameInYt`` option use
          ``String(optional=True)`` instead.

        Every column gets ``hidden`` from the ``options_pb2.HiddenField``
        option.

        With ``respect_reserved``, gaps in field numbering are filled with
        ``_reserved_<number>`` placeholder columns (``Any(default=0)``). The
        goal is that column position i corresponds to field number i + 1.
        NOTE(review): this only lines up if fields are declared in ascending
        field-number order.
        """
        fields = OrderedDict()
        for descriptor in proto_class.DESCRIPTOR.fields:
            if respect_reserved:
                # range() is empty when there is no gap before this field.
                for number in range(len(fields) + 1, descriptor.number):
                    fields['_reserved_{}'.format(number)] = Any(default=0)

            options = descriptor.GetOptions()
            if descriptor.label == descriptor.LABEL_REPEATED:
                template = Any([])
            else:
                template = TYPE_MAP.get(descriptor.type)
                if template is None:
                    raise TypeError('Unmatched type value for field {}: {}'.format(descriptor.name, descriptor.type))
                if descriptor.type == FieldDescriptor.TYPE_ENUM and options.Extensions[options_pb2.UseEnumNameInYt]:
                    template = String(optional=True)

            # Copy before mutating: TYPE_MAP templates are shared by every
            # field and every converter.
            column = deepcopy(template)
            column.hidden = options.Extensions[options_pb2.HiddenField]
            fields[modified_key_func(descriptor.name)] = column
        return fields

    @staticmethod
    def _get_label_class(name: str, fields: OrderedDict[str, Field]) -> Type[Schematized]:
        """Create a ``Schematized`` subclass called ``name`` with ``fields`` as ``__fields__``."""
        namespace = {'__fields__': fields}
        # noinspection PyTypeChecker
        return type(str(name), (Schematized,), namespace)


# Label converters keyed by label category name. Each spec is
# (key, label protobuf class, name of the generated Schematized class).
# All converters are built eagerly, at import time, in the order listed.
LABEL_CONVERTERS = {
    key: LabelConverter(proto_cls, fields_cls_name)
    for key, proto_cls, fields_cls_name in (
        ('generic', label_generic_pb2.TGenericLabel, 'LabelFieldsGeneric'),
        ('avia', label_avia_pb2.TLabel, 'LabelFieldsAvia'),
        ('hotels', label_hotels_pb2.TLabel, 'LabelFieldsHotels'),
        ('suburban', label_suburban_pb2.TSuburbanCpaLabel, 'LabelFieldsSuburban'),
        ('train', label_trains_pb2.TLabelParams, 'LabelFieldsTrains'),
        ('tours', label_tours_pb2.TToursLabel, 'LabelFieldsTours'),
        ('buses', label_buses_pb2.TLabelParams, 'LabelFieldsBuses'),
    )
}
