# -*- coding: utf-8 -*-
import logging
import os
import pytz
import yenv

from collections import namedtuple
from django.conf import settings
from django.db import transaction, connections
from django.utils import timezone
from json import loads as json_loads
from yt.wrapper import YtClient, TablePath

from events.celery_app import app
from events.common_app.utils import chunks
from events.surveyme.models import (
    AnswerExportYtStatus,
    ProfileSurveyAnswer,
)


logger = logging.getLogger(__name__)

# Row limit of one fetch (and row lock) per transaction in export_all_answers.
LOCK_ANSWERS_LIMIT = 10000
# Maximum number of answers sent to YT in a single batch / write_table call.
YT_UPLOAD_LIMIT = 2000
# Maximum number of primary keys per UPDATE issued by change_status.
DB_UPDATE_LIMIT = 50

# One answer prepared for export: identifiers, request metadata and the
# converted answer payload (see _get_answer).
Answer = namedtuple(
    'Answer',
    ['answer_id', 'survey_id', 'created', 'lang', 'ip', 'uid', 'yandexuid', 'answer_data'],
)


def group_to_dict(answer_question):
    """Convert the fieldsets of an ``answer_group`` question in place.

    For a question whose ``question.answer_type.slug`` is ``'answer_group'``,
    each non-empty fieldset in ``value`` is converted with ``to_dict`` and
    ``value`` is replaced by the resulting list (a missing/empty ``value``
    becomes ``[]``). Any other question is left untouched.

    Returns the same (possibly mutated) ``answer_question`` dict.
    """
    answer_type = answer_question.get('question', {}).get('answer_type', {})
    if answer_type.get('slug') != 'answer_group':
        return answer_question
    fieldsets = answer_question.get('value') or []
    answer_question['value'] = [to_dict(fieldset) for fieldset in fieldsets if fieldset]
    return answer_question


def to_dict(data):
    """Turn a list of answered questions into a dict keyed by question slug.

    Each non-empty item is passed through ``group_to_dict`` (which expands
    nested answer groups) and stored under ``item['question']['slug']``;
    later items with the same slug overwrite earlier ones. Anything that is
    not a list is returned unchanged.
    """
    if not isinstance(data, list):
        return data
    by_slug = {}
    for item in data:
        if not item:
            continue
        slug = item.get('question', {}).get('slug')
        by_slug[slug] = group_to_dict(item)
    return by_slug


def convert_answer_data(answer_data):
    """Replace ``answer_data['data']`` (a list of questions) with its slug-keyed dict form.

    Mutates and returns ``answer_data``; a missing ``data`` key is treated as
    an empty list.
    """
    questions = answer_data.get('data', [])
    answer_data['data'] = to_dict(questions)
    return answer_data


def _get_answer(row):
    answer_id, survey_id, created, source_request, answer_data = row
    if not answer_data:
        return None
    if isinstance(answer_data, str):
        answer_data = json_loads(answer_data)
    if isinstance(source_request, str):
        source_request = json_loads(source_request)
    elif source_request is None:
        source_request = {}
    lang = source_request.get('lang')
    ip = answer_data.pop('ip', None) or source_request.get('ip')
    uid = answer_data.pop('uid', None)
    yandexuid = answer_data.pop('yandexuid', None) or source_request.get('cookies', {}).get('yandexuid')
    return Answer(
        answer_id=answer_id,
        survey_id=str(survey_id),
        created=created,
        lang=lang,
        ip=ip,
        uid=uid,
        yandexuid=yandexuid,
        answer_data=convert_answer_data(answer_data),
    )


def get_answer(obj):
    """Build an ``Answer`` from a raw query row (tuple) or a ``ProfileSurveyAnswer``.

    Returns ``None`` for answers without data and for objects of any other type.
    """
    if isinstance(obj, tuple):
        row = obj
    elif isinstance(obj, ProfileSurveyAnswer):
        row = (obj.pk, obj.survey_id, obj.date_created, obj.source_request, obj.data)
    else:
        return None
    return _get_answer(row)


