# -*- coding: utf-8 -*-
import logging

from collections import defaultdict, deque
from datetime import datetime, timedelta
from decimal import Decimal
from django.conf import settings
from django.db import connection, transaction
from django.utils import timezone
from json import loads as json_loads
from more_itertools import flatten, pairwise
from pytz import UTC

from events.surveyme.models import (
    ProfileSurveyAnswer,
    ProfileSurveyAnswerCounter,
)
from events.countme.models import (
    AnswerCount,
    AnswerCountByDate,
    AnswerScoresCount,
    AnswerScoresCountByDate,
    QuestionCount,
    QuestionCountByDate,
    QuestionScoresCount,
    QuestionScoresCountByDate,
    QuestionGroupDepth,
    MAX_COMPOSITE_KEY,
)

# Earliest timestamp counters are computed from: init_survey_counters() seeds
# last_modified with it and update_survey_counters() uses it as the lower bound.
START_EPOCH = timezone.make_aware(datetime(2019, 4, 1), UTC)
# check_if_need_rebuild_counters() only considers surveys whose stored answer
# count is below this value.
MAX_STATISTICS_REBUILD_COUNT = 250
# Freshness window used by check_if_need_rebuild_counters(): the newest answer
# must be younger than this many minutes.
DELTA_STATISTICS_UPDATE = 5  # minutes
# Maximum number of surveys selected by get_surveys() in one update_counters() run.
SURVEY_LOCK_LIMIT = 100
logger = logging.getLogger(__name__)


def get_surveys(started_at, limit):
    """Yield ``(survey_id, last_modified)`` for surveys whose counters are stale.

    A ``countme_answercount`` row is selected when its ``last_modified`` is
    earlier than ``started_at`` and the survey has at least one answer created
    at or after ``last_modified`` and later than one hour before
    ``started_at``. At most ``limit`` rows are returned, and the query is
    issued with ``for update of t skip locked``.

    Args:
        started_at: timestamp of the current update run.
        limit: maximum number of rows to select.
    """
    # TODO: write a test once the test suite is migrated to PostgreSQL
    sql = '''
        select t.survey_id, t.last_modified
        from countme_answercount t
        where t.last_modified < %s
        and exists(
            select null
            from surveyme_profilesurveyanswer a
            where a.survey_id = t.survey_id
                and a.date_created >= t.last_modified
                and a.date_created > %s - interval '1 hour'
        )
        limit %s
        for update of t skip locked
    '''
    with connection.cursor() as cursor:
        # started_at is bound twice: once for the staleness check and once
        # for the one-hour window in the sub-query.
        cursor.execute(sql, (started_at, started_at, limit))
        for row in cursor.fetchall():
            survey_id, last_modified = row
            yield survey_id, last_modified


class QuestionType:
    """Fallback handler for answer entries that have no dedicated handler.

    Subclasses set ``answer_type`` to the slug they handle and override
    ``execute``. The base class keeps ``answer_type = None``.
    """

    answer_type = None

    def get_value_data(self, answer_question):
        """Return the entry's ``'value'`` item, or None when it is absent."""
        return answer_question.get('value')

    def get_scores(self, answer_question):
        """Return the entry's ``'scores'`` item, or None when it is absent."""
        return answer_question.get('scores')

    def get_question_id(self, answer_question):
        """Return ``answer_question['question']['id']``, or None when missing."""
        question = answer_question.get('question', {})
        return question.get('id')

    def execute(self, answer_question, *args):
        """Return ``(question_id, '', None)``: empty composite key, no scores.

        Extra positional arguments are accepted and ignored.
        """
        question_id = self.get_question_id(answer_question)
        return (question_id, '', None)


