import json
import uuid
from collections import defaultdict
import logging

from django import db

from intranet.search.core.utils import generate_prefix, parse_prefix
from intranet.search.core import redis
from intranet.search.core.models import StageStatus

from .base import Storage

log = logging.getLogger(__name__)


class StageStatusStorage(Storage):
    STATUS_NEW = StageStatus.STATUS_NEW
    STATUS_DONE = StageStatus.STATUS_DONE
    STATUS_FAIL = StageStatus.STATUS_FAIL
    STATUS_CANCEL = StageStatus.STATUS_CANCEL
    STATUS_IN_PROGRESS = StageStatus.STATUS_IN_PROGRESS
    STATUS_RETRY = StageStatus.STATUS_RETRY

    id_sep = ':'
    id_schema = ['revision_id', 'indexation_id', 'stage', 'uuid']

    terminal_statuses = frozenset((STATUS_DONE, STATUS_FAIL, STATUS_CANCEL))
    nonterminal_statuses = frozenset((STATUS_NEW, STATUS_IN_PROGRESS, STATUS_RETRY))
    statuses = terminal_statuses | nonterminal_statuses

    stages = ('setup', 'walk', 'push', 'load', 'fetch',
              'create', 'content', 'store')

    def __init__(self, revision_id, indexation_id):
        self.revision_id = revision_id
        self.indexation_id = indexation_id

    # Публичный интерфейс

    def create(self, stage, app, status=STATUS_NEW):
        """Создает новый статус с заданным именем
        """
        assert app == 'global'
        assert stage in self.stages

        status_id = self._generate_id(stage)

        self.update(status_id, status)

        return status_id

    def update(self, status_id, status):
        """ Обновляет статус
        """
        assert status in self.statuses
        if not status_id:
            log.debug('Trying to set status `%s` without status_id', status)
            return status

        max_retries = 3
        retries = 0
        while True:
            try:
                StageStatus.objects.create(status_id=status_id, status=status)
            except db.DatabaseError:
                retries += 1
                if retries >= max_retries:
                    raise
            else:
                log.debug('Set status %s %s', status_id, status)
                return status

    def start(self, status_id):
        """Отметить стадию как начатую
        """
        return self.update(status_id, self.STATUS_IN_PROGRESS)

    def cancel(self, status_id):
        """Отметить стадию как отмененную
        """
        return self.update(status_id, self.STATUS_CANCEL)

    def fail(self, status_id):
        """Отметить стадию как сломавшуюся
        """
        return self.update(status_id, self.STATUS_FAIL)

    def succeed(self, status_id):
        """Отметить стадию как удачно завершенную
        """
        return self.update(status_id, self.STATUS_DONE)

    def retry(self, status_id):
        """Отметить стадию в состоянии ретрая
        """
        return self.update(status_id, self.STATUS_RETRY)

    def get_stats(self):
        """Возвращает статистику по количеству статусов для стадий
        """
        qs = self._filter()
        stream = list(qs.order_by('id').values())

        statuses = {stage_status['status_id']: stage_status for stage_status in stream}

        leave = []
        result = self.get_default_stats()

        for status_id, status in statuses.items():
            stage = self._get_stage_from_id(status_id)
            result[stage, status['status']] += 1

            # если статус не терминальный, то надо эту запись оставить для следующего тика
            if status['status'] not in self.terminal_statuses:
                log.debug('Leave status in db: `%s`.`%s` %s', stage, status['id'], status['status'])
                leave.append(status['id'])

        if stream:
            qs.filter(id__lte=stream[-1]['id']).exclude(id__in=leave).delete()

        return result

    def delete(self, stage=None):
        """Удаляет заданные стадии
        """
        stages = self._filter(stage)
        stages.delete()

    @classmethod
    def get_indexations(cls, limit=None):
        pairs = set()
        qs = StageStatus.objects.exclude(status_id__contains='push').values_list('status_id', flat=True)
        if limit:
            qs = qs.order_by('id')[:limit]

        for status_id in qs.iterator():
            prefix_bits = parse_prefix(cls.id_sep, cls.id_schema, status_id)
            indexation_id = prefix_bits.get('indexation_id', 0)
            if indexation_id != 0 and indexation_id != 'push':
                pairs.add((prefix_bits['revision_id'], prefix_bits['indexation_id']))

        return pairs

    @classmethod
    def get_default_stats(cls):
        return {(stage, status): 0 for stage in cls.stages for status in cls.statuses}

    @classmethod
    def _generate_id_prefix(cls, **kwargs):
        return generate_prefix(cls.id_sep, cls.id_schema, **kwargs)

    def _generate_id(self, stage):
        return self._generate_id_prefix(revision_id=self.revision_id, indexation_id=self.indexation_id,
                                        stage=stage, uuid=uuid.uuid1())

    @classmethod
    def _get_stage_from_id(cls, id_):
        return id_.split(cls.id_sep)[-2]

    def _filter(self, stage=None):
        prefix_bits = {'revision_id': self.revision_id, 'indexation_id': self.indexation_id}

        qs = StageStatus.objects.all()

        if isinstance(stage, str):
            prefix_bits['stage'] = stage

        prefix = self._generate_id_prefix(**prefix_bits)
        return qs.filter(status_id__startswith=prefix + self.id_sep)


