from datetime import datetime, timedelta
from decimal import Decimal
from functools import wraps
from typing import Dict, List, Optional, Tuple, Type

import dateutil
import dateutil.parser
from marshmallow import ValidationError, fields, post_load, pre_dump

from maps_adv.common.helpers.enums import CampaignTypeEnum
from maps_adv.common.proto import campaign_pb2
from maps_adv.common.protomallow import (
    PbDateTimeField,
    PbDecimalField,
    PbEnumField,
    ProtobufSchema,
)
from maps_adv.statistics.dashboard.proto import campaign_stat_pb2, campaigns_stat_pb2
from maps_adv.statistics.dashboard.server.lib.ch_query_log import ClickHouseQueryLog
from maps_adv.statistics.dashboard.server.lib.data_manager import (
    AbstractDataManager,
    NoCampaignsPassed,
)
from maps_adv.statistics.dashboard.server.lib.domain import Domain


# Campaign types that exist under the same member name in both the protobuf
# ``CampaignType`` message and ``CampaignTypeEnum``.
_CAMPAIGN_TYPE_NAMES = (
    "PIN_ON_ROUTE",
    "BILLBOARD",
    "ZERO_SPEED_BANNER",
    "OVERVIEW_BANNER",
    "CATEGORY_SEARCH",
    "ROUTE_BANNER",
    "VIA_POINTS",
    "PROMOCODE",
)

# (protobuf value, python enum member) pairs, passed as ``values_map`` to
# PbEnumField for converting between the two representations.
ENUMS_MAP = {
    "campaign_type": [
        (getattr(campaign_pb2.CampaignType, name), getattr(CampaignTypeEnum, name))
        for name in _CAMPAIGN_TYPE_NAMES
    ],
}


class with_proto_schemas:
    """Decorator that adds protobuf (de)serialization around an async method.

    ``input_schema`` and ``output_schema`` are schema *classes*; a fresh
    instance is created on every call.

    * If ``input_schema`` is set, the decorated method accepts ``pb_data``
      (serialized protobuf). It is decoded with ``input_schema().from_bytes``,
      and the resulting mapping is merged into the keyword arguments. Decoded
      keys override keywords passed explicitly.
    * If ``output_schema`` is set, the method's return value is encoded with
      ``output_schema().to_bytes``, and those bytes are returned. Otherwise the
      return value is passed through unchanged.

    When ``input_schema`` is ``None``, ``pb_data`` is accepted and ignored.
    """

    __slots__ = "input_schema", "output_schema"

    input_schema: Optional[Type[ProtobufSchema]]
    output_schema: Optional[Type[ProtobufSchema]]

    def __init__(
        self,
        input_schema: Optional[Type[ProtobufSchema]] = None,
        output_schema: Optional[Type[ProtobufSchema]] = None,
    ):
        self.input_schema = input_schema
        self.output_schema = output_schema

    def __call__(self, func):
        # functools.wraps keeps the decorated method's __name__, __qualname__
        # and __doc__. Logs, tracebacks and introspection then show the real
        # provider method instead of "wrapped".
        @wraps(func)
        async def wrapped(instance, pb_data=None, **kwargs):
            if self.input_schema is not None:
                loaded = self.input_schema().from_bytes(pb_data)
                kwargs.update(loaded)

            got = await func(instance, **kwargs)

            if self.output_schema is not None:
                return self.output_schema().to_bytes(got)
            return got

        return wrapped


def _validate_campaign_list(campaigns):
    if len(campaigns) < 1:
        raise NoCampaignsPassed