def get_table_path(survey_id, created, app_type=None):
    """Return the YT table path an answer is exported to.

    The path is ``//home/forms/answers/<app_type>/<yenv.type>/<tail>``, where
    ``app_type`` defaults to ``settings.APP_TYPE``. On the business site
    ``<tail>`` is the creation date (``YYYY-MM-DD``); otherwise it is
    ``<survey_id>/data``.
    """
    resolved_app_type = app_type or settings.APP_TYPE
    if settings.IS_BUSINESS_SITE:
        tail = [created.date().isoformat()]
    else:
        tail = [survey_id, 'data']
    segments = ['//home/forms/answers', resolved_app_type, yenv.type]
    segments.extend(tail)
    return '/'.join(segments)


def get_answers_for_update(export_status, limit_size, started_at):
    """Yield up to ``limit_size`` answers locked with ``SELECT ... FOR UPDATE`` (ORM path).

    Selects answers whose export status equals ``export_status`` and that were
    created before ``started_at``. They are ordered by creation date on the
    business site, otherwise by survey id. Rows without answer data are skipped.
    """
    order_field = 'survey_id'
    if settings.IS_BUSINESS_SITE:
        order_field = 'date_created'
    rows = ProfileSurveyAnswer.objects.select_for_update().filter(
        export_yt_status__exported=export_status,
        date_created__lt=started_at,
    ).order_by(order_field).values_list(
        'pk', 'survey_id', 'date_created', 'source_request', 'data',
    )[:limit_size]
    for answer in map(get_answer, rows):
        if answer:
            yield answer


def get_answers_skip_locked(export_status, limit_size, started_at):
    """Yield up to ``limit_size`` answers, locking their export-status rows with SKIP LOCKED.

    Uses raw SQL (``FOR UPDATE OF st SKIP LOCKED``) on the
    ``settings.DATABASE_DEFAULT`` connection. Status rows already locked by
    another transaction are skipped instead of waited on. Results are
    ordered by survey id except on the business site, where no ordering is
    applied. Rows without answer data are skipped.

    The cursor is closed right after fetching. The row locks belong to the
    surrounding transaction, not to the cursor, so closing it does not
    release them. Previously the cursor was never closed.
    """
    ordering = '' if settings.IS_BUSINESS_SITE else 'order by t.survey_id'
    sql = '''
        select st.answer_id, t.survey_id, t.date_created, t.source_request, t.data
        from surveyme_answerexportytstatus st
        join surveyme_profilesurveyanswer t on t.id = st.answer_id
        where st.exported = %s
            and t.date_created < %s
        {ordering}
        limit %s
        for update of st skip locked
    '''.format(ordering=ordering)
    connection_master = connections[settings.DATABASE_DEFAULT]
    params = (export_status, started_at, limit_size)
    with connection_master.cursor() as c:
        c.execute(sql, params)
        rows = c.fetchall()
    for row in rows:
        answer = get_answer(row)
        if answer:
            yield answer


def get_answers(export_status, limit_size, started_at):
    """Yield exportable answers using the locking strategy suited to the DB vendor.

    PostgreSQL uses the raw ``SKIP LOCKED`` query; other backends fall back to
    the ORM ``select_for_update`` query.
    """
    vendor = connections[settings.DATABASE_DEFAULT].vendor
    if vendor == 'postgresql':
        fetch = get_answers_skip_locked
    else:
        fetch = get_answers_for_update
    yield from fetch(export_status, limit_size, started_at)


def datetime_to_string(dt):
    """Format ``dt`` in UTC as ``YYYY-MM-DDTHH:MM:SSZ`` (seconds precision).

    NOTE(review): ``dt`` is expected to be timezone-aware; Django's
    ``make_naive`` rejects naive datetimes in recent versions — confirm.
    """
    utc_naive = timezone.make_naive(dt, pytz.UTC)
    return utc_naive.strftime('%Y-%m-%dT%H:%M:%SZ')


