import datetime
import logging
import typing

from asgiref.sync import async_to_sync
from collections import defaultdict
from dateutil.relativedelta import relativedelta
from crm.agency_cabinet.common.consts import RewardsTypes, DocumentType, compute_reward_type
from crm.agency_cabinet.common.yt.base import YtModelLoader, MethodExtractor, BaseRowLoadException
from crm.agency_cabinet.rewards.server.config.clients import YT_CONFIG
from crm.agency_cabinet.rewards.server.src.celery.base import celery_app as celery
from crm.agency_cabinet.rewards.server.src.db import models, db

LOGGER = logging.getLogger('celery.load_documents')


class RewardNotFoundException(BaseRowLoadException):
    """Row-level load error: the YT row's (contract, reward type, period start) matches no known reward."""

    default_message = 'Reward not found'


class UnknownPeriodException(BaseRowLoadException):
    """Row-level load error: the row's period cannot be parsed or mapped to a reward type / document name."""

    default_message = 'Unknown period'


class DocumentInfoLoader(YtModelLoader):
    """Load reward act documents from a YT table into ``models.Document``.

    Each YT row is matched to a reward by the key
    ``(contract eid, reward type, period start)`` (see ``_get_reward_key``).
    Besides creating/updating documents, the loader collects the payment info
    found in the same rows (``is_paid`` / ``payment_date``) and writes it to
    ``models.Reward`` in ``_after_finish``.
    """

    expected_columns = ('PAID_STATUS', 'PERIOD_START_DATE',
                        'PAID_FLAG', 'PERIOD_END_DATE',
                        'CONTRACT_NAME', 'SEEN_STATUS',
                        'RECEIVE_DOC_STATUS', 'SEND_ORIGINAL_DATE',)

    # Display names used by _get_name, keyed by the month of the period start.
    _RU_MONTH_NAMES = {
        1: 'Январь',
        2: 'Февраль',
        3: 'Март',
        4: 'Апрель',
        5: 'Май',
        6: 'Июнь',
        7: 'Июль',
        8: 'Август',
        9: 'Сентябрь',
        10: 'Октябрь',
        11: 'Ноябрь',
        12: 'Декабрь'
    }
    # Quarters / half-years start in Mar/Jun/Sep/Dec; presumably a fiscal year
    # beginning in March — TODO confirm with the rewards domain owners.
    _RU_QUARTER_NAMES = {
        3: '1 квартал',
        6: '2 квартал',
        9: '3 квартал',
        12: '4 квартал',
    }
    _RU_SEMIYEAR_NAMES = {
        3: '1 полугодие',
        9: '2 полугодие'
    }

    def __init__(self, period_from, period_to, **kwargs):
        """
        :param period_from: lower bound (inclusive) for a row's period start date.
        :param period_to: upper bound (inclusive) for a row's period start date.
        :param kwargs: forwarded to ``YtModelLoader``.
        """
        # Set before super().__init__ in case the base constructor already reads the table.
        self.period_from = period_from
        self.period_to = period_to
        super().__init__(**kwargs)

    def _init(self):
        """Cache every reward (joined with its contract) keyed by (eid, reward type, period start)."""
        @async_to_sync
        async def get_all_rewards():
            return await db.select([
                models.Reward.id,
                models.Reward.period_from,
                models.Reward.type,
                models.Contract.eid,
                models.Reward.is_paid,
                models.Reward.payment_date
            ]).select_from(
                models.Reward.join(models.Contract)
            ).gino.all()

        self.rewards_map = {}
        for reward_id, period_from, reward_type, eid, is_paid, payment_date in get_all_rewards():
            self.rewards_map[(eid, reward_type, period_from)] = {
                'id': reward_id,
                'is_paid': is_paid,
                'payment_date': payment_date
            }

        # reward_id -> {'is_paid': ..., 'payment_date': ...}; flushed to the DB in _after_finish.
        self.update_rewards = defaultdict(dict)

    def _read_table(self, **kwargs):
        """Pre-filter the source table with a YT map operation and expose the result as ``self._table``.

        A row is kept when its CONTRACT_NAME belongs to a known reward and its
        period start lies in ``[self.period_from, self.period_to]``.
        """
        # Built once here; previously it was rebuilt from rewards_map for every input row.
        contract_eids = {key[0] for key in self.rewards_map}
        period_from = self.period_from
        period_to = self.period_to

        def filter_documents(input_row):
            # Cheap membership test first, so dates are parsed only for relevant rows.
            if input_row['CONTRACT_NAME'] not in contract_eids:
                return
            if period_from <= self._get_period_from(input_row) <= period_to:
                yield input_row

        # Strip a trailing slash only if present; blind [:-1] slicing would chop
        # the last character of a path that does not end with '/'.
        table_path_without_last_slash = self.table_path.rstrip('/')

        with self.client.TempTable() as output_table:
            self.client.run_map(
                filter_documents,
                table_path_without_last_slash,
                output_table,
                **kwargs
            )

            # NOTE(review): the returned iterator is consumed after the TempTable context
            # exits; this relies on read_table holding its own snapshot of the data — confirm.
            self._table = self.client.read_table(output_table, **kwargs)

    def _process_new(self, yt_row) -> typing.List[dict]:
        """Create the document via the base loader and queue the reward's payment info."""
        res = super()._process_new(yt_row)

        # Presumably the mapped column values for the new document(s); 'reward_id' comes from columns_mapper.
        reward_id = res[0]['reward_id']
        self.update_rewards[reward_id]['is_paid'] = self._get_is_paid(yt_row)
        self.update_rewards[reward_id]['payment_date'] = self._get_payment_date(yt_row)

        return res

    def _find_duplicate(self, yt_row) -> typing.Optional[models.Document]:
        """Return the existing document for this row's reward.

        Returns None when the row matches no known reward. UnknownPeriodException
        from period parsing is not caught here and propagates.
        """
        @async_to_sync
        async def _get_document(reward_id):
            return await models.Document.query.where(
                models.Document.reward_id == reward_id
            ).gino.first()

        try:
            reward_id = self._get_reward_id(yt_row)
        except RewardNotFoundException:
            return None

        return _get_document(reward_id)

    def _process_duplicate(self, yt_row, db_row: models.Document):
        """Refresh an existing document's statuses and queue reward payment changes, if any."""
        @async_to_sync
        async def _update_document(doc):
            async with self.db_bind:
                await doc.update(
                    got_scan=self._get_scan_status(yt_row),
                    got_original=self._get_original_status(yt_row),
                    sending_date=self._get_sending_date(yt_row)
                ).apply()

        _update_document(db_row)

        reward_payment_info = self.rewards_map.get(self._get_reward_key(yt_row))
        if reward_payment_info is None:
            return

        is_paid = self._get_is_paid(yt_row)
        payment_date = self._get_payment_date(yt_row)
        # Queue an UPDATE only when the payment info differs from the snapshot loaded in _init.
        if reward_payment_info['is_paid'] != is_paid or reward_payment_info['payment_date'] != payment_date:
            self.update_rewards[db_row.reward_id]['is_paid'] = is_paid
            self.update_rewards[db_row.reward_id]['payment_date'] = payment_date

    def _after_finish(self):
        """Write the queued payment info to ``models.Reward``, one UPDATE per reward."""
        @async_to_sync
        async def _update_reward(reward_id, is_paid, payment_date):
            return await models.Reward.update.values(
                is_paid=is_paid,
                payment_date=payment_date
            ).where(models.Reward.id == reward_id).gino.status(read_only=False, reuse=False)

        LOGGER.info('Updating payment info for %d rewards', len(self.update_rewards))
        for reward_id, update_data in self.update_rewards.items():
            _update_reward(reward_id, update_data['is_paid'], update_data['payment_date'])

    def _get_reward_type(self, yt_row):
        """Derive the reward type from the row's period start/end dates.

        :raises UnknownPeriodException: if the dates cannot be parsed or
            ``compute_reward_type`` rejects the period.
        """
        try:
            period_from = self._get_period_from(yt_row)
            period_to = self._get_period_to(yt_row)
        except (TypeError, ValueError) as ex:
            # ValueError: malformed date string; TypeError: non-string cell value such as None.
            raise UnknownPeriodException('Can\'t parse period') from ex

        try:
            return compute_reward_type(period_from, period_to)
        except ValueError as ex:
            raise UnknownPeriodException(f'Unknown period: {period_from} - {period_to}') from ex

    def _get_reward_key(self, yt_row) -> tuple:
        """Build the ``rewards_map`` key ``(eid, reward_type, period_from)`` for a YT row."""
        eid = yt_row['CONTRACT_NAME']
        reward_type = self._get_reward_type(yt_row)
        period_from = self._get_period_from(yt_row)
        return eid, reward_type, period_from

    def _get_reward_id(self, yt_row) -> int:
        """Return the id of the reward matching the row.

        :raises RewardNotFoundException: if no cached reward matches.
        """
        key = self._get_reward_key(yt_row)
        if key in self.rewards_map:
            return self.rewards_map[key]['id']

        eid, reward_type, period_from = key
        raise RewardNotFoundException(f'Can\'t find reward: {eid}, {reward_type}, {period_from}')

    def _get_scan_status(self, yt_row):
        """True when SEEN_STATUS is 'Y'."""
        return self._get_status(yt_row, 'SEEN_STATUS')

    def _get_original_status(self, yt_row):
        """True when RECEIVE_DOC_STATUS is 'Y'."""
        return self._get_status(yt_row, 'RECEIVE_DOC_STATUS')

    def _get_is_paid(self, yt_row) -> bool:
        """True when PAID_FLAG is 'Y'."""
        return self._get_status(yt_row, 'PAID_FLAG')

    def _get_sending_date(self, yt_row) -> typing.Optional[datetime.datetime]:
        """Parse SEND_ORIGINAL_DATE; None when the cell is empty."""
        sending_date = None
        if yt_row['SEND_ORIGINAL_DATE']:
            sending_date = self._str_to_date(yt_row['SEND_ORIGINAL_DATE'])
        return sending_date

    def _get_payment_date(self, yt_row) -> typing.Optional[datetime.datetime]:
        """Parse the PAID_STATUS column as the payment date; None when the cell is empty."""
        payment_date = None
        if yt_row['PAID_STATUS']:
            payment_date = self._str_to_date(yt_row['PAID_STATUS'])
        return payment_date

    def _get_name(self, yt_row) -> str:
        """Build the Russian display name of the act for the row's reward period.

        :raises UnknownPeriodException: if the period cannot be parsed, its start
            month has no name for the reward type, or the type is not month/quarter/semiyear.
        """
        # Resolve the type first: it wraps date-parsing errors into UnknownPeriodException.
        reward_type = self._get_reward_type(yt_row)
        period_from = self._get_period_from(yt_row)
        try:
            if reward_type == RewardsTypes.month.value:
                return 'Акт премии за {month} {year}'.format(month=self._RU_MONTH_NAMES[period_from.month], year=period_from.year)

            if reward_type == RewardsTypes.quarter.value:
                return 'Акт премии за {quarter} {year}'.format(quarter=self._RU_QUARTER_NAMES[period_from.month], year=period_from.year)

            if reward_type == RewardsTypes.semiyear.value:
                return 'Акт премии за {semiyear} {year}'.format(semiyear=self._RU_SEMIYEAR_NAMES[period_from.month], year=period_from.year)
        except KeyError as e:
            raise UnknownPeriodException(f'Unknown period type {reward_type}: {period_from}') from e
        raise UnknownPeriodException(f'Unknown period for type {reward_type}: {period_from}')

    def _get_period_from(self, yt_row) -> datetime.datetime:
        """Parse PERIOD_START_DATE as a UTC datetime."""
        return self._str_to_date(yt_row['PERIOD_START_DATE'])

    def _get_period_to(self, yt_row) -> datetime.datetime:
        """Parse PERIOD_END_DATE as a UTC datetime."""
        return self._str_to_date(yt_row['PERIOD_END_DATE'])

    def _str_to_date(self, date: str) -> datetime.datetime:
        """Parse 'YYYY-MM-DD HH:MM:SS' and mark the result as UTC."""
        return datetime.datetime.strptime(date, '%Y-%m-%d %H:%M:%S').replace(tzinfo=datetime.timezone.utc)

    def _get_status(self, yt_row, status_name) -> bool:
        """Interpret a 'Y'-flag column: True only for the exact value 'Y'."""
        return yt_row[status_name] == 'Y'


