from operator import itemgetter
from typing import Dict, List

from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load, validate

from maps_adv.common.protomallow import (
    PbDateTimeField,
    PbEnumField,
    ProtobufSchema,
    with_schemas,
)
from maps_adv.geosmb.promoter.proto import (
    leads_pb2,
    list_lead_segments_pb2,
    segments_pb2,
)
from maps_adv.geosmb.promoter.server.lib.enums import (
    EventType,
    OrderByField,
    OrderDirection,
    SegmentType,
    Source,
)
from maps_adv.geosmb.proto import common_pb2

from ..data_manager import BaseDataManager

# Translation tables between protobuf enum values and server-side enum members.
# Each key holds a list of ``(protobuf value, server enum member)`` pairs; the
# schemas below pass a list to ``PbEnumField`` as ``values_map=``, presumably so
# the field can convert in both directions.
# NOTE(review): confirm exact lookup semantics in maps_adv.common.protomallow.
# When a value is added to either side of an enum, add its pair here as well.
ENUMS_MAP = {
    # leads_pb2.OrderBy.OrderField <-> OrderByField (used by OrderBySchema.field)
    "order_by_field": [
        (leads_pb2.OrderBy.OrderField.NAME, OrderByField.NAME),
        (leads_pb2.OrderBy.OrderField.MAKE_ROUTES, OrderByField.MAKE_ROUTES),
        (leads_pb2.OrderBy.OrderField.REVIEW_RATING, OrderByField.REVIEW_RATING),
        (leads_pb2.OrderBy.OrderField.CLICKS_ON_PHONE, OrderByField.CLICKS_ON_PHONE),
        (leads_pb2.OrderBy.OrderField.SITE_OPENS, OrderByField.SITE_OPENS),
        (
            leads_pb2.OrderBy.OrderField.LAST_ACTIVITY_TIMESTAMP,
            OrderByField.LAST_ACTIVITY_TIMESTAMP,
        ),
        (
            leads_pb2.OrderBy.OrderField.VIEW_WORKING_HOURS,
            OrderByField.VIEW_WORKING_HOURS,
        ),
        (leads_pb2.OrderBy.OrderField.VIEW_ENTRANCES, OrderByField.VIEW_ENTRANCES),
        (leads_pb2.OrderBy.OrderField.CTA_BUTTON_CLICK, OrderByField.CTA_BUTTON_CLICK),
        (leads_pb2.OrderBy.OrderField.FAVOURITE_CLICK, OrderByField.FAVOURITE_CLICK),
        (leads_pb2.OrderBy.OrderField.LOCATION_SHARING, OrderByField.LOCATION_SHARING),
        (
            leads_pb2.OrderBy.OrderField.BOOKING_SECTION_INTERACTION,
            OrderByField.BOOKING_SECTION_INTERACTION,
        ),
        (
            leads_pb2.OrderBy.OrderField.SHOWCASE_PRODUCT_CLICK,
            OrderByField.SHOWCASE_PRODUCT_CLICK,
        ),
        (leads_pb2.OrderBy.OrderField.PROMO_TO_SITE, OrderByField.PROMO_TO_SITE),
        (
            leads_pb2.OrderBy.OrderField.GEOPRODUCT_BUTTON_CLICK,
            OrderByField.GEOPRODUCT_BUTTON_CLICK,
        ),
    ],
    # leads_pb2.OrderBy.OrderDirection <-> OrderDirection
    "order_direction": [
        (leads_pb2.OrderBy.OrderDirection.ASC, OrderDirection.ASC),
        (leads_pb2.OrderBy.OrderDirection.DESC, OrderDirection.DESC),
    ],
    # segments_pb2.SegmentType <-> SegmentType (shared by several schemas)
    "segment_type": [
        (segments_pb2.SegmentType.PROSPECTIVE, SegmentType.PROSPECTIVE),
        (segments_pb2.SegmentType.ACTIVE, SegmentType.ACTIVE),
        (segments_pb2.SegmentType.LOST, SegmentType.LOST),
        (segments_pb2.SegmentType.LOYAL, SegmentType.LOYAL),
        (segments_pb2.SegmentType.DISLOYAL, SegmentType.DISLOYAL),
    ],
    # leads_pb2.Event.EventType values (accessed via the enclosing Event
    # message) <-> EventType
    "event_type": [
        (leads_pb2.Event.MAKE_ROUTE, EventType.MAKE_ROUTE),
        (leads_pb2.Event.REVIEW, EventType.REVIEW),
        (leads_pb2.Event.CLICK_ON_PHONE, EventType.CLICK_ON_PHONE),
        (leads_pb2.Event.OPEN_SITE, EventType.OPEN_SITE),
        (leads_pb2.Event.VIEW_WORKING_HOURS, EventType.VIEW_WORKING_HOURS),
        (leads_pb2.Event.VIEW_ENTRANCES, EventType.VIEW_ENTRANCES),
        (leads_pb2.Event.CTA_BUTTON_CLICK, EventType.CTA_BUTTON_CLICK),
        (leads_pb2.Event.FAVOURITE_CLICK, EventType.FAVOURITE_CLICK),
        (leads_pb2.Event.LOCATION_SHARING, EventType.LOCATION_SHARING),
        (
            leads_pb2.Event.BOOKING_SECTION_INTERACTION,
            EventType.BOOKING_SECTION_INTERACTION,
        ),
        (leads_pb2.Event.SHOWCASE_PRODUCT_CLICK, EventType.SHOWCASE_PRODUCT_CLICK),
        (leads_pb2.Event.PROMO_TO_SITE, EventType.PROMO_TO_SITE),
        (leads_pb2.Event.GEOPRODUCT_BUTTON_CLICK, EventType.GEOPRODUCT_BUTTON_CLICK),
    ],
    # leads_pb2.Source <-> Source (used on leads and on events)
    "source": [
        (leads_pb2.Source.EXTERNAL_ADVERT, Source.EXTERNAL_ADVERT),
        (leads_pb2.Source.STRAIGHT, Source.STRAIGHT),
        (leads_pb2.Source.DISCOVERY_ADVERT, Source.DISCOVERY_ADVERT),
        (leads_pb2.Source.DISCOVERY_NO_ADVERT, Source.DISCOVERY_NO_ADVERT),
    ],
}


