import asyncio
import concurrent.futures
from collections import defaultdict
from datetime import date, datetime, time, timedelta, timezone
from itertools import chain
from operator import itemgetter
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple, Union

import aioch
import copy
import itertools
from dateutil import tz
from yql.api.v1.client import YqlClient

from smb.common.pgswim import PoolType
from maps_adv.common.helpers.enums import CampaignTypeEnum
from maps_adv.common.helpers.mappers import EVENT_TYPES
from maps_adv.statistics.dashboard.server.lib import sqls
from maps_adv.statistics.dashboard.server.lib.db.engine import DB

# Public API of the module. CampaignsFromDifferentVersionOfStatistics is listed
# as well because DataManager.calculate_by_campaigns_and_period raises it and
# callers need the name to handle that case.
__all__ = [
    "AbstractDataManager",
    "CampaignsFromDifferentVersionOfStatistics",
    "DataManager",
    "NothingFound",
    "NoCampaignsPassed",
]


class NothingFound(Exception):
    """Raised when no campaign ids are given or a query returns no data rows."""


class NoCampaignsPassed(Exception):
    """Raised when a statistics request comes with an empty campaign list."""


class CampaignsFromDifferentVersionOfStatistics(Exception):
    """Raised when one request mixes campaigns served by v1 and v2 statistics."""


