import logging
import typing
from datetime import datetime, timezone
from asgiref.sync import async_to_sync
from functools import partial

from crm.agency_cabinet.rewards.server.config.clients import YT_CONFIG
from crm.agency_cabinet.rewards.server.src.celery.base import celery_app as celery
from crm.agency_cabinet.common.consts.contract import PaymentType, ContractType
from crm.agency_cabinet.common.consts.service import service_id_to_name
from crm.agency_cabinet.rewards.common.consts import CONTRACT_CODES_FOR_BASE, CONTRACT_CODES_FOR_PROF, \
    CONTRACT_CODES_FOR_ALL
from crm.agency_cabinet.rewards.server.src.db import models
from crm.agency_cabinet.common.yt.base import YtModelLoader, BaseRowLoadException, MethodExtractor


LOGGER = logging.getLogger('celery.load_contracts_info')


class UnknownPaymentType(BaseRowLoadException):
    """Row-level load error for a YT value the loader cannot translate.

    Raised for an unrecognised ``payment_type``; ``ContractInfoLoader``
    also raises it for an unrecognised contract type code.
    """


class ContractInfoLoader(YtModelLoader):
    """Load contract attributes from a YT table into ``models.Contract``.

    The source table is first filtered by a YT map job (see ``_read_table``).
    The ``get_*`` static methods convert raw YT values into model values.
    They are used by ``_process_duplicate`` and, via ``MethodExtractor``
    in the task's ``columns_mapper``, by the base loader.
    """

    # Contracts whose ``finish_dt`` is earlier than this are skipped by the
    # map job. The comparison is a plain string comparison, which is only
    # correct for zero-padded 'YYYY-MM-DD' values.
    # NOTE(review): assumes YT always emits zero-padded dates — confirm.
    MIN_FINISH_DT = '2022-03-01'

    @staticmethod
    def get_services(yt_row):
        """Return service names for the comma-separated ids in ``ids_services``.

        Tokens are whitespace-stripped, so both ``'7, 11'`` and ``'7,11'`` are
        accepted. Empty tokens and a null/empty field are ignored. Ids for
        which ``service_id_to_name`` returns a falsy value are dropped.
        """
        raw_ids = yt_row['ids_services'] or ''
        services = []
        for token in raw_ids.split(','):
            token = token.strip()
            if not token:
                continue
            service_name = service_id_to_name(int(token))
            if service_name:
                services.append(service_name)

        return services

    @staticmethod
    def get_finish_date(yt_row):
        """Parse ``finish_dt`` ('%Y-%m-%d') into a UTC-aware ``datetime``."""
        return datetime.strptime(yt_row['finish_dt'], '%Y-%m-%d').replace(tzinfo=timezone.utc)

    @staticmethod
    def get_payment_type(yt_row):
        """Map the Russian ``payment_type`` label to a ``PaymentType`` value.

        'постоплата' is postpayment and 'предоплата' is prepayment.

        Raises:
            UnknownPaymentType: for any other label.
        """
        if yt_row['payment_type'] == 'постоплата':
            return PaymentType.postpayment.value

        if yt_row['payment_type'] == 'предоплата':
            return PaymentType.prepayment.value

        raise UnknownPaymentType(f'UNKNOWN_PAYMENT_TYPE: {yt_row["payment_type"]}')

    @staticmethod
    def get_contract_type(yt_row):
        """Map the reward-scale code to a ``ContractType`` value (base or prof).

        Raises:
            UnknownPaymentType: when the code is in neither code set. The
                exception class is kept as-is so existing handlers still
                catch it, although the name is misleading here.
        """
        # TODO: prefer wsale_ag_prm_awrd_sc_code over wsale_ag_prm_awrd_sc_tp
        contract_code = yt_row['wsale_ag_prm_awrd_sc_tp']
        if contract_code in CONTRACT_CODES_FOR_BASE:
            return ContractType.base.value

        if contract_code in CONTRACT_CODES_FOR_PROF:
            return ContractType.prof.value

        raise UnknownPaymentType(f'UNKNOWN_CONTRACT_TYPE: {contract_code}')

    def _read_table(self, **kwargs):
        """Filter the source table with a YT map job and read the result.

        A row is kept when its reward-scale code is in
        ``CONTRACT_CODES_FOR_ALL`` and its ``finish_dt`` is not earlier than
        ``MIN_FINISH_DT``. The kept rows go to a temporary table, and
        ``self._table`` is set to ``client.read_table`` over it. The same
        ``kwargs`` are forwarded to both ``run_map`` and ``read_table``.
        """
        def filter_agencies(min_finish_dt, input_row):
            # TODO: prefer wsale_ag_prm_awrd_sc_code over wsale_ag_prm_awrd_sc_tp
            if (
                input_row['wsale_ag_prm_awrd_sc_tp'] in CONTRACT_CODES_FOR_ALL
                and input_row['finish_dt'] >= min_finish_dt
            ):
                yield input_row

        # rstrip rather than [:-1]: a path without a trailing slash must not
        # lose the last character of the table name.
        source_table = self.table_path.rstrip('/')
        with self.client.TempTable() as output_table:
            self.client.run_map(
                partial(filter_agencies, self.MIN_FINISH_DT),
                source_table,
                output_table,
                **kwargs
            )

            # NOTE(review): the temp table is removed when this `with` exits;
            # confirm the reader returned here does not need it afterwards.
            self._table = self.client.read_table(output_table, **kwargs)

    def _find_duplicate(self, yt_row) -> typing.Optional[models.Contract]:
        """Return the ``Contract`` whose id equals ``yt_row['contract_id']``, or None."""
        @async_to_sync
        async def _get_contract(contract_id):
            async with self.db_bind:
                return await models.Contract.query.where(models.Contract.id == contract_id).gino.first()

        return _get_contract(yt_row['contract_id'])

    def _process_duplicate(self, yt_row, db_row: models.Contract):
        """Overwrite an existing contract with the values from ``yt_row``.

        Updates eid, finish_date, services, payment_type and type.

        Raises:
            UnknownPaymentType: propagated from the payment/contract type mappers.
        """
        @async_to_sync
        async def _update_contract(contract, eid, finish_date, services, payment_type, contract_type):
            async with self.db_bind:
                await contract.update(
                    eid=eid,
                    finish_date=finish_date,
                    services=services,
                    payment_type=payment_type,
                    type=contract_type
                ).apply()

        _update_contract(
            db_row,
            yt_row['contract_eid'],
            self.get_finish_date(yt_row),
            self.get_services(yt_row),
            self.get_payment_type(yt_row),
            self.get_contract_type(yt_row)
        )


