import logging
import typing
from datetime import datetime, timezone
from decimal import Decimal

from crm.agency_cabinet.common.consts.reward import RewardsTypes
from crm.agency_cabinet.common.consts.service import Services
from crm.agency_cabinet.common.yt.base import async_to_sync, BaseRowLoadException, YtModelLoader
from crm.agency_cabinet.rewards.server.src.db import db as current_db
from crm.agency_cabinet.rewards.server.src.db.models import Contract, Reward, ServiceReward
from crm.agency_cabinet.common.consts.service import service_id_to_name

LOGGER = logging.getLogger('celery.load_rewards')


class UnknownDiscountType(BaseRowLoadException):
    """Raised when a row's discount_type / service_id cannot be mapped to a service name."""


class ContractNotFound(BaseRowLoadException):
    """Raised when a row references a contract eid that does not exist in the database."""


class BadYtRow(BaseRowLoadException):
    """Raised when a YT row contains a malformed value (e.g. an unparsable ``yyyymm``)."""


class InconsistentPeriodFrom(BaseRowLoadException):
    """Raised when a row's month differs from the month encoded in the table path."""

    default_message = 'INCONSISTENT_PERIOD_FROM'


class NoEidInYtRow(BaseRowLoadException):
    """Signals a YT row from which no usable ``contract_eid`` could be read."""

    default_message = 'CAN_NOT_GET_EID_FROM_ROW'


