import logging
import typing
from asgiref.sync import async_to_sync
from datetime import datetime
from sqlalchemy import and_

from crm.agency_cabinet.common.yt.base import YtModelLoader, BaseRowLoadException
from crm.agency_cabinet.rewards.server.src.db import models


# Module logger. It uses an explicit dotted name rather than __name__, which
# places it under the 'celery.calculator' logger hierarchy.
LOGGER = logging.getLogger(name='celery.calculator.base')


class CalculatorLoadDataException(BaseRowLoadException):
    """Row-load error specific to calculator data.

    ``UnknownServiceException`` derives from this class, so callers can catch
    calculator load failures as a group.
    """


class CalculatorUpdateDataException(BaseRowLoadException):
    """Raised when updating an already-stored calculator row fails.

    ``ServiceDataLoader._process_duplicate`` raises it in place of whatever
    error the underlying update produced.
    """


class UnknownServiceException(CalculatorLoadDataException):
    """Calculator load error signalling an unrecognised service.

    NOTE(review): not raised in the code visible here; the meaning is inferred
    from the class name and should be confirmed against its raise sites.
    """


def get_predict_start_from(now: typing.Optional[datetime] = None) -> datetime:
    """Return midnight on the first day of the month containing ``now``.

    :param now: reference moment. When omitted, ``datetime.now()`` is used,
        which matches the original zero-argument behaviour. Passing it
        explicitly makes the result deterministic, for example in tests. Only
        ``now.year`` and ``now.month`` are read.
    :return: a naive ``datetime`` (no tzinfo) equal to
        ``datetime(now.year, now.month, 1)``.
    """
    if now is None:
        now = datetime.now()
    # Build a fresh datetime instead of using now.replace(...): this zeroes the
    # time fields and drops any tzinfo, exactly as the original did.
    return datetime(now.year, now.month, 1)


class ServiceDataLoader(YtModelLoader):
    """YT loader for ``CalculatorData`` rows of a single ``service``.

    Incoming rows are matched against stored ones by ``(contract_id, service)``.
    A match has its data overwritten from the YT row.
    """

    # Compared against ``CalculatorData.service`` in queries; presumably set by
    # concrete subclasses (it is ``None`` here).
    service = None

    def _find_duplicate(self, yt_row) -> typing.Optional[models.CalculatorData]:
        """Return the stored row with the same contract and service, if any.

        :param yt_row: incoming YT row; only ``yt_row['contract_id']`` is read.
        :return: the first matching ``CalculatorData`` row, or ``None``.
        """
        @async_to_sync
        async def _get_calculator_data(contract_id: int) -> typing.Optional[models.CalculatorData]:
            async with self.db_bind:
                return await models.CalculatorData.query.where(
                    and_(
                        models.CalculatorData.contract_id == contract_id,
                        models.CalculatorData.service == self.service
                    )
                ).gino.first()

        return _get_calculator_data(yt_row['contract_id'])

    def _process_duplicate(self, yt_row, db_row: models.CalculatorData):
        """Overwrite ``db_row``'s data with ``yt_row['data']``.

        :param yt_row: incoming YT row; only ``yt_row['data']`` is read.
        :param db_row: stored row previously returned by ``_find_duplicate``.
        :raises CalculatorUpdateDataException: if reading the data or applying
            the update fails. The original error is chained as ``__cause__``.
        """
        @async_to_sync
        async def _update_data(new_data, existing: models.CalculatorData):
            # Enter the same DB binding the queries in this class use.
            # NOTE(review): presumably the bind entered by ``self.db_bind`` is
            # released when those contexts exit, so without this the update
            # would run unbound. Confirm against YtModelLoader.
            async with self.db_bind:
                await models.CalculatorDataUpdater(existing).update_data(new_data)

        try:
            # yt_row['data'] is evaluated inside the try, so a missing key is
            # still reported as CalculatorUpdateDataException, as before.
            _update_data(yt_row['data'], db_row)
        except Exception as exc:
            raise CalculatorUpdateDataException() from exc


class UpdateDataLoader(ServiceDataLoader):
    """Service loader that snapshots which contracts already have stored data.

    ``_init`` reads every ``CalculatorData`` row for ``self.service`` once.
    ``_check_contract_id`` then answers membership queries from that in-memory
    snapshot.
    """

    def _init(self, **kwargs):
        @async_to_sync
        async def _fetch_service_rows():
            async with self.db_bind:
                same_service = models.CalculatorData.service == self.service
                return await models.CalculatorData.query.where(same_service).gino.all()

        stored_rows = _fetch_service_rows()
        # Populated before delegating to the parent initializer.
        self.contract_ids = {row.contract_id for row in stored_rows}
        super()._init(**kwargs)

    def _check_contract_id(self, contract_id: int) -> bool:
        """Return True when ``contract_id`` was present in the ``_init`` snapshot."""
        known_contract_ids = self.contract_ids
        return contract_id in known_contract_ids
