import typing
import datetime

from sqlalchemy import and_
from asgiref.sync import async_to_sync
from crm.agency_cabinet.common.yt.base import YtModelLoader, MethodExtractor, BaseRowLoadException
from crm.agency_cabinet.agencies.server.src.celery.base import celery_app as celery
from crm.agency_cabinet.agencies.server.src.db.models import Client, Agency, AgencyAnalytics
from crm.agency_cabinet.agencies.server.src.db import db
from crm.agency_cabinet.agencies.server.src import config


class UnknownAgency(BaseRowLoadException):
    """Raised when a YT row's ``agency_name`` matches no agency known to the loader."""


class EmptyAgencyName(BaseRowLoadException):
    """Raised when a YT row carries ``None`` in its ``agency_name`` column."""


class UnknownClient(BaseRowLoadException):
    """Raised when a YT row's ``clientid`` is not among the client ids known to the loader."""


class AnalyticsDataLoader(YtModelLoader):
    """Loader that turns YT analytics rows into ``AgencyAnalytics`` values.

    Agency names and client ids found in YT rows are resolved against lookup
    structures that ``_init`` builds from the database. Rows that cannot be
    resolved raise a ``BaseRowLoadException`` subclass.
    NOTE(review): how the base loader reacts to those exceptions (skip, log,
    abort) is not visible in this file -- confirm in ``YtModelLoader``.
    """

    expected_columns = ('agency_name', 'all_money', 'check_ctg', 'clientid',
                        'cohort_date', 'domain', 'epoch', 'month', 'role', 'tier')

    # strptime format shared by the ``cohort_date`` and ``month`` YT columns.
    _YT_DATE_FORMAT = '%Y-%m-%d'

    def _init(self, **kwargs) -> None:
        """Build the agency-name -> id map and the set of known client ids.

        Both lookups are filled by one query each, so the per-row extractors
        below only touch in-memory structures.
        ``self.db_bind`` is not defined in this class; presumably it is
        provided by ``YtModelLoader`` -- TODO confirm.
        """
        @async_to_sync
        async def _get_all_clients():
            # Only the id column is selected, so each result row exposes ``.id``.
            async with self.db_bind:
                return await db.select([Client.id]).select_from(Client).gino.all()

        @async_to_sync
        async def _get_all_agencies():
            # Only id and name are selected; rows expose ``.id`` and ``.name``.
            async with self.db_bind:
                return await db.select([Agency.id, Agency.name]).select_from(Agency).gino.all()

        # If two agencies share a name, the one returned last by the query wins.
        self.map_name_to_agency_id = {a.name: a.id for a in _get_all_agencies()}
        self.client_ids = {c.id for c in _get_all_clients()}

    def _get_agency_id(self, yt_row):
        """Return the agency id whose name equals the row's ``agency_name``.

        Raises:
            EmptyAgencyName: the column value is ``None``.
            UnknownAgency: no agency with that name was loaded by ``_init``.
        """
        if yt_row['agency_name'] is None:
            raise EmptyAgencyName('NULL agency_name in yt_row')
        name = str(yt_row['agency_name'])
        agency_id = self.map_name_to_agency_id.get(name)
        if agency_id is None:
            raise UnknownAgency(f'Can\'t find agency_id for {name}')
        return agency_id

    def _get_client_id(self, yt_row):
        """Return the row's ``clientid`` after checking that it is a known client.

        Raises:
            UnknownClient: the id is not in the set loaded by ``_init``.
        """
        client_id = yt_row['clientid']
        if client_id not in self.client_ids:
            raise UnknownClient(f'Can\'t find client with id {client_id}')
        return client_id

    def _get_cohort_date(self, yt_row) -> datetime.datetime:
        """Parse the row's ``cohort_date`` (``YYYY-MM-DD``) into a ``datetime``."""
        return datetime.datetime.strptime(yt_row['cohort_date'], self._YT_DATE_FORMAT)

    def _get_month(self, yt_row) -> datetime.datetime:
        """Parse the row's ``month`` (``YYYY-MM-DD``) into a ``datetime``."""
        return datetime.datetime.strptime(yt_row['month'], self._YT_DATE_FORMAT)

    def _find_duplicate(self, yt_row) -> typing.Optional[AgencyAnalytics]:
        """Return the stored row matching ``yt_row``, or ``None`` if there is none.

        A stored row matches when agency id, client id, epoch and cohort date
        are all equal. The client id is taken from the row as-is, without the
        known-client check done by ``_get_client_id``.

        Raises:
            EmptyAgencyName, UnknownAgency: propagated from ``_get_agency_id``.
        """
        @async_to_sync
        async def _get_agency_analytics(agency_id, client_id, epoch, cohort_date):
            async with self.db_bind:
                return await AgencyAnalytics.query.where(and_(AgencyAnalytics.agency_id == agency_id,
                                                              AgencyAnalytics.client_id == client_id,
                                                              AgencyAnalytics.epoch == epoch,
                                                              AgencyAnalytics.cohort_date == cohort_date)).gino.first()

        return _get_agency_analytics(
            self._get_agency_id(yt_row),
            yt_row['clientid'],
            yt_row['epoch'],
            self._get_cohort_date(yt_row),
        )

    def _process_duplicate(self, yt_row, db_row: AgencyAnalytics) -> None:
        """Leave an already stored row untouched.

        Deliberately a no-op: a duplicate is expected to carry the same
        values as the stored row, so nothing is written. If that assumption
        stops holding, the fields to refresh on ``db_row`` from ``yt_row``
        are ``check_ctg``, ``all_money`` and ``domain``.
        """


@celery.task(bind=True)
def load_client_analytics_data_task(self):
    """Celery task that loads the agency analytics YT table into ``AgencyAnalytics``.

    Args:
        self: the bound task instance. ``bind=True`` makes Celery pass it as
            the first positional argument, so the parameter must exist even
            though the body does not use it.
    """
    loader = AnalyticsDataLoader(
        table_path='//home/geoadv/geosmb/analytics/hackathon_10_2021/database',
        model=AgencyAnalytics,
        # Keys are model fields. Values are either a YT column name or a
        # MethodExtractor naming a loader method -- presumably called with the
        # YT row to compute the value; confirm in YtModelLoader.
        columns_mapper={
            'agency_id': MethodExtractor('_get_agency_id'),
            'client_id': MethodExtractor('_get_client_id'),
            'all_money': 'all_money',
            'check_ctg': 'check_ctg',
            'cohort_date': MethodExtractor('_get_cohort_date'),
            'month': MethodExtractor('_get_month'),
            'domain': 'domain',
            'epoch': 'epoch',
            'role': 'role',
            'tier': 'tier'
        },
        default_columns={},
        client_config={
            'cluster': 'hahn',
            'token': config.YT_CONFIG['TOKEN'],
            'config': {}
        },
        use_bulk_insert=True
    )
    loader.load()