@celery.task(bind=True)
def load_documents_info_task(self, period_from: typing.Optional[str] = None, period_to: typing.Optional[str] = None):
    """Celery task: load reward act documents (and reward payment info) from YT.

    :param period_from: 'YYYY-MM-DD' (taken as UTC midnight), lower bound for a
        document's period start; defaults to three months before ``period_to``.
    :param period_to: 'YYYY-MM-DD' (taken as UTC midnight), upper bound for a
        document's period start; defaults to the current UTC time.
    """
    if period_to:
        period_to = datetime.datetime.strptime(period_to, '%Y-%m-%d').replace(tzinfo=datetime.timezone.utc)
    else:
        # now(tz=utc) is the real current UTC instant; now().replace(tzinfo=utc)
        # would relabel the server's local wall-clock time as UTC.
        period_to = datetime.datetime.now(tz=datetime.timezone.utc)

    if period_from:
        period_from = datetime.datetime.strptime(period_from, '%Y-%m-%d').replace(tzinfo=datetime.timezone.utc)
    else:
        # period_to is UTC-aware here, and relativedelta arithmetic preserves tzinfo.
        period_from = period_to - relativedelta(months=3)

    client_config = {
        'cluster': 'hahn',
        'token': YT_CONFIG['TOKEN'],
        'config': {}
    }

    loader = DocumentInfoLoader(
        period_from=period_from,
        period_to=period_to,
        table_path='//home/bi/stable/dwh/ods/oebs_oracle/XXYA/XXAR_HEADER_AGENCIES/XXAR_HEADER_AGENCIES',
        model=models.Document,
        columns_mapper={
            'reward_id': MethodExtractor('_get_reward_id'),
            'got_scan': MethodExtractor('_get_scan_status'),
            'got_original': MethodExtractor('_get_original_status'),
            'sending_date': MethodExtractor('_get_sending_date'),
            'name': MethodExtractor('_get_name')
        },
        default_columns={
            'type': DocumentType.act.value,
        },
        client_config=client_config
    )

    loader.load()
