import logging
import typing
from dataclasses import dataclass
from datetime import datetime
from dateutil.relativedelta import relativedelta
from decimal import Decimal
from gino.loader import ColumnLoader
from sqlalchemy import and_
from ylog.context import log_context

from crm.agency_cabinet.common.consts import Services, get_fin_year, get_start_of_current_fin_year_with_dt
from crm.agency_cabinet.rewards.common import structs
from crm.agency_cabinet.rewards.common.utils import get_reward_index
from crm.agency_cabinet.rewards.server.src.db import Contract, Reward, ServiceReward, db
from crm.agency_cabinet.rewards.server.src.db.queries import build_contract_to_services_map
from .exceptions import UnknownYear

LOGGER = logging.getLogger('rewards.procedures')


@dataclass
class GetDashboard:
    """Build the rewards dashboard for an agency.

    For every contract of ``request.agency_id`` whose ``finish_date`` is after the start of
    the current financial year (optionally narrowed to ``request.filter_contract``), one
    ``structs.DashboardItem`` is produced per ``Services`` member that is not in
    ``Services.get_exception_services_list()`` (optionally narrowed to ``request.filter_service``).
    Items for services the contract actually uses are ``active`` and get zero-filled reward
    slots (1-, 3- and 6-month steps), which are then overwritten with the ``ServiceReward``
    sums found for the requested financial year.
    """

    async def __call__(self, request: structs.GetDashboard) -> structs.GetDashboardResponse:
        """Return the dashboard for ``request``.

        :param request: agency id, year and optional contract / service filters.
        :returns: response whose ``dashboard`` lists one item per (contract, service) pair,
            in contract order, then ``Services`` order.
        :raises UnknownYear: if ``request.year`` is later than the current calendar year.
        """
        # Validate the request before issuing any database query.
        if request.year > datetime.now().year:
            raise UnknownYear()
        start_fin_year, end_fin_year = get_fin_year(request.year)

        contracts = await self._load_contracts(request)
        if not contracts:
            # Nothing to show: skip the rewards / services queries entirely instead of
            # running them with an empty ``IN ()`` filter.
            return structs.GetDashboardResponse(dashboard=[])

        contract_ids = [contract.id for contract in contracts]
        service_rewards = await self._load_service_rewards(
            request, contract_ids, start_fin_year, end_fin_year
        )
        services_map = await build_contract_to_services_map(contract_ids)

        dashboard_map = self._build_empty_dashboard(
            request, contracts, services_map, start_fin_year, end_fin_year
        )
        self._apply_service_rewards(request, dashboard_map, service_rewards, start_fin_year)

        return structs.GetDashboardResponse(dashboard=list(dashboard_map.values()))

    @staticmethod
    async def _load_contracts(request: structs.GetDashboard) -> typing.List[Contract]:
        """Fetch the agency's contracts that finish after the current financial year starts."""
        # NOTE(review): the cut-off is the *current* financial year, not ``request.year``,
        # so contracts that ended before the current year are hidden even when a past
        # year is requested — confirm this is intended.
        query = Contract.query.where(
            and_(
                Contract.agency_id == request.agency_id,
                Contract.finish_date > get_start_of_current_fin_year_with_dt(),
            )
        )
        if request.filter_contract:
            query = query.where(Contract.id == request.filter_contract)
        return await query.gino.all()

    @staticmethod
    async def _load_service_rewards(
        request: structs.GetDashboard,
        contract_ids: typing.List,
        start_fin_year: datetime,
        end_fin_year: datetime,
    ) -> typing.List[typing.Tuple[Reward, Decimal, Decimal, str]]:
        """Fetch the rewards of ``contract_ids`` whose ``period_from`` lies in the financial year.

        Revenue and payment are summed per (reward, service); each row is
        ``(reward, revenue_sum, payment_sum, service)``.
        """
        sum_revenue = db.func.sum(ServiceReward.revenue).label('service_revenue')
        sum_payment = db.func.sum(ServiceReward.payment).label('service_payment')

        query = db.select(
            [Reward, ServiceReward.service, sum_revenue, sum_payment]
        ).select_from(
            ServiceReward.join(Reward)
        ).where(
            and_(
                Reward.contract_id.in_(contract_ids),
                Reward.period_from >= start_fin_year,
                Reward.period_from < end_fin_year,
            )
        )
        if request.filter_service:
            query = query.where(ServiceReward.service == request.filter_service)

        return await query.group_by(
            Reward.id,
            ServiceReward.service,
        ).gino.load(
            (
                Reward,
                ColumnLoader(sum_revenue),
                ColumnLoader(sum_payment),
                ColumnLoader(ServiceReward.service),
            )
        ).all()

    def _build_empty_dashboard(
        self,
        request: structs.GetDashboard,
        contracts: typing.List[Contract],
        services_map: typing.Mapping,
        start_fin_year: datetime,
        end_fin_year: datetime,
    ) -> typing.Dict[tuple, structs.DashboardItem]:
        """Create one zero-filled dashboard item per (contract id, service value) pair."""
        # The set of displayed services does not depend on the contract: compute it once.
        excluded = Services.get_exception_services_list()
        services = [
            service for service in Services
            if service not in excluded
            and (not request.filter_service or request.filter_service == service.value)
        ]

        dashboard_map = {}
        for contract in contracts:
            # Services reported by build_contract_to_services_map plus those listed on the
            # contract. ``.get`` guards against the map omitting a contract — presumably one
            # with no services; verify against build_contract_to_services_map.
            actual_services = services_map.get(contract.id, set()) | set(contract.services or [])
            for service in services:
                active = service.value in actual_services
                dashboard_map[(contract.id, service.value)] = structs.DashboardItem(
                    contract_id=contract.id,
                    service=service,
                    active=active,
                    rewards=structs.DashboardRewardsMap(
                        self.create_empty_rewards(start_fin_year, end_fin_year, 1),
                        self.create_empty_rewards(start_fin_year, end_fin_year, 3),
                        self.create_empty_rewards(start_fin_year, end_fin_year, 6),
                    ) if active else None,
                    updated_at=None,
                )
        return dashboard_map

    @staticmethod
    def _apply_service_rewards(
        request: structs.GetDashboard,
        dashboard_map: typing.Dict[tuple, structs.DashboardItem],
        service_rewards: typing.List[typing.Tuple[Reward, Decimal, Decimal, str]],
        start_fin_year: datetime,
    ) -> None:
        """Overwrite the matching empty reward slots in ``dashboard_map`` with loaded reward sums."""
        for reward, revenue, payment, service in service_rewards:
            dashboard_item = dashboard_map.get((reward.contract_id, service))
            if dashboard_item is None:
                # Filtered-out or exception service: nothing to show.
                continue
            if not dashboard_item.active:
                with log_context(agency_id=request.agency_id):
                    LOGGER.warning('Reward for service %s is not included in the contracts', service)
                continue

            reward_index = get_reward_index(start_fin_year, reward.period_from, reward.type)
            # ``reward.type`` names the DashboardRewardsMap attribute holding this period's slots.
            dashboard_reward = getattr(dashboard_item.rewards, reward.type)[reward_index]
            dashboard_reward.reward = payment
            # Guard against division by zero when there is no revenue.
            dashboard_reward.reward_percent = 100 * payment / revenue if revenue else Decimal(0)
            dashboard_reward.predict = reward.predict
            dashboard_reward.period_from = reward.period_from

            # Track the most recent update among all rewards contributing to this item.
            if dashboard_item.updated_at is None:
                dashboard_item.updated_at = reward.updated_at
            else:
                dashboard_item.updated_at = max(reward.updated_at, dashboard_item.updated_at)

    @staticmethod
    def create_empty_rewards(begin: datetime, end: datetime, step: int) -> typing.List[structs.DashboardReward]:
        """Return zero-valued reward slots from ``begin`` (inclusive) to ``end`` (exclusive).

        One slot starts every ``step`` months. Each slot is built as
        ``DashboardReward(Decimal(0), Decimal(0), True, period_start)``; presumably the
        positional fields are (reward, reward_percent, predict, period_from) — confirm
        against ``structs.DashboardReward``.
        """
        rewards = []
        while begin < end:
            rewards.append(structs.DashboardReward(Decimal(0), Decimal(0), True, begin))
            begin += relativedelta(months=step)
        return rewards
