from typing import List

from marshmallow import fields, pre_dump

from maps_adv.common.protomallow import (
    PbDateTimeField,
    PbEnumField,
    ProtobufSchema,
    with_schemas,
)
from maps_adv.geosmb.scenarist.proto import scenarios_pb2

from ..domain import Domain
from ..enums import ScenarioName, SegmentType, SubscriptionStatus

# Ordered (protobuf enum value, domain enum member) pairs for every enum the
# API exchanges. Each list is handed to a PbEnumField below as ``values_map``.
PB_TO_ENUMS_MAP = dict(
    scenario_name=[
        (scenarios_pb2.ScenarioName.DISCOUNT_FOR_LOST, ScenarioName.DISCOUNT_FOR_LOST),
        (
            scenarios_pb2.ScenarioName.ENGAGE_PROSPECTIVE,
            ScenarioName.ENGAGE_PROSPECTIVE,
        ),
        (scenarios_pb2.ScenarioName.THANK_THE_LOYAL, ScenarioName.THANK_THE_LOYAL),
        (
            scenarios_pb2.ScenarioName.DISCOUNT_FOR_DISLOYAL,
            ScenarioName.DISCOUNT_FOR_DISLOYAL,
        ),
    ],
    # Segment values are the nested Scenario.SegmentType enum, reached through
    # the containing Scenario message class.
    segment_type=[
        (scenarios_pb2.Scenario.PROSPECTIVE, SegmentType.PROSPECTIVE),
        (scenarios_pb2.Scenario.ACTIVE, SegmentType.ACTIVE),
        (scenarios_pb2.Scenario.LOST, SegmentType.LOST),
        (scenarios_pb2.Scenario.LOYAL, SegmentType.LOYAL),
        (scenarios_pb2.Scenario.DISLOYAL, SegmentType.DISLOYAL),
        (scenarios_pb2.Scenario.REGULAR, SegmentType.REGULAR),
        (scenarios_pb2.Scenario.UNPROCESSED_ORDERS, SegmentType.UNPROCESSED_ORDERS),
        (scenarios_pb2.Scenario.NO_ORDERS, SegmentType.NO_ORDERS),
    ],
    subscription_status=[
        (scenarios_pb2.SubscriptionStatus.ACTIVE, SubscriptionStatus.ACTIVE),
        (scenarios_pb2.SubscriptionStatus.PAUSED, SubscriptionStatus.PAUSED),
        (scenarios_pb2.SubscriptionStatus.COMPLETED, SubscriptionStatus.COMPLETED),
    ],
)


class ListScenariosInputSchema(ProtobufSchema):
    """Request schema for ``ApiProvider.list_scenarios``.

    Carries only the mandatory ``biz_id`` whose scenarios are requested.
    """

    biz_id = fields.Integer(required=True)

    class Meta:
        pb_message_class = scenarios_pb2.ListScenariosInput


class CouponStatisticsLineSchema(ProtobufSchema):
    """One line of coupon counters; all three counters are mandatory."""

    sent = fields.Integer(required=True)
    opened = fields.Integer(required=True)
    clicked = fields.Integer(required=True)

    class Meta:
        pb_message_class = scenarios_pb2.CouponStatisticsLine


class CouponStatisticsSchema(ProtobufSchema):
    """Coupon statistics message holding a single ``total`` line.

    The object being dumped is the statistics line itself. ``reformat`` nests
    it under the ``total`` key before the fields are serialized.
    """

    total = fields.Nested(CouponStatisticsLineSchema)

    class Meta:
        pb_message_class = scenarios_pb2.CouponStatistics

    @pre_dump
    def reformat(self, data: dict) -> dict:
        wrapped = {"total": data}
        return wrapped


class SubscriptionSchema(ProtobufSchema):
    """A scenario subscription as exposed over the API.

    Also used directly as the response schema of ``create_subscription``.
    ``coupon_id`` is the only field that is not required.
    """

    subscription_id = fields.Integer(required=True)
    biz_id = fields.Integer(required=True)
    scenario_name = PbEnumField(
        enum=ScenarioName,
        pb_enum=scenarios_pb2.ScenarioName,
        values_map=PB_TO_ENUMS_MAP["scenario_name"],
        required=True,
    )
    status = PbEnumField(
        enum=SubscriptionStatus,
        pb_enum=scenarios_pb2.SubscriptionStatus,
        values_map=PB_TO_ENUMS_MAP["subscription_status"],
        required=True,
    )
    coupon_id = fields.Integer()

    class Meta:
        pb_message_class = scenarios_pb2.Subscription