class GroupType(QuestionType):
    """Handler for ``answer_group`` entries (a list of nested answer sets)."""

    answer_type = 'answer_group'

    def execute(self, answer_question, qstack):
        """Queue the nested answers and report how many groups were given.

        Every non-empty group in the entry's value is appended to ``qstack``
        (for a dict group, its values are appended).

        Returns:
            ``(question_id, None, number_of_groups)``, or None when the entry
            has no value.
        """
        fieldset = self.get_value_data(answer_question)
        if not fieldset:
            return None

        question_id = self.get_question_id(answer_question)
        for entry in fieldset:
            nested = entry.values() if isinstance(entry, dict) else entry
            if nested:
                qstack.extend(nested)
        return question_id, None, len(fieldset)


class ShortTextType(QuestionType):
    """Handler for ``answer_short_text`` entries."""

    answer_type = 'answer_short_text'

    def execute(self, answer_question, *args):
        """Return ``(question_id, composite_key, scores)`` for a short text answer.

        The composite key is the answer text cut to ``MAX_COMPOSITE_KEY``
        characters, but only when the entry carries truthy scores; otherwise
        it is an empty string.
        """
        question_id = self.get_question_id(answer_question)
        scores = self.get_scores(answer_question)
        if not scores:
            return (question_id, '', scores)
        text = self.get_value_data(answer_question)
        return (question_id, text[:MAX_COMPOSITE_KEY], scores)


class BooleanType(QuestionType):
    """Handler for ``answer_boolean`` entries."""

    answer_type = 'answer_boolean'

    def execute(self, answer_question, *args):
        """Return ``(question_id, '1' or '0', None)`` based on the value's truthiness."""
        question_id = self.get_question_id(answer_question)
        if self.get_value_data(answer_question):
            composite_key = '1'
        else:
            composite_key = '0'
        return (question_id, composite_key, None)


class ChoicesType(QuestionType):
    """Handler for ``answer_choices`` entries (one entry may hold several choices)."""

    answer_type = 'answer_choices'

    def execute(self, answer_question, *args):
        """Return ``(question_id, [composite_key, ...], scores)``.

        One composite key is produced per item of the entry's value; the way
        a key is built depends on the question's ``options.data_source``.
        """
        question_id = self.get_question_id(answer_question)
        scores = self.get_scores(answer_question)
        options = answer_question.get('question', {}).get('options', {})
        build_key = self.get_composite_key_for(options.get('data_source'))
        chosen = self.get_value_data(answer_question) or []
        composite_keys = [build_key(item) for item in chosen]
        return (question_id, composite_keys, scores)

    def get_composite_key_for(self, data_source):
        """Return a function that turns one value item into its composite key.

        * ``survey_question_matrix_choice``: ``'<row key>:<col key>'``;
        * ``survey_question_choice``: the item's ``key``, truncated;
        * anything else: the item's ``slug`` (or ``key`` when the slug is
          falsy), truncated.

        Truncation is to ``MAX_COMPOSITE_KEY`` characters.
        """
        def matrix_key(value_item):
            row_key = value_item.get('row', {}).get('key')
            col_key = value_item.get('col', {}).get('key')
            return '%s:%s' % (row_key, col_key)

        def choice_key(value_item):
            return value_item.get('key')[:MAX_COMPOSITE_KEY]

        def generic_key(value_item):
            raw_key = value_item.get('slug') or value_item.get('key')
            return raw_key[:MAX_COMPOSITE_KEY]

        if data_source == 'survey_question_matrix_choice':
            return matrix_key
        if data_source == 'survey_question_choice':
            return choice_key
        return generic_key


# Handler instances keyed by answer-type slug. The base QuestionType has
# answer_type None, so it is stored under the None key.
QUESTION_TYPES = {
    handler_class.answer_type: handler_class()
    for handler_class in (
        QuestionType,
        GroupType,
        ShortTextType,
        BooleanType,
        ChoicesType,
    )
}


