import datetime
import logging
import re
import typing

from dateutil.relativedelta import relativedelta
from sqlalchemy import join

from crm.agency_cabinet.rewards.server.src.celery.base import celery_app as celery
from crm.agency_cabinet.rewards.server.src.db.models import Document, Reward
from crm.agency_cabinet.rewards.server.src.db import db
from crm.agency_cabinet.common.consts import YaDocDocumentType, compute_reward_type
from crm.agency_cabinet.common.yadoc import YaDocClient
from crm.agency_cabinet.common.server.common.tvm import get_tvm_client
from crm.agency_cabinet.rewards.server.config.clients import YADOC_CONFIG
from crm.agency_cabinet.rewards.server.config.tvm import RewardsTvm2Config
from crm.agency_cabinet.common.celery.base import async_to_sync


# Logger used by the YaDoc id loading task.
LOGGER = logging.getLogger(name='celery.load_yadoc_ids')

# Parses YaDoc document numbers such as '110462/18_Aug-21-Aug-21'.
# Adjacent raw-string literals are concatenated, so the compiled pattern is
# exactly r'(\d*/\d*)_(\w*-\d*)-(\w*-\d*)'.
DOC_NUMBER_REGEXP = re.compile(
    r'(\d*/\d*)'    # group 1: contract external id, e.g. '110462/18'
    r'_'
    r'(\w*-\d*)'    # group 2: period start as '<month abbr>-<yy>', e.g. 'Aug-21'
    r'-'
    r'(\w*-\d*)'    # group 3: period end as '<month abbr>-<yy>', e.g. 'Aug-21'
)


def _parse_date_arg(
    value: typing.Optional[typing.Union[datetime.date, str]]
) -> typing.Optional[typing.Union[datetime.date, datetime.datetime]]:
    """Convert a 'YYYY-MM-DD' string task argument to a datetime; pass other values through unchanged.

    Strings become naive ``datetime.datetime`` objects at midnight (not ``datetime.date``).
    Dates, datetimes and ``None`` are returned as-is.
    NOTE(review): string inputs are presumably produced by Celery argument serialization — confirm.

    Raises:
        ValueError: if a string does not match '%Y-%m-%d'.
    """
    if isinstance(value, str):
        return datetime.datetime.strptime(value, '%Y-%m-%d')
    return value


def _parse_doc_number(
    doc_number: typing.Optional[str]
) -> typing.Optional[typing.Tuple[str, datetime.datetime, datetime.datetime]]:
    """Split a YaDoc doc_number into (contract_eid, period start, period end).

    Example: '110462/18_Aug-21-Aug-21' -> ('110462/18', datetime(2021, 8, 1), datetime(2021, 8, 1)).
    Returns ``None`` when doc_number is empty, does not match DOC_NUMBER_REGEXP,
    or its month parts do not parse with '%b-%y'.
    NOTE(review): '%b' is locale-dependent; this assumes English month abbreviations
    in the worker's locale.
    """
    if not doc_number:
        return None
    parsed = DOC_NUMBER_REGEXP.match(doc_number)
    if parsed is None:
        return None
    try:
        date_from = datetime.datetime.strptime(parsed.group(2), '%b-%y')
        date_to = datetime.datetime.strptime(parsed.group(3), '%b-%y')
    except ValueError:
        return None
    return parsed.group(1), date_from, date_to


