import collections
import logging
from datetime import datetime

from django.db import connection
from django.db.models import Count
from django.forms import model_to_dict

from intranet.search.core.models import Indexation, PushInstance, PushRecord, StageStatus
from intranet.search.core.storages.base import Storage
from intranet.search.core.storages.indexation import indexation_model_to_dict
from intranet.search.core.storages.revision import rev_model_to_dict

log = logging.getLogger(__name__)


def instance_model_to_dict(obj, dump_related=False):
    """Serialize a push instance model into a plain dict.

    ``status`` is replaced by its display label and ``created_at`` is added
    explicitly (``model_to_dict`` does not include it on its own here).

    :param obj: push instance model object.
    :param dump_related: when True, also embed the serialized related
        ``indexation`` and ``revision`` objects (``None`` when absent).
    :returns: dict with the instance's fields.
    """
    data = model_to_dict(obj, exclude=['status'])
    data.update(
        status=obj.get_status_display(),
        created_at=obj.created_at,
    )

    if dump_related:
        related_indexation = obj.indexation
        data['indexation'] = (
            indexation_model_to_dict(related_indexation) if related_indexation else None
        )
        related_revision = obj.revision
        data['revision'] = rev_model_to_dict(related_revision) if related_revision else None

    return data


def push_model_to_dict(obj):
    """Serialize a push record model into a plain dict.

    ``status`` is replaced by its display label, ``meta`` and ``date`` are
    copied as-is, and ``organization`` is the serialized organization, or an
    empty dict when the push has no organization.

    :param obj: push record model object.
    :returns: dict with the record's fields.
    """
    data = model_to_dict(obj, exclude=['meta', 'status'])
    data.update(
        status=obj.get_status_display(),
        meta=obj.meta,
        date=obj.date,
    )

    organization = {}
    if obj.organization_id:
        organization = model_to_dict(obj.organization)
    data['organization'] = organization

    return data


