import datetime
from importlib import import_module
import json
import logging
import re
from typing import Any, Dict

import dateutil.parser

import celery

from django.db import DatabaseError, models
from django.db.transaction import get_connection

from staff.celery_app import app
from staff.lib.db import atomic

logger = logging.getLogger(__name__)


@app.task
def execute_ordered_task(task_module_name, task_class_name, identity):
    """Celery task: process the next queue item of ``identity`` for an ``OrderedTasks`` subclass.

    The subclass is looked up by the module path and class name it was published with
    (``OrderedTasks._delay_task`` sends ``cls.__module__`` and ``cls.__name__``), and its
    ``execute_next_task`` classmethod is invoked for ``identity``.
    """
    namespace = vars(import_module(task_module_name))
    namespace[task_class_name].execute_next_task(identity)


class DatetimeJSONEncoder(json.JSONEncoder):
    """JSON encoder that lets ``datetime``/``date``/``time`` values survive a round trip.

    Temporal values are written as ``MAGIC_PREFIX + value.isoformat()``. Passing
    :meth:`hook_date_time` as ``object_hook`` to ``json.loads`` converts such strings
    back into ``datetime.datetime``, ``datetime.date`` or ``datetime.time`` objects.
    """

    # Marker that distinguishes encoded temporal values from ordinary strings.
    MAGIC_PREFIX = '!@#$%^&(*:JGF&^%$#@#$%^&*^%;jgf$#$%^&UHGFRUHFTyvyuwq345'
    DATE_MASK = r'\d{4}-\d{2}-\d{2}'
    TIME_MASK = r'\d{2}:\d{2}:\d{2}'

    # ISO_PATTERN and TIME_PATTERN are matched against the *start* of the encoded value so
    # that the fractional seconds and UTC offsets isoformat() may append
    # (e.g. '2020-01-02T03:04:05.123456+03:00') are still recognised.
    ISO_PATTERN = re.compile(f'{DATE_MASK}T{TIME_MASK}')
    DATE_PATTERN = re.compile(DATE_MASK)
    TIME_PATTERN = re.compile(TIME_MASK)

    def default(self, o):
        """Encode temporal values as prefixed ISO strings; defer everything else to the base class."""
        if isinstance(o, (datetime.datetime, datetime.date, datetime.time)):
            return self.MAGIC_PREFIX + o.isoformat()
        return super().default(o)

    @classmethod
    def _decode_value(cls, value):
        """Decode one JSON value produced by :meth:`default`; other values are returned as is.

        Lists are decoded element-wise because ``object_hook`` is only invoked for JSON
        objects, so prefixed strings inside arrays would otherwise stay undecoded.
        """
        if isinstance(value, list):
            return [cls._decode_value(item) for item in value]

        if not isinstance(value, str) or not value.startswith(cls.MAGIC_PREFIX):
            return value

        encoded_value = value[len(cls.MAGIC_PREFIX):]

        if cls.ISO_PATTERN.match(encoded_value):
            return dateutil.parser.parse(encoded_value)
        if cls.DATE_PATTERN.fullmatch(encoded_value):
            return dateutil.parser.parse(encoded_value).date()
        if cls.TIME_PATTERN.match(encoded_value):
            # timetz() keeps a UTC offset if the encoded time had one (time() would drop it).
            return dateutil.parser.parse(encoded_value).timetz()

        # Unrecognised payload: leave it untouched rather than guessing.
        return value

    @classmethod
    def hook_date_time(cls, original):
        """``object_hook`` for ``json.loads`` that decodes values written by :meth:`default`.

        Mutates and returns ``original``.
        """
        for key, value in original.items():
            original[key] = cls._decode_value(value)
        return original