class PaginationSchema(ProtobufSchema):
    """Schema for the shared ``common_pb2.Pagination`` message.

    Both ``limit`` and ``offset`` are mandatory.
    """

    class Meta:
        pb_message_class = common_pb2.Pagination

    # Page size.
    limit = fields.Integer(required=True)
    # Number of records to skip.
    offset = fields.Integer(required=True)


class StatisticsSchema(ProtobufSchema):
    """Per-lead activity counters, bound to ``leads_pb2.Statistics``.

    Nested inside ``LeadSchema.statistics``. Every counter is required;
    ``review_rating`` is an optional string.
    """

    class Meta:
        pb_message_class = leads_pb2.Statistics

    # Counters (all required integers) plus the optional rating string.
    make_routes = fields.Integer(required=True)
    review_rating = fields.String()
    clicks_on_phone = fields.Integer(required=True)
    site_opens = fields.Integer(required=True)
    view_working_hours = fields.Integer(required=True)
    view_entrances = fields.Integer(required=True)
    cta_button_click = fields.Integer(required=True)
    favourite_click = fields.Integer(required=True)
    location_sharing = fields.Integer(required=True)
    booking_section_interaction = fields.Integer(required=True)
    showcase_product_click = fields.Integer(required=True)
    promo_to_site = fields.Integer(required=True)
    geoproduct_button_click = fields.Integer(required=True)

    # Time of the most recent activity.
    last_activity_timestamp = PbDateTimeField(required=True)


