from collections import defaultdict
from threading import local
from typing import Collection
from itertools import chain

from django.db.transaction import get_connection
from django.conf import settings

from staff.emission.logbroker.storage import LogbrokerOrmStorage


class Controller:
    """Buffer outgoing messages per thread and flush them to ``storage``.

    Each object passed to :meth:`append` is serialized immediately and kept in
    a thread-local cache. The cache is a ``defaultdict(list)`` keyed by the id
    of the innermost savepoint active on the DB connection at append time
    (``None`` when no savepoint is active). Each value is a list of
    ``(data, action)`` tuples. :meth:`reset_context` can drop a single
    savepoint's bucket, presumably when that savepoint is rolled back (confirm
    with callers). :meth:`commit` writes every bucket with one ``bulk_create``.

    ``storage`` must provide ``serialize_objects``, ``bulk_create``,
    ``get_unsent_queryset`` and ``mark_sent``.
    """

    def __init__(self, storage):
        self.storage = storage
        # threading.local: attributes set here are visible only to the thread
        # that set them, so every thread gets its own buffer.
        self._context = local()

    def reset_context(self, sid=None):
        """Discard buffered messages for the current thread.

        :param sid: savepoint id whose bucket should be dropped. When falsy,
            the whole buffer is cleared.
        """
        self._try_init()

        if sid:
            self._context.objects.pop(sid, None)
        else:
            self._context.objects = defaultdict(list)

    def _try_init(self):
        """Lazily create this thread's buffer on first use."""
        if not getattr(self._context, 'init', False):
            self._context.objects = defaultdict(list)
            self._context.init = True

    @property
    def cached_objects(self):
        """This thread's buffer: savepoint id -> list of ``(data, action)``."""
        self._try_init()
        return self._context.objects

    @staticmethod
    def _get_sid(using):
        """Return the last entry of the connection's ``savepoint_ids``, or None.

        :param using: database alias passed to ``get_connection``.
        """
        con = get_connection(using)
        if con.savepoint_ids:
            return con.savepoint_ids[-1]
        else:
            return None

    def append(self, obj, action, using=None):
        """Serialize ``obj`` now and buffer it with ``action``.

        The entry is filed under the savepoint currently active on the
        ``using`` connection.
        """
        data = self.storage.serialize_objects([obj])
        self.cached_objects[self._get_sid(using)].append((data, action))

    def get_unsent(self, count: int):
        """Return at most ``count`` unsent storage rows, ordered by ``id``."""
        return self.storage.get_unsent_queryset().order_by('id')[:count]

    def mark_sent(self, message_ids: Collection):
        """Mark the unsent rows whose ``id`` is in ``message_ids`` as sent."""
        self.storage.mark_sent(
            self.storage.get_unsent_queryset().filter(id__in=message_ids)
        )

    def commit(self):
        """Persist every buffered message for this thread, then clear the buffer.

        Buckets are written in the dict's insertion order. The buffer is
        cleared even when ``bulk_create`` raises. Otherwise the thread-local
        buffer would keep these messages, and the thread's next ``commit()``
        would write them as part of an unrelated later transaction.
        """
        pending = list(chain.from_iterable(self.cached_objects.values()))
        try:
            if pending:
                self.storage.bulk_create(pending)
        finally:
            self.reset_context()


# Module-wide Controller instance backed by the ORM storage. The storage's
# transaction wait delta comes from settings.LOG_BROKER_TRANSACTION_WAIT_DELTA.
controller = Controller(
    storage=LogbrokerOrmStorage(
        transaction_wait_delta=settings.LOG_BROKER_TRANSACTION_WAIT_DELTA,
    ),
)
