# -*- coding: utf-8 -*-

from collections.abc import Iterable
from typing import Any, Type

from google.protobuf.message import Message

# noinspection PyUnresolvedReferences
import travel.hotels.proto2.label_pb2 as label_hotels_pb2
# noinspection PyUnresolvedReferences
import travel.proto.cpa.generic_label_pb2 as label_generic_pb2
# noinspection PyUnresolvedReferences
import travel.proto.trains.label_params_pb2 as label_trains_pb2
# noinspection PyUnresolvedReferences
import travel.proto.buses.label_params_pb2 as label_buses_pb2


class LabelMapper(object):
    """Base class for converting a source label protobuf message into another message type.

    Subclasses set two class attributes:

    * ``__dst_proto_cls__`` -- the protobuf message class that ``get_mapped_proto`` builds;
    * ``__map__`` -- a dict keyed by destination field name. A ``str`` value names the
      source field whose value is copied into that destination field; a function or
      ``staticmethod`` value is called as ``mapper(src, dst)`` and fills ``dst`` itself.

    Every destination field that is not a ``__map__`` key is copied from the source
    attribute of the same name, if the source message has one.
    """

    __map__ = {}
    __dst_proto_cls__ = None

    def __init__(self):
        """Validate the subclass declaration and precompute the auto-mapped field set.

        :raises Exception: if ``__dst_proto_cls__`` is not set, if ``__map__`` has keys
            that are not fields of ``__dst_proto_cls__``, or if a ``__map__`` value is not
            a str, callable or staticmethod.
        """
        if self.__dst_proto_cls__ is None:
            raise Exception('{} must define __dst_proto_cls__'.format(type(self).__name__))
        all_dst_fields = set(self._get_proto_field_names(self.__dst_proto_cls__))
        map_fields = set(self.__map__.keys())
        unknown_fields = map_fields - all_dst_fields
        if unknown_fields:
            raise Exception('{} has no such fields: {}'.format(self.__dst_proto_cls__, unknown_fields))
        # Reject bad mapper values at construction time instead of on the first mapped message.
        for key, mapper in self.__map__.items():
            if not (isinstance(mapper, (str, staticmethod)) or callable(mapper)):
                raise self._wrong_mapper_type(key, mapper)
        self.all_dst_fields = all_dst_fields
        self.auto_map_fields = all_dst_fields - map_fields

    def get_mapped_proto(self, src: Message) -> Message:
        """Build a new ``__dst_proto_cls__`` message from *src*.

        ``__map__`` entries are applied first, then the remaining destination fields
        are copied from same-named attributes of *src*.

        :param src: the source label message.
        :return: a freshly created destination message.
        :raises Exception: if a ``__map__`` value has an unsupported type.
        """
        dst = self.__dst_proto_cls__()

        for key, mapper in self.__map__.items():
            if isinstance(mapper, str):
                self._apply_field_value(key, getattr(src, mapper), dst)
            elif isinstance(mapper, staticmethod):
                # Functions referenced from a class body are stored as raw staticmethod objects,
                # which are not callable before Python 3.10; call the wrapped function instead.
                mapper.__func__(src, dst)
            elif callable(mapper):
                mapper(src, dst)
            else:
                raise self._wrong_mapper_type(key, mapper)

        for key in self.auto_map_fields:
            if hasattr(src, key):
                self._apply_field_value(key, getattr(src, key), dst)

        return dst

    @staticmethod
    def _wrong_mapper_type(key: str, mapper: Any) -> Exception:
        """Build the error raised for a ``__map__`` value of unsupported type."""
        return Exception(
            'Wrong mapper type for {}:{}. Use str, callable or staticmethod'.format(key, type(mapper))
        )

    @staticmethod
    def _apply_field_value(field_name: str, value: Any, dst: Message) -> None:
        """Store *value* into the field *field_name* of *dst*.

        Protobuf does not allow plain assignment to repeated or singular message
        fields: repeated fields are extended, message fields are copied with
        ``CopyFrom``, and scalar fields are assigned directly.
        """
        field_descriptor = dst.DESCRIPTOR.fields_by_name[field_name]
        if field_descriptor.label == field_descriptor.LABEL_REPEATED:
            getattr(dst, field_name).extend(value)
        elif field_descriptor.message_type is not None:
            getattr(dst, field_name).CopyFrom(value)
        else:
            setattr(dst, field_name, value)

    @staticmethod
    def _get_proto_field_names(proto_cls: Type[Message]) -> list[str]:
        """Return the names of all fields declared on the protobuf message class *proto_cls*."""
        return [f.name for f in proto_cls.DESCRIPTOR.fields]