def get_stat_info(data):
    """Yield one ``(slug, *handler_result)`` tuple per answered question.

    ``data['data']`` may be a list of answer entries or a dict whose values
    are the entries. Entries are processed through a queue; each handler
    receives the queue as well, so it can append further entries to it.
    Falsy entries and handlers returning None produce no output.

    Args:
        data: decoded answer payload (may be None or empty).
    """
    payload = (data or {}).get('data') or []
    entries = payload.values() if isinstance(payload, dict) else payload
    fallback = QUESTION_TYPES.get(None)
    pending = deque(entries)
    while pending:
        current = pending.popleft()
        if current:
            question = current.get('question', {})
            slug = question.get('answer_type', {}).get('slug')
            # Unknown slugs are served by the handler registered under None.
            handler = QUESTION_TYPES.get(slug, fallback)
            outcome = handler.execute(current, pending)
            if outcome is not None:
                yield (slug, *outcome)


class Counters:
    """Accumulate statistics for one survey in memory and flush them to the DB.

    ``add_answer`` / ``add_question`` / ``add_group_depth`` only touch the
    in-memory dictionaries; ``save`` writes every non-empty dictionary to its
    ``countme_*`` table with ``insert ... on conflict ... do update``.

    All multi-row upserts share one private helper (``_upsert``) instead of
    repeating the same SQL-building code per table.
    """

    # Conflict action used by every "count" table: add the new batch to the
    # value that is already stored.
    _ADD_COUNT = 'count = a.count + excluded.count'

    def __init__(self, survey_id):
        self.survey_id = survey_id
        self.answer_count = 0
        # created -> count
        self.answer_count_by_date = defaultdict(int)
        # scores -> count
        self.answer_scores_count = defaultdict(int)
        # (scores, created) -> count
        self.answer_scores_count_by_date = defaultdict(int)
        # (question_id, composite_key) -> count
        self.question_count = defaultdict(int)
        # (question_id, composite_key, created) -> count
        self.question_count_by_date = defaultdict(int)
        # (question_id, scores) -> count
        self.question_scores_count = defaultdict(int)
        # (question_id, scores, created) -> count
        self.question_scores_count_by_date = defaultdict(int)
        # question_id -> largest depth seen
        self.group_depth = defaultdict(int)

    @staticmethod
    def _normalize_scores(scores):
        """Round ``scores`` to two places and return it as a Decimal.

        The Decimal is built from ``str()`` of the rounded number rather than
        from the number itself.
        """
        return Decimal(str(round(scores, 2)))

    def add_answer(self, created, scores):
        """Register one answer created on ``created`` with optional total ``scores``."""
        self.answer_count += 1
        self.answer_count_by_date[created] += 1
        if scores is not None:
            scores = self._normalize_scores(scores)
            self.answer_scores_count[scores] += 1
            self.answer_scores_count_by_date[scores, created] += 1

    def _add_composite_key(self, question_id, composite_key, created):
        """Count one occurrence of ``composite_key`` for ``question_id``."""
        self.question_count[(question_id, composite_key)] += 1
        self.question_count_by_date[(question_id, composite_key, created)] += 1

    def add_question(self, question_id, composite_keys, created, scores):
        """Register one answered question.

        ``composite_keys`` is either a single key or a list of keys; every
        key is counted. ``scores`` is counted once per call when not None.
        """
        if isinstance(composite_keys, list):
            for composite_key in composite_keys:
                self._add_composite_key(question_id, composite_key, created)
        else:
            self._add_composite_key(question_id, composite_keys, created)
        if scores is not None:
            scores = self._normalize_scores(scores)
            self.question_scores_count[(question_id, scores)] += 1
            self.question_scores_count_by_date[(question_id, scores, created)] += 1

    def add_group_depth(self, question_id, depth):
        """Remember the largest ``depth`` seen for the group ``question_id``."""
        self.group_depth[question_id] = max(self.group_depth.get(question_id, 0), depth)

    def _upsert(self, table, columns, conflict, rows, update=_ADD_COUNT):
        """Insert ``rows`` into ``table`` in one statement, updating on conflict.

        Args:
            table: table name.
            columns: column names, in the same order as the values in each row.
            conflict: column names of the ``on conflict`` target.
            rows: iterable of value tuples; nothing is executed when empty.
            update: the ``do update set`` clause (``a`` is the table alias).

        ``table``, ``columns``, ``conflict`` and ``update`` are interpolated
        into the SQL text, so they must be constants from this module and
        never external input. Row values are passed as query parameters.
        """
        rows = list(rows)
        if not rows:
            return

        # One "(%s, %s, ...)" group per row.
        row_pattern = '(%s)' % ', '.join(['%s'] * len(columns))
        sql = '''
            insert into {table} as a ({columns})
            values{values}
            on conflict ({conflict})
            do update set {update}
        '''.format(
            table=table,
            columns=', '.join(columns),
            values=', '.join([row_pattern] * len(rows)),
            conflict=', '.join(conflict),
            update=update,
        )
        params = tuple(flatten(rows))
        with connection.cursor() as c:
            c.execute(sql, params)

    def _save_answer_count(self, started_at):
        """Upsert the survey total; always runs so that last_modified moves to ``started_at``."""
        self._upsert(
            'countme_answercount',
            ('survey_id', 'last_modified', 'count'),
            ('survey_id',),
            [(self.survey_id, started_at, self.answer_count)],
            update=self._ADD_COUNT + ', last_modified = excluded.last_modified',
        )

    def _save_answer_count_by_date(self):
        self._upsert(
            'countme_answercountbydate',
            ('survey_id', 'created', 'count'),
            ('survey_id', 'created'),
            (
                (self.survey_id, created, count)
                for created, count in self.answer_count_by_date.items()
            ),
        )

    def _save_answer_scores_count(self):
        self._upsert(
            'countme_answerscorescount',
            ('survey_id', 'scores', 'count'),
            ('survey_id', 'scores'),
            (
                (self.survey_id, scores, count)
                for scores, count in self.answer_scores_count.items()
            ),
        )

    def _save_answer_scores_count_by_date(self):
        self._upsert(
            'countme_answerscorescountbydate',
            ('survey_id', 'scores', 'created', 'count'),
            ('survey_id', 'scores', 'created'),
            (
                (self.survey_id, scores, created, count)
                for (scores, created), count in self.answer_scores_count_by_date.items()
            ),
        )

    def _save_question_count(self):
        self._upsert(
            'countme_questioncount',
            ('survey_id', 'question_id', 'composite_key', 'count'),
            ('question_id', 'composite_key'),
            (
                (self.survey_id, question_id, composite_key, count)
                for (question_id, composite_key), count in self.question_count.items()
            ),
        )

    def _save_question_count_by_date(self):
        self._upsert(
            'countme_questioncountbydate',
            ('survey_id', 'question_id', 'composite_key', 'created', 'count'),
            ('question_id', 'composite_key', 'created'),
            (
                (self.survey_id, question_id, composite_key, created, count)
                for (question_id, composite_key, created), count
                in self.question_count_by_date.items()
            ),
        )

    def _save_question_scores_count(self):
        self._upsert(
            'countme_questionscorescount',
            ('survey_id', 'question_id', 'scores', 'count'),
            ('question_id', 'scores'),
            (
                (self.survey_id, question_id, scores, count)
                for (question_id, scores), count in self.question_scores_count.items()
            ),
        )

    def _save_question_scores_count_by_date(self):
        self._upsert(
            'countme_questionscorescountbydate',
            ('survey_id', 'question_id', 'scores', 'created', 'count'),
            ('question_id', 'scores', 'created'),
            (
                (self.survey_id, question_id, scores, created, count)
                for (question_id, scores, created), count
                in self.question_scores_count_by_date.items()
            ),
        )

    def _save_group_depth(self):
        """Upsert group depths, keeping the larger of the stored and the new value."""
        self._upsert(
            'countme_questiongroupdepth',
            ('survey_id', 'question_id', 'depth'),
            ('question_id',),
            (
                (self.survey_id, question_id, depth)
                for question_id, depth in self.group_depth.items()
            ),
            update='depth = greatest(a.depth, excluded.depth)',
        )

    def save(self, started_at):
        """Flush every accumulated counter to the database.

        Args:
            started_at: stored as the new ``last_modified`` of the survey's
                ``countme_answercount`` row.
        """
        # TODO: write a test once the test suite is migrated to PostgreSQL
        self._save_answer_count(started_at)
        self._save_answer_count_by_date()
        self._save_question_count()
        self._save_question_count_by_date()
        self._save_answer_scores_count()
        self._save_answer_scores_count_by_date()
        self._save_question_scores_count()
        self._save_question_scores_count_by_date()
        self._save_group_depth()