class PbFixedFloatField(fields.Field):
    """Serialize a number as a fixed-point integer with ``places`` decimal digits.

    The value is multiplied by ``10 ** places`` and truncated toward zero.
    For example, with ``places=2`` the value ``1.239`` is serialized as ``123``.
    ``None`` is serialized as ``None``.
    """

    def __init__(self, places: int, **kwargs):
        self._places = places
        self._coefficient = 10**places
        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            return None
        if isinstance(value, float):
            # Scale the shortest decimal representation instead of the binary
            # float. Otherwise 0.29 * 100 == 28.999999999999996 and int()
            # truncates it to 28.
            value = Decimal(str(value))
        return int(value * self._coefficient)


class CampaignsStatInputSchema(ProtobufSchema):
    """Decodes ``CampaignsStatInput``: campaign ids and a reporting period.

    An empty ``campaign_ids`` list fails validation via
    ``_validate_campaign_list``, which raises ``NoCampaignsPassed``.
    """

    class Meta:
        pb_message_class = campaigns_stat_pb2.CampaignsStatInput

    campaign_ids = fields.List(fields.Integer(), validate=_validate_campaign_list)
    # Decoded by PbDateTimeField; the provider calls ``.date()`` on these in
    # calculate_by_campaigns_and_period, so they are datetime-like objects.
    period_from = PbDateTimeField()
    period_to = PbDateTimeField()


class CampaignsStatDetailsSchema(ProtobufSchema):
    """Encodes one ``CampaignsStatDetails`` message of campaign statistics.

    Used both for the nested ``details`` of a per-date row and for the
    ``total`` of ``CampaignsStatOutputSchema``.
    """

    class Meta:
        pb_message_class = campaigns_stat_pb2.CampaignsStatDetails

    # Presumably per-event counters. The camelCase names presumably have to
    # match the CampaignsStatDetails proto field names; confirm before renaming.
    call = fields.Integer()
    makeRoute = fields.Integer()
    openSite = fields.Integer()
    saveOffer = fields.Integer()
    search = fields.Integer()
    show = fields.Integer()
    tap = fields.Integer()
    # Float metrics; their values come from the data manager, nothing is
    # computed here.
    ctr = fields.Float()
    clicks_to_routes = fields.Float()
    # places=2: presumably two fractional digits (PbDecimalField is defined
    # elsewhere; confirm its encoding).
    charged_sum = PbDecimalField(places=2)
    show_unique = fields.Integer()


class CampaignsStatOnDateSchema(ProtobufSchema):
    """Encodes one per-date row: a ``date`` string plus nested statistics."""

    class Meta:
        pb_message_class = campaigns_stat_pb2.CampaignsStatOnDate

    date = fields.String()
    details = fields.Nested(CampaignsStatDetailsSchema)

    @pre_dump
    def _extract_submessage(self, data: dict) -> dict:
        """Reshape a flat row into ``{"date": ..., "details": {...}}``.

        Every key except ``date`` is moved under ``details``. The row is
        copied first, so the caller's dict is left untouched and the same row
        can be dumped more than once.

        Raises:
            KeyError: if ``data`` has no ``"date"`` key.
        """
        details = dict(data)
        _date = details.pop("date")
        return {"date": _date, "details": details}


class CampaignsStatOutputSchema(ProtobufSchema):
    """Encodes ``CampaignsStatOutput``: per-date rows plus an aggregate total.

    ``ApiProvider.calculate_by_campaigns_and_period`` fills ``by_dates`` with
    every data-manager row but the last, and ``total`` with the last row.
    """

    class Meta:
        pb_message_class = campaigns_stat_pb2.CampaignsStatOutput

    by_dates = fields.Nested(CampaignsStatOnDateSchema, many=True)
    total = fields.Nested(CampaignsStatDetailsSchema)


class CampaignChargedSumInputSchema(ProtobufSchema):
    """Decodes ``CampaignChargedSumInput``: campaign ids and a point in time.

    An empty ``campaign_ids`` list fails validation with ``NoCampaignsPassed``.
    """

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignChargedSumInput

    campaign_ids = fields.List(fields.Integer(), validate=_validate_campaign_list)
    # NOTE(review): plain integer; presumably a Unix timestamp in seconds.
    # Confirm against the proto definition.
    on_timestamp = fields.Integer()