class HotelsToGenericLabelMapper(LabelMapper):
    """Maps hotels labels (``label_hotels_pb2.TLabel``) to ``TGenericLabel``.

    UTM and experiment fields are copied under new names; ``ExpBoxes`` and
    ``PortalExpBoxes`` are serialized from the paired test-id / test-bucket lists.
    """

    @staticmethod
    def _get_boxes(test_ids: Iterable[int], test_buckets: Iterable[int]) -> str:
        """Serialize paired test ids and buckets as ``'<id>,0,<bucket>'`` boxes joined with ``;``.

        The inputs are the ``Int*TestIds`` / ``Int*TestBuckets`` repeated fields, whose values
        are also copied into the integer ``TestIds`` / ``TestBuckets`` fields of the generic
        label. Items are paired positionally with ``zip``, so when the lengths differ the
        surplus items of the longer input are silently dropped.
        """
        return ';'.join([
            '{},0,{}'.format(test_id, test_bucket) for test_id, test_bucket in zip(test_ids, test_buckets)
        ])

    @staticmethod
    def _map_exp_boxes(src: label_hotels_pb2.TLabel, dst: label_generic_pb2.TGenericLabel) -> None:
        """Fill ``dst.ExpBoxes``; leave it unset when either source list is empty."""
        if not src.IntTestIds or not src.IntTestBuckets:
            return
        dst.ExpBoxes = HotelsToGenericLabelMapper._get_boxes(src.IntTestIds, src.IntTestBuckets)

    @staticmethod
    def _map_portal_exp_boxes(src: label_hotels_pb2.TLabel, dst: label_generic_pb2.TGenericLabel) -> None:
        """Fill ``dst.PortalExpBoxes``; leave it unset when either source list is empty."""
        if not src.IntPortalTestIds or not src.IntPortalTestBuckets:
            return
        dst.PortalExpBoxes = HotelsToGenericLabelMapper._get_boxes(src.IntPortalTestIds, src.IntPortalTestBuckets)

    # Destination field -> source field name, or function filling dst from src.
    __map__ = {
        'UtmSource': 'Source',
        'UtmMedium': 'Medium',
        'UtmCampaign': 'Campaign',
        'UtmContent': 'Content',
        'UtmTerm': 'Term',
        'TestIds': 'IntTestIds',
        'TestBuckets': 'IntTestBuckets',
        'PortalTestIds': 'IntPortalTestIds',
        'PortalTestBuckets': 'IntPortalTestBuckets',
        'ExpBoxes': _map_exp_boxes,
        'PortalExpBoxes': _map_portal_exp_boxes,
    }

    __dst_proto_cls__ = label_generic_pb2.TGenericLabel