def get_answers(started_at, survey_id, last_modified):
    """Return ``(date_created, data)`` rows of a survey's answers in a time window.

    The window is ``last_modified <= date_created < started_at``. The query
    goes to the ``settings.DATABASE_ROLOCAL`` database alias and has its
    ordering cleared with an empty ``order_by()``.
    """
    source = ProfileSurveyAnswer.objects.using(settings.DATABASE_ROLOCAL)
    in_window = source.filter(
        survey_id=survey_id,
        date_created__gte=last_modified,
        date_created__lt=started_at,
    )
    return in_window.order_by().values_list('date_created', 'data')


def calendar(start, finish):
    """Yield boundaries from ``start`` to ``finish`` in one-day steps.

    Every value ``start + k days`` that is strictly before ``finish`` is
    yielded, followed by ``finish`` itself. When ``start`` is not before
    ``finish`` the only value yielded is ``finish``.
    """
    one_day = timedelta(days=1)
    boundary = start
    while boundary < finish:
        yield boundary
        boundary = boundary + one_day
    yield finish


def get_count_estimate(survey_id):
    """Return ``answers_count`` from the survey's ProfileSurveyAnswerCounter row.

    Returns 0 when the survey has no such row.
    """
    try:
        counter = ProfileSurveyAnswerCounter.objects.get(survey_id=survey_id)
    except ProfileSurveyAnswerCounter.DoesNotExist:
        return 0
    return counter.answers_count