class CampaignChargedSumSchema(ProtobufSchema):
    """Encodes one ``CampaignChargedSum`` entry: a campaign id and its charged sum."""

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignChargedSum

    campaign_id = fields.Integer()
    # places=2: presumably two fractional digits; PbDecimalField is defined
    # elsewhere.
    charged_sum = PbDecimalField(places=2)


class CampaignChargedSumOutputSchema(ProtobufSchema):
    """Encodes ``CampaignChargedSumOutput``, a list of per-campaign charged sums.

    ``ApiProvider.calculate_campaigns_charged_sum`` places the data manager's
    result under ``campaigns_charged_sums``.
    """

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignChargedSumOutput

    campaigns_charged_sums = fields.Nested(CampaignChargedSumSchema, many=True)


class IconCampaignsStatDetailsSchema(ProtobufSchema):
    """Encodes ``IconCampaignsStatDetails``: icon, pin and route counters."""

    class Meta:
        pb_message_class = campaigns_stat_pb2.IconCampaignsStatDetails

    icon_shows = fields.Integer()
    icon_clicks = fields.Integer()
    pin_shows = fields.Integer()
    pin_clicks = fields.Integer()
    routes = fields.Integer()
    # Optional: the provider never includes it in the total, and drops it from
    # per-date rows when the data manager reports None.
    unique_icon_shows = fields.Integer(required=False)


class IconCampaignsStatOnDateSchema(ProtobufSchema):
    """Encodes one per-date icon statistics row.

    Unlike ``CampaignsStatOnDateSchema``, there is no ``pre_dump`` here. The
    input must already have the ``{"date": ..., "details": {...}}`` shape,
    which ``ApiProvider.fetch_search_icons_statistics`` builds.
    """

    class Meta:
        pb_message_class = campaigns_stat_pb2.IconCampaignsStatOnDate

    date = fields.String()
    details = fields.Nested(IconCampaignsStatDetailsSchema)


class IconCampaignsStatOutputSchema(ProtobufSchema):
    """Encodes ``IconCampaignsStatOutput``: an overall total plus per-date rows."""

    class Meta:
        pb_message_class = campaigns_stat_pb2.IconCampaignsStatOutput

    total = fields.Nested(IconCampaignsStatDetailsSchema)
    # NOTE(review): the other schemas express repeated messages as
    # fields.Nested(..., many=True). Confirm ProtobufSchema treats both forms
    # identically before unifying them.
    by_dates = fields.List(fields.Nested(IconCampaignsStatOnDateSchema))


class CampaignEventsInputPartSchema(ProtobufSchema):
    """Decodes one ``CampaignEventsInputPart`` into a ``(campaign_id, campaign_type)`` pair.

    ``campaign_type`` is converted from the protobuf enum to
    ``CampaignTypeEnum`` using the ``ENUMS_MAP["campaign_type"]`` pairs.
    """

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignEventsInputPart

    campaign_id = fields.Integer()
    campaign_type = PbEnumField(
        enum=CampaignTypeEnum,
        pb_enum=campaign_pb2.CampaignType.Enum,
        values_map=ENUMS_MAP["campaign_type"],
    )

    @post_load
    def campaigns_load(self, data):
        # Collapse the loaded dict into a plain pair. The enclosing input
        # schema collects these pairs into its "events_query" list.
        campaign_id = data["campaign_id"]
        campaign_type = data["campaign_type"]
        return campaign_id, campaign_type


