from asgiref.sync import async_to_sync
from collections import defaultdict
from dateutil.relativedelta import relativedelta
from sqlalchemy import and_

from crm.agency_cabinet.common.consts.reward import RewardsTypes
from crm.agency_cabinet.common.consts import START_FIN_YEAR_2021
from crm.agency_cabinet.rewards.server.src.db import models
from crm.agency_cabinet.rewards.server.src.db.engine import db
from .base import get_predict_start_from


class RevenueLoader:
    """Base loader that aggregates monthly service revenue per contract.

    The loader sums ``ServiceReward.revenue`` for monthly rewards of
    ``self.service`` inside ``[period_from, period_to)``, lays the result out
    over the twelve months starting at ``START_FIN_YEAR_2021`` and stores one
    ``CalculatorData`` row per contract.

    Subclasses must implement :meth:`make_month_data` and :meth:`make_data`;
    they are presumably also expected to set :attr:`service` (it is ``None``
    here and is used as a filter value in every query).
    """

    # Value compared against ``ServiceReward.service`` and
    # ``CalculatorData.service`` in the queries below.
    service = None

    def __init__(self, period_from, period_to, index_id, prev=False):
        """Store the loading parameters.

        :param period_from: inclusive lower bound for ``Reward.period_from``.
        :param period_to: exclusive upper bound for ``Reward.period_from``.
        :param index_id: stored on the instance; not used by this base class.
        :param prev: when true, every fetched period is shifted forward by
            one year before being matched against the fiscal-year months
            (presumably to project previous-year revenue onto the current
            year -- confirm with callers).
        """
        self.period_from = period_from
        self.period_to = period_to
        self.index_id = index_id
        self.prev = prev

    async def get_service_rewards(self):
        """Return ``(contract_id, period_from, sum(revenue))`` rows.

        Only monthly rewards (``RewardsTypes.month``) of ``self.service`` whose
        ``period_from`` lies in ``[self.period_from, self.period_to)`` are
        taken into account; rows are grouped by contract and period.
        """
        return await db.select(
            [
                models.Reward.contract_id,
                models.Reward.period_from,
                db.func.sum(models.ServiceReward.revenue),
            ]
        ).select_from(
            models.Reward.join(models.ServiceReward)
        ).where(
            and_(
                models.ServiceReward.service == self.service,
                models.Reward.period_from >= self.period_from,
                models.Reward.period_from < self.period_to,
                models.Reward.type == RewardsTypes.month.value
            )
        ).group_by(
            models.Reward.contract_id, models.Reward.period_from
        ).gino.all()

    async def create_or_update_data(self, contract_id, json_data):
        """Insert or update the ``CalculatorData`` row for a contract.

        The row is looked up by ``contract_id`` and ``self.service``. When it
        does not exist a new row is created with ``json_data``; otherwise the
        existing row is updated through ``CalculatorDataUpdater``.
        """
        db_row = await models.CalculatorData.query.where(
            and_(
                models.CalculatorData.contract_id == contract_id,
                models.CalculatorData.service == self.service
            ),
        ).gino.first()

        if db_row is None:
            await models.CalculatorData.create(
                contract_id=contract_id,
                service=self.service,
                data=json_data
            )
        else:
            await models.CalculatorDataUpdater(db_row).update_data(json_data)

    def _build_contract_months(self, service_rewards):
        """Map each contract id to its list of twelve month entries.

        :param service_rewards: iterable of
            ``(contract_id, period_from, revenue)`` rows as returned by
            :meth:`get_service_rewards`.
        :return: dict ``contract_id -> [month_data, ...]`` where every entry
            is produced by :meth:`make_month_data`; months without a fetched
            revenue get ``0``.
        """
        contract_period_revenue_map = defaultdict(dict)
        for contract_id, period_from, revenue in service_rewards:
            if self.prev:
                period_from = period_from + relativedelta(years=1)
            # Keys are timezone-naive so they compare equal to the naive
            # fiscal-year months built below.
            contract_period_revenue_map[contract_id][period_from.replace(tzinfo=None)] = revenue

        if not contract_period_revenue_map:
            return {}

        # Both values are the same for every contract, so compute them once
        # instead of once per contract and month.
        fiscal_months = [
            (START_FIN_YEAR_2021 + relativedelta(months=offset)).replace(tzinfo=None)
            for offset in range(12)
        ]
        predict_start = get_predict_start_from()

        contract_data_map = {}
        for contract_id, period_revenue_map in contract_period_revenue_map.items():
            contract_data_map[contract_id] = [
                self.make_month_data(
                    predict_start <= month,
                    month,
                    period_revenue_map.get(month, 0),
                )
                for month in fiscal_months
            ]
        return contract_data_map

    def load(self):
        """Synchronously fetch revenues and store calculator data.

        Runs the asynchronous pipeline through ``async_to_sync``: fetch the
        grouped revenues, build the month entries for every contract, then
        persist one ``CalculatorData`` payload (built by :meth:`make_data`)
        per contract via :meth:`create_or_update_data`.
        """
        @async_to_sync
        async def load_from_db():
            service_rewards = await self.get_service_rewards()
            contract_data_map = self._build_contract_months(service_rewards)

            for contract_id, months in contract_data_map.items():
                json_data = self.make_data(months)
                # Reuse the shared upsert so subclass overrides are honoured.
                await self.create_or_update_data(contract_id, json_data)

        load_from_db()

    def make_month_data(self, predict, period_from, revenue):
        """Build the entry for one month; must be implemented by subclasses.

        :param predict: ``True`` when ``period_from`` is not earlier than the
            value returned by ``get_predict_start_from()``.
        :param period_from: timezone-naive start of the month.
        :param revenue: summed revenue for the month, or ``0`` when absent.
        """
        raise NotImplementedError

    def make_data(self, months):
        """Build the stored payload from the month entries; subclass hook.

        :param months: list of values produced by :meth:`make_month_data`.
        """
        raise NotImplementedError