class AbstractDataManager:
    """Interface of the statistics data manager.

    Every method here is a stub that raises ``NotImplementedError``; a
    concrete manager is expected to override all of them.
    """

    async def calculate_by_campaigns_and_period(
        self, *, campaign_ids: Iterable[int], period_from: date, period_to: date
    ) -> List[dict]:
        """Return per-day statistics of the campaigns for the given period."""
        raise NotImplementedError()

    async def calculate_campaigns_charged_sum(
        self,
        campaign_ids: Iterable[int],
        on_timestamp: Optional[Union[int, float]] = None,
    ) -> List[dict]:
        """Return the charged sum of every campaign up to ``on_timestamp``.

        ``on_timestamp`` is optional, matching the signature of the concrete
        implementation in this module.
        """
        raise NotImplementedError()

    async def sync_category_search_reports(self):
        """Synchronize category search reports into the local storage."""
        raise NotImplementedError()

    async def fetch_search_icons_statistics(
        self, campaign_ids: List[int], period_from: str, period_to: str
    ) -> List[dict]:
        """Return search icon statistics of the campaigns for the period."""
        raise NotImplementedError()

    async def calculate_campaigns_events_for_period(
        self,
        events_query: Collection[Tuple[int, CampaignTypeEnum]],
        period_from: Optional[datetime] = None,
        period_to: Optional[datetime] = None,
    ) -> Dict[int, int]:
        """Return a mapping of campaign id to its event count for the period."""
        raise NotImplementedError()

    async def calculate_metrics(
        self,
        start_time: datetime,
        end_time: datetime,
        campaign_ids: Optional[List[int]] = None,
    ) -> dict:
        """Return a single dict of metrics calculated for the time range."""
        raise NotImplementedError()

    async def retrieve_tables_metrics(self, end_time: datetime) -> List[dict]:
        """Return per-table metrics for the period ending at ``end_time``."""
        raise NotImplementedError()

    async def get_campaign_ids_for_period(
        self, start_time: datetime, end_time: datetime
    ) -> List[int]:
        """Return ids of campaigns that have events within the time range."""
        raise NotImplementedError()

    async def get_aggregated_normalized_events_by_campaign(
        self, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Return aggregated normalized events keyed by campaign id."""
        raise NotImplementedError()

    async def get_aggregated_processed_events_by_campaign(
        self, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Return aggregated processed events keyed by campaign id."""
        raise NotImplementedError()

    async def get_aggregated_mapkit_events_by_campaign(
        self, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Return aggregated mapkit events keyed by campaign id."""
        raise NotImplementedError()


class DataManager(AbstractDataManager):
    """Statistics data manager that queries ClickHouse, PostgreSQL and YQL."""

    # NOTE(review): __init__ also assigns `_clients` and `_use_only_v2`, which
    # are not listed below. That works only because the base class declares no
    # __slots__, so instances still have a __dict__.
    __slots__ = (
        "_ch_config",
        "_database",
        "_table",
        "_aggregated_table",
        "_pg",
        "_yql_config",
        "_campaigns_only_for_v2",
    )

    # ClickHouse connection options (the "hosts" entry is removed in __init__).
    _ch_config: dict
    # YQL settings; cluster, table name, pool and token are read from it.
    _yql_config: dict
    # ClickHouse database name, taken from the connection options.
    _database: str
    # Name of the table queried for charged sums of v1 campaigns.
    _table: str
    # Name of the per-day aggregated table queried for v1 campaigns.
    _aggregated_table: str
    # PostgreSQL access object used for category search reports.
    _pg: DB
    # Ids of campaigns whose statistics are read from the v2 tables.
    _campaigns_only_for_v2: Set[int]

    def __init__(
        self,
        *,
        ch_config: dict,
        table: str,
        aggregated_table: str,
        postgres_db: DB = None,
        yql_config: Optional[dict] = None,
        campaigns_only_for_v2: Optional[Iterable[int]] = None,
        use_only_v2: bool = False,
    ):
        self._pg = postgres_db
        self._yql_config = yql_config
        self._ch_config = copy.copy(ch_config)
        self._clients = itertools.cycle(
            map(
                lambda replica: aioch.Client(**self._ch_config, **replica),
                self._ch_config.pop("hosts"),
            ),
        )
        self._database = self._ch_config["database"]
        self._table = table
        self._aggregated_table = aggregated_table
        self._campaigns_only_for_v2 = set()
        if campaigns_only_for_v2:
            self._campaigns_only_for_v2 = set(campaigns_only_for_v2)
        self._use_only_v2 = use_only_v2

    def _split_campaign_ids(self, campaign_ids) -> Tuple[List[int], List[int]]:
        campaigns_for_v1_statistics = []
        campaigns_for_v2_statistics = []

        if self._use_only_v2:
            campaigns_for_v2_statistics.extend(campaign_ids)

        else:
            for campaign_id in campaign_ids:
                if campaign_id in self._campaigns_only_for_v2:
                    campaigns_for_v2_statistics.append(campaign_id)
                else:
                    campaigns_for_v1_statistics.append(campaign_id)

        return campaigns_for_v1_statistics, campaigns_for_v2_statistics

    async def calculate_by_campaigns_and_period(
        self, *, campaign_ids: Iterable[int], period_from: date, period_to: date
    ) -> List[dict]:
        """Calculate per-day event statistics of campaigns within a period.

        All campaigns must belong to the same version of statistics (v1 or
        v2); the matching aggregated table is queried.

        Returns a list of dicts: one per date (newest first), followed by a
        totals dict that has the same keys except "date".

        Raises:
            NothingFound: no campaign ids were passed or no data was found.
            CampaignsFromDifferentVersionOfStatistics: the ids mix campaigns
                of v1 and v2 statistics.
        """
        # Materialize once: the argument is only promised to be an iterable,
        # and a generator is always truthy, so the emptiness check below would
        # otherwise let an empty generator through.
        campaign_ids = list(campaign_ids)
        if not campaign_ids:
            raise NothingFound()

        campaign_ids_for_v1, campaign_ids_for_v2 = self._split_campaign_ids(
            campaign_ids
        )

        if campaign_ids_for_v1 and campaign_ids_for_v2:
            raise CampaignsFromDifferentVersionOfStatistics(campaign_ids)

        # Exactly one of the two lists is non-empty at this point.
        if campaign_ids_for_v2:
            sql = f"""
                SELECT
                    date,
                    countMerge(call) AS call,
                    countMerge(make_route) AS make_route,
                    countMerge(open_site) AS open_site,
                    countMerge(save_offer) AS save_offer,
                    countMerge(search) AS search,
                    countMerge(show) AS show,
                    countMerge(tap) AS tap,
                    (show = 0 ? 0 : floor(tap / show, 4)) as ctr,
                    (tap = 0 ? 0 : floor(make_route / tap, 4)) as clicks_to_routes,
                    round(sumMerge(charged_sum), 2) AS charged_sum,
                    least(uniqMerge(show_unique), show) as show_unique
                FROM {self._database}.aggregated_processed_events_by_campaigns_and_days_distributed
                WHERE campaign_id IN ({', '.join(map(str, campaign_ids_for_v2))})
                    AND date BETWEEN %(period_from)s AND %(period_to)s
                GROUP BY date WITH TOTALS
                ORDER BY date DESC
            """
        else:
            sql = f"""
                SELECT
                    date,
                    countMerge(call) AS call,
                    countMerge(makeRoute) AS makeRoute,
                    countMerge(openSite) AS openSite,
                    countMerge(saveOffer) AS saveOffer,
                    countMerge(search) AS search,
                    countMerge(show) AS show,
                    countMerge(tap) AS tap,
                    (show = 0 ? 0 : floor(tap / show, 4)) as ctr,
                    (tap = 0 ? 0 : floor(makeRoute / tap, 4)) as clicks_to_routes,
                    round(sumMerge(charged_sum), 2) AS charged_sum,
                    least(uniqMerge(show_unique), show) as show_unique
                FROM {self._database}.{self._aggregated_table}
                WHERE CampaignID IN ({', '.join(map(str, campaign_ids_for_v1))})
                    AND date BETWEEN %(period_from)s AND %(period_to)s
                GROUP BY date WITH TOTALS
                ORDER BY date DESC
            """

        got = await next(self._clients).execute(
            sql, {"period_from": period_from, "period_to": period_to}
        )

        # result should include empty total if nothing found
        if len(got) <= 1:
            raise NothingFound()

        # Rows are mapped positionally, so both the v1 and the v2 query
        # produce the same (camelCase) keys.
        keys = (
            "date",
            "call",
            "makeRoute",
            "openSite",
            "saveOffer",
            "search",
            "show",
            "tap",
            "ctr",
            "clicks_to_routes",
            "charged_sum",
            "show_unique",
        )

        # The last row holds the totals; its date column is dropped.
        return [
            *[dict(zip(keys, el)) for el in got[:-1]],
            dict(zip(keys[1:], got[-1][1:])),
        ]

    async def calculate_campaigns_charged_sum(
        self,
        campaign_ids: Iterable[int],
        on_timestamp: Optional[Union[int, float]] = None,
    ) -> List[dict]:
        """Calculate the sum charged from every campaign up to a moment.

        Campaigns of v1 and v2 statistics are read from their own tables and
        the results are combined with UNION ALL.

        Args:
            campaign_ids: ids of the campaigns to calculate the sum for.
            on_timestamp: unix timestamp limiting the events taken into
                account; a falsy value (``None`` or ``0``) means "now".

        Returns a list of ``{"campaign_id": ..., "charged_sum": ...}`` dicts.

        Raises:
            NothingFound: no campaign ids were passed or no rows were found.
        """
        # Materialize once: a generator is always truthy, so an empty one
        # would pass the check below and produce an empty query.
        campaign_ids = list(campaign_ids)
        if not campaign_ids:
            raise NothingFound()

        if not on_timestamp:
            on_timestamp = datetime.now(tz=timezone.utc).timestamp()

        campaign_ids_for_v1, campaign_ids_for_v2 = self._split_campaign_ids(
            campaign_ids
        )

        # Named `queries` so the imported `sqls` module is not shadowed.
        queries = []
        if campaign_ids_for_v1:
            queries.append(
                f"""
                SELECT
                    CampaignID,
                    floor(sum(Cost), 2) AS charged_sum
                FROM {self._database}.{self._table}
                WHERE CampaignID IN ({', '.join(map(str, campaign_ids_for_v1))})
                  AND ReceiveTimestamp <= {int(on_timestamp)}
                GROUP BY CampaignID
                """
            )

        if campaign_ids_for_v2:
            queries.append(
                f"""
                SELECT
                    campaign_id,
                    floor(sum(cost), 2) AS charged_sum
                FROM {self._database}.processed_events_distributed
                WHERE campaign_id IN ({', '.join(map(str, campaign_ids_for_v2))})
                  AND receive_timestamp <= {int(on_timestamp)}
                GROUP BY campaign_id
                """
            )

        sql = "\nUNION ALL\n".join(queries)

        got = await next(self._clients).execute(sql)

        if len(got) == 0:
            raise NothingFound()

        return [dict(zip(("campaign_id", "charged_sum"), el)) for el in got]

    async def fetch_search_icons_statistics(
        self, campaign_ids: List[int], period_from: str, period_to: str
    ) -> List[dict]:
        """Fetch category search statistics of campaigns from PostgreSQL.

        Returns one dict per date (newest first) followed by a totals dict
        whose "date" is NULL.

        Raises:
            NoCampaignsPassed: ``campaign_ids`` is empty.
            NothingFound: the query returned at most one row.
        """
        if not campaign_ids:
            raise NoCampaignsPassed()

        # Per-day unique icon shows are selected only for a single campaign;
        # with several campaigns the column is NULL.
        unique_shows_expr = "SUM(devices)" if len(campaign_ids) <= 1 else "NULL"

        query = f"""
            SELECT
                date::text as date,
                SUM(icon_shows) as icon_shows,
                SUM(icon_clicks) as icon_clicks,
                {unique_shows_expr}
                    as unique_icon_shows,
                SUM(pin_shows) as pin_shows,
                SUM(pin_clicks) as pin_clicks,
                SUM(routes) as routes
            FROM category_search_report
            WHERE (date BETWEEN $1 AND $2) AND campaign_id = ANY($3::int[])
            GROUP BY date
        UNION ALL
            SELECT
                NULL as date,
                SUM(icon_shows) as icon_shows,
                SUM(icon_clicks) as icon_clicks,
                NULL as unique_icon_shows,
                SUM(pin_shows) as pin_shows,
                SUM(pin_clicks) as pin_clicks,
                SUM(routes) as routes
            FROM category_search_report
            WHERE (date BETWEEN $1 AND $2) AND campaign_id = ANY($3::int[])
        ORDER BY date DESC NULLS LAST
        """

        async with self._pg.acquire(PoolType.replica) as connection:
            rows = await connection.fetch(query, period_from, period_to, campaign_ids)

        # Fewer than two rows means there is no per-date row next to the totals.
        if len(rows) < 2:
            raise NothingFound()

        return list(map(dict, rows))

    async def sync_category_search_reports(self) -> None:
        """Copy fresh category search report rows from YT into PostgreSQL.

        Steps:
        1. read the greatest ``created_at`` stored in ``category_search_report``
           (0 when the table has no rows);
        2. run a YQL query in a worker thread that selects every row of each
           date having at least one row created after that value;
        3. in a single transaction, for every such date delete the stored rows
           and insert the selected ones.
        """
        cluster_name = self._yql_config["cluster"]
        yt_table_name = self._yql_config["category_search_report_table"]

        async with self._pg.acquire() as conn:
            async with conn.transaction():
                max_sync_timestamp = await conn.fetchval(
                    """
                        SELECT coalesce(max(created_at), 0)
                        FROM category_search_report"""
                )

        def _yql_query() -> Dict[date, List[tuple]]:
            """Run the YQL query and group the resulting rows by report date."""
            # The pool pragmas are added only when a YT pool is configured.
            pool_pragma = (
                f"""
                PRAGMA yt.PoolTrees = "physical";
                PRAGMA yt.Pool = "{self._yql_config["yt_pool"]}";
                """
                if self._yql_config["yt_pool"]
                else ""
            )
            with YqlClient(token=self._yql_config["token"]) as client:
                request = client.query(
                    f"""
                    {pool_pragma}
                    SELECT
                        `campaign_id`,
                        `fielddate` as `date`,
                        `created_at`,
                        `icon_clicks`,
                        `icon_shows`,
                        `clicks` as `pin_clicks`,
                        `shows` as `pin_shows`,
                        `routes`,
                        `devices`
                    FROM {cluster_name}.`{yt_table_name}`
                    WHERE fielddate in (
                        SELECT `fielddate`
                        FROM {cluster_name}.`{yt_table_name}`
                        WHERE created_at>{max_sync_timestamp}
                        GROUP BY fielddate
                    )""",
                    syntax_version=1,
                )

                request.run()
                table = request.get_results().table

                data = defaultdict(list)
                for row in table.get_iterator():
                    # The first three columns are taken as is, the remaining
                    # six are counters where a missing value is replaced by 0.
                    campaign_id, date_report, created_at = row[:3]
                    icon_clicks, icon_shows, pin_clicks, pin_shows, routes, devices = [
                        value if value is not None else 0 for value in row[3:]
                    ]

                    # The date is parsed from a "YYYY-MM-DD" string.
                    date_report = datetime.strptime(date_report, "%Y-%m-%d").date()

                    # The tuple order matches the `columns` list passed to
                    # copy_records_to_table below.
                    data[date_report].append(
                        (
                            int(campaign_id),
                            date_report,
                            created_at,
                            icon_clicks,
                            icon_shows,
                            pin_clicks,
                            pin_shows,
                            routes,
                            devices,
                        )
                    )

                return data

        # The YQL client is called synchronously, so the query runs in a
        # separate thread while this coroutine awaits its result.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            result = await asyncio.get_event_loop().run_in_executor(pool, _yql_query)

        async with self._pg.acquire() as conn:
            async with conn.transaction():
                # Replace the rows of every affected date as a whole.
                for sync_date, rows in result.items():
                    await conn.execute(
                        """
                        DELETE FROM category_search_report
                        WHERE date=$1""",
                        sync_date,
                    )
                    await conn.copy_records_to_table(
                        "category_search_report",
                        records=rows,
                        columns=[
                            "campaign_id",
                            "date",
                            "created_at",
                            "icon_clicks",
                            "icon_shows",
                            "pin_clicks",
                            "pin_shows",
                            "routes",
                            "devices",
                        ],
                    )

    async def calculate_campaigns_events_for_period(
        self,
        events_query: Collection[Tuple[int, CampaignTypeEnum]],
        period_from: Optional[datetime] = None,
        period_to: Optional[datetime] = None,
    ) -> Dict[int, int]:
        if not events_query:
            return {}

        if not period_from:
            period_from = datetime.combine(
                date.today(), time(0, 0, 0, tzinfo=tz.tzutc())
            )

        if not period_to:
            period_to = datetime.now(tz=timezone.utc)

        result = {id: 0 for id in map(itemgetter(0), events_query)}
        events_query = chain.from_iterable(
            [(id, event) for event in EVENT_TYPES[type]] for id, type in events_query
        )
        got = await next(self._clients).execute(
            sqls.calculate_campaign_events(events_query, period_from, period_to)
        )
        result.update(dict(got))

        return result

    async def calculate_metrics(
        self,
        start_time: datetime,
        end_time: datetime,
        campaign_ids: Optional[List[int]] = None,
    ) -> dict:
        """Count pin shows, pin taps and unique devices for a time range.

        The raw metrika log table is queried; when ``campaign_ids`` is given
        and non-empty, only these campaigns are taken into account.

        Returns a dict with the "shows", "clicks" and "users" keys.
        """
        # TODO: add support for new statistic sources - GEOPROD-4520
        period_start = int(start_time.timestamp())
        period_end = int(end_time.timestamp())

        query = f"""
        SELECT
            countIf(EventName, EventName = 'geoadv.bb.pin.show') AS shows,
            countIf(EventName, EventName = 'geoadv.bb.pin.tap') AS clicks,
            uniqExact(DeviceID) as users
        FROM {self._database}.maps_adv_statistics_raw_metrika_log_distributed
        WHERE _timestamp BETWEEN  {period_start} AND {period_end}
        """  # noqa: E501

        if campaign_ids:
            id_list = ", ".join(str(campaign_id) for campaign_id in campaign_ids)
            query = f"{query} AND CampaignID IN ({id_list})"

        rows = await next(self._clients).execute(query)

        # Only the first row is returned.
        return dict(zip(("shows", "clicks", "users"), rows[0]))

    async def retrieve_tables_metrics(self, end_time: datetime) -> List[dict]:
        """Return the latest receive timestamp of every events table.

        Only the seven days preceding ``end_time`` are looked at. The result
        is a list of ``{"table": ..., "max_receive_timestamp": ...}`` dicts
        ordered by table name.
        """
        since, until = (
            int(moment.timestamp())
            for moment in (end_time - timedelta(days=7), end_time)
        )

        query = f"""
            SELECT *
            FROM (
                SELECT 'mapkit_events' AS name,
                       toUnixTimestamp(max(receive_time)) AS max_receive_timestamp
                FROM {self._database}.mapkit_events_distributed
                WHERE receive_time BETWEEN {since} AND {until}

                UNION ALL
                SELECT 'maps_adv_statistics_raw_metrika_log' AS name,
                       toUnixTimestamp(max(ReceiveTimestamp)) AS max_receive_timestamp
                FROM {self._database}.maps_adv_statistics_raw_metrika_log_distributed
                WHERE ReceiveTimestamp BETWEEN {since} AND {until}

                UNION ALL
                SELECT 'normalized_events' AS name,
                       toUnixTimestamp(max(receive_timestamp)) AS max_receive_timestamp
                FROM {self._database}.normalized_events_distributed
                WHERE receive_timestamp BETWEEN  {since} AND {until}

                UNION ALL
                SELECT 'processed_events' AS name,
                       toUnixTimestamp(max(receive_timestamp)) AS max_receive_timestamp
                FROM {self._database}.processed_events_distributed
                WHERE receive_timestamp BETWEEN {since} AND {until}
            ) as tables
            ORDER BY tables.name
        """

        rows = await next(self._clients).execute(query)

        columns = ("table", "max_receive_timestamp")
        return [dict(zip(columns, row)) for row in rows]

    async def get_campaign_ids_for_period(
        self, start_time: datetime, end_time: datetime
    ) -> List[int]:
        """Return distinct campaign ids found in the raw metrika log table
        for the given time range."""
        query = f"""SELECT DISTINCT CampaignID
                FROM {self._database}.maps_adv_statistics_raw_metrika_log_distributed
                WHERE _timestamp BETWEEN  {int(start_time.timestamp())}
                                 AND {int(end_time.timestamp())}
                """

        rows = await next(self._clients).execute(query)

        # Every row has a single column holding the campaign id.
        return list(map(itemgetter(0), rows))

    async def get_aggregated_normalized_events_by_campaign(
        self, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Aggregate events of the normalized events table by campaign."""
        return await self._get_aggregated_events_from_table_by_campaign(
            table_name="normalized_events_distributed",
            start_time=start_time,
            end_time=end_time,
        )

    async def get_aggregated_processed_events_by_campaign(
        self, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Aggregate events of the processed events table by campaign."""
        return await self._get_aggregated_events_from_table_by_campaign(
            table_name="processed_events_distributed",
            start_time=start_time,
            end_time=end_time,
        )

    async def _get_aggregated_events_from_table_by_campaign(
        self, table_name: str, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Count billboard shows and route actions per campaign in a table.

        ``table_name`` is a table of the configured database. The result maps
        a campaign id to ``{"billboard_show": ..., "action_make_route": ...}``
        for events received within the given time range.
        """
        query = f"""SELECT campaign_id, countIf(event_name='BILLBOARD_SHOW') as billboard_show, countIf(event_name='ACTION_MAKE_ROUTE') as action_make_route
                  FROM {self._database}.{table_name}
                  WHERE receive_timestamp BETWEEN {int(start_time.timestamp())} AND {int(end_time.timestamp())}
                  GROUP BY campaign_id;"""

        rows = await next(self._clients).execute(query)

        aggregated = {}
        for row in rows:
            aggregated[row[0]] = {
                "billboard_show": row[1],
                "action_make_route": row[2],
            }
        return aggregated

    async def get_aggregated_mapkit_events_by_campaign(
        self, start_time: datetime, end_time: datetime
    ) -> Dict[int, dict]:
        """Count billboard shows and route actions per campaign in mapkit events.

        The campaign id is extracted from the ``campaignId`` field of the
        ``log_id`` JSON. The result maps a campaign id to
        ``{"billboard_show": ..., "action_make_route": ...}`` for events
        received within the given time range.
        """
        # The table is addressed through the configured database, like every
        # other query of this manager, instead of a hard-coded database name.
        # NOTE(review): toUInt64OrNull presumably yields NULL for ids that
        # can't be parsed, so the result may contain a None key - confirm.
        sql = f"""SELECT toUInt64OrNull(JSONExtractString(log_id, 'campaignId')) AS campaign_id,
                        countIf(event='billboard.show') as billboard_show,
                        countIf(event='billboard.navigation.via') as action_make_route
                  FROM {self._database}.mapkit_events_distributed
                  WHERE receive_time BETWEEN {int(start_time.timestamp())} AND {int(end_time.timestamp())}
                  GROUP BY campaign_id;"""

        result = await next(self._clients).execute(sql)
        return {
            row[0]: {"billboard_show": row[1], "action_make_route": row[2]}
            for row in result
        }