class BaseRewardLoader(YtModelLoader):
    """Load reward rows from a YT table into ``Reward`` / ``ServiceReward`` models.

    The reporting month is taken from the last six characters of the table path
    (``YYYYMM``). Subclasses set ``REWARD_PERIOD`` and ``IS_PROF``; both are written
    into newly built ``Reward`` rows and select which predicted rewards are cleared
    before loading starts.
    """

    REWARD_PERIOD: RewardsTypes = None
    IS_PROF: bool = None
    expected_columns = ('contract_eid', 'amt', 'reward')

    def _init(self, **kwargs):
        """Initialise per-load caches and derive the reporting month from the table path.

        Keyword args:
            raise_on_unknown_contract: when true (the default) rows referencing an
                unknown contract raise ``ContractNotFound``; otherwise a stub contract
                is created for them.
        """
        # contract_eid -> Reward of the reporting month, filled during row preprocessing.
        self.rewards_map: typing.Dict[str, Reward] = {}
        # contract_eid -> Contract, caches DB lookups across rows.
        self.contracts_map: typing.Dict[str, Contract] = {}
        # The table name must end with YYYYMM, e.g. ``.../202101``.
        self.period_from = datetime.strptime(self.table_path.strip('/')[-6:], '%Y%m').replace(tzinfo=timezone.utc)
        self.raise_on_unknown_contract = kwargs.get('raise_on_unknown_contract', True)
        # if 'payment' not in self.columns_mapper:
        #    raise ValueError('Expected payment column')

    def _get_clear_query(self):
        """Build the DELETE removing predicted rewards this load is about to replace.

        Matches rewards of the loaded month with this loader's ``REWARD_PERIOD`` type
        and ``IS_PROF`` flag that have ``predict`` set. Returns the not-yet-awaited
        ``gino.status()`` coroutine.
        """
        return Reward.delete.where(
            Reward.period_from == self.period_from
        ).where(
            Reward.type == self.REWARD_PERIOD.value
        ).where(
            Reward.is_prof == self.IS_PROF
        ).where(
            Reward.predict.is_(True)
        ).gino.status()

    def _find_duplicate(self, yt_row) -> typing.Optional[ServiceReward]:
        """Return the stored non-early-payment ``ServiceReward`` matching this row, if any.

        The match is on the contract's cached reward and the row's raw
        ``discount_type`` (falling back to ``service_id``). Returns ``None`` when the
        contract has no cached reward or that reward is a prediction.
        """
        eid = yt_row['contract_eid']
        d_type = yt_row.get('discount_type') or yt_row.get('service_id')
        reward = self.rewards_map.get(eid)
        # A missing reward (presumably only when preprocessing of this row failed)
        # previously crashed with AttributeError on ``None.predict``.
        if reward is None or reward.predict:
            return None

        @async_to_sync
        async def _get_service_reward(reward_id, discount_type):
            async with self.db_bind:
                return await ServiceReward.query.where(
                    ServiceReward.reward_id == reward_id
                ).where(
                    ServiceReward.discount_type == discount_type
                ).where(
                    ServiceReward.service != Services.early_payment.value
                ).gino.first()

        return _get_service_reward(reward.id, d_type)

    def _process_duplicate(self, yt_row, db_row: ServiceReward):
        """Overwrite an existing ``ServiceReward`` with the row's ``reward`` (payment) and ``amt`` (revenue)."""
        @async_to_sync
        async def _update_service_reward(service_reward: ServiceReward, payment: float, revenue: float):
            async with self.db_bind:
                return await service_reward.update(
                    payment=payment,
                    revenue=revenue
                ).apply()

        _update_service_reward(db_row, yt_row['reward'], yt_row['amt'])

    def _before_start(self):
        """Delete predicted rewards for this month, type and prof flag, in one transaction."""
        @async_to_sync
        async def _clear_rewards():
            async with self.db_bind:
                async with current_db.transaction():
                    query = self._get_clear_query()
                    if query:
                        return await query

        _clear_rewards()

    @staticmethod
    def _get_service(yt_row):
        """Map the row's ``discount_type`` (or ``service_id``) to a service name.

        Raises:
            UnknownDiscountType: the value has no known service name; the message
                carries the original, unmapped value.
        """
        raw_type = yt_row.get('discount_type') or yt_row.get('service_id')
        service = service_id_to_name(raw_type)
        if not service:
            # Report the raw value: the mapped one is falsy here and useless for debugging.
            raise UnknownDiscountType(f'UNKNOWN_DISCOUNT_TYPE: {raw_type}')
        return service

    @staticmethod
    def _get_reward_percent(yt_row) -> typing.Optional[Decimal]:
        """Return ``reward / amt * 100``, or ``None`` if either value is missing or zero."""
        raw_payment, raw_revenue = yt_row['reward'], yt_row['amt']
        if raw_payment is None or raw_revenue is None:
            # Decimal(None) would raise TypeError; the return type already allows None.
            return None
        payment = Decimal(raw_payment)
        revenue = Decimal(raw_revenue)
        return payment / revenue * 100 if payment and revenue else None

    def _get_reward_id(self, yt_row):
        """Return the id of the cached ``Reward`` for the row's contract, or ``None``."""
        reward = self.rewards_map.get(yt_row['contract_eid'])
        return reward.id if reward is not None else None

    def _on_contract_not_found(self, eid):
        """Handle a row whose contract is not in the database.

        Raises ``ContractNotFound`` when ``raise_on_unknown_contract`` is set;
        otherwise creates and returns a stub ``Contract`` (agency_id=0, STUB types).
        """
        if self.raise_on_unknown_contract:
            raise ContractNotFound(f'CONTRACT_NOT_FOUND: {eid}')

        @async_to_sync
        async def _create_stub_contract():
            async with self.db_bind:
                return await Contract.create(eid=eid, agency_id=0, payment_type='STUB', type='STUB')

        return _create_stub_contract()

    def _build_reward_columns(self, yt_row, contract):
        """Build column values for the ``Reward`` of ``contract`` in the loaded month.

        The optional ``yyyymm`` row value must name the same month as the table path.

        Raises:
            BadYtRow: ``yyyymm`` is present but is not in ``YYYYMM`` form.
            InconsistentPeriodFrom: ``yyyymm`` names a different month than the table.
        """
        yt_row_period_from = yt_row.get('yyyymm')
        if yt_row_period_from:
            try:
                # str() lets integer yyyymm values parse; strptime rejects non-str with TypeError.
                period_from = datetime.strptime(str(yt_row_period_from), '%Y%m').replace(tzinfo=timezone.utc)
            except ValueError as ex:
                raise BadYtRow('Can\'t parse yyyymm') from ex
        else:
            period_from = self.period_from
        if period_from != self.period_from:
            raise InconsistentPeriodFrom()
        return {
            'contract_id': contract.id,
            'period_from': period_from,
            'type': self.REWARD_PERIOD.value,
            'is_prof': self.IS_PROF,
            'payment': Decimal(0),
            'is_accrued': True
        }

    def _process_rewards_map(self, eid, contract, yt_row):
        """Ensure ``self.rewards_map`` holds this month's ``Reward`` for ``eid``, fetching or creating it."""
        if eid in self.rewards_map:
            return

        @async_to_sync
        async def _get_or_create_reward(columns):
            async with self.db_bind:
                # NOTE(review): lookup ignores is_prof although _get_clear_query filters on it — confirm intended.
                filter_columns = {
                    'contract_id': columns['contract_id'],
                    'period_from': columns['period_from'],
                    'type': columns['type'],
                }
                model = await Reward.get_or_create(filter_columns, columns)
                # In-memory reset only; _after_creation recomputes the stored total.
                model.payment = Decimal(0)
                return model

        # @async_to_sync
        # async def _clear_predict(reward_id):
        #     async with current_db.transaction():
        #         await Reward.update.values(predict=False).gino.status()
        #         await ServiceReward.delete.where(ServiceReward.reward_id == reward_id).gino.status()

        r_columns = self._build_reward_columns(yt_row, contract)
        reward: Reward = _get_or_create_reward(r_columns)
        self.rewards_map[eid] = reward
        # if reward.predict:
        #     _clear_predict(reward.id)

    def _preprocess_yt_row(self, yt_row):
        """Resolve the row's contract (cache, DB, or stub) and register its month reward.

        Raises:
            NoEidInYtRow: the row has no usable ``contract_eid``.
            ContractNotFound: the contract is unknown and ``raise_on_unknown_contract`` is set.
        """
        @async_to_sync
        async def _get_contract(_eid):
            async with self.db_bind:
                return await Contract.query.where(Contract.eid == _eid).gino.first()

        super()._preprocess_yt_row(yt_row)

        eid = yt_row.get('contract_eid')
        if not eid:
            # Otherwise the lookup would search for a NULL eid and could create a stub with eid=None.
            raise NoEidInYtRow()

        if eid not in self.contracts_map:
            contract = _get_contract(eid)
            if contract is not None:
                self.contracts_map[eid] = contract
            LOGGER.debug('Contract %s was not found in map, so try to find: %s', eid, contract is not None)
        else:
            contract = self.contracts_map[eid]
            LOGGER.debug('Contract %s was found in map', eid)

        if contract is None:
            contract = self._on_contract_not_found(eid)
            if contract is not None:
                LOGGER.warning('Created stub contract %s', eid)
                self.contracts_map[eid] = contract

        self._process_rewards_map(eid, contract, yt_row)

    def _after_creation(self):
        """Recompute each loaded reward's ``payment`` as the sum of its ``ServiceReward.payment`` values.

        Failures are logged per reward and do not stop the remaining recalculations.
        """
        @async_to_sync
        async def _recalculate_reward(reward: Reward):
            # Enter the loader's bind like every other DB helper here; previously this
            # update ran outside it.
            async with self.db_bind:
                reward_sum = current_db.select(
                    [current_db.func.sum(ServiceReward.payment).label('payment')]
                ).where(
                    ServiceReward.reward_id == reward.id
                ).group_by(
                    ServiceReward.reward_id
                ).cte("reward_sum")
                # NOTE(review): with no ServiceReward rows the scalar subquery is NULL — confirm intended.
                await reward.update(payment=current_db.select([reward_sum.c.payment]).select_from(reward_sum)).apply()

        for reward in self.rewards_map.values():
            try:
                _recalculate_reward(reward)
            except Exception as ex:
                LOGGER.exception('Couldn\'t recalculate reward %s due to %s', reward.id, ex)

    # def _process_new(self, yt_row) -> typing.List[dict]:
    #     @async_to_sync
    #     async def _update_reward(reward, payment):
    #         async with self.db_bind:
    #             await reward.update(payment=Reward.payment + payment).apply()
    #     if not yt_row.get('contract_eid'):
    #         raise BadYtRow('Bad value for contract_eid')
    #     elif not yt_row.get('payment'):
    #         raise BadYtRow('Bad value for payment')
    #     eid = yt_row['contract_eid']
    #     rows = super()._process_new(yt_row)
    #     # N.B.: can't use yt_row['reward'] because with obfuscation it's hard to adjust values
    #     # TODO: update reward after service_reward was added
    #     # _update_reward(self.rewards_map[eid], sum((row.get('payment') or 0) for row in rows))
    #     return rows