class DeployAnswers:
    """Append exported answers to YT tables on the ``hahn`` cluster."""

    def __init__(self):
        self.client = self._get_client()

    def _get_client(self):
        """Create a YT client for ``hahn`` authenticated with ``settings.YT_TOKEN``."""
        return YtClient('hahn', config={'token': settings.YT_TOKEN})

    def _create_folder(self, folder_path):
        """Create ``folder_path`` (with parents) unless it already exists."""
        if self.client.exists(folder_path):
            return
        self.client.mkdir(folder_path, recursive=True)

    def _create_table(self, table_path):
        """Create ``table_path`` with the answers schema; no-op if it already exists."""
        self.client.create(
            'table',
            table_path,
            recursive=True,
            ignore_existing=True,
            attributes={'schema': self._get_schema()},
        )

    def _get_schema(self):
        """Return the YT schema of the answers table.

        ``id``, ``created`` and ``survey_id`` are required columns; the rest
        are optional.
        """
        columns = [
            ('id', 'int64', True),
            ('created', 'string', True),
            ('survey_id', 'string', True),
            ('lang', 'string', False),
            ('uid', 'string', False),
            ('yandexuid', 'string', False),
            ('ip', 'string', False),
            ('answer', 'any', False),
        ]
        schema = []
        for name, column_type, required in columns:
            column = {'name': name, 'type': column_type}
            if required:
                column['required'] = True
            schema.append(column)
        return schema

    def _write_table(self, answers, table_path):
        """Append one row per ``Answer`` in ``answers`` to ``table_path``."""
        rows = []
        for answer in answers:
            rows.append({
                'id': answer.answer_id,
                'created': datetime_to_string(answer.created),
                'survey_id': answer.survey_id,
                'uid': answer.uid,
                'lang': answer.lang,
                'yandexuid': answer.yandexuid,
                'ip': answer.ip,
                'answer': answer.answer_data,
            })
        logger.info('writing to the table %s ... %s rows', table_path, len(rows))
        self.client.write_table(TablePath(table_path, append=True), rows)

    def write_data(self, answers, table_path):
        """Ensure the parent folder and the table exist, then append ``answers``."""
        self._create_folder(os.path.dirname(table_path))
        self._create_table(table_path)
        self._write_table(answers, table_path)


def change_status(answers, status):
    """Set ``exported=status`` on the export-status rows of ``answers``.

    Updates are issued in chunks of ``DB_UPDATE_LIMIT`` primary keys.
    """
    for batch in chunks(answers, DB_UPDATE_LIMIT):
        AnswerExportYtStatus.objects.filter(
            pk__in=(item.answer_id for item in batch),
        ).update(exported=status)


def deploy_answers(answers, table_path, status):
    """Write ``answers`` to the YT table ``table_path``, then mark them with ``status``.

    The YT write happens first, so a failure there leaves the DB status unchanged.
    """
    DeployAnswers().write_data(answers, table_path)
    change_status(answers, status)


def export_answers(old_export_status, new_export_status, limit_size, started_at):
    """Export one batch of answers to YT and switch their export status.

    Fetches up to ``limit_size`` answers with status ``old_export_status``
    created before ``started_at``. Consecutive answers targeting the same
    table are grouped and uploaded in chunks of at most ``YT_UPLOAD_LIMIT``.
    A change of table path flushes the current chunk, so interleaved tables
    produce smaller writes. Each uploaded chunk is marked with
    ``new_export_status``.

    Returns the number of answers processed (rows skipped for lacking data
    are not counted).
    """
    processed = 0
    current_path = None
    batch = []
    for answer in get_answers(old_export_status, limit_size, started_at):
        processed += 1
        path = get_table_path(answer.survey_id, answer.created)
        must_flush = current_path is not None and (
            path != current_path or len(batch) >= YT_UPLOAD_LIMIT
        )
        if must_flush:
            deploy_answers(batch, current_path, new_export_status)
            batch = []
        current_path = path
        batch.append(answer)
    if batch:
        deploy_answers(batch, current_path, new_export_status)
    return processed


@app.task(bind=True, ignore_result=True, soft_time_limit=6*3600, time_limit=7*3600)
def export_all_answers(self):
    """Celery task: export every not-yet-exported answer created before the task started.

    Each batch of ``LOCK_ANSWERS_LIMIT`` answers is exported inside its own
    transaction. Batches repeat until one processes no answers.

    NOTE(review): answers skipped for having no data are never marked
    exported, so they are re-selected on every iteration. If a whole batch
    consists of such rows, the loop stops even though exportable rows may
    remain — confirm this is intended.
    """
    started_at = timezone.now()
    processed = None
    while processed != 0:
        with transaction.atomic():
            processed = export_answers(False, True, LOCK_ANSWERS_LIMIT, started_at)
