import logging
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import AsyncGenerator, Dict, List, Optional, Tuple

from smb.common.pgswim import PoolType

from maps_adv.geosmb.promoter.server.lib.db import DB
from maps_adv.geosmb.promoter.server.lib.enums import (
    OrderByField,
    OrderDirection,
    SegmentType,
)
from maps_adv.geosmb.promoter.server.lib.exceptions import (
    BadSortingParams,
    NoLeadIdFieldsPassed,
    UnknownLead,
)

from . import sqls

# A (column name, value) pair describing one lead identifier to search by;
# list_lead_segments turns a list of these into an OR-ed SQL filter.
DbParam = namedtuple("DbParam", ["db_field", "db_value"])


class BaseDataManager(ABC):
    """Abstract storage interface for promoter leads, lead events and segments."""

    @abstractmethod
    async def list_leads(
        self,
        *,
        biz_id: int,
        order_by_field: Optional[OrderByField] = None,
        # Named ``order_direction`` to match the concrete implementation and its
        # callers; the previous ``order_by_direction`` did not match either.
        order_direction: Optional[OrderDirection] = None,
        filter_by_segment: Optional[SegmentType] = None,
        limit: int,
        offset: int,
    ) -> dict:
        """Return one page of a business's leads together with the total count."""
        raise NotImplementedError()

    @abstractmethod
    async def retrieve_lead(self, *, biz_id: int, lead_id: int) -> dict:
        """Return a single lead of the business (raise if it is unknown)."""
        raise NotImplementedError()

    @abstractmethod
    async def iter_leads_for_export(self, chunk_size: int):
        """Asynchronously yield all leads for export in lists of ``chunk_size``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_segments(self, biz_id: int) -> dict:
        """Return the business's total lead count and per-segment lead counts."""
        raise NotImplementedError()

    @abstractmethod
    async def list_lead_events(
        self, *, biz_id: int, lead_id: int, limit: int, offset: int
    ) -> dict:
        """Return one page of a lead's events together with the total count."""
        raise NotImplementedError()

    @abstractmethod
    async def import_events_from_generator(
        self, generator: AsyncGenerator[List[list], None]
    ):
        """Import batches of raw event records produced by ``generator``."""
        raise NotImplementedError()

    @abstractmethod
    async def list_lead_segments(
        self,
        *,
        passport_uid: Optional[str] = None,
        yandex_uid: Optional[str] = None,
        device_id: Optional[str] = None,
        biz_id: Optional[int] = None,
    ) -> List[dict]:
        """Return segments of the leads matching any of the given identifiers."""
        raise NotImplementedError()

    @abstractmethod
    async def check_leads_existence_by_passport(self, passport_uid: str) -> bool:
        """Return whether any lead is linked to ``passport_uid``."""
        raise NotImplementedError()

    @abstractmethod
    async def delete_leads_data_by_passport(self, passport_uid: str) -> List[dict]:
        """Delete lead data linked to ``passport_uid`` and return what was removed."""
        raise NotImplementedError()