class CouponSchema(ProtobufSchema):
    """One entry of a subscription's coupon history.

    ``created_at`` and ``statistics`` are required; ``coupon_id`` is not.
    """

    created_at = PbDateTimeField(required=True)
    coupon_id = fields.Integer()
    statistics = fields.Nested(CouponStatisticsSchema, required=True)

    class Meta:
        pb_message_class = scenarios_pb2.Coupon


class ScenarioSchema(ProtobufSchema):
    """A scenario, the client segments it targets and its subscription.

    ``subscription`` is not marked required.
    """

    name = PbEnumField(
        enum=ScenarioName,
        pb_enum=scenarios_pb2.ScenarioName,
        values_map=PB_TO_ENUMS_MAP["scenario_name"],
        required=True,
    )
    # Repeated enum field: each item maps through the nested
    # Scenario.SegmentType protobuf enum.
    segments = fields.List(
        PbEnumField(
            enum=SegmentType,
            pb_enum=scenarios_pb2.Scenario.SegmentType,
            values_map=PB_TO_ENUMS_MAP["segment_type"],
        )
    )
    subscription = fields.Nested(SubscriptionSchema)

    class Meta:
        pb_message_class = scenarios_pb2.Scenario


class ListScenariosOutputSchema(ProtobufSchema):
    """Response schema for ``ApiProvider.list_scenarios``.

    ``list_scenarios`` returns a bare ``List[dict]`` of scenarios. The
    ``reformat`` hook wraps that list under the ``scenarios`` key expected by
    the ``ListScenariosOutput`` message.
    """

    scenarios = fields.Nested(ScenarioSchema, many=True)

    class Meta:
        pb_message_class = scenarios_pb2.ListScenariosOutput

    @pre_dump
    def reformat(self, data: List[dict]) -> dict:
        # ``data`` is the list of scenario dicts, not a mapping. Nest it so the
        # many=True Nested field receives the whole list.
        return dict(scenarios=data)


class CreateSubscriptionInputSchema(ProtobufSchema):
    """Request schema for ``ApiProvider.create_subscription``.

    ``biz_id`` and ``scenario_name`` are required; ``coupon_id`` is not.
    """

    biz_id = fields.Integer(required=True)
    scenario_name = PbEnumField(
        enum=ScenarioName,
        pb_enum=scenarios_pb2.ScenarioName,
        values_map=PB_TO_ENUMS_MAP["scenario_name"],
        required=True,
    )
    coupon_id = fields.Integer()

    class Meta:
        pb_message_class = scenarios_pb2.CreateSubscriptionInput


class RetrieveSubscriptionInputSchema(ProtobufSchema):
    """Request schema for ``ApiProvider.retrieve_subscription``.

    Both identifiers, ``subscription_id`` and ``biz_id``, are required.
    """

    subscription_id = fields.Integer(required=True)
    biz_id = fields.Integer(required=True)

    class Meta:
        pb_message_class = scenarios_pb2.RetrieveSubscriptionInput


class RetrieveSubscriptionOutputSchema(ProtobufSchema):
    """Response schema for ``ApiProvider.retrieve_subscription``.

    Holds the required subscription plus a list of coupons under
    ``coupons_history``.
    """

    subscription = fields.Nested(SubscriptionSchema, required=True)
    coupons_history = fields.Nested(CouponSchema, many=True)

    class Meta:
        pb_message_class = scenarios_pb2.RetrieveSubscriptionOutput