class LeadSchema(ProtobufSchema):
    """Schema for a single lead (``leads_pb2.Lead``).

    Used as the output of ``retrieve_lead`` and nested (many) in
    ``ListLeadsOutputSchema``.
    """

    class Meta:
        pb_message_class = leads_pb2.Lead

    lead_id = fields.Integer(required=True)
    biz_id = fields.Integer(required=True)
    name = fields.String(required=True)
    statistics = fields.Nested(StatisticsSchema, required=True)

    # Segments the lead belongs to; each element is a SegmentType.
    segments = fields.List(
        PbEnumField(
            enum=SegmentType,
            pb_enum=segments_pb2.SegmentType,
            values_map=ENUMS_MAP["segment_type"],
            required=True,
        )
    )

    # Optional traffic source of the lead.
    source = PbEnumField(
        values_map=ENUMS_MAP["source"],
        pb_enum=leads_pb2.Source,
        enum=Source,
    )


class OrderBySchema(ProtobufSchema):
    """Sort specification (``leads_pb2.OrderBy``): which field and which way."""

    class Meta:
        pb_message_class = leads_pb2.OrderBy

    # Column to sort leads by.
    field = PbEnumField(
        enum=OrderByField,
        pb_enum=leads_pb2.OrderBy.OrderField,
        values_map=ENUMS_MAP["order_by_field"],
        required=True,
    )
    # ASC or DESC.
    direction = PbEnumField(
        enum=OrderDirection,
        pb_enum=leads_pb2.OrderBy.OrderDirection,
        values_map=ENUMS_MAP["order_direction"],
        required=True,
    )


class ListLeadsInputSchema(ProtobufSchema):
    """Input schema for listing leads (``leads_pb2.ListLeadsInput``).

    After loading, the nested ``pagination`` and ``order_by`` objects are
    flattened into plain keyword arguments (``limit``/``offset`` and
    ``order_by_field``/``order_direction``).
    """

    class Meta:
        pb_message_class = leads_pb2.ListLeadsInput

    biz_id = fields.Integer(validate=validate.Range(min=1), required=True)
    order_by = fields.Nested(OrderBySchema)
    filter_by_segment = PbEnumField(
        values_map=ENUMS_MAP["segment_type"],
        pb_enum=segments_pb2.SegmentType,
        enum=SegmentType,
    )
    pagination = fields.Nested(PaginationSchema, required=True)

    @post_load
    def _unpack_pagination(self, data: dict) -> dict:
        """Replace ``pagination`` with top-level ``limit`` and ``offset`` keys."""
        read_bounds = itemgetter("limit", "offset")
        page_limit, page_offset = read_bounds(data.pop("pagination"))
        data.update(limit=page_limit, offset=page_offset)
        return data

    @post_load
    def _unpack_ordering(self, data: dict) -> dict:
        """Replace ``order_by`` with ``order_by_field``/``order_direction``.

        Nothing happens when ``order_by`` is absent or falsy.
        """
        ordering = data.get("order_by")
        if not ordering:
            return data

        data.pop("order_by")
        data.update(
            order_by_field=ordering["field"],
            order_direction=ordering["direction"],
        )
        return data


class RetrieveLeadInputSchema(ProtobufSchema):
    """Input schema for fetching one lead (``leads_pb2.RetrieveLeadInput``)."""

    class Meta:
        pb_message_class = leads_pb2.RetrieveLeadInput

    # Both identifiers must be positive integers.
    biz_id = fields.Integer(validate=validate.Range(min=1), required=True)
    lead_id = fields.Integer(validate=validate.Range(min=1), required=True)


class ListLeadsOutputSchema(ProtobufSchema):
    """Output schema for a page of leads (``leads_pb2.ListLeadsOutput``).

    Expects a mapping with ``total_count`` and a ``leads`` sequence.
    """

    class Meta:
        pb_message_class = leads_pb2.ListLeadsOutput

    total_count = fields.Integer(required=True)
    leads = fields.Nested(LeadSchema, many=True)


class ListSegmentsInputSchema(ProtobufSchema):
    """Input schema for listing segments of a business (``segments_pb2.ListSegmentsInput``)."""

    class Meta:
        pb_message_class = segments_pb2.ListSegmentsInput

    # Positive business identifier.
    biz_id = fields.Integer(validate=validate.Range(min=1), required=True)