class TrainToGenericLabelMapper(LabelMapper):
    """Maps train label params (``label_trains_pb2.TLabelParams``) to ``TGenericLabel``.

    The source ``TestBuckets`` is a ``;``-separated string of ``,``-separated boxes. It is
    copied verbatim into ``ExpBoxes`` and also parsed into the integer ``TestIds``
    (box position 0) and ``TestBuckets`` (box position 2) lists.
    """

    @staticmethod
    def _parse_boxes(src: label_trains_pb2.TLabelParams, position: int) -> list[int]:
        """Return the integer at *position* of every box in ``src.TestBuckets``.

        Empty boxes (an empty string, or the empty tail left by a trailing ``;``) are skipped.

        :raises Exception: if a box is too short or its item is not an integer;
            the message includes the offending *src*.
        """
        try:
            return [int(box.split(',')[position]) for box in src.TestBuckets.split(';') if box]
        except (ValueError, IndexError) as e:
            raise Exception('{}\nsrc = {}'.format(e, src)) from e

    @staticmethod
    def _map_test_ids(src: label_trains_pb2.TLabelParams, dst: label_generic_pb2.TGenericLabel) -> None:
        """Append the test id (box position 0) of every source box to ``dst.TestIds``."""
        dst.TestIds.extend(TrainToGenericLabelMapper._parse_boxes(src, 0))

    @staticmethod
    def _map_test_buckets(src: label_trains_pb2.TLabelParams, dst: label_generic_pb2.TGenericLabel) -> None:
        """Append the bucket (box position 2) of every source box to ``dst.TestBuckets``."""
        dst.TestBuckets.extend(TrainToGenericLabelMapper._parse_boxes(src, 2))

    # Destination field -> source field name, or function filling dst from src.
    __map__ = {
        'ExpBoxes': 'TestBuckets',
        'TestIds': _map_test_ids,
        'TestBuckets': _map_test_buckets,
        'UserDevice': 'Device',
        'UserRegion': 'RegionId',
        'Uuid': 'SerpUuid',
    }

    __dst_proto_cls__ = label_generic_pb2.TGenericLabel


class BusesToGenericLabelMapper(LabelMapper):
    """Maps bus label params (``label_buses_pb2.TLabelParams``) to ``TGenericLabel``.

    The source ``TestBuckets`` is a ``;``-separated string of ``,``-separated boxes. It is
    copied verbatim into ``ExpBoxes`` and also parsed into the integer ``TestIds``
    (box position 0) and ``TestBuckets`` (box position 2) lists.
    """

    @staticmethod
    def _parse_boxes(src: label_buses_pb2.TLabelParams, position: int) -> list[int]:
        """Return the integer at *position* of every box in ``src.TestBuckets``.

        Empty boxes (an empty string, or the empty tail left by a trailing ``;``) are skipped.

        :raises Exception: if a box is too short or its item is not an integer;
            the message includes the offending *src*.
        """
        try:
            return [int(box.split(',')[position]) for box in src.TestBuckets.split(';') if box]
        except (ValueError, IndexError) as e:
            raise Exception('{}\nsrc = {}'.format(e, src)) from e

    @staticmethod
    def _map_test_ids(src: label_buses_pb2.TLabelParams, dst: label_generic_pb2.TGenericLabel) -> None:
        """Append the test id (box position 0) of every source box to ``dst.TestIds``."""
        dst.TestIds.extend(BusesToGenericLabelMapper._parse_boxes(src, 0))

    @staticmethod
    def _map_test_buckets(src: label_buses_pb2.TLabelParams, dst: label_generic_pb2.TGenericLabel) -> None:
        """Append the bucket (box position 2) of every source box to ``dst.TestBuckets``."""
        dst.TestBuckets.extend(BusesToGenericLabelMapper._parse_boxes(src, 2))

    # Destination field -> source field name, or function filling dst from src.
    __map__ = {
        'ExpBoxes': 'TestBuckets',
        'TestIds': _map_test_ids,
        'TestBuckets': _map_test_buckets,
        'UserDevice': 'Device',
        'UserRegion': 'RegionId',
        'Uuid': 'SerpUuid',
    }

    __dst_proto_cls__ = label_generic_pb2.TGenericLabel


# Registry: (source label kind, destination label kind) -> ready-to-use mapper instance.
# Every mapper is constructed here at import time, so a mapper whose __map__ names
# unknown destination fields makes the module import fail.
LABEL_MAPPERS = {
    (source_kind, 'generic'): mapper_cls()
    for source_kind, mapper_cls in (
        ('hotels', HotelsToGenericLabelMapper),
        ('train', TrainToGenericLabelMapper),
        ('buses', BusesToGenericLabelMapper),
    )
}
