import typing
from datetime import datetime, timezone
from decimal import Decimal

from crm.agency_cabinet.common.consts.reward import RewardsTypes
from crm.agency_cabinet.common.consts.service import Services
from crm.agency_cabinet.common.yt.base import async_to_sync
from crm.agency_cabinet.rewards.server.src.db.models import Contract, Reward, ServiceReward
from crm.agency_cabinet.rewards.server.src.celery.tasks.deprecated.rewards_2021.load_rewards.base import BadYtRow, BaseRewardLoader
from crm.agency_cabinet.common.consts import START_FIN_YEAR_2021


class EarlyPaymentRewardLoader(BaseRewardLoader):
    """Load monthly early-payment rewards from a YT table into the rewards DB.

    Source rows are restricted to ``reward_type == 311`` with ``till_dt`` not
    earlier than ``MIN_TILL_DT`` (see ``filter_reward``). For each
    (contract_eid, month-start) pair a ``Reward`` is fetched or created and
    cached in ``rewards_map``. An existing early-payment ``ServiceReward`` found
    for a row gets its ``payment`` set to the row's ``reward_to_charge``.
    """

    REWARD_PERIOD = RewardsTypes.month
    IS_PROF = None
    # Same string format as YT ``till_dt`` values, so ``filter_reward`` can
    # compare them lexicographically.
    MIN_TILL_DT = START_FIN_YEAR_2021.strftime('%Y-%m-%dT%H:%M:%SZ')
    expected_columns = ('contract_eid', 'reward_to_charge')

    def _init(self, **kwargs):
        """Initialise per-run caches.

        Keyword Args:
            raise_on_unknown_contract: stored on the instance (default True);
                presumably consulted by ``BaseRewardLoader`` when a row's
                contract is unknown — confirm against the base class.
        """
        # (contract_eid, period_from) -> Reward fetched/created in this run.
        self.rewards_map: typing.Dict[typing.Tuple[str, datetime], Reward] = {}
        self.contracts_map: typing.Dict[str, Contract] = {}
        self.raise_on_unknown_contract = kwargs.get('raise_on_unknown_contract', True)

    def _get_clear_query(self):
        """Return no clear query (presumably: nothing is deleted before loading)."""
        return None

    def _find_duplicate(self, yt_row) -> typing.Optional[ServiceReward]:
        """Return the existing early-payment ServiceReward for ``yt_row``, if any.

        The match is made on the Reward cached in ``rewards_map`` for the row's
        (contract_eid, period_from) key and on the row's ``discount_type``
        (falling back to ``service_id``).

        Returns:
            The matching ServiceReward, or None when no Reward is cached for the
            row's key or no matching ServiceReward exists.
        """
        eid = yt_row['contract_eid']
        d_type = yt_row.get('discount_type') or yt_row.get('service_id')
        period_from = self._get_period_from(yt_row)
        reward = self.rewards_map.get((eid, period_from))
        if reward is None:
            # Without a cached Reward there is nothing to match against;
            # dereferencing ``reward.id`` here would raise AttributeError.
            return None

        @async_to_sync
        async def _get_service_reward(reward_id, discount_type):
            async with self.db_bind:
                return await ServiceReward.query.where(
                    ServiceReward.reward_id == reward_id
                ).where(
                    ServiceReward.discount_type == discount_type
                ).where(
                    ServiceReward.service == Services.early_payment.value
                ).gino.first()

        return _get_service_reward(reward.id, d_type)

    def _process_duplicate(self, yt_row, db_row: ServiceReward):
        """Overwrite ``db_row.payment`` with the row's ``reward_to_charge``."""
        @async_to_sync
        async def _update_service_reward(service_reward: ServiceReward, payment: float):
            async with self.db_bind:
                return await service_reward.update(
                    payment=payment,
                ).apply()

        _update_service_reward(db_row, yt_row['reward_to_charge'])

    @staticmethod
    def _get_service(yt_row):
        """Every row of this loader belongs to the early-payment service."""
        return Services.early_payment.value

    @staticmethod
    def _get_reward_percent(yt_row):
        """Return the fixed reward percent for early payment (2)."""
        return Decimal(2)

    def _get_reward_id(self, yt_row):
        """Return the id of the cached Reward for the row's key, or None if not cached."""
        eid = yt_row['contract_eid']
        period_from = self._get_period_from(yt_row)
        reward = self.rewards_map.get((eid, period_from))
        return reward.id if reward is not None else None

    def _process_rewards_map(self, eid, contract, yt_row):
        """Ensure ``rewards_map`` holds a Reward for (``eid``, row's period_from).

        On a cache miss the Reward is fetched or created in the DB, keyed by
        (contract_id, period_from, type).
        """
        @async_to_sync
        async def _get_or_create_reward(columns):
            async with self.db_bind:
                filter_columns = {
                    'contract_id': columns['contract_id'],
                    'period_from': columns['period_from'],
                    'type': columns['type'],
                }
                model = await Reward.get_or_create(filter_columns, columns)
                # Only the in-memory attribute is reset here; nothing is
                # persisted by this assignment.
                model.payment = Decimal(0)
                return model

        key = (eid, self._get_period_from(yt_row))
        if key not in self.rewards_map:
            r_columns = self._build_reward_columns(yt_row, contract)
            self.rewards_map[key] = _get_or_create_reward(r_columns)

    def _get_period_from(self, yt_row) -> datetime:
        """Return the UTC start of the month containing the row's ``till_dt``.

        The result is cached in ``yt_row['parsed_till_dt']``, so repeated calls
        for the same row do not re-parse it.

        Raises:
            BadYtRow: if ``till_dt`` is missing or not in ``%Y-%m-%dT%H:%M:%SZ`` format.
        """
        period_from = yt_row.get('parsed_till_dt')
        if period_from is None:
            raw_till_dt = yt_row.get('till_dt')
            try:
                till_dt = datetime.strptime(raw_till_dt, '%Y-%m-%dT%H:%M:%SZ').replace(tzinfo=timezone.utc)
            except (TypeError, ValueError) as ex:
                # TypeError covers a missing/non-string till_dt (strptime(None, ...)).
                raise BadYtRow("Can't parse till_dt") from ex
            period_from = till_dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            yt_row['parsed_till_dt'] = period_from
        return period_from

    def _build_reward_columns(self, yt_row, contract):
        """Build the column dict for creating the row's Reward.

        ``is_prof`` is True only when the row's ``tp`` equals ``'prof'``;
        ``payment`` starts at 0 and ``is_accrued`` is always True.
        """
        return {
            'contract_id': contract.id,
            'period_from': self._get_period_from(yt_row),
            'type': self.REWARD_PERIOD.value,
            'is_prof': yt_row.get('tp') == 'prof',
            'payment': Decimal(0),
            'is_accrued': True
        }

    @classmethod
    def filter_reward(cls, input_row):
        """YT mapper: yield ``input_row`` only for reward type 311 rows from ``MIN_TILL_DT`` on.

        ``till_dt`` is compared as a string; this assumes it uses the same
        zero-padded ``%Y-%m-%dT%H:%M:%SZ`` format as ``MIN_TILL_DT``.
        """
        till_dt = input_row.get('till_dt')
        if input_row.get('reward_type') == 311 and till_dt and till_dt >= cls.MIN_TILL_DT:
            yield input_row

    def _read_table(self, **kwargs):
        """Filter the source table with a YT map operation and read the result.

        ``filter_reward`` runs as the mapper into a temporary table, and the
        reader over that table is stored in ``self._table``. ``kwargs`` are
        forwarded to both ``run_map`` and ``read_table``.
        """
        source_path = self.table_path
        if source_path.endswith('/'):
            # Strip only a trailing slash that is actually present; slicing
            # ``[:-1]`` unconditionally would cut the last character of a
            # path that has none.
            source_path = source_path[:-1]
        with self.client.TempTable() as output_table:
            self.client.run_map(
                self.filter_reward,
                source_path,
                output_table,
                **kwargs
            )
            # NOTE(review): TempTable presumably removes the table when this
            # ``with`` exits, so this relies on the read_table result staying
            # readable afterwards (fully fetched or snapshot-locked) — confirm.
            self._table = self.client.read_table(output_table, **kwargs)


class EarlyPaymentRewardLocalFilterLoader(EarlyPaymentRewardLoader):
    """Variant that filters rows client-side instead of running a YT map operation.

    The whole source table is streamed and filtered locally with the same
    predicate as the parent loader's ``filter_reward``.
    """

    @classmethod
    def filter_reward(cls, input_row):
        """Return True if ``input_row`` passes the parent loader's filter."""
        # The parent filter is a generator mapper that yields the row itself at
        # most once; an immediately exhausted generator means "rejected".
        # ``super()`` keeps ``cls`` bound so class-level overrides
        # (e.g. MIN_TILL_DT) are respected.
        return next(super().filter_reward(input_row), None) is not None

    def _read_table(self, **kwargs):
        """Stream the source table and lazily keep only rows passing ``filter_reward``."""
        source_path = self.table_path
        if source_path.endswith('/'):
            # Strip only an actual trailing slash instead of always dropping
            # the last character.
            source_path = source_path[:-1]
        self._table = filter(self.filter_reward, self.client.read_table(source_path, **kwargs))