class SegmentSchema(ProtobufSchema):
    """One segment with its size (``segments_pb2.Segment``).

    ``segment_type`` is emitted under the protobuf field name ``type``.
    """

    class Meta:
        pb_message_class = segments_pb2.Segment

    segment_type = PbEnumField(
        dump_to="type",
        enum=SegmentType,
        pb_enum=segments_pb2.SegmentType,
        values_map=ENUMS_MAP["segment_type"],
        required=True,
    )
    # Number of leads in the segment.
    size = fields.Integer(required=True)


class ListSegmentsOutputSchema(ProtobufSchema):
    """Output schema for segment listing (``segments_pb2.ListSegmentsOutput``).

    The incoming object carries ``segments`` as a mapping of segment type to
    size. It is converted into a list of ``SegmentSchema``-shaped dicts before
    dumping.
    """

    class Meta:
        pb_message_class = segments_pb2.ListSegmentsOutput

    total_leads = fields.Integer(required=True)
    segments = fields.Nested(SegmentSchema, many=True)

    @pre_dump
    def _compose_segments(self, data):
        """Expand ``data["segments"]`` into a list sorted by segment type name.

        Keys of the mapping must expose ``.name``; they are dumped through
        ``SegmentSchema.segment_type``, so they are expected to be
        ``SegmentType`` members.

        Returns a shallow copy of ``data``. The caller's object is not
        modified, because hooks that write into the value being serialized
        leak state back to the data manager.
        """
        segments = sorted(
            (
                dict(segment_type=segment_type, size=size)
                for segment_type, size in data["segments"].items()
            ),
            # Sort by enum name for a deterministic output order.
            key=lambda segment: segment["segment_type"].name,
        )
        return dict(data, segments=segments)


class EventSchema(ProtobufSchema):
    """Dump schema for one lead event (``leads_pb2.Event``).

    Server-side keys are renamed on dump via ``dump_to``:
    ``event_type`` -> ``type``, ``event_type_str`` -> ``type_str``,
    ``event_timestamp`` -> ``timestamp``, ``event_value`` -> ``value``.
    """

    class Meta:
        pb_message_class = leads_pb2.Event

    event_type = PbEnumField(
        required=True,
        enum=EventType,
        pb_enum=leads_pb2.Event.EventType,
        values_map=ENUMS_MAP["event_type"],
        dump_to="type",
    )
    # Filled in by ``set_type_str``; it is not expected from the data manager.
    event_type_str = fields.String(required=True, dump_to="type_str")
    event_timestamp = PbDateTimeField(required=True, dump_to="timestamp")
    event_value = fields.String(dump_to="value")
    source = PbEnumField(
        enum=Source,
        pb_enum=leads_pb2.Source,
        values_map=ENUMS_MAP["source"],
    )

    @pre_dump
    def set_type_str(self, data: dict) -> dict:
        """Add ``event_type_str`` (the name of ``data["event_type"]``).

        Returns a shallow copy so the event object produced by the data
        manager is not modified during serialization.
        """
        return dict(data, event_type_str=data["event_type"].name)


class ListLeadEventsOutputSchema(ProtobufSchema):
    """Output schema for a page of lead events (``leads_pb2.ListLeadEventsOutput``).

    Expects a mapping with ``total_events`` and an ``events`` sequence.
    """

    class Meta:
        pb_message_class = leads_pb2.ListLeadEventsOutput

    total_events = fields.Integer(required=True)
    events = fields.Nested(EventSchema, many=True)


class ListLeadEventsInputSchema(ProtobufSchema):
    """Input schema for listing a lead's events (``leads_pb2.ListLeadEventsInput``).

    The nested ``pagination`` object is flattened into ``limit``/``offset``
    after loading.
    """

    class Meta:
        pb_message_class = leads_pb2.ListLeadEventsInput

    biz_id = fields.Integer(validate=validate.Range(min=1), required=True)
    lead_id = fields.Integer(validate=validate.Range(min=1), required=True)
    pagination = fields.Nested(PaginationSchema, required=True)

    @post_load
    def _unpack_pagination(self, data: dict) -> dict:
        """Replace ``pagination`` with top-level ``limit`` and ``offset`` keys."""
        read_bounds = itemgetter("limit", "offset")
        page_limit, page_offset = read_bounds(data.pop("pagination"))
        data.update(limit=page_limit, offset=page_offset)
        return data


