import logging
import typing
from collections import defaultdict
from datetime import datetime
from dateutil.relativedelta import relativedelta
from decimal import Decimal

from crm.agency_cabinet.common.consts import START_FIN_YEAR_2021, SERVICE_DISCOUNT_TYPE_MAP
from crm.agency_cabinet.rewards.common import structs
from .base import UpdateDataLoader, UnknownServiceException

# Module-level logger used to report per-contract failures while building rows.
# NOTE(review): the logger name ends in 'base' although this module defines
# EarlyPaymentDataLoader (and only imports from .base) — confirm the shared name
# is intentional, since it decides which logging configuration applies here.
LOGGER = logging.getLogger('celery.calculator.base')


class EarlyPaymentDataLoader(UpdateDataLoader):
    """Build the early-payment reward index (reward type 311) per contract and month.

    Rows are read through ``self.client.read_table``. Only rows whose contract
    passes ``_check_contract_id``, whose ``reward_type`` equals ``REWARD_TYPE``,
    whose ``discount_type`` belongs to this service's discount types and whose
    ``till_dt`` is not earlier than ``period_from`` are kept. For each contract,
    ``reward_to_charge`` is summed per calendar month and converted back to the
    revenue it was charged on (see ``_calculate_revenue``).

    Subclasses must implement ``make_month_data`` and ``make_data``.
    """

    REWARD_TYPE = 311
    INDEX_ID = 'early_payment'
    # The reward is REWARD_PERCENT percent of revenue; used to recover revenue.
    REWARD_PERCENT = 2

    def __init__(self, period_from, *args, **kwargs):
        """Initialize the loader.

        :param period_from: inclusive lower bound for a row's ``till_dt``;
            rows ending earlier are skipped. Compared against naive datetimes
            parsed by ``_get_till_dt``.
        :raises UnknownServiceException: if ``self.service`` has no entry in
            ``SERVICE_DISCOUNT_TYPE_MAP``.

        Any remaining arguments are forwarded to ``UpdateDataLoader.__init__``.
        """
        self.period_from = period_from
        # NOTE(review): self.service is read before super().__init__ runs, so it
        # must already exist at this point (e.g. as a class attribute) — confirm.
        self.discount_types = SERVICE_DISCOUNT_TYPE_MAP.get(self.service, None)
        if self.discount_types is None:
            raise UnknownServiceException(f'Unknown service: {self.service}')
        super().__init__(*args, **kwargs)

    @classmethod
    def _calculate_revenue(cls, reward: Decimal) -> Decimal:
        """Return the revenue on which ``reward`` was charged at ``REWARD_PERCENT`` percent.

        Resolved through ``cls`` so a subclass overriding ``REWARD_PERCENT``
        uses its own rate; still callable as ``EarlyPaymentDataLoader._calculate_revenue(x)``.
        """
        return reward * 100 / cls.REWARD_PERCENT

    @staticmethod
    def _get_till_dt(yt_row):
        """Parse the row's ``till_dt`` string (``YYYY-MM-DDTHH:MM:SSZ``) into a naive datetime."""
        return datetime.strptime(yt_row['till_dt'], '%Y-%m-%dT%H:%M:%SZ')

    def _filter_row(self, row):
        """Return True if ``row`` belongs to this loader's contract, reward type, discount types and period."""
        if not self._check_contract_id(row['contract_id']):
            return False

        if row['reward_type'] != self.REWARD_TYPE or row['discount_type'] not in self.discount_types:
            return False

        till_dt = self._get_till_dt(row)
        return self.period_from <= till_dt

    def _read_table(self, **kwargs):
        """Read the source table, group matching rows by contract and fill ``self._table``.

        A contract whose row cannot be built is logged and skipped, so one bad
        contract does not abort the whole load.

        :param kwargs: forwarded to ``self.client.read_table``.
        """
        # The table is read by its path without a trailing slash. Strip it only
        # when present so a path given without one is not truncated.
        table_path = self.table_path
        if table_path.endswith('/'):
            table_path = table_path[:-1]

        contracts = defaultdict(list)
        for row in self.client.read_table(table_path, **kwargs):
            if self._filter_row(row):
                contracts[row['contract_id']].append(row)

        self._table = []
        for contract_id, rows in contracts.items():
            try:
                self._table.append(self._make_table_row(contract_id, rows))
            except Exception as ex:
                LOGGER.exception('Something went wrong: %s', ex)

    def _make_table_row(self, contract_id, rows):
        """Aggregate one contract's rows into monthly index data.

        All 12 months starting at ``START_FIN_YEAR_2021`` are emitted, with zero
        reward when a month has no rows. Months outside that range that occur in
        ``rows`` are appended after them, in order of first appearance.

        :param contract_id: the contract the rows belong to.
        :param rows: filtered source rows of that contract.
        :returns: ``{'contract_id': contract_id, 'data': self.make_data(months=...)}``.
        """
        period_reward_map: typing.Dict[datetime, Decimal] = defaultdict(Decimal)
        # TODO: build data only for months that have already passed, not the whole fiscal year.
        for i in range(12):
            period_reward_map[START_FIN_YEAR_2021 + relativedelta(months=i)] = Decimal(0)

        for row in rows:
            till_dt = self._get_till_dt(row)
            period_from = datetime(till_dt.year, till_dt.month, 1)
            # Convert via str() so a float column value becomes its shortest decimal
            # form (0.1 -> Decimal('0.1')) instead of its exact binary expansion;
            # str, int and Decimal values convert exactly as before.
            period_reward_map[period_from] += Decimal(str(row['reward_to_charge']))

        months = []
        for period_from, reward in period_reward_map.items():
            # reward is a Decimal, so the computed revenue is a Decimal as well.
            index = structs.CalculatorIndexData(
                index_id=self.INDEX_ID,
                revenue=self._calculate_revenue(reward),
            )
            months.append(self.make_month_data(False, period_from, index))

        return {
            'contract_id': contract_id,
            'data': self.make_data(months=months),
        }

    def make_month_data(self, predict, period_from, index):
        """Build one month's payload from its index data; must be implemented by subclasses.

        Called by ``_make_table_row`` with ``predict=False``.
        """
        raise NotImplementedError

    def make_data(self, months):
        """Combine the per-month payloads into a contract's data; must be implemented by subclasses."""
        raise NotImplementedError
