from dataclasses import dataclass
from decimal import Decimal
from sqlalchemy import and_
from gino.loader import ModelLoader, ColumnLoader

from crm.agency_cabinet.common.consts import DocumentType, PaymentsStatuses
from crm.agency_cabinet.rewards.common import structs
from crm.agency_cabinet.rewards.server.src.db import Contract, db, Reward, ServiceReward, Document

from .exceptions import NoSuchReward, UnsuitableAgency


@dataclass
class GetRewardsInfo:
    """List an agency's non-predicted rewards with per-reward service names and act-document flags."""

    async def __call__(self, request: structs.GetRewardsInfoRequest) -> structs.GetRewardsInfoResponse:
        """Return the rewards of ``request.agency_id`` that match the request's optional filters.

        Only rewards whose ``predict`` flag is False are returned. ``filter_from``,
        ``filter_to``, ``filter_type`` and ``filter_contract`` are applied only when
        truthy; ``filter_is_paid`` is applied whenever it is not None, so ``False``
        is a meaningful filter value.
        """
        query = Reward.load(
            contract=Contract.on(Reward.contract_id == Contract.id)
        ).where(
            and_(
                Contract.agency_id == request.agency_id,
                Reward.predict.is_(False)
            )
        )
        if request.filter_from:
            query = query.where(Reward.period_from >= request.filter_from)
        if request.filter_to:
            # Both date bounds are compared against period_from.
            query = query.where(Reward.period_from <= request.filter_to)
        if request.filter_is_paid is not None:
            query = query.where(Reward.is_paid == request.filter_is_paid)
        if request.filter_type:
            query = query.where(Reward.type == request.filter_type)
        if request.filter_contract:
            query = query.where(Contract.id == request.filter_contract)

        rewards = await query.gino.all()

        # Nothing matched: skip the follow-up queries, which would otherwise be
        # issued with an empty IN (...) predicate and could only return nothing.
        if not rewards:
            return structs.GetRewardsInfoResponse(rewards=[])

        reward_ids = [reward.id for reward in rewards]

        # reward_id -> array of distinct service names attached to that reward.
        services_map = {
            reward_id: services for reward_id, services in await db.select(
                [
                    ServiceReward.reward_id,
                    db.func.array_agg(
                        db.func.distinct(ServiceReward.service), type_=db.ARRAY(db.VARCHAR)
                    ),
                ]
            ).select_from(
                ServiceReward
            ).where(
                ServiceReward.reward_id.in_(reward_ids)
            ).group_by(
                ServiceReward.reward_id
            ).gino.all()
        }

        # reward_id -> (got_scan, got_original) of the reward's act document.
        # NOTE(review): if a reward has several act documents, the last row returned
        # wins; confirm there is at most one act per reward.
        docs_info_map = {
            reward_id: (got_scan, got_original) for reward_id, got_scan, got_original in await db.select(
                [
                    Document.reward_id,
                    Document.got_scan,
                    Document.got_original
                ]
            ).select_from(
                Document
            ).where(
                and_(
                    Document.reward_id.in_(reward_ids),
                    Document.type == DocumentType.act.value
                )
            ).gino.all()
        }

        # Rewards without an act document report both flags as False.
        no_act_document = (False, False)

        return structs.GetRewardsInfoResponse(
            rewards=[
                structs.RewardInfo(
                    id=reward.id,
                    contract_id=reward.contract.id,
                    type=reward.type,
                    services=services_map.get(reward.id, []),
                    got_scan=docs_info_map.get(reward.id, no_act_document)[0],
                    got_original=docs_info_map.get(reward.id, no_act_document)[1],
                    is_accrued=reward.is_accrued,
                    is_paid=reward.is_paid,
                    payment_date=reward.payment_date,
                    payment=reward.payment,
                    period_from=reward.period_from,
                )
                for reward in rewards
            ]
        )


@dataclass
class GetDetailedRewardInfo:
    """Return one reward with its per-service breakdown and its documents."""

    async def __call__(self, request: structs.GetDetailedRewardInfoRequest) -> structs.GetDetailedRewardInfoResponse:
        """Build the detailed view of ``request.reward_id`` for ``request.agency_id``.

        Raises:
            NoSuchReward: no reward with ``request.reward_id`` exists.
            UnsuitableAgency: the reward's contract is missing or belongs to
                another agency.
        """
        reward = await Reward.query.where(
            Reward.id == request.reward_id
        ).gino.first()

        if not reward:
            raise NoSuchReward()

        contract = await Contract.query.where(
            Contract.id == reward.contract_id
        ).gino.first()

        # A reward whose contract cannot be found cannot be attributed to the
        # requesting agency; reject it instead of failing with AttributeError on None.
        if contract is None or contract.agency_id != request.agency_id:
            raise UnsuitableAgency()

        sum_revenue = db.func.sum(ServiceReward.revenue).label('revenue')
        sum_payment = db.func.sum(ServiceReward.payment).label('payment')
        percents = db.func.array_agg(db.func.distinct(ServiceReward.reward_percent)).label('percents')
        # One row per service: summed revenue/payment plus the distinct percents seen.
        # TODO: select and group by currency too (currency is hard-coded to RUB below).
        services = await db.select(
            [
                ServiceReward.service,
                percents,
                sum_revenue,
                sum_payment,
            ]
        ).select_from(
            ServiceReward
        ).where(
            ServiceReward.reward_id == request.reward_id
        ).group_by(
            ServiceReward.service,
        ).gino.load(
            (
                ModelLoader(
                    ServiceReward,
                    ServiceReward.service,
                    revenue=sum_revenue,
                    payment=sum_payment
                ),
                ColumnLoader(percents)
            )
        ).all()

        docs = await Document.query.where(
            Document.reward_id == reward.id
        ).gino.all()

        return structs.GetDetailedRewardInfoResponse(
            reward=structs.DetailedRewardInfo(
                id=reward.id,
                contract_id=reward.contract_id,
                type=reward.type,
                services=[
                    structs.DetailedServiceInfo(
                        service=service.service,
                        revenue=service.revenue,
                        currency='RUB',  # TODO: only roubles are supported for now
                        # Use the percent directly when the service has a single non-null one.
                        # NOTE(review): the ModelLoader above only loads service/revenue/payment,
                        # so confirm raw_percent is actually populated on these instances.
                        reward_percent=(
                            service_percents[0]
                            if len(service_percents) == 1 and service_percents[0] is not None
                            else service.raw_percent
                        ),
                        accrual=service.payment,
                        error_message=None,  # TODO: once conditions are available
                    ) for service, service_percents in services
                ],
                documents=[
                    structs.DocumentInfo(
                        id=doc.id,
                        name=doc.name,
                        sending_date=doc.sending_date,
                        got_scan=doc.got_scan,
                        got_original=doc.got_original
                    ) for doc in docs
                ],
                status=PaymentsStatuses.paid.value if reward.is_paid else PaymentsStatuses.accrued.value,
                # Without a per-service breakdown the accrual is 0.
                accrual=sum(service.payment for service, _ in services) if services else Decimal(0),
                payment=reward.payment,
                accrual_date=reward.created_at,
                payment_date=reward.payment_date,
                period_from=reward.period_from,
                predict=reward.predict,
            )
        )
