import logging
import typing
from datetime import datetime, timezone
from dateutil.relativedelta import relativedelta

from crm.agency_cabinet.rewards.server.src.db import models
from crm.agency_cabinet.common.consts import compute_reward_type
from crm.agency_cabinet.common.yt.base import async_to_sync, BaseRowLoadException
from crm.agency_cabinet.common.yql.base import YqlModelLoader

# Logger for this loader; the dotted name makes it a child of the
# 'celery.load_rewards' logger in the stdlib logging hierarchy.
LOGGER = logging.getLogger('celery.load_rewards.reward_payment')


class UnknownPeriodException(BaseRowLoadException):
    """Signal that a YT row's from_dt/till_dt period could not be interpreted.

    Raised by RewardPaymentRewardLoader when parsing the period bounds, or
    deriving a reward type from them, fails with ValueError.
    """


class RewardPaymentRewardLoader(YqlModelLoader):
    """Copy aggregated reward payments from YT onto existing Reward rows.

    YQL_QUERY sums ``reward_to_pay`` per (contract_id, from_dt, till_dt) for
    periods with ``from_dt >= $period_from``. Each resulting row is matched
    against a preloaded ``models.Reward`` and that reward's ``payment`` is
    updated. ``_create_db_rows`` returns nothing, so presumably this loader
    never inserts new rewards (confirm against YqlModelLoader).
    """

    YQL_QUERY = '''
    USE hahn;

    DECLARE $payment_table_name AS String;
    DECLARE $period_from AS String;

    SELECT
        contract_id AS contract_id,
        SUM(reward_to_pay) AS payment,
        from_dt, till_dt

    FROM $payment_table_name
    WHERE from_dt >= $period_from
    GROUP BY (contract_id, from_dt, till_dt);
    '''

    # strptime format of the from_dt / till_dt columns returned by YQL_QUERY.
    _YT_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ'
    # Shift applied to parsed YT timestamps before they are used as lookup keys.
    # NOTE(review): presumably converts UTC to Moscow time (UTC+3) while keeping
    # tzinfo=UTC, to mirror how Reward.period_from is stored — confirm.
    _PERIOD_SHIFT = relativedelta(hours=3)

    def __init__(self, from_dt: datetime, *args, **kwargs):
        """Create a loader for rewards whose period starts at or after ``from_dt``.

        :param from_dt: lower bound on Reward.period_from for the preloaded rewards.
        """
        # Set before calling the base initializer. NOTE(review): the base class
        # presumably calls _init() from __init__, and _init reads self.from_dt.
        self.from_dt = from_dt
        super().__init__(*args, **kwargs)

    def _init(self, **kwargs):
        """Preload every Reward with ``period_from >= self.from_dt`` into ``self.rewards_map``."""
        super()._init(**kwargs)

        @async_to_sync
        async def get_all_rewards(from_dt: datetime):
            return await models.Reward.query.where(
                models.Reward.period_from >= from_dt
            ).gino.all()

        # Keyed by the same (contract_id, type, period_from) triple that
        # _find_duplicate derives from a YT row; on key collisions the last
        # reward returned by the query wins.
        self.rewards_map = {
            (reward.contract_id, reward.type, reward.period_from): reward
            for reward in get_all_rewards(self.from_dt)
        }

    def _parse_period_bound(self, yt_row, column: str) -> datetime:
        """Parse a YT period-bound column into the shifted, UTC-tagged datetime used as a map key.

        :raises ValueError: if the value does not match ``_YT_DATETIME_FORMAT``.
        """
        raw_value = self._get_column_value(yt_row, column)
        parsed = datetime.strptime(raw_value, self._YT_DATETIME_FORMAT).replace(tzinfo=timezone.utc)
        return parsed + self._PERIOD_SHIFT

    def _find_duplicate(self, yt_row) -> typing.Optional[models.Reward]:
        """Return the preloaded Reward matching ``yt_row``, or None if there is none.

        A missing reward is logged as a warning.

        :raises UnknownPeriodException: if from_dt/till_dt cannot be parsed or
            ``compute_reward_type`` raises ValueError for the period.
        """
        contract_id = self._get_column_value(yt_row, 'contract_id')

        try:
            period_from = self._parse_period_bound(yt_row, 'from_dt')
            period_to = self._parse_period_bound(yt_row, 'till_dt')
            reward_type = compute_reward_type(period_from, period_to)
        except ValueError as err:
            # Keep the original parsing error as the explicit cause.
            raise UnknownPeriodException from err

        reward = self.rewards_map.get((contract_id, reward_type, period_from))
        if reward is None:
            LOGGER.warning(f"Couldn't find reward: contract_id={contract_id}, period_from={period_from}, "
                           f"type={reward_type}")
        return reward

    def _process_duplicate(self, yt_row, db_row: models.Reward):
        """Store the row's aggregated ``payment`` on the matched Reward."""
        @async_to_sync
        async def update_reward(reward: models.Reward, payment: float):
            await reward.update(payment=payment).apply()

        update_reward(db_row, self._get_column_value(yt_row, 'payment'))

    def _create_db_rows(self, db_rows) -> typing.List[models.Reward]:
        """Return no rows: this loader only updates rewards that already exist."""
        return []