class ContractInnLoader(YtModelLoader):
    """Fill ``Contract.inn`` from a YT table of persons.

    Only person rows whose ``id`` matches the ``person_id`` of a contract
    already in the database are read. Each such row sets ``inn`` on the
    contracts of that person.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        @async_to_sync
        async def _get_all_contracts():
            async with self.db_bind:
                return await models.Contract.query.gino.all()

        contracts = _get_all_contracts()
        # Snapshot of every contract's person_id at construction time. It may
        # contain duplicates and None; _read_table normalises it.
        self.person_ids = [contract.person_id for contract in contracts]

    def _read_table(self, **kwargs):
        """Filter the person table to known ``person_id`` values and read it.

        The kept rows go to a temporary table, and ``self._table`` is set to
        ``client.read_table`` over it. The same ``kwargs`` are forwarded to
        both ``run_map`` and ``read_table``.
        """
        def filter_persons(person_ids, input_row):
            if input_row['id'] in person_ids:
                yield input_row

        # The membership test runs once per person row, so use a set (O(1))
        # instead of the list. None ids are not real persons and are dropped.
        wanted_ids = frozenset(pid for pid in self.person_ids if pid is not None)

        # rstrip rather than [:-1]: a path without a trailing slash must not
        # lose the last character of the table name.
        source_table = self.table_path.rstrip('/')
        with self.client.TempTable() as output_table:
            self.client.run_map(
                partial(filter_persons, wanted_ids),
                source_table,
                output_table,
                **kwargs
            )

            # NOTE(review): the temp table is removed when this `with` exits;
            # confirm the reader returned here does not need it afterwards.
            self._table = self.client.read_table(output_table, **kwargs)

    def _find_duplicate(self, yt_row) -> typing.Optional[models.Contract]:
        """Return one ``Contract`` with ``person_id == yt_row['id']``, or None."""
        @async_to_sync
        async def _get_contract(person_id):
            async with self.db_bind:
                return await models.Contract.query.where(models.Contract.person_id == person_id).gino.first()

        return _get_contract(yt_row['id'])

    def _process_duplicate(self, yt_row, db_row: models.Contract):
        """Set ``inn`` on every contract whose ``person_id`` equals ``yt_row['id']``.

        Several contracts can reference the same person. Updating only
        ``db_row`` (the first match from ``_find_duplicate``) would leave the
        others without an INN, so a bulk UPDATE by ``person_id`` is issued.
        """
        @async_to_sync
        async def _update_contracts(person_id, inn):
            async with self.db_bind:
                await models.Contract.update.values(inn=inn).where(
                    models.Contract.person_id == person_id
                ).gino.status()

        _update_contracts(yt_row['id'], yt_row['inn'])


@celery.task(bind=True)
def load_contracts_info_task(self):
    """Load contract data from YT (cluster 'hahn') into ``models.Contract``.

    Runs two loaders in sequence:

    1. ``ContractInfoLoader`` over ``v_contract_apex_full``: id, eid, agency,
       payment type, contract type, services, finish date and person id.
    2. ``ContractInnLoader`` over ``t_person``: the INN of each contract's person.

    The order matters. ``ContractInnLoader.__init__`` reads the contracts'
    ``person_id`` values from the database, so it is constructed only after
    stage 1 has been loaded.
    """
    client_config = {
        'cluster': 'hahn',
        'token': YT_CONFIG['TOKEN'],
        'config': {}
    }

    LOGGER.info('Loading contracts info from YT')
    contract_loader = ContractInfoLoader(
        table_path='//home/balance/prod/bo/v_contract_apex_full',
        model=models.Contract,
        columns_mapper={
            'id': 'contract_id',
            'eid': 'contract_eid',
            'agency_id': 'agency_id',
            'payment_type': MethodExtractor('get_payment_type'),
            'type': MethodExtractor('get_contract_type'),
            'services': MethodExtractor('get_services'),
            'finish_date': MethodExtractor('get_finish_date'),
            'person_id': 'person_id'
        },
        default_columns={},
        client_config=client_config
    )

    contract_loader.load()

    # Must be built after contract_loader.load(): its constructor snapshots
    # the person_id values of the contracts just written.
    LOGGER.info('Loading contracts INN from YT')
    inn_loader = ContractInnLoader(
        table_path='//home/balance/prod/bo/t_person',
        model=models.Contract,
        columns_mapper={
            'inn': 'inn',
            'person_id': 'id'
        },
        default_columns={},
        client_config=client_config
    )

    inn_loader.load()
    LOGGER.info('Contracts info loaded')