def is_large_number_of_answers(started_at, survey_id, last_modified):
    """Tell whether the pending window is both long and heavily populated.

    True only when more than 7 whole days separate ``last_modified`` from
    ``started_at`` and the survey's estimated answer count exceeds 10000.
    The estimate is looked up only when the first condition holds.
    """
    elapsed = started_at - last_modified
    if elapsed.days <= 7:
        return False
    return get_count_estimate(survey_id) > 10000


def get_all_answers(started_at, survey_id, last_modified):
    """Yield ``(date_created, data)`` rows for the whole pending window.

    When is_large_number_of_answers() reports a large window, it is read in
    consecutive slices produced by calendar(); otherwise a single
    get_answers() query covers it.
    """
    if not is_large_number_of_answers(started_at, survey_id, last_modified):
        yield from get_answers(started_at, survey_id, last_modified)
        return

    slices = pairwise(calendar(last_modified, started_at))
    for slice_start, slice_end in slices:
        yield from get_answers(slice_end, survey_id, slice_start)


def get_counters(started_at, survey_id, last_modified):
    """Build a Counters object from the answers in the pending window.

    For every answer with a non-empty payload, the answer itself is counted
    (with ``data['quiz']['scores']`` when present) and then every item
    produced by get_stat_info(): ``answer_group`` items update the group
    depth, all other items are counted as questions.

    Returns:
        The populated, not yet saved, Counters instance.
    """
    counters = Counters(survey_id)
    # Needed during the initial data backfill; later this can be replaced
    # with get_answers.
    for date_created, data in get_all_answers(started_at, survey_id, last_modified):
        if not data:
            continue
        # For sqlite support: the payload arrives as a JSON string.
        payload = json_loads(data) if isinstance(data, str) else data
        created = date_created.date()
        answer_scores = (payload.get('quiz') or {}).get('scores')
        counters.add_answer(created, answer_scores)
        for slug, question_id, composite_keys, extra in get_stat_info(payload):
            if slug == 'answer_group':
                # For groups the last item is the number of nested groups.
                counters.add_group_depth(question_id, extra)
            else:
                # For everything else the last item is the question's scores.
                counters.add_question(question_id, composite_keys, created, extra)
    return counters