@celery.task(bind=True)
def load_yadoc_ids_task(
    self,
    doc_date_from: typing.Optional[typing.Union[datetime.date, str]] = None,
    doc_date_to: typing.Optional[typing.Union[datetime.date, str]] = None,
    reward_period_from_start: typing.Optional[typing.Union[datetime.date, str]] = None,
    reward_period_from_end: typing.Optional[typing.Union[datetime.date, str]] = None,
    exclude_reversed: bool = True,
    ignore_null_yadoc_id: bool = False
):
    """Store YaDoc document ids in ``Document.yadoc_id`` by matching local documents with YaDoc acts.

    Loads Document rows joined with their Reward and requests act documents for the
    involved contracts from YaDoc. Each YaDoc document is matched to a local one by
    (contract_id, reward type, period year, period month); the reward type and period
    are derived from the YaDoc doc_number.

    Args:
        doc_date_from: ``date_from`` passed to YaDoc ``get_docs_info``, as a 'YYYY-MM-DD'
            string or a date. Defaults to the first day of the previous month (UTC).
        doc_date_to: ``date_to`` passed to ``get_docs_info``. Defaults to
            ``doc_date_from`` plus one month.
        reward_period_from_start: if set, only rewards with ``period_from >= value`` are used.
        reward_period_from_end: if set, only rewards with ``period_from < value`` are used.
        exclude_reversed: forwarded to ``get_docs_info``.
        ignore_null_yadoc_id: when False (default), only documents without a ``yadoc_id``
            are processed. When True, documents that already have one are re-matched
            and may be overwritten.

    Errors on individual YaDoc documents are logged and skipped, so one bad entry
    does not abort the whole run. Returns None.
    """
    doc_date_from = _parse_date_arg(doc_date_from)
    if doc_date_from is None:
        # Default window start: first day of the previous month, in UTC.
        now = datetime.datetime.now(datetime.timezone.utc)
        doc_date_from = (
            now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - relativedelta(months=1)
        ).date()

    doc_date_to = _parse_date_arg(doc_date_to)
    if doc_date_to is None:
        doc_date_to = doc_date_from + relativedelta(months=1)

    reward_period_from_start = _parse_date_arg(reward_period_from_start)
    reward_period_from_end = _parse_date_arg(reward_period_from_end)

    @async_to_sync
    async def _do_stuff():
        query = db.select([Document, Reward]).select_from(
            join(Document, Reward, Reward.id == Document.reward_id)
        )

        if not ignore_null_yadoc_id:
            # Restrict to documents that are not linked to YaDoc yet.
            query = query.where(Document.yadoc_id.is_(None))

        if reward_period_from_start is not None:
            query = query.where(Reward.period_from >= reward_period_from_start)
        if reward_period_from_end is not None:
            query = query.where(Reward.period_from < reward_period_from_end)

        # (contract_id, reward type, period year, period month) -> local document.
        # NOTE(review): if several documents share a key, only the last loaded one is kept.
        reward_doc_map: typing.Dict[typing.Tuple, Document] = {}
        contract_ids = set()
        loader = Document.load(reward=Reward.on(Reward.id == Document.reward_id))
        for document_model in await query.gino.load(loader).all():
            document_model: Document
            reward = document_model.reward
            key = (
                reward.contract_id,
                reward.type,
                reward.period_from.year,
                reward.period_from.month
            )
            reward_doc_map[key] = document_model
            contract_ids.add(reward.contract_id)

        # Bail out before creating TVM/YaDoc clients when there is nothing to match.
        if not reward_doc_map:
            LOGGER.info('All documents have yadoc_id')
            return

        tvm_client = get_tvm_client(RewardsTvm2Config.from_environ())
        yadoc_client = YaDocClient(
            YADOC_CONFIG['endpoint_url'],
            tvm_client=tvm_client,
            yadoc_tvm_id=YADOC_CONFIG['tvm_id'],
            raise_for_status=True,
        )

        result = yadoc_client.get_docs_info(
            contract_ids=list(contract_ids),
            doc_types=[YaDocDocumentType.act.value],
            date_from=doc_date_from,
            date_to=doc_date_to,
            exclude_reversed=exclude_reversed
        )

        async for doc_info in result:
            for contract_doc_info in doc_info.get('contracts') or []:
                # NOTE(review): assumes YaDoc's contract_id has the same type as
                # Reward.contract_id; otherwise no key will ever match.
                contract_id = contract_doc_info.get('contract_id')
                for document in contract_doc_info.get('documents') or []:
                    try:
                        doc_number = document.get('doc_number')  # e.g. 110462/18_Aug-21-Aug-21
                        doc_id = document.get('doc_id')
                        parsed = _parse_doc_number(doc_number)
                        if parsed is None:
                            LOGGER.warning('Skip doc %s: unexpected doc_number %r', doc_id, doc_number)
                            continue
                        contract_eid, date_from, date_to = parsed
                        doc_date = document.get('doc_date')
                        key = (
                            contract_id,
                            compute_reward_type(date_from, date_to),
                            date_from.year,
                            date_from.month
                        )
                        LOGGER.debug('Process doc %s (date: %s, contract: %s)', doc_id, doc_date, contract_eid)
                        doc_model = reward_doc_map.get(key)
                        if doc_model is not None:
                            await doc_model.update(yadoc_id=doc_id).apply()
                        else:
                            LOGGER.debug('Not found doc for key: %s', key)
                    except Exception as ex:
                        # Best-effort: one failing document must not stop the rest.
                        LOGGER.exception('Error during processing doc: %s', ex)

    _do_stuff()
