from datetime import datetime
from typing import List

import pytz

from maps_adv.stat_tasks_starter.lib.base.exceptions import UnexpectedNaiveDateTime


class UnexpectedEmptyCampaignsIds(Exception):
    """Raised when a stats query is requested with no campaign ids at all."""


# ClickHouse query template rendered by ``build_select_events_stat`` via
# ``str.format``.
#
# Placeholders:
#   database, normalized_table, charged_table -- source table locations;
#   campaigns_ids -- comma-separated list of campaign ids;
#   timestamp_start / timestamp_end -- inclusive ReceiveTimestamp window
#       used for ``events_count``;
#   local_day_start / local_day_end -- inclusive ReceiveTimestamp window
#       used for ``charged_daily``.
#
# Row set: every requested campaign that has at least one
# 'geoadv.bb.pin.show' event in the normalized table (at any time).
# ``events_count``, ``charged_total`` (no time filter) and ``charged_daily``
# are LEFT JOINed onto that set by CampaignID.
# NOTE(review): placeholders are substituted verbatim (this is not a
# parameterized query), so every value must already be trusted/validated.
__select_events_stat_template = """SELECT CampaignID, charged_daily, charged_total, events_count
FROM
(
    SELECT CampaignID, events_count
    FROM (
        SELECT DISTINCT CampaignID
        FROM {database}.{normalized_table}
        WHERE
            CampaignID IN ({campaigns_ids})
            AND EventName = 'geoadv.bb.pin.show'
    ) LEFT JOIN (
        SELECT CampaignID, Count(*) AS events_count
        FROM {database}.{normalized_table}
        WHERE
            CampaignID IN ({campaigns_ids})
            AND EventName = 'geoadv.bb.pin.show'
            AND ReceiveTimestamp BETWEEN {timestamp_start} AND {timestamp_end}
        GROUP BY CampaignID
    ) USING (CampaignID)
) LEFT JOIN (
    SELECT CampaignID, charged_total, charged_daily
    FROM (
        SELECT CampaignID, SUM(Cost) AS charged_total
        FROM {database}.{charged_table}
        WHERE
            CampaignID IN ({campaigns_ids})
            AND EventName = 'geoadv.bb.pin.show'
        GROUP BY CampaignID
    ) LEFT JOIN (
        SELECT CampaignID, Sum(Cost) AS charged_daily
        FROM {database}.{charged_table}
        WHERE
            CampaignID IN ({campaigns_ids})
            AND EventName = 'geoadv.bb.pin.show'
            AND ReceiveTimestamp BETWEEN {local_day_start} AND {local_day_end}
        GROUP BY CampaignID
    ) USING (CampaignID)
) USING (CampaignID)"""


def build_select_events_stat(
    database: str,
    normalized_table: str,
    charged_table: str,
    timing_from: datetime,
    timing_to: datetime,
    campaigns_ids: List[int],
    tz_name: str,
) -> str:
    """Render the per-campaign show/charge statistics query.

    The query yields one row per campaign from ``campaigns_ids`` that has at
    least one ``geoadv.bb.pin.show`` event in ``normalized_table``, with:

    * ``events_count`` -- show events received within
      ``[timing_from, timing_to]``;
    * ``charged_total`` -- sum of ``Cost`` over all charged show events;
    * ``charged_daily`` -- sum of ``Cost`` over charged show events received
      during the local calendar day (in ``tz_name``) containing
      ``timing_from``.

    :param database: database name, substituted into the query verbatim.
    :param normalized_table: table holding raw show events.
    :param charged_table: table holding charged events with a ``Cost`` column.
    :param timing_from: timezone-aware start of the counting window; also
        selects the local day used for ``charged_daily``.
    :param timing_to: timezone-aware end of the counting window (inclusive).
    :param campaigns_ids: non-empty list of campaign ids.
    :param tz_name: IANA timezone name used to determine the local day.
    :return: the rendered SQL query.
    :raises UnexpectedNaiveDateTime: if ``timing_from`` or ``timing_to`` is naive.
    :raises UnexpectedEmptyCampaignsIds: if ``campaigns_ids`` is empty.
    """
    # A datetime is aware only if it has a tzinfo that yields an offset.
    for moment in (timing_from, timing_to):
        if moment.tzinfo is None or moment.utcoffset() is None:
            raise UnexpectedNaiveDateTime()

    if not campaigns_ids:
        raise UnexpectedEmptyCampaignsIds()

    timestamp_start = int(timing_from.timestamp())
    timestamp_end = int(timing_to.timestamp())

    timezone = pytz.timezone(tz_name)
    # Wall-clock time of ``timing_from`` in the target zone, without tzinfo.
    local_wall_clock = timing_from.astimezone(timezone).replace(tzinfo=None)
    # With pytz the zone must be attached through ``localize`` so that the UTC
    # offset matches the wall-clock time being built.  Calling ``replace`` on
    # an aware datetime would keep the offset valid at ``timing_from``, which
    # is wrong when a DST transition happens elsewhere on that day.
    local_day_start = int(
        timezone.localize(
            local_wall_clock.replace(hour=0, minute=0, second=0, microsecond=0)
        ).timestamp()
    )
    local_day_end = int(
        timezone.localize(
            local_wall_clock.replace(hour=23, minute=59, second=59, microsecond=0)
        ).timestamp()
    )

    return __select_events_stat_template.format(
        database=database,
        normalized_table=normalized_table,
        charged_table=charged_table,
        timestamp_start=timestamp_start,
        timestamp_end=timestamp_end,
        local_day_start=local_day_start,
        local_day_end=local_day_end,
        campaigns_ids=", ".join(map(str, campaigns_ids)),
    )


def build_union_select_events(events_selects: List[str]) -> str:
    """Combine several SELECT statements into a single ``UNION ALL`` query.

    :param events_selects: SELECT statements, kept in the given order.
    :return: the statements separated by `` UNION ALL ``; an empty list
        produces an empty string.
    """
    union_separator = " UNION ALL "
    return union_separator.join(events_selects)