def update_counters(started_at):
    """Refresh counters for up to SURVEY_LOCK_LIMIT stale surveys.

    Everything runs inside one ``transaction.atomic()`` block: the surveys
    are selected with get_surveys(), and for each one the new counters are
    computed and saved.

    Returns:
        The number of surveys processed.
    """
    processed = 0
    with transaction.atomic():
        stale_surveys = get_surveys(started_at, limit=SURVEY_LOCK_LIMIT)
        for processed, (survey_id, last_modified) in enumerate(stale_surveys, start=1):
            counters = get_counters(started_at, survey_id, last_modified)
            counters.save(started_at)

            if counters.answer_count > 0:
                logger.info('Updated survey (%s) counters for %s answers', survey_id, counters.answer_count)
    return processed


def delete_survey_counters(survey_id):
    """Delete the survey's rows from every counter model."""
    counter_models = (
        AnswerCount,
        AnswerCountByDate,
        AnswerScoresCount,
        AnswerScoresCountByDate,
        QuestionCount,
        QuestionCountByDate,
        QuestionScoresCount,
        QuestionScoresCountByDate,
        QuestionGroupDepth,
    )
    for model in counter_models:
        model.objects.filter(survey_id=survey_id).delete()


def init_survey_counters(survey_id):
    """Create the survey's AnswerCount row with a zero count dated START_EPOCH."""
    initial_state = dict(survey_id=survey_id, count=0, last_modified=START_EPOCH)
    AnswerCount.objects.create(**initial_state)


def update_survey_counters(survey_id, started_at):
    """Count all answers of one survey from START_EPOCH up to ``started_at`` and save them."""
    counters = get_counters(started_at, survey_id, START_EPOCH)
    counters.save(started_at)

    answers_counted = counters.answer_count
    if answers_counted <= 0:
        return
    logger.info('Updated survey (%s) counters for %s answers', survey_id, answers_counted)


def rebuild_counters(survey_id):
    """Recompute one survey's counters from scratch inside a single transaction.

    The existing counter rows are deleted, a fresh zero AnswerCount row is
    created, and the counters are recalculated up to the current time.
    """
    with transaction.atomic():
        delete_survey_counters(survey_id)
        init_survey_counters(survey_id)
        # Taken after the reset, immediately before the recount.
        rebuilt_at = timezone.now()
        update_survey_counters(survey_id, rebuilt_at)


def check_if_need_rebuild_counters(survey):
    """Tell whether the survey's counters should be rebuilt right now.

    True only when all of the following hold:

    * the survey is not archived;
    * it has no ``answercount`` or the stored count is below
      MAX_STATISTICS_REBUILD_COUNT;
    * its last answer (by primary key) was created less than
      DELTA_STATISTICS_UPDATE minutes ago.
    """
    if survey.date_archived:
        return False

    # NOTE(review): if `answercount` is a Django reverse one-to-one accessor,
    # a missing row raises RelatedObjectDoesNotExist instead of giving None;
    # confirm against the model definition.
    answer_count = survey.answercount
    if answer_count is not None and not answer_count.count < MAX_STATISTICS_REBUILD_COUNT:
        return False

    last_created = (
        ProfileSurveyAnswer.objects.using(settings.DATABASE_ROLOCAL)
        .filter(survey=survey)
        .values_list('date_created', flat=True)
        .order_by('pk')
        .last()
    )
    if not last_created:
        return False

    age = timezone.now() - last_created
    return age < timedelta(minutes=DELTA_STATISTICS_UPDATE)