class BusinessSegmentsSchema(ProtobufSchema):
    """Segments of one lead within one business (``list_lead_segments_pb2.BusinessSegments``)."""

    class Meta:
        pb_message_class = list_lead_segments_pb2.BusinessSegments

    lead_id = fields.Integer(required=True)
    biz_id = fields.Integer(required=True)

    # Each element is a SegmentType.
    segments = fields.List(
        PbEnumField(
            values_map=ENUMS_MAP["segment_type"],
            pb_enum=segments_pb2.SegmentType,
            enum=SegmentType,
            required=True,
        )
    )


class ListLeadSegmentsOutputSchema(ProtobufSchema):
    """Output schema for lead segments (``list_lead_segments_pb2.ListLeadSegmentsOutput``).

    The value being dumped is a bare sequence of business-segment records.
    It is wrapped under the ``biz_segments`` key before serialization.
    """

    class Meta:
        pb_message_class = list_lead_segments_pb2.ListLeadSegmentsOutput

    biz_segments = fields.Nested(BusinessSegmentsSchema, many=True)

    @pre_dump
    def _to_dict(self, data):
        """Wrap the raw sequence into the message-shaped mapping."""
        wrapped = dict(biz_segments=data)
        return wrapped


class ListLeadSegmentsInputSchema(ProtobufSchema):
    """Input schema for lead segment lookup (``list_lead_segments_pb2.ListLeadSegmentsInput``).

    At least one identity field (``passport_uid``, ``yandex_uid`` or
    ``device_id``) must be non-empty. ``biz_id`` is an optional filter and
    does not count as an identity.
    """

    class Meta:
        pb_message_class = list_lead_segments_pb2.ListLeadSegmentsInput

    passport_uid = fields.String(required=False, validate=validate.Length(min=1))
    yandex_uid = fields.String(required=False, validate=validate.Length(min=1))
    device_id = fields.String(required=False, validate=validate.Length(min=1))
    biz_id = fields.Integer(required=False, validate=validate.Range(min=1))

    @pre_load
    def _validate_id_fields(self, data):
        """Reject input that has no non-empty identity field.

        Raises:
            ValidationError: if passport_uid, yandex_uid and device_id are
                all missing or empty.
        """
        id_fields = ("passport_uid", "yandex_uid", "device_id")
        # Only the identity fields are inspected. Previously
        # ``any(data.values())`` also counted ``biz_id``, so a request with
        # just a biz_id slipped through, contrary to the error message below.
        if not any(data.get(name) for name in id_fields):
            raise ValidationError(
                "At least one id field should be listed: "
                "passport_uid, yandex_uid or device_id"
            )
        return data


class SearchLeadsForGdprInputSchema(ProtobufSchema):
    """Input schema for the GDPR lead search (``leads_pb2.SearchLeadsForGdprInput``).

    ``passport_uid`` arrives as an integer and is handed on as a string.
    """

    class Meta:
        pb_message_class = leads_pb2.SearchLeadsForGdprInput

    passport_uid = fields.Integer(required=True)

    @post_load
    def _to_str(self, data: dict) -> dict:
        """Convert the loaded ``passport_uid`` to its string form."""
        as_text = str(data["passport_uid"])
        data["passport_uid"] = as_text
        return data


class SearchLeadsForGdprOutputSchema(ProtobufSchema):
    """Output schema for the GDPR lead search (``leads_pb2.SearchLeadsForGdprOutput``).

    The value being dumped is a bare boolean. It is wrapped under
    ``leads_exist`` before serialization.
    """

    class Meta:
        pb_message_class = leads_pb2.SearchLeadsForGdprOutput

    leads_exist = fields.Boolean(required=True)

    @pre_dump
    def to_dict(self, data: bool) -> dict:
        """Wrap the boolean flag into the message-shaped mapping."""
        return dict(leads_exist=data)