class UpdateSubscriptionStatusInputSchema(ProtobufSchema):
    """Request schema for ``ApiProvider.update_subscription_status``.

    Identifies the subscription by ``subscription_id`` + ``biz_id`` and
    carries the requested ``status``. All three fields are required.
    """

    subscription_id = fields.Integer(required=True)
    biz_id = fields.Integer(required=True)
    # NOTE(review): pb_enum is the nested StatusForUpdate enum, but values_map
    # pairs scenarios_pb2.SubscriptionStatus values with domain members. This
    # presumably works only if both protobuf enums give the shared names the
    # same numbers. Confirm against the .proto definition.
    status = PbEnumField(
        enum=SubscriptionStatus,
        pb_enum=scenarios_pb2.UpdateSubscriptionStatusInput.StatusForUpdate,
        values_map=PB_TO_ENUMS_MAP["subscription_status"],
        required=True,
    )

    class Meta:
        pb_message_class = scenarios_pb2.UpdateSubscriptionStatusInput


class ReplaceCouponInputSchema(ProtobufSchema):
    """Request schema for ``ApiProvider.replace_subscription_coupon``.

    ``subscription_id`` and ``biz_id`` are required. ``coupon_id`` is not.
    """

    subscription_id = fields.Integer(required=True)
    biz_id = fields.Integer(required=True)
    coupon_id = fields.Integer()

    class Meta:
        pb_message_class = scenarios_pb2.ReplaceCouponInput


class ReplaceCouponOutputSchema(ProtobufSchema):
    """Response schema for ``ApiProvider.replace_subscription_coupon``.

    The coupon-replacement call returns a bare status value. The ``reformat``
    hook wraps it under the ``status`` key of ``ReplaceCouponOutput``.
    """

    status = PbEnumField(
        enum=SubscriptionStatus,
        pb_enum=scenarios_pb2.SubscriptionStatus,
        values_map=PB_TO_ENUMS_MAP["subscription_status"],
        required=True,
    )

    class Meta:
        pb_message_class = scenarios_pb2.ReplaceCouponOutput

    @pre_dump
    def reformat(self, data: SubscriptionStatus) -> dict:
        # The incoming value becomes the ``status`` field, which is declared with
        # enum=SubscriptionStatus. It is therefore a status member, not a dict.
        return dict(status=data)


class ApiProvider:
    """Adapter exposing ``Domain`` coroutines through protobuf schemas.

    Each public coroutine is decorated with ``with_schemas``, which receives
    the input schema and, where declared, the output schema. Presumably it
    parses the request and serializes the result (see protomallow). The
    method bodies only forward keyword arguments to the matching ``Domain``
    coroutine.
    """

    __slots__ = ["_domain"]

    _domain: Domain

    def __init__(self, domain: Domain):
        self._domain = domain

    @with_schemas(
        input_schema=ListScenariosInputSchema,
        output_schema=ListScenariosOutputSchema,
    )
    async def list_scenarios(self, **params) -> List[dict]:
        """Forward to ``Domain.list_scenarios``."""
        scenarios = await self._domain.list_scenarios(**params)
        return scenarios

    @with_schemas(
        input_schema=CreateSubscriptionInputSchema,
        output_schema=SubscriptionSchema,
    )
    async def create_subscription(self, **params) -> dict:
        """Forward to ``Domain.create_subscription``."""
        subscription = await self._domain.create_subscription(**params)
        return subscription

    @with_schemas(
        input_schema=RetrieveSubscriptionInputSchema,
        output_schema=RetrieveSubscriptionOutputSchema,
    )
    async def retrieve_subscription(self, **params) -> dict:
        """Forward to ``Domain.retrieve_subscription``."""
        details = await self._domain.retrieve_subscription(**params)
        return details

    @with_schemas(input_schema=UpdateSubscriptionStatusInputSchema)
    async def update_subscription_status(self, **params) -> dict:
        """Forward to ``Domain.update_subscription_status``.

        No output schema is declared for this call.
        """
        outcome = await self._domain.update_subscription_status(**params)
        return outcome

    @with_schemas(
        input_schema=ReplaceCouponInputSchema,
        output_schema=ReplaceCouponOutputSchema,
    )
    async def replace_subscription_coupon(self, **params) -> dict:
        """Forward to ``Domain.replace_subscription_coupon``."""
        outcome = await self._domain.replace_subscription_coupon(**params)
        return outcome