class DataManager(BaseDataManager):
    """SQL-backed implementation of ``BaseDataManager``.

    Read queries explicitly use the replica pool (except ``list_segments``,
    which uses ``acquire()``'s default), while the import and the deletion go
    through the master pool.
    """

    __slots__ = ("_db",)

    SEGMENTS_TRACK_PERIOD_DAYS = 90

    SEGMENT_TO_CTE_MAP = {
        SegmentType.PROSPECTIVE: "prospective_leads",
        SegmentType.ACTIVE: "active_leads",
        SegmentType.LOST: "lost_leads",
        SegmentType.LOYAL: "loyal_leads",
        SegmentType.DISLOYAL: "disloyal_leads",
    }

    # Row columns copied into a lead's ``statistics`` dict, in output order.
    _LEAD_STATISTICS_FIELDS = (
        "clicks_on_phone",
        "site_opens",
        "make_routes",
        "review_rating",
        "view_working_hours",
        "view_entrances",
        "showcase_product_click",
        "promo_to_site",
        "location_sharing",
        "geoproduct_button_click",
        "favourite_click",
        "cta_button_click",
        "booking_section_interaction",
        "last_activity_timestamp",
    )

    # Positional meaning of the values in each imported event record.
    _IMPORT_EVENTS_COLUMNS = (
        "biz_id",
        "event_type_name",
        "event_timestamp",
        "data_source",
        "lead_name",
        "passport_uid",
        "device_id",
        "yandex_uid",
        "event_value",
        "events_amount",
        "source",
    )

    _db: DB

    def __init__(self, db: DB):
        self._db = db

    async def list_leads(
        self,
        *,
        biz_id: int,
        order_by_field: Optional[OrderByField] = None,
        order_direction: Optional[OrderDirection] = None,
        filter_by_segment: Optional[SegmentType] = None,
        limit: int,
        offset: int,
    ) -> dict:
        """Return one page of the business's leads.

        Without ``order_by_field`` leads are ordered by
        ``last_activity_timestamp DESC``. An explicit ASC order puts NULLs
        last, DESC puts them first.

        Returns:
            ``{"total_count": int, "leads": [normalized lead, ...]}``.

        Raises:
            BadSortingParams: ``order_by_field`` is set without
                ``order_direction``.
        """
        if order_by_field:
            if not order_direction:
                raise BadSortingParams("Order direction must be set with order field")
            direction_expression = (
                "ASC NULLS LAST"
                if order_direction == OrderDirection.ASC
                else "DESC NULLS FIRST"
            )
            # Interpolated into SQL: order_by_field is an OrderByField member,
            # so its value comes from a fixed set, not free-form input.
            order_by_expression = f"{order_by_field.value} {direction_expression}"
        else:
            order_by_expression = "last_activity_timestamp DESC"

        # The query receives the segment's enum *name*, or NULL for no filter.
        segment_name = filter_by_segment.name if filter_by_segment else None
        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sqls.list_leads.format(order_by_expression=order_by_expression),
                biz_id,
                limit,
                offset,
                segment_name,
            )

        # Normalization needs no connection, so do it after releasing it.
        total_count, leads = self._normalize_leads(rows)
        return dict(total_count=total_count, leads=leads)

    async def retrieve_lead(self, *, biz_id: int, lead_id: int) -> dict:
        """Return one normalized lead of the business.

        Raises:
            UnknownLead: no row matches ``biz_id`` / ``lead_id``.
        """
        async with self._db.acquire(PoolType.replica) as con:
            row = await con.fetchrow(sqls.retrieve_lead, biz_id, lead_id)

        if not row:
            raise UnknownLead(f"Unknown lead with biz_id={biz_id}, lead_id={lead_id}")

        return self._normalize_lead(row)

    async def import_events_from_generator(
        self, generator: AsyncGenerator[List[list], None]
    ):
        """Bulk-import lead events and reconcile leads in one master transaction.

        Every batch from ``generator`` is a list of records whose values follow
        ``_IMPORT_EVENTS_COLUMNS``. Batches are copied into
        ``import_events_tmp``; then a fixed sequence of SQL steps links, merges,
        updates and creates leads, inserts the events and refreshes the
        precalculated statistics. Since everything runs inside a single
        transaction, a failure at any step rolls back the whole import.
        """
        logger = logging.getLogger("data_manager.import_events")
        columns = list(self._IMPORT_EVENTS_COLUMNS)

        async with self._db.acquire(PoolType.master) as con:
            async with con.transaction():
                await con.execute(sqls.create_table_import_events_tmp)

                async for records in generator:
                    await con.copy_records_to_table(
                        "import_events_tmp", records=records, columns=columns
                    )

                await con.execute(sqls.analyze_import_events_tmp)

                # Fill in passport_uids for some of the new events. This is a
                # preliminary lead-merge stage, limited to the new data.
                logger.info("Starting set_events_passport_uids")
                await con.execute(sqls.set_events_passport_uids)

                # Link every matching lead to the new events.
                logger.info("Starting set_events_lead_ids")
                await con.execute(sqls.set_events_lead_ids)

                # Merge leads that turned out to be the same person.
                # Criterion: they matched the same event in the new batch.
                logger.info("Starting merge_leads")
                await con.execute(sqls.merge_leads)

                # Update lead data:
                # - already known leads may have new data in the new events;
                # - on merge, the "surviving" lead receives the combined data
                #   of the leads merged into it.
                logger.info("Starting update_leads")
                await con.execute(sqls.update_leads)

                # Create new leads for events that matched no existing lead.
                logger.info("Starting create_leads")
                await con.execute(sqls.create_leads)

                # Store the new events.
                logger.info("Starting insert_lead_events")
                await con.execute(sqls.insert_lead_events)

                # Refresh the precalculated statistics.
                logger.info("Starting refresh_events_stat_precalced")
                await con.execute(sqls.refresh_events_stat_precalced)

                logger.info("Import done")

    async def iter_leads_for_export(
        self, chunk_size: int
    ) -> AsyncGenerator[List[dict], None]:
        """Yield all leads for export in lists of at most ``chunk_size`` items.

        Rows are read through a cursor opened inside a replica transaction.
        The lead id is exported as ``promoter_id`` and segments are sorted.
        """
        async with self._db.acquire(PoolType.replica) as con:
            async with con.transaction():
                cur = await con.cursor(sqls.list_leads_for_scenarist)

                while True:
                    rows = await cur.fetch(chunk_size)
                    if not rows:
                        break
                    yield [
                        dict(
                            promoter_id=row["lead_id"],
                            biz_id=row["biz_id"],
                            passport_uid=row["passport_uid"],
                            yandex_uid=row["yandex_uid"],
                            device_id=row["device_id"],
                            name=row["name"],
                            segments=sorted(row["segments"]),
                        )
                        for row in rows
                    ]

    def _normalize_leads(self, rows: list) -> Tuple[int, List[dict]]:
        """Split ``list_leads`` rows into ``(total_count, leads)``.

        Every row carries ``total_count``. A single row whose ``lead_id`` is
        NULL carries only the count (presumably an empty page) and yields no
        leads; no rows at all means ``(0, [])``.
        """
        if not rows:
            return 0, []

        total_count = rows[0]["total_count"]
        if len(rows) == 1 and rows[0]["lead_id"] is None:
            return total_count, []

        return total_count, [self._normalize_lead(row) for row in rows]

    def _normalize_lead(self, row: dict) -> dict:
        """Convert a lead row into the output dict.

        Segment names are mapped to ``SegmentType`` members and the activity
        columns are grouped under ``statistics``.
        """
        return dict(
            lead_id=row["lead_id"],
            biz_id=row["biz_id"],
            name=row["name"],
            source=row["source"],
            segments=[SegmentType[s] for s in row["segments"]],
            statistics={field: row[field] for field in self._LEAD_STATISTICS_FIELDS},
        )

    async def list_segments(self, biz_id: int) -> dict:
        """Return the business's total lead count and per-segment counts.

        Returns:
            ``{"total_leads": ..., "segments": {SegmentType: count, ...}}``.
        """
        # NOTE(review): unlike the other read queries this one does not ask for
        # PoolType.replica explicitly; confirm which pool acquire() defaults to.
        async with self._db.acquire() as con:
            result = await con.fetchrow(sqls.list_segments, biz_id)

        return dict(
            total_leads=result.get("total_leads"),
            segments={
                SegmentType.PROSPECTIVE: result.get("prospective_count"),
                SegmentType.ACTIVE: result.get("active_count"),
                SegmentType.LOST: result.get("lost_count"),
                SegmentType.LOYAL: result.get("loyal_count"),
                SegmentType.DISLOYAL: result.get("disloyal_count"),
            },
        )

    async def list_lead_events(
        self, *, biz_id: int, lead_id: int, limit: int, offset: int
    ) -> dict:
        """Return one page of a lead's events.

        Returns:
            ``{"total_events": int, "events": [normalized event, ...]}``.

        Raises:
            UnknownLead: the query reports zero events for the lead, or
                returns no rows at all.
        """
        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sqls.list_lead_events, biz_id, lead_id, limit, offset
            )

        total_events, events = self._normalize_events(rows, biz_id, lead_id)
        return dict(total_events=total_events, events=events)

    def _normalize_events(
        self, rows: list, biz_id: int, lead_id: int
    ) -> Tuple[int, List[dict]]:
        """Split ``list_lead_events`` rows into ``(total_events, events)``.

        A zero total is reported as an unknown lead. A single row whose
        ``event_type`` is NULL carries only the total (presumably a page past
        the end) and yields no events.

        Raises:
            UnknownLead: the total is 0 or there are no rows.
        """
        # An empty result used to crash with IndexError; treat it like a zero
        # total so callers get the same UnknownLead they already handle.
        total_events = rows[0]["total_events"] if rows else 0

        if total_events == 0:
            raise UnknownLead(f"Unknown lead with biz_id={biz_id}, lead_id={lead_id}")

        if len(rows) == 1 and rows[0]["event_type"] is None:
            return total_events, []

        return total_events, [self._normalize_event(row) for row in rows]

    def _normalize_event(self, row: dict) -> dict:
        """Pick the public event fields out of an event row."""
        return dict(
            event_type=row["event_type"],
            event_value=row["event_value"],
            event_timestamp=row["event_timestamp"],
            source=row["source"],
        )

    async def list_lead_segments(
        self,
        *,
        passport_uid: Optional[str] = None,
        yandex_uid: Optional[str] = None,
        device_id: Optional[str] = None,
        biz_id: Optional[int] = None,
    ) -> List[dict]:
        """Return segments of leads matching ANY of the given identifiers.

        Only truthy identifiers take part in the filter; ``biz_id`` is passed
        to the query as ``$1``.

        Raises:
            NoLeadIdFieldsPassed: none of ``passport_uid``, ``yandex_uid``,
                ``device_id`` is set.
        """
        if not any([passport_uid, yandex_uid, device_id]):
            raise NoLeadIdFieldsPassed()

        search_params = [
            DbParam(field, value)
            for field, value in (
                ("passport_uid", passport_uid),
                ("yandex_uid", yandex_uid),
                ("device_id", device_id),
            )
            if value
        ]
        # $1 is taken by biz_id, so identifier placeholders start at $2. Column
        # names are fixed literals; user-supplied values stay parameterized.
        filter_expression = " OR ".join(
            f"{param.db_field}=${position}"
            for position, param in enumerate(search_params, start=2)
        )

        async with self._db.acquire(PoolType.replica) as con:
            rows = await con.fetch(
                sqls.list_lead_segments.format(
                    leads_filter_str=f"({filter_expression})"
                ),
                biz_id,
                *[param.db_value for param in search_params],
            )

        return [
            dict(
                lead_id=row["lead_id"],
                biz_id=row["biz_id"],
                segments=[SegmentType[s] for s in row["segments"]],
            )
            for row in rows
        ]

    async def check_leads_existence_by_passport(self, passport_uid: str) -> bool:
        """Return the scalar result of the existence query for ``passport_uid``."""
        async with self._db.acquire(PoolType.replica) as con:
            return await con.fetchval(
                sqls.check_leads_existence_by_passport, passport_uid
            )

    async def delete_leads_data_by_passport(self, passport_uid: str) -> List[dict]:
        """Run the deletion query on master and return its rows as plain dicts."""
        async with self._db.acquire(PoolType.master) as con:
            rows = await con.fetch(sqls.delete_leads_data_by_passport, passport_uid)

        return [dict(row) for row in rows]