class CampaignEventsOnDateInputSchema(ProtobufSchema):
    """Decodes ``CampaignEventsForPeriodInput`` into data-manager keyword arguments.

    The result holds ``events_query`` (a list of ``(campaign_id,
    campaign_type)`` pairs), plus ``period_from``/``period_to`` when present.
    """

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignEventsForPeriodInput

    campaigns = fields.Nested(CampaignEventsInputPartSchema, many=True)
    period_from = PbDateTimeField(required=False)
    period_to = PbDateTimeField(required=False)

    @post_load
    def campaigns_load(self, data):
        # Rename "campaigns" to the "events_query" keyword that is forwarded to
        # the data manager. Any decoded period bounds are kept unchanged.
        events_query = data.pop("campaigns")
        return {"events_query": events_query, **data}


class CampaignEventsOnDateSchema(ProtobufSchema):
    """Encodes one ``CampaignEvents`` entry from a ``(campaign_id, events)`` pair."""

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignEvents

    campaign_id = fields.Integer()
    events = fields.Integer()

    @pre_dump
    def predump(self, data):
        """Turn a ``(campaign_id, events)`` pair into the schema's dict shape."""
        # Unpack into descriptive names; the previous ``id`` shadowed the builtin.
        campaign_id, events = data
        return {"campaign_id": campaign_id, "events": events}


class CampaignEventsOnDateOutputSchema(ProtobufSchema):
    """Encodes ``CampaignEventsForPeriodOutput`` from a campaign -> events mapping."""

    class Meta:
        pb_message_class = campaign_stat_pb2.CampaignEventsForPeriodOutput

    campaigns_events = fields.Nested(CampaignEventsOnDateSchema, many=True)

    @pre_dump
    def predump(self, data):
        # Expand the mapping into the (campaign_id, events) pairs that
        # CampaignEventsOnDateSchema unpacks.
        return {"campaigns_events": [*data.items()]}