class OrderedTasks:
    """Persistent per-identity FIFO queue of deferred calls, executed one by one via Celery.

    Subclasses set ``Model`` to the Django model that stores queue items; this class reads and
    writes its ``identity_field``, ``args``, ``module``, ``callable``, ``fail_count`` and ``id``
    fields. Items sharing an identity are executed in ``id`` order, each inside a transaction,
    and the next one is scheduled only after the previous one has been committed and deleted.
    """

    # Items that failed this many times are "dead": never picked again (see get_dead_tasks).
    max_retry_attempts = 13
    # Django model holding the queue items; must be provided by subclasses.
    Model: models.Model = None
    # Model field whose value splits items into independent ordered queues.
    identity_field: str = 'entity_id'
    # Set on the class while execute_next_task runs; while set, schedule_ordered_task does not
    # publish Celery messages itself.
    _execute_next_task_currently_running = False
    # While execute_next_task runs: the set of (class, identity) pairs that received new queue
    # items and must be delayed once the running execution has committed. None otherwise.
    _identities_scheduled_during_execution = None

    @staticmethod
    def _get_name_and_module(callable_or_task):
        """Return ``(name, module)`` under which ``callable_or_task`` can be looked up again.

        Celery task instances are identified by their class name/module, plain callables by
        their ``__name__``/``__module__``. The pair is resolved back by ``_get_task_to_execute``.

        Raises:
            TypeError: if ``callable_or_task`` is neither a Celery task nor callable.
        """
        if isinstance(callable_or_task, (app.Task, celery.Task)):
            task_class = callable_or_task.__class__
            return task_class.__name__, task_class.__module__

        if callable(callable_or_task):
            return callable_or_task.__name__, callable_or_task.__module__

        # An explicit raise instead of ``assert False``: asserts are stripped under ``python -O``,
        # which would return None here and fail later with an obscure unpacking error.
        raise TypeError(f'Expected a celery task or a callable, got {callable_or_task!r}')

    @staticmethod
    def _serialize(kwargs: Dict[str, Any]) -> str:
        """Serialize task kwargs to JSON, preserving datetime/date/time values."""
        return json.dumps(kwargs, cls=DatetimeJSONEncoder)

    @classmethod
    def _kwargs_for_model_creation(
        cls,
        identity: Any,
        kwargs: Dict[str, Any],
        module: str,
        callable: str,
    ) -> Dict[str, Any]:
        """Build the field values for a new ``Model`` queue item."""
        return {
            cls.identity_field: identity,
            'args': cls._serialize(kwargs),
            'module': module,
            'callable': callable,
        }

    @classmethod
    def schedule_ordered_task(cls, identity, callable_or_task, kwargs):
        """Append a call of ``callable_or_task(**kwargs)`` to the queue of ``identity``.

        Must be called inside an atomic block (checked with ``assert``), so that the business
        logic changes and the new queue item are committed or rolled back together.
        """
        # Tasks are meant to be enqueued only if the surrounding transaction succeeds:
        # the business logic and the appearance of the queue item must share one transaction.
        assert get_connection().in_atomic_block
        logger.debug('Scheduling new ordered task for %s', identity)
        callable_name, callable_module = cls._get_name_and_module(callable_or_task)

        cls.Model.objects.create(
            **cls._kwargs_for_model_creation(identity, kwargs, callable_module, callable_name)
        )

        if cls._execute_next_task_currently_running:
            # Do not publish from inside a running execute_next_task: on setups where Celery's
            # 'delay' executes immediately this would recurse endlessly. Instead, remember the
            # identity; execute_next_task delays it after its transaction has committed (and
            # drops it if that transaction is rolled back, which also removes the new item).
            pending = cls._identities_scheduled_during_execution
            if pending is not None:
                pending.add((cls, identity))
            return

        # NOTE(review): the Celery message is published before the caller's transaction commits.
        # A worker that picks it up early will not see the new row yet and skip it (see
        # _get_next_queue_item), leaving the item until retry_all. transaction.on_commit would
        # close this gap but may change behaviour under eager Celery in tests — confirm first.
        cls._delay_task(identity)
        logger.debug('Ordered task for %s scheduled', identity)

    @staticmethod
    def _get_task_to_execute(queue_item):
        """Resolve the callable/task stored in ``queue_item`` by its module and name."""
        module = import_module(queue_item.module)
        task = module.__dict__[queue_item.callable]
        return task

    @classmethod
    def _get_next_queue_item(cls, identity):
        """Lock and return the oldest queue item of ``identity``, or None if nothing can run.

        Returns None when the queue is empty, when its head has reached ``max_retry_attempts``
        failures, or when a ``DatabaseError`` is raised while selecting it.
        """
        # Always take only the first item in the queue and lock on it.
        # NOTE(review): select_for_update() without nowait=True / skip_locked=True waits for a
        # concurrent lock instead of raising, so the DatabaseError branch below only covers
        # errors such as lock timeouts or deadlocks. Catching it inside atomic() without a
        # savepoint is also discouraged by Django — confirm the intended locking semantics.
        try:
            queue_item = (
                cls.Model.objects
                .order_by('id')
                .filter(**{cls.identity_field: identity})
                .select_for_update()
                .first()
            )

            if not queue_item:
                logger.info('Queue item for %s not found. Probably already executed. Skip it', identity)
                return None

            if queue_item.fail_count >= cls.max_retry_attempts:
                logger.debug('Queue item for %s failed too many times', identity)
                return None

            return queue_item
        except DatabaseError:
            logger.debug('Queue item already locked by another transaction')
            return None

    @staticmethod
    def _deserialize(args: str) -> Dict[str, Any]:
        """Inverse of ``_serialize``: decode JSON kwargs, restoring datetime/date/time values."""
        return json.loads(args, object_hook=DatetimeJSONEncoder.hook_date_time)

    @classmethod
    def _execute_task(cls, task: Any, args: str):
        """Call ``task`` with the kwargs stored (serialized) in ``args``."""
        task(**cls._deserialize(args))

    @classmethod
    def execute_next_task(cls, identity):
        """Execute the oldest live queue item of ``identity`` and schedule the following one.

        The item is locked, executed and deleted inside a single transaction. On success, the
        next delivery for ``identity`` is scheduled via Celery, together with every identity that
        received new items while the task was running. On failure, the item's ``fail_count`` is
        incremented and the exception is re-raised.
        """
        queue_item = None
        queue_item_id = None
        # (class, identity) pairs scheduled by the executed task; delayed only after commit.
        scheduled_during_execution = set()

        try:
            cls._execute_next_task_currently_running = True
            cls._identities_scheduled_during_execution = scheduled_during_execution
            with atomic():
                logger.info('Executing next %s ordered task for %s', cls.__name__, identity)

                queue_item = cls._get_next_queue_item(identity)

                if queue_item is None:
                    return

                # Remember the id now: Model.delete() resets queue_item.id to None, and a failure
                # at commit time must still increment the right row.
                queue_item_id = queue_item.id

                logger.debug('Ordered task for %s locked', identity)
                task = cls._get_task_to_execute(queue_item)

                # All code inside the task runs within this transaction, i.e. the task is executed
                # and removed from the queue only if the business logic succeeded, the item was
                # deleted from the queue and the commit went through.
                cls._execute_task(task, queue_item.args)
                logger.debug('Ordered task for %s executed', identity)

                queue_item.delete()
        except Exception:
            if queue_item is not None:
                cls.Model.objects.filter(id=queue_item_id).update(fail_count=models.F('fail_count') + 1)

                # queue_item holds the value read before this failure; the row now has one more.
                fail_count = queue_item.fail_count + 1
                if fail_count >= cls.max_retry_attempts:
                    logger.exception('%s ordered task failed %s times', queue_item_id, fail_count)
                else:
                    logger.info(
                        '%s ordered task will be retried (attempt %s)',
                        queue_item_id,
                        fail_count,
                        exc_info=True,
                    )
            raise
        finally:
            cls._execute_next_task_currently_running = False
            cls._identities_scheduled_during_execution = None

        logger.debug('Ordered task execution committed')
        cls._delay_task(identity)
        logger.debug('Execution of next ordered task for %s scheduled', identity)

        # The current identity has just been delayed; deliver the others that the executed task
        # enqueued (their schedule_ordered_task calls skipped publishing while we were running).
        scheduled_during_execution.discard((cls, identity))
        for scheduled_cls, scheduled_identity in scheduled_during_execution:
            scheduled_cls._delay_task(scheduled_identity)
            logger.debug('Ordered task for %s scheduled', scheduled_identity)

    @classmethod
    def _delay_task(cls, identity):
        """Publish a Celery message that runs ``execute_next_task`` of this class for ``identity``."""
        execute_ordered_task.delay(cls.__module__, cls.__name__, identity)

    @classmethod
    def get_dead_tasks(cls):
        """Return a queryset of queue items that reached ``max_retry_attempts`` failures."""
        return cls.Model.objects.filter(fail_count__gte=cls.max_retry_attempts)

    @classmethod
    def retry_all(cls):
        """Synchronously try the head item of every identity that still has live queue items.

        A failure for one identity does not prevent the remaining identities from being tried;
        the first exception encountered is re-raised after all of them were attempted.
        """
        # .distinct(field) is DISTINCT ON, presumably PostgreSQL-only — confirm the backend.
        identities = list(
            cls.Model.objects
            .filter(fail_count__lt=cls.max_retry_attempts)
            .distinct(cls.identity_field)
            .values_list(cls.identity_field, flat=True)
        )

        first_error = None
        for identity in identities:
            try:
                cls.execute_next_task(identity)
            except Exception as error:
                logger.warning('Retry of %s ordered tasks for %s failed', cls.__name__, identity)
                if first_error is None:
                    first_error = error

        if first_error is not None:
            raise first_error