class LocalStageStatusStorage(StageStatusStorage):
    key_sep = ':'
    key_prefix = 'stage-status'
    key_schema = ['key_prefix', 'revision_id', 'indexation_id']

    id_sep = '.'
    id_schema = ['stage', 'uuid']

    def __init__(self, revision_id, indexation_id):
        super().__init__(revision_id, indexation_id)

        self.global_storage = StageStatusStorage(revision_id, indexation_id)
        self.key = generate_prefix(self.key_sep, self.key_schema,
                                   key_prefix=self.key_prefix, revision_id=self.revision_id,
                                   indexation_id=self.indexation_id)

    def create(self, stage, app, status=StageStatusStorage.STATUS_NEW):
        if app == 'global':
            id_ = self.global_storage.create(stage, app, status)
        else:
            assert app == 'local'
            assert stage in self.stages

            # нужно зашить стадию в id, чтобы всегда её можно было определить
            id_ = self._generate_id(stage)
            self._push(id=id_, status=status)

        log.debug('Create stage %s, %s, %s, %s', stage, status, app, id_)
        return id_

    def update(self, status_id, status):
        assert status in self.statuses
        if not status_id:
            log.debug('Trying to set status `%s` without status_id', status)
            return status

        log.debug('Update %s: %s', status_id, status)

        if self._is_global_id(status_id):
            return self.global_storage.update(status_id, status)
        else:
            self._push(id=status_id, status=status)
            return status

    def get_stats(self, with_global=False):
        client = redis.get_client()
        if with_global:
            result = self.global_storage.get_stats()
        else:
            result = self.get_default_stats()
        log.debug('Get stats with_global=%s: %s', with_global, result)
        statuses = defaultdict(dict)

        stream = client.lrange(self.key, 0, -1)
        for stage_status in stream:
            stage_status = self._decode(stage_status)

            statuses[stage_status['id']].update(stage_status)

        client.ltrim(self.key, len(stream), -1)
        leave = []

        for stage_status in statuses.values():
            log.debug('Got stage_status: %s', stage_status)
            stage = self._get_stage_from_id(stage_status['id'])
            status = stage_status['status']

            result[stage, status] += 1

            # если статус не терминальный, то надо эту запись оставить для следующего тика
            if status not in self.terminal_statuses:
                log.debug('Leave status in redis: `%s`.`%s` %s', stage, stage_status['id'], status)
                leave.append(self._encode(stage_status))

        # вставляем оставшиеся записи в начало списка
        if leave:
            log.debug('Push non terminal statuses back. key: %s, leave count: %s',
                      self.key, len(leave))
            client.lpush(self.key, *leave)
        elif client.llen(self.key) == 0:
            # если ничего не осталось - удаляем ключ. Любая следующая вставка автоматически
            # его добавит, а если вставок больше не будет - то мы не будем обрабатывать статистику
            # для лишней индексации.
            log.debug('Delete stats key: %s', self.key)
            client.delete(self.key)
        log.debug('Update stats: %s', result)
        return result

    def delete(self, stage=None):
        log.debug('Start statuses cleanup. Key: %s', self.key)
        redis.get_client().delete(self.key)
        self.global_storage.delete(stage)

    @classmethod
    def get_indexations(cls, with_global=True, with_global_limit=None):
        """ Все пары (ревизия, индексация) которые есть в redis """
        if with_global:
            pairs = StageStatusStorage.get_indexations(with_global_limit)
        else:
            pairs = set()

        key = generate_prefix(cls.key_sep, cls.key_schema,
                              key_prefix=cls.key_prefix, revision_id='*', indexation_id='*')

        for value in redis.get_client().scan_iter(key):
            prefix_bits = parse_prefix(cls.key_sep, cls.key_schema, value)
            if prefix_bits.get('indexation_id', 0) != 0:  # 0 - специальный id для пушей
                pairs.add((int(prefix_bits['revision_id']), int(prefix_bits['indexation_id'])))

        return pairs

    def _push(self, **kwargs):
        data = self._encode(kwargs)
        redis.get_client().rpush(self.key, data)

    def _generate_id(self, stage):
        return self._generate_id_prefix(stage=stage, uuid=uuid.uuid1())

    @classmethod
    def _is_global_id(cls, id_):
        """ Проверяет ключ из redis или базы """
        # id_.startswith('stage-status') нужно на переходный период в релизе,
        # потому что иначе статусы индексаций, начавшихся до релиза попадают в mysql
        return cls.id_sep not in id_ and not id_.startswith('stage-status')

    def _encode(self, data):
        return json.dumps(data)

    def _decode(self, raw_data):
        return json.loads(raw_data)