class ApiProvider:
    """Protobuf-facing API over the data manager, domain and ClickHouse query log.

    Methods decorated with ``with_proto_schemas`` accept serialized protobuf
    input as ``pb_data`` and return serialized protobuf bytes.

    The monitoring methods take ``now`` (an ISO 8601 timestamp) and ``period``
    (e.g. ``"15s"``) strings and return ``{"metrics": ...}``. Malformed
    parameters raise ``marshmallow.ValidationError``.
    """

    __slots__ = "_dm", "_domain", "_ch_query_log"

    _dm: AbstractDataManager
    _domain: Domain
    _ch_query_log: ClickHouseQueryLog

    def __init__(
        self, dm: AbstractDataManager, domain: Domain, ch_query_log: ClickHouseQueryLog
    ):
        self._dm = dm
        self._domain = domain
        self._ch_query_log = ch_query_log

    @with_proto_schemas(CampaignsStatInputSchema, CampaignsStatOutputSchema)
    async def calculate_by_campaigns_and_period(self, **kwargs) -> dict:
        """Return campaign statistics broken down by date, plus a total.

        ``kwargs`` are the decoded ``CampaignsStatInput`` fields. The
        ``period_from``/``period_to`` datetimes are narrowed to dates before the
        data manager is queried.

        The last element of the data manager's result is reported as
        ``total``, and all preceding elements as ``by_dates``.
        """
        kwargs["period_from"] = kwargs["period_from"].date()
        kwargs["period_to"] = kwargs["period_to"].date()

        got = await self._dm.calculate_by_campaigns_and_period(**kwargs)

        return {"by_dates": got[:-1], "total": got[-1]}

    @with_proto_schemas(CampaignChargedSumInputSchema, CampaignChargedSumOutputSchema)
    async def calculate_campaigns_charged_sum(self, **kwargs) -> Dict[str, List[dict]]:
        """Return per-campaign charged sums for ``campaign_ids`` at ``on_timestamp``."""
        got = await self._dm.calculate_campaigns_charged_sum(**kwargs)
        return {"campaigns_charged_sums": got}

    @with_proto_schemas(CampaignsStatInputSchema, IconCampaignsStatOutputSchema)
    async def fetch_search_icons_statistics(self, **kwargs) -> dict:
        """Return search-icon statistics: an overall total plus per-date rows.

        The last row from the data manager is the total, and the rows before
        it are per-date rows.

        ``unique_icon_shows`` is never included in the total. It is omitted
        from a per-date row when missing or ``None``.

        The data manager's rows are copied, not modified.
        """
        got = await self._dm.fetch_search_icons_statistics(**kwargs)

        total_row = got[-1]
        total = {
            "icon_shows": total_row["icon_shows"],
            "icon_clicks": total_row["icon_clicks"],
            "pin_shows": total_row["pin_shows"],
            "pin_clicks": total_row["pin_clicks"],
            "routes": total_row["routes"],
        }

        by_dates = []
        for row in got[:-1]:
            details = dict(row)
            date = details.pop("date")
            # The output field is optional; send nothing rather than None.
            if details.get("unique_icon_shows") is None:
                details.pop("unique_icon_shows", None)
            by_dates.append({"date": date, "details": details})

        return {"total": total, "by_dates": by_dates}

    @with_proto_schemas(
        CampaignEventsOnDateInputSchema, CampaignEventsOnDateOutputSchema
    )
    async def calculate_campaigns_events_for_period(self, **kwargs) -> Dict[int, int]:
        """Return the data manager's campaign id -> events count mapping.

        ``kwargs`` contain ``events_query`` (``(campaign_id, campaign_type)``
        pairs), plus ``period_from``/``period_to`` when they were supplied.
        """
        return await self._dm.calculate_campaigns_events_for_period(**kwargs)

    async def calculate_monitoring_data(self, now: str, period: str) -> dict:
        """Return domain monitoring metrics for the ``period`` ending at ``now``.

        Raises:
            ValidationError: if ``now`` or ``period`` cannot be parsed.
        """
        to_datetime, period_delta = self._parse_solomon_params(now, period)

        metrics = await self._domain.calculate_monitoring_data(
            to_datetime, int(period_delta.total_seconds())
        )

        return {"metrics": metrics}

    async def calculate_monitoring_data_for_queries(
        self, now: str, period: str
    ) -> dict:
        """Return ClickHouse query-log metrics for ``[now - period, now]``.

        Raises:
            ValidationError: if ``now`` or ``period`` cannot be parsed.
        """
        to_datetime, period_delta = self._parse_solomon_params(now, period)

        metrics = await self._ch_query_log.retrieve_metrics_for_queries(
            from_datetime=to_datetime - period_delta, to_datetime=to_datetime
        )

        return {"metrics": metrics}

    @classmethod
    def convert_period_to_seconds(cls, period: str) -> int:
        """Convert a period string such as ``"15s"``, ``"30m"`` or ``"2h"`` to seconds.

        The last character is the unit (``s``, ``m`` or ``h``). Everything
        before it must parse as an integer.

        Raises:
            ValueError: if the numeric part is not an integer.
            KeyError: if the unit is not one of ``s``, ``m``, ``h``.
        """
        seconds_per_unit = {"s": 1, "m": 60, "h": 3600}

        return int(period[:-1]) * seconds_per_unit[period[-1]]

    @classmethod
    def _parse_solomon_params(cls, now: str, period: str) -> Tuple[datetime, timedelta]:
        """Parse the ``now`` (ISO 8601) and ``period`` monitoring parameters.

        Returns:
            ``(now_datetime, period_timedelta)``.

        Raises:
            ValidationError: if either value is malformed, or if the period is
                too large to represent as a ``timedelta``.
        """
        try:
            now_dt = dateutil.parser.isoparse(now)
            # Built inside the try: an oversized period raises OverflowError
            # here, which must be reported as invalid input.
            period_delta = timedelta(seconds=cls.convert_period_to_seconds(period))
        except (ValueError, TypeError, KeyError, IndexError, OverflowError) as e:
            raise ValidationError("Invalid parameters: " + str(e)) from e

        return now_dt, period_delta