class PushStorage(Storage):
    """Storage facade over the ``PushRecord`` and ``PushInstance`` models.

    A push record owns a set of push instances (``PushRecord.instances``).
    The record's status is either set directly (``kill``, ``update``) or
    derived from its instances' statuses by ``update_statuses``.
    """

    STATUS_NEW = PushRecord.STATUS_NEW
    STATUS_FAIL = PushRecord.STATUS_FAIL
    STATUS_DONE = PushRecord.STATUS_DONE
    STATUS_CANCEL = PushRecord.STATUS_CANCEL
    STATUS_RETRIED = PushRecord.STATUS_RETRY
    STATUS_KNOWN_FAIL = PushRecord.STATUS_KNOWN_FAIL
    # Status code -> status name lookup (used by ``stats``).
    STATUSES = dict(PushRecord.STATUSES)

    # Push status corresponding to an indexation status; a stopped indexation
    # is treated as a finished (done) push.
    indexation_push_status_map = {
        Indexation.STATUS_NEW: STATUS_NEW,
        Indexation.STATUS_DONE: STATUS_DONE,
        Indexation.STATUS_STOP: STATUS_DONE,
        Indexation.STATUS_FAIL: STATUS_FAIL,
    }

    # Push status corresponding to a stage status; every not-yet-finished
    # stage state (new, in progress, retry) keeps the push NEW.
    stage_push_status_map = {
        StageStatus.STATUS_NEW: STATUS_NEW,
        StageStatus.STATUS_IN_PROGRESS: STATUS_NEW,
        StageStatus.STATUS_RETRY: STATUS_NEW,
        StageStatus.STATUS_DONE: STATUS_DONE,
        StageStatus.STATUS_FAIL: STATUS_FAIL,
        StageStatus.STATUS_CANCEL: STATUS_CANCEL,
    }

    # Recomputes the status of in-progress push records (record status 1 or 5)
    # whose id lies in the half-open range [%s, %s). The lowest instance status
    # code wins (e.g. one failed instance fails the whole push). Records that
    # still have a status-1 (new) instance are skipped (HAVING MIN > 1). Records
    # whose minimum is not listed in the CASE (e.g. 5) are left untouched
    # (status is null). end_time becomes the latest instance ``updated_at``.
    # NOTE(review): the numeric codes are hard-coded; keep them in sync with
    # the PushRecord.STATUS_* constants.
    update_sql = '''
WITH t as (
    SELECT rec.id as id,
    CASE
        WHEN MIN(inst.status) = 1 THEN 1  -- new
        WHEN MIN(inst.status) = 2 THEN 2  -- fail
        WHEN MIN(inst.status) = 6 THEN 6  -- known_fail
        WHEN MIN(inst.status) = 3 THEN 3  -- done
        WHEN MIN(inst.status) = 4 THEN 4  -- cancel
        ELSE null
    END AS status, MAX(inst.updated_at) as end_time
    FROM isearch_pushinstance AS inst
    INNER JOIN isearch_pushrecord AS rec ON (inst.push_id=rec.id)
    WHERE rec.status IN (1,5) and rec.id >= %s and rec.id < %s
    GROUP BY rec.id
    HAVING MIN(inst.status) > 1
)
update isearch_pushrecord
set status=t.status, end_time=t.end_time
from t
where isearch_pushrecord.id=t.id and t.status is not null;
-- returning isearch_pushrecord.id
'''

    def create(self, search, index, type_='unknown', meta=None, comment=None,
               object_id=None, organization_id=None):
        """Create a new push record started now.

        :returns: id of the created record.
        """
        rec = PushRecord.objects.create(
            search=search,
            index=index,
            type=type_,
            meta=meta,
            comment=comment or '',
            object_id=object_id,
            organization_id=organization_id,
            start_time=datetime.now()
        )
        log.info('Push record created: %s, %s', rec.id, push_model_to_dict(rec))
        return rec.id

    def get(self, id_):
        """Return push record ``id_`` serialized as a dict, or None if it does not exist."""
        try:
            record = PushRecord.objects.get(id=id_)
        except PushRecord.DoesNotExist:
            # QuerySet.get() signals a missing row with DoesNotExist, not KeyError.
            return None
        return push_model_to_dict(record)

    def update(self, id_, **kwargs):
        """Apply ``kwargs`` as field updates to push record ``id_``.

        :returns: number of updated rows (0 if the record does not exist).
        """
        log.info('Push record updated: %s, %s', id_, kwargs)
        return PushRecord.objects.filter(id=id_).update(**kwargs)

    def kill(self, id_, cancel=False, comment=None):
        """Mark a push as finished (does not affect the indexers).

        All instances of the push and the push record itself get the same
        final status. A missing push is logged and ignored.

        :param id_: push record id.
        :param cancel: if True the push is marked as cancelled, otherwise as failed.
        :param comment: comment to add to the push. It is prefixed with
            ``[<status> comment]`` and appended on a new line after the
            existing comment.
        """
        status = self.STATUS_CANCEL if cancel else self.STATUS_FAIL
        record = self.get(id_)
        if record is None:
            log.warning('Push %s not found, nothing to kill', id_)
            return

        old_comment = record['comment']
        if comment:
            comment = f'[{status} comment] {comment}'
            if old_comment:
                comment = '\n'.join([old_comment, comment])
        else:
            comment = old_comment
        self.set_instance_status(id_, status)
        self.update(id_, status=status, comment=comment)

    def stats(self, since=None, organization_id=None):
        """Count push records per search and status.

        :param since: if given, only pushes started at or after this moment are counted.
        :param organization_id: if given, only pushes of this organization are counted.
        :returns: ``{search: {status_name: count}}`` (a defaultdict).
        """
        data = PushRecord.objects.all()

        if since:
            data = data.filter(start_time__gte=since)
        if organization_id:
            data = data.filter(organization_id=organization_id)

        # order_by() clears any default model ordering so that it cannot be
        # added to the GROUP BY clause and split the counts.
        data = data.order_by().values_list('search', 'status').annotate(Count('pk'))

        res = collections.defaultdict(dict)
        for search, status, count in data:
            res[search][self.STATUSES[status]] = count

        return res

    def delete_older_than(self, old_threshold, interval=100000):
        """Delete finished pushes started at or before ``old_threshold``.

        Pushes still in STATUS_NEW are kept. Deletion uses ``_raw_delete``,
        which issues direct SQL DELETEs with no signals and no Django-side
        cascade handling. Instances are therefore deleted before their
        records, in id chunks of ``interval`` records.

        :param old_threshold: datetime; pushes with ``start_time`` <= it are deleted.
        :param interval: number of record ids processed per chunk (at least 1).
        """
        pushes_qs = (
            PushRecord.objects
            .exclude(status=PushRecord.STATUS_NEW)
            .filter(start_time__lte=old_threshold)
        )
        instances_qs = (
            PushInstance.objects
            .exclude(push__status=PushRecord.STATUS_NEW)
            .filter(push__start_time__lte=old_threshold)
        )

        ordered = pushes_qs.order_by('id')
        first = ordered.first()
        if not first:
            return
        last = ordered.last()

        step = max(interval, 1)
        from_id = first.id
        last_id = last.id

        # Half-open chunks [from_id, till_id); `<=` so a single old record is handled too.
        while from_id <= last_id:
            till_id = min(from_id + step, last_id + 1)

            log.info('Delete PushInstances from %s to %s', from_id, till_id)
            qs = instances_qs.filter(push_id__gte=from_id, push_id__lt=till_id)
            qs._raw_delete(qs.db)

            log.info('Delete PushRecords from %s to %s', from_id, till_id)
            qs = pushes_qs.filter(id__gte=from_id, id__lt=till_id)
            qs._raw_delete(qs.db)

            from_id = till_id

    def get_by_search_count(self, add_status=False, **filters):
        """Count push records grouped by ``(search, index)``, optionally also by status.

        :param add_status: if True, also group by ``status``.
        :param filters: extra ``QuerySet.filter`` lookups applied to PushRecord.
        :returns: dict mapping ``(search, index)`` or ``(search, index, status)``
            tuples to record counts.
        """
        query = PushRecord.objects.all()
        if filters:
            query = query.filter(**filters)

        values = ['search', 'index']
        if add_status:
            values.append('status')
        # order_by() keeps default model ordering out of the GROUP BY clause.
        query = query.order_by().values(*values).annotate(count=Count('id'))

        return {tuple(row[k] for k in values): row['count'] for row in query}

    def get_push_status_by_indexation(self, indexation_status):
        """Map an indexation status to a push status (KeyError if unmapped)."""
        return self.indexation_push_status_map[indexation_status]

    def get_push_status_by_stage(self, stage_status):
        """Map a stage status to a push status (KeyError if unmapped)."""
        return self.stage_push_status_map[stage_status]

    def set_instance_status(self, id_, status, revision_id=None, indexation_id=None):
        """Set ``status`` (and ``updated_at``) on instances of push ``id_``.

        The update can be narrowed to a revision and/or an indexation. If no
        instance matched and a revision or indexation was given, a new
        instance with that status is created. If no instance matched and
        neither was given, only a warning is logged.
        """
        log.info('Update push=%s, push_status=%s, revision_id=%s, indexation_id=%s',
                 id_, status, revision_id, indexation_id)

        query = PushInstance.objects.filter(push_id=id_)
        if indexation_id:
            query = query.filter(indexation_id=indexation_id)
        if revision_id:
            query = query.filter(revision_id=revision_id)

        updated = query.update(status=status, updated_at=datetime.now())

        if updated == 0:
            if not (revision_id or indexation_id):
                log.warning('Push %s has no instances to update!', id_)
                return

            log.debug('Create new instance for push %s', id_)
            PushInstance.objects.create(
                push_id=id_,
                revision_id=revision_id,
                indexation_id=indexation_id,
                status=status,
            )

    def get_pushes_for_retry(self, max_retries=3):
        """Return ids of pushes that should be restarted.

        These are failed pushes whose ``retries`` counter is below
        ``max_retries``. The result is a lazy ``values_list`` queryset.
        """
        return (
            PushRecord.objects
            .filter(status=self.STATUS_FAIL, retries__lt=max_retries)
            .values_list('id', flat=True)
        )

    def get_instances_for_retry(self, push_id, only_failed=True):
        """Return serialized instances of push ``push_id`` that should be retried.

        Instances already marked as retried are skipped. If ``only_failed`` is
        True, only failed instances are returned. Related indexation and
        revision are embedded in each dict.
        """
        instances_to_retry = (
            PushInstance.objects
            .filter(push=push_id)
            .exclude(status=self.STATUS_RETRIED)
            .select_related('revision', 'indexation')
        )
        if only_failed:
            instances_to_retry = instances_to_retry.filter(status=self.STATUS_FAIL)

        return [instance_model_to_dict(obj, dump_related=True) for obj in instances_to_retry]

    def update_statuses(self, interval=10000):
        """Update statuses of unfinished (new or retried) pushes.

        Only pushes with at least one finished instance are considered. Within
        each id chunk of ``interval`` records:

        * pushes of deleted organizations are marked as known failures;
        * all other pushes get their status and end time from ``update_sql``.

        :param interval: number of record ids processed per chunk (at least 1).
        """
        pushes_to_update = PushRecord.objects.filter(
            status__in=(self.STATUS_NEW, self.STATUS_RETRIED),
            # KNOWN_FAIL is included because update_sql also finalizes records
            # whose instances are all known failures; without it such records
            # could fall outside the scanned id range and never be updated.
            instances__status__in=(
                self.STATUS_DONE, self.STATUS_FAIL, self.STATUS_CANCEL, self.STATUS_KNOWN_FAIL,
            ),
        )

        first = pushes_to_update.order_by('id').first()
        if not first:
            return
        last = pushes_to_update.order_by('-id').first()

        step = max(interval, 1)
        from_id = first.id
        last_id = last.id

        # Half-open chunks [from_id, till_id), matching update_sql's id bounds.
        while from_id <= last_id:
            till_id = min(from_id + step, last_id + 1)

            log.info('Updating statuses of PushInstances from %s to %s', from_id, till_id)
            chunk_qs = pushes_to_update.filter(id__gte=from_id, id__lt=till_id)
            # Set end_time as fail_staled_instances does for known failures.
            chunk_qs.filter(organization__deleted=True).update(
                status=self.STATUS_KNOWN_FAIL, end_time=datetime.now(),
            )

            with connection.cursor() as cursor:
                cursor.execute(self.update_sql, [from_id, till_id])

            from_id = till_id

    def fail_staled_instances(self, stale_threshold):
        """Fail push instances and pushes that have been NEW for too long.

        * NEW instances without an indexation and not updated since
          ``stale_threshold`` are marked as failed.
        * NEW pushes with no instances at all, started at or before
          ``stale_threshold``, are marked as known failures, with ``end_time``
          set to now.
        """
        updated_count = (
            PushInstance.objects
            .filter(status=self.STATUS_NEW, indexation_id__isnull=True,
                    updated_at__lte=stale_threshold)
            .update(status=self.STATUS_FAIL)
        )
        log.info('Fail %s staled push instances, threshold=%s', updated_count, stale_threshold)

        pushes_updated_count = (
            PushRecord.objects
            .filter(status=self.STATUS_NEW, instances__isnull=True,
                    start_time__lte=stale_threshold)
            .update(status=self.STATUS_KNOWN_FAIL, end_time=datetime.now())
        )
        log.info('Fail %s staled pushes, threshold=%s', pushes_updated_count, stale_threshold)