import logging
import typing
from datetime import datetime
from decimal import Decimal
from sqlalchemy import and_
from collections import defaultdict
from crm.agency_cabinet.common.yt.synchronizers import BaseSynchronizer
from crm.agency_cabinet.rewards.common import structs
from crm.agency_cabinet.common.consts import CalculatorServiceType
from crm.agency_cabinet.rewards.common.schemas import calculator as calculator_schemas
from crm.agency_cabinet.rewards.server.src.db import db, models
from crm.agency_cabinet.rewards.server.src.celery.tasks.calculator.base import VERSION


# Module-level logger; its name is set explicitly instead of being derived from __name__.
LOGGER = logging.getLogger('celery.tasks.calculator.synchronizer')


class BaseCalculatorSynchronizer(BaseSynchronizer):
    """Base synchronizer that turns grouped source rows into stored calculator data.

    Subclasses configure the class attributes below. ``contract_type``,
    ``service`` and ``predict`` are mandatory: ``process_data`` raises
    ``NotImplementedError`` while any of them is still ``None``.
    """

    # index id -> name of the row attribute that holds the revenue of that index
    INDEX_ID_TO_YT_COLUMN_MAP: typing.Dict[str, str] = {}
    # grade id -> {CalculatorGradeData field name -> name of the row attribute}.
    # When use_average_grade is set, every inner map must contain the keys
    # 'domains_count' and 'revenue_average'.
    GRADE_TO_YT_COLUMN_MAP: typing.Dict[str, typing.Dict[str, str]] = {}
    # grade id -> meta info; only 'threshold_start' is read here (average-grade mode)
    GRADE_META_INFO: typing.Dict[str, typing.Dict[str, int]] = {}  # TODO: extract grade meta info from bunker
    service: typing.Union[str, CalculatorServiceType] = None  # CalculatorType->value
    contract_type = None
    predict = None
    use_average_grade = False

    def _process_indexes(self, row) -> list[structs.CalculatorIndexData]:
        """Build one ``CalculatorIndexData`` per configured index.

        A missing attribute and a falsy value (``None``, ``0``) both yield a
        revenue of ``Decimal(0)``.
        """
        return [
            structs.CalculatorIndexData(
                index_id=index_id,
                revenue=Decimal(getattr(row, column_name, 0) or 0),
            )
            for index_id, column_name in self.INDEX_ID_TO_YT_COLUMN_MAP.items()
        ]

    def _process_grades(self, row) -> list[structs.CalculatorGradeData]:
        """Build the grade entries of a single row.

        With ``use_average_grade`` disabled, every configured grade produces an
        entry whose fields are copied from the row (``None`` when the row lacks
        the attribute).

        With ``use_average_grade`` enabled, a grade is emitted only when its
        rounded domains count is non-zero; ``revenue_average`` is then derived
        from the raw row values and the grade's ``threshold_start``.
        """
        if not self.use_average_grade:
            grades = []
            for grade_id, column_map in self.GRADE_TO_YT_COLUMN_MAP.items():
                # 'grade_id' goes first so that a mapped field with the same
                # name overrides it, exactly as a plain dict update would.
                grade_fields = {'grade_id': grade_id}
                grade_fields.update(
                    (field, getattr(row, column, None))
                    for field, column in column_map.items()
                )
                grades.append(structs.CalculatorGradeData(**grade_fields))
            return grades

        grades = []
        for grade_id, column_map in self.GRADE_TO_YT_COLUMN_MAP.items():
            # `or 0` treats a present-but-None cell like a missing one, matching
            # _process_indexes; round(None) would otherwise raise TypeError.
            domains_count_raw = getattr(row, column_map['domains_count'], 0) or 0
            domains_count = round(domains_count_raw)
            if not domains_count:
                # Nothing to report; this also keeps the divisions below safe.
                continue

            revenue_raw = getattr(row, column_map['revenue_average'], 0) or 0
            # Ratio between the raw and the rounded (reported) domains count.
            coef = domains_count_raw / domains_count
            revenue_average = coef * (
                revenue_raw / domains_count_raw
                + self.GRADE_META_INFO[grade_id]['threshold_start']
            )

            grades.append(structs.CalculatorGradeData(
                grade_id=grade_id,
                domains_count=domains_count,
                revenue_average=revenue_average,
            ))
        return grades

    def process_row(self, row):
        """Return the ``(indexes, grades)`` pair extracted from one row."""
        return self._process_indexes(row), self._process_grades(row)

    async def get_contracts(self):
        """Fetch every contract whose type equals ``self.contract_type``."""
        return await models.Contract.query.where(models.Contract.type == self.contract_type).gino.all()

    async def process_data(self, rows: list[tuple], prefer_actual=False, validate_actual=True, *args, **kwargs) -> bool:
        """Convert grouped rows into month data and store it per contract.

        :param rows: sequence of ``(contract_id, items)`` pairs; every item
            exposes ``dt`` as a ``'%Y-%m-%d'`` string plus the attributes named
            in the column maps.
        :param prefer_actual: forwarded to ``CalculatorDataUpdater.update_data``
            when a record for the contract/service/version already exists.
        :param validate_actual: when ``True`` and ``self.predict`` is ``False``,
            months that do not start before the current moment are flagged as
            predictions.
        :return: always ``True``.
        :raises NotImplementedError: if ``contract_type``, ``service`` or
            ``predict`` is not configured.
        """
        if self.contract_type is None or self.service is None or self.predict is None:
            raise NotImplementedError(
                'contract_type, service and predict must be defined by the subclass'
            )

        # A set gives O(1) membership checks inside the loop over rows.
        known_contract_ids = {contract.id for contract in await self.get_contracts()}
        now = datetime.now()
        check_future = validate_actual is True and self.predict is False

        contract_data_map = defaultdict(list)
        for row in rows:
            contract_id = row[0]
            if contract_id not in known_contract_ids:
                continue

            # Parse every date once and reuse it for ordering and as period_from.
            # Sorting by the date only keeps the ordering stable for equal dates.
            dated_items = sorted(
                ((datetime.strptime(item.dt, '%Y-%m-%d'), item) for item in row[1]),
                key=lambda pair: pair[0],
            )
            for period_from, item in dated_items:
                indexes, grades = self.process_row(item)
                if check_future:
                    predict_status = self.predict if period_from < now else True  # no actual data for future
                else:
                    predict_status = self.predict
                contract_data_map[contract_id].append(structs.CalculatorMonthData(
                    period_from=period_from,
                    predict=predict_status,
                    indexes=indexes,
                    grades=grades,
                ))

        schema = calculator_schemas.CalculatorDataSchema()
        for contract_id, months in contract_data_map.items():
            # Serialization does not touch the database, so it stays outside
            # the transaction.
            dumped = schema.dump(structs.CalculatorData(months=months))

            # One transaction per contract: lookup and write happen together.
            async with db.transaction(read_only=False, reuse=False):
                duplicate = await models.CalculatorData.query.where(
                    and_(
                        models.CalculatorData.contract_id == contract_id,
                        models.CalculatorData.service == self.service,
                        models.CalculatorData.version == VERSION,
                    )
                ).gino.first()

                if duplicate:
                    await models.CalculatorDataUpdater(duplicate).update_data(dumped, prefer_actual=prefer_actual)
                else:
                    await models.CalculatorData.create(
                        contract_id=contract_id,
                        data=dumped,
                        version=VERSION,
                        service=self.service,
                    )

        return True
