from collections import defaultdict
from datetime import datetime, timedelta
from itertools import chain
from threading import local
from typing import Collection

from django.db.transaction import get_connection

from staff.emission.django.emission_master.storage import MasterOrmStorage


class Controller:
    """Front end to a message storage with a per-thread buffer of pending objects.

    Objects passed to ``append`` are serialized and held in a thread-local
    mapping keyed by the current savepoint id (``None`` when the connection
    reports no savepoints) until ``commit`` writes them to storage in bulk.
    Read access is offered through indexing, slicing and iteration, all of
    which delegate to the storage object.
    """

    def __init__(self, storage):
        """Remember *storage* and create the thread-local buffer holder."""
        self.storage = storage
        self._context = local()

    def reset_context(self, sid=None):
        """Drop buffered objects.

        With a truthy *sid* only the entries recorded under that savepoint id
        are removed; otherwise the whole buffer of the current thread is
        replaced with an empty one.
        """
        self._try_init()

        if sid:
            self._context.objects.pop(sid, None)
        else:
            self._context.objects = defaultdict(list)

    def _try_init(self):
        """Create the buffer for the current thread if it does not exist yet.

        Attributes of a ``threading.local`` are per thread, so the buffer
        cannot be set up once in ``__init__``; it is created lazily instead.
        """
        if not getattr(self._context, 'init', False):
            self._context.objects = defaultdict(list)
            self._context.init = True

    @property
    def cached_objects(self):
        """Mapping of savepoint id to the list of ``(data, action)`` pairs buffered in this thread."""
        self._try_init()
        return self._context.objects

    def __getitem__(self, key):
        """Return one message (int key) or a generator of messages (slice key).

        Raises:
            KeyError: for non-positive ids, slices without a positive start,
                slices whose stop is non-positive or lies before start, and
                keys that are neither ``int`` nor ``slice``.
        """
        if isinstance(key, slice):
            if key.start is None or key.start <= 0:
                raise KeyError(key)
            if key.stop is not None and (key.stop < key.start or key.stop <= 0):
                raise KeyError(key)

            # NOTE: unlike list slicing, the stop bound is inclusive here
            # (see get_slice), and slice.step is ignored.
            return self.get_slice(key.start, key.stop)

        elif isinstance(key, int):
            if key <= 0:
                raise KeyError(key)

            return self.get_one(key)

        else:
            raise KeyError(key)

    def __iter__(self):
        """Iterate over messages as ``get_iterator()`` does with default arguments."""
        return self.get_iterator()

    @staticmethod
    def _get_cut_date(length_in_days):
        """Return the current local time minus *length_in_days* days."""
        return datetime.now() - timedelta(days=length_in_days)

    def get_iterator(self, from_id=None, max_rows=0):
        """Yield messages from storage batch by batch.

        Each batch is requested with ``storage.get_slice(start=from_id,
        max_rows=max_rows)``; after a message is yielded, the next request
        starts at its id plus one. Iteration stops at the first empty batch.
        """
        while True:
            rows_not_found = True
            for msg in self.storage.get_slice(start=from_id, max_rows=max_rows):
                from_id = msg['id'] + 1
                rows_not_found = False
                yield msg
            if rows_not_found:
                # A generator must finish with ``return``: since PEP 479 an
                # explicit ``raise StopIteration`` inside a generator body is
                # converted into RuntimeError.
                return

    def get_last_id(self):
        """Return ``storage.get_last_id()``."""
        return self.storage.get_last_id()

    def get_next_id(self, current_id):
        """Return ``storage.get_next_id(current_id)``."""
        return self.storage.get_next_id(current_id)

    def cut_outdated(self, length_in_days=7):
        """Call ``storage.cut`` with the date *length_in_days* days before now.

        Presumably this removes messages older than that date; the actual
        effect is defined by the storage implementation.
        """
        date = self._get_cut_date(length_in_days)
        self.storage.cut(date)

    @staticmethod
    def _get_sid(using):
        """Return the most recent savepoint id of connection *using*, or ``None`` if there is none."""
        con = get_connection(using)
        if con.savepoint_ids:
            return con.savepoint_ids[-1]
        else:
            return None

    def append(self, obj, action, using=None):
        """Serialize *obj* and buffer it with *action* under the current savepoint id of *using*."""
        data = self.storage.serialize_objects([obj])
        self.cached_objects[self._get_sid(using)].append((data, action))

    def commit(self):
        """Write all buffered ``(data, action)`` pairs via ``storage.bulk_create`` and clear the buffer.

        ``bulk_create`` is skipped when nothing is buffered; the buffer is
        reset in either case.
        """
        cached_objects = list(chain.from_iterable(self.cached_objects.values()))
        if cached_objects:
            self.storage.bulk_create(cached_objects)
        self.reset_context()

    def insert(self, msg_id, data, action):
        """Delegate to ``storage.insert(msg_id, data, action)``."""
        self.storage.insert(msg_id, data, action)

    def get_one(self, msg_id):
        """Return ``storage.get_one(msg_id)``."""
        return self.storage.get_one(msg_id)

    def delete_one(self, msg_id):
        """Delegate to ``storage.delete_one(msg_id)``."""
        self.storage.delete_one(msg_id)

    def get_slice(self, start, stop=None, max_rows=0):
        """Yield a gap-free run of messages with ids from *start* to the right bound.

        The right bound (inclusive) is computed by ``_limit_right_bound``.
        Ids in that range for which storage returned no message are filled
        with placeholder messages from ``_generate_empty_messages``.

        Assumes ``storage.get_slice(start, stop)`` returns messages in
        ascending id order -- TODO confirm against the storage implementation.
        """
        stop = self._limit_right_bound(start, stop, max_rows)

        next_id = start

        for msg in self.storage.get_slice(start, stop):
            if msg['id'] > next_id:
                # Fill the hole between the previous message and this one.
                yield from self._generate_empty_messages(next_id, msg['id'] - 1)

            next_id = msg['id'] + 1
            yield msg

        # Fill whatever is missing between the last message and the bound.
        yield from self._generate_empty_messages(next_id, stop)

    @staticmethod
    def _generate_empty_messages(start, stop):
        """Yield placeholder messages for every id from *start* to *stop* inclusive.

        Each placeholder has ``'[]'`` as data, ``'modify'`` as action and the
        string form of ``datetime.fromtimestamp(0)`` as creation time.
        Nothing is yielded when *stop* is less than *start*.
        """
        creation_time = str(datetime.fromtimestamp(0))

        for i in range(start, stop + 1):
            yield {'id': i, 'data': '[]', 'action': 'modify', 'creation_time': creation_time}

    def _limit_right_bound(self, start, stop, max_rows):
        """Return the smallest applicable right bound for a slice starting at *start*.

        Candidates are the storage's last id, ``start + max_rows - 1`` when
        *max_rows* is truthy, and *stop* when it is not ``None``.
        """
        last_id = self.storage.get_last_id()

        bounds = [last_id]

        if max_rows:
            bounds.append(start + max_rows - 1)

        if stop is not None:
            bounds.append(stop)

        return min(bounds)

    def get_unsent(self, count: int):
        """Return the first *count* entries of the storage's unsent queryset ordered by ``id``."""
        return self.storage.get_unsent_queryset().order_by('id')[:count]

    def mark_sent(self, message_ids: Collection):
        """Pass the unsent entries whose id is in *message_ids* to ``storage.mark_sent``."""
        self.storage.mark_sent(
            self.storage.get_unsent_queryset().filter(id__in=message_ids)
        )


# Module-level Controller instance backed by MasterOrmStorage; it is created
# when this module is imported.
controller = Controller(storage=MasterOrmStorage())
