import logging
import uuid
from typing import Any

from aiopg.sa.connection import SAConnection
from ylog.context import pop_from_context, put_to_context

from intranet.trip.src.config import settings
from intranet.trip.src.db import gateways
from intranet.trip.src.lib.utils import make_hash

logger = logging.getLogger(__name__)


class UnitOfWork:
    """
    Этот класс нужен, чтобы:
      - иметь единую точку входа для всех репозиториев (DBGateway)
      - скрыть объект Connection
      - ставить и выполнять отложенные таски в redis после коммита транзакции

    Пример использования:
    uow = UnitOfWork(conn=conn, redis=redis)
    async with uow:
        await uow.trips.update(trip_id=trip_id, **fields)
        uow.add_job('trip_updated', trip_id=trip_id)

    """
    def __init__(self, conn: SAConnection, redis=None):
        self._conn = conn
        self._redis = redis

        self.bonus_cards = gateways.PersonBonusCardGateway(conn)
        self.billing_deposits = gateways.BillingDepositGateway(conn)
        self.billing_transactions = gateways.BillingTransactionGateway(conn)
        self.companies = gateways.CompanyGateway(conn)
        self.documents = gateways.PersonDocumentGateway(conn)
        self.employees = gateways.EmployeeGateway(conn)
        self.ext_persons = gateways.ExtPersonGateway(conn)
        self.holdings = gateways.HoldingGateway(conn)
        self.person_relationships = gateways.PersonRelationshipGateway(conn)
        self.person_trips = gateways.PersonTripGateway(conn)
        self.persons = gateways.PersonGateway(conn)
        self.purposes = gateways.PurposeGateway(conn)
        self.services = gateways.ServiceGateway(conn)
        self.service_providers = gateways.ServiceProviderGateway(conn)
        self.trips = gateways.TripGateway(conn)

        self._jobs: list[tuple[str, dict]] = []
        self._transaction_id = None

    async def __aenter__(self):
        self._transaction = await self._conn.begin()
        if settings.LOG_TRANSACTION_ID:
            res = await self._conn.execute('select txid_current()')
            data = await res.fetchall()
            if data and data[0]:
                self._transaction_id = data[0][0]
                put_to_context('transaction_id', self._transaction_id)
                logger.info('Starting transaction %s', self._transaction_id)
        return self

    async def __aexit__(self, exn_type, exn_value, traceback):
        if exn_type is None:
            await self._transaction.commit()
            await self._enqueue_jobs()
        else:
            await self._transaction.rollback()
        if self._transaction_id:
            pop_from_context('transaction_id')
            self._transaction_id = None

    def add_job(self, job_name: str, **kwargs: Any) -> None:
        """
        Добавляет таску, чтобы ее выполнить после коммита транзакции

        :param job_name: задача, которую нужно запланировать
        :param kwargs: аргументы, которые нужно передать в задачу
        """
        assert self._redis is not None
        self._jobs.append((job_name, kwargs))

    async def run_job(
            self,
            job_name: str,
            unique: bool = True,
            **kwargs: Any,
    ) -> None:
        """
        Запускает фоновую задачу

        :param job_name: задача, которую нужно запустить
        :param unique: блокировать запуск задачи с таким же набором аргументов
        :param kwargs: аргументы, которые нужно передать в задачу
        """
        assert self._redis is not None
        job_id = kwargs.get('_job_id', self._get_job_id(unique, **kwargs))
        kwargs['_job_id'] = f'{job_name}_{job_id}'
        await self._redis.enqueue_job(job_name, **kwargs)

    async def _enqueue_jobs(self) -> None:
        for job_name, kwargs in self._jobs:
            await self.run_job(job_name, **kwargs)
        self._jobs = []

    @staticmethod
    def _is_of_simple_type(value: Any) -> bool:
        """
        Проверяет, является ли значение целым числом, строкой
        либо списком из 1 элемента
        """
        return (
            isinstance(value, (int, str))
            or (
                isinstance(value, list)
                and len(value) == 1
                and isinstance(value[0], (int, str))
            )
        )

    def _build_job_id_from_args(self, **job_kwargs: Any) -> str:
        simple_keys = [
            key
            for key in sorted(job_kwargs.keys())
            if self._is_of_simple_type(job_kwargs[key])
        ]
        values = [str(job_kwargs[key]) for key in simple_keys]
        hashed_other_kwargs = make_hash({
            key: value
            for key, value in job_kwargs.items()
            if key not in simple_keys
        })
        values.append(hashed_other_kwargs)
        return '_'.join(values)

    def _get_job_id(self, unique: bool, **job_kwargs: Any) -> str:
        if unique:
            return self._build_job_id_from_args(**job_kwargs)
        return uuid.uuid4().hex