class RemoveLeadsForGdprInputSchema(ProtobufSchema):
    """Input schema for GDPR lead removal (``leads_pb2.RemoveLeadsForGdprInput``).

    ``passport_uid`` arrives as an integer and is handed on as a string.
    """

    class Meta:
        pb_message_class = leads_pb2.RemoveLeadsForGdprInput

    passport_uid = fields.Integer(required=True)

    @post_load
    def _to_str(self, data: dict) -> dict:
        """Convert the loaded ``passport_uid`` to its string form."""
        as_text = str(data["passport_uid"])
        data["passport_uid"] = as_text
        return data


class RemovedLeadSchema(ProtobufSchema):
    """One removed lead (``leads_pb2.RemoveLeadsForGdprOutput.RemovedLead``)."""

    class Meta:
        pb_message_class = leads_pb2.RemoveLeadsForGdprOutput.RemovedLead

    # Business and lead identifiers of the removed record.
    biz_id = fields.Integer(required=True)
    lead_id = fields.Integer(required=True)


class RemoveLeadsForGdprOutputSchema(ProtobufSchema):
    """Output schema for GDPR lead removal (``leads_pb2.RemoveLeadsForGdprOutput``).

    The value being dumped is a bare sequence of removed-lead records. It is
    wrapped under ``removed_leads`` before serialization.
    """

    class Meta:
        pb_message_class = leads_pb2.RemoveLeadsForGdprOutput

    removed_leads = fields.Nested(RemovedLeadSchema, many=True, required=True)

    @pre_dump
    def to_dict(self, data):
        """Wrap the raw sequence into the message-shaped mapping."""
        wrapped = dict(removed_leads=data)
        return wrapped


class ApiProvider:
    """Exposes data-manager operations through protobuf schemas.

    Each method is wrapped by ``with_schemas(InputSchema, OutputSchema)``,
    presumably to load the request with the input schema and dump the return
    value with the output schema (NOTE(review): confirm in protomallow). The
    annotations below describe what the undecorated coroutine returns, i.e.
    the value handed to the output schema.
    """

    __slots__ = ["_dm"]

    _dm: BaseDataManager

    def __init__(self, dm: BaseDataManager):
        self._dm = dm

    @with_schemas(ListLeadsInputSchema, ListLeadsOutputSchema)
    async def list_leads(self, **kwargs) -> dict:
        """Return a page of leads (dumped with ``total_count`` and ``leads``)."""
        return await self._dm.list_leads(**kwargs)

    @with_schemas(RetrieveLeadInputSchema, LeadSchema)
    async def retrieve_lead(self, **kwargs) -> dict:
        """Return a single lead record (dumped with ``LeadSchema``)."""
        return await self._dm.retrieve_lead(**kwargs)

    @with_schemas(ListSegmentsInputSchema, ListSegmentsOutputSchema)
    async def list_segments(self, **kwargs) -> dict:
        """Return ``total_leads`` plus a segment-type -> size mapping."""
        return await self._dm.list_segments(**kwargs)

    @with_schemas(ListLeadEventsInputSchema, ListLeadEventsOutputSchema)
    async def list_lead_events(self, **kwargs) -> dict:
        """Return a page of lead events (dumped with ``total_events`` and ``events``)."""
        return await self._dm.list_lead_events(**kwargs)

    @with_schemas(ListLeadSegmentsInputSchema, ListLeadSegmentsOutputSchema)
    async def list_lead_segments(self, **kwargs) -> List[dict]:
        """Return per-business segment records for the identified lead."""
        return await self._dm.list_lead_segments(**kwargs)

    @with_schemas(SearchLeadsForGdprInputSchema, SearchLeadsForGdprOutputSchema)
    async def search_leads_for_gdpr(self, **kwargs) -> bool:
        """Return whether any lead exists for the given passport uid."""
        return await self._dm.check_leads_existence_by_passport(**kwargs)

    @with_schemas(RemoveLeadsForGdprInputSchema, RemoveLeadsForGdprOutputSchema)
    async def remove_leads_for_gdpr(self, **kwargs) -> List[dict]:
        """Delete lead data for the given passport uid; return removed records."""
        return await self._dm.delete_leads_data_by_passport(**kwargs)
