# -*- coding: utf-8 -*-
import csv
import logging
import os
import re
import sys
import yenv

from collections import (
    defaultdict,
    namedtuple,
    OrderedDict,
)
from celery.exceptions import TaskError
from datetime import datetime
from io import BytesIO, StringIO
from openpyxl import Workbook
from openpyxl.styles import Font
from openpyxl.writer.excel import ExcelWriter
from json import loads as json_loads, dumps as json_dumps
from uuid import uuid4
from pytz import UTC
from yql.api.v1.client import YqlClient
from yql.client.parameter_value_builder import YqlParameterValueBuilder as ValueBuilder
from yql.config import config
from zipfile import ZipFile, ZIP_DEFLATED
from django.conf import settings
from django.db.models import Sum
from django.utils import timezone
from django.utils.translation import (
    override as override_lang,
    ugettext as _,
)
from django_mds.client import APIError
from yt.wrapper import JsonFormat, TablePath

from events.accounts.utils import GENDER_MALE, GENDER_FEMALE, get_external_uid
from events.countme.models import AnswerCountByDate
from events.common_app.blackbox_requests import JsonBlackbox
from events.common_storages.storage import MdsClient
from events.common_storages.utils import get_mds_url
from events.common_app.directory import CachedDirectoryClient
from events.common_app.disk.client import DiskClient, DiskUploadError
from events.common_app.utils import chunks, get_user_ip_address
from events.countme.models import QuestionGroupDepth
from events.surveyme.models import (
    ProfileSurveyAnswer,
    SurveyQuestion,
    SurveyQuestionChoice,
    SurveyQuestionMatrixTitle,
    Survey,
)
from events.common_app.yt import utils as yt_utils

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Positions inside the (begin, end) tuple built for date-range answers.
DATE_RANGE_BEGIN = 0
DATE_RANGE_END = 1
# Separator between the parts of a composite column header (see make_label).
LABEL_DELIMITER = ' / '
# Separator used when several values are joined into a single cell.
FIELD_DELIMITER = ', '
# Positions inside the (rows, cols) pair built by get_ordered_titles.
ROW = 0
COL = 1
# NOTE(review): not referenced in the visible part of the file; presumably the
# batch size for passport lookups — confirm against its user.
PASSPORT_CACHE_SIZE = 100

# Bit flags describing the installations a getter is available in; a getter's
# env_type is AND-ed with the current environment in AnswerMetadata.filter_getters.
ENV_NONE = 0
ENV_INTRANET = 0b001
ENV_BUSINESS = 0b010
ENV_PERSONAL = 0b100
ENV_ALL = ENV_INTRANET | ENV_BUSINESS | ENV_PERSONAL
# Matches the control characters \x00-\x08, \x0b-\x0c, \x0e-\x1f and U+FFFE.
# NOTE(review): usage is outside the visible part of the file; presumably used
# to strip characters a spreadsheet cell cannot hold — confirm.
ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]|[\ufffe]')

# Survey plus dictionaries of its questions, choices and matrix titles keyed by pk.
SurveyMetadata = namedtuple('SurveyMetadata', ('survey', 'questions', 'choices', 'titles'))
# One answer row as consumed by AnswerData.
Answer = namedtuple('Answer', ['answer_id', 'created_at', 'updated_at', 'data', 'source_request'])
# Container with a content type, a file name and a stream.
ExportedAnswers = namedtuple('ExportedAnswers', ['content_type', 'file_name', 'stream'])


class ExportError(TaskError):
    """Error type of this export module, derived from celery's ``TaskError``."""


def datetime_to_string(dt):
    """Convert ``dt`` to UTC and format it as ``YYYY-MM-DDTHH:MM:SSZ``."""
    utc_value = dt.astimezone(UTC)
    return utc_value.strftime('%Y-%m-%dT%H:%M:%SZ')


def string_to_datetime(s):
    """Parse a ``YYYY-MM-DDTHH:MM:SSZ`` string into a UTC-localized datetime."""
    naive_value = datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ')
    return UTC.localize(naive_value)


def make_label(*parts):
    """Join the given header fragments with ``LABEL_DELIMITER``."""
    label = LABEL_DELIMITER.join(parts)
    return label


def get_survey_metadata(survey_id):
    """Load a survey together with its questions, choices and matrix titles.

    All queries go to the ``settings.DATABASE_ROLOCAL`` database. Choices and
    matrix titles are only fetched when at least one ``answer_choices``
    question uses the corresponding data source.

    Returns:
        SurveyMetadata whose ``questions``, ``choices`` and ``titles`` are
        dictionaries keyed by primary key.
    """
    db = settings.DATABASE_ROLOCAL

    survey = Survey.objects.using(db).select_related('org').get(pk=survey_id)

    question_rows = (
        SurveyQuestion.objects.using(db)
        .select_related('answer_type')
        .filter(survey_id=survey_id)
        .order_by()
    )
    questions = {}
    need_choices = False
    need_titles = False
    for question in question_rows:
        questions[question.pk] = question
        if question.answer_type.slug != 'answer_choices':
            continue
        if question.param_data_source == 'survey_question_choice':
            need_choices = True
        elif question.param_data_source == 'survey_question_matrix_choice':
            need_titles = True

    choices = {}
    if need_choices:
        choice_rows = (
            SurveyQuestionChoice.objects.using(db)
            .filter(survey_question__survey_id=survey_id)
            .order_by()
        )
        for choice in choice_rows:
            choices[choice.pk] = choice

    titles = {}
    if need_titles:
        title_rows = (
            SurveyQuestionMatrixTitle.objects.using(db)
            .filter(survey_question__survey_id=survey_id)
            .order_by()
        )
        for title in title_rows:
            titles[title.pk] = title

    return SurveyMetadata(survey=survey, questions=questions, choices=choices, titles=titles)


def get_table_data(cluster, path):
    """Yield ``(id, created, answer)`` tuples read from a YT table.

    Only the ``id``, ``created`` and ``answer`` columns are requested; the
    ``created`` value is parsed with ``string_to_datetime``.
    """
    yt_client = yt_utils.get_client(cluster)
    source = TablePath(f'//{path}', columns=['id', 'created', 'answer'])
    raw_rows = yt_client.read_table(source, raw=True, format=JsonFormat())
    for raw_row in raw_rows:
        record = json_loads(yt_utils.decode(raw_row))
        created_at = string_to_datetime(record['created'])
        yield (record['id'], created_at, record['answer'])


def get_query_data(sql, parameters):
    """Run a YQL query and yield ``(pk, created, answer)`` tuples.

    When the result exposes table paths the rows are read from those tables
    via ``get_table_data``; otherwise the rows embedded in the result are
    used and their ``created`` value is parsed into a datetime.
    """
    config.no_color = True
    config.is_tty = True
    yql_client = YqlClient(token=settings.YQL_TOKEN)
    results = yql_client.query(sql).run(parameters=parameters)
    table_paths = list(results.table_paths())
    if not table_paths:
        for table in results:
            for pk, created, answer in table.rows:
                yield pk, string_to_datetime(created), answer
        return
    for cluster, path in table_paths:
        yield from get_table_data(cluster, path)


def get_archived_data(survey_id, started_at=None, finished_at=None, answers_pks=None, limit=None):
    """Yield archived answers of a survey fetched through a YQL query.

    The query reads the range of tables under
    ``//home/forms/answers/forms_biz/<yenv.type>`` between the date parts
    (first 10 characters) of the formatted ``started_at`` and ``finished_at``
    and keeps rows of the survey created within that interval.

    NOTE(review): ``started_at`` and ``finished_at`` default to ``None`` but are
    passed straight to ``datetime_to_string``, which calls ``astimezone`` on
    them — callers apparently must always supply both; confirm.

    Yields:
        Tuples ``(pk, created, created, data, None)``: the creation time is
        used for both timestamps and the last element is always ``None``.
    """
    params = {
        '$survey_id': ValueBuilder.make_string(survey_id),
        '$date_started': ValueBuilder.make_string(datetime_to_string(started_at)),
        '$date_finished': ValueBuilder.make_string(datetime_to_string(finished_at)),
    }
    pks_clause = ''
    if answers_pks:
        pks_clause = 'and t.id in $pks'
        params['$pks'] = ValueBuilder.make_list(map(ValueBuilder.make_int64, answers_pks))
        # an explicit pk list replaces whatever limit was passed in
        limit = len(answers_pks)
    parameters = ValueBuilder.build_json_map(params)
    limit_clause = f'limit {limit}' if limit else ''
    # NOTE(review): $pks is declared in the query even when no value is bound
    # for it — confirm YQL accepts a declared but unbound parameter.
    sql = f'''
        use hahn;

        pragma yt.TmpFolder = '//home/forms/tmp';

        declare $survey_id as string;
        declare $date_started as string;
        declare $date_finished as string;
        declare $pks as list<int64>;

        $date_from = substring($date_started, 0, 10);
        $date_to = substring($date_finished, 0, 10);

        select t.id, t.created, t.answer
          from range(`//home/forms/answers/forms_biz/{yenv.type}`, $date_from, $date_to) as t
         where t.survey_id = $survey_id
           and t.created >= $date_started
           and t.created <= $date_finished
           {pks_clause}
        {limit_clause};
    '''
    logger.info('%s; params = %s', sql, parameters)
    for row in get_query_data(sql, parameters):
        pk, created, data = row
        yield pk, created, created, data, None


def get_answers_data(survey_id, started_at=None, finished_at=None, answers_pks=None, limit=None):
    """Yield answer rows of a survey from the ``DATABASE_ROLOCAL`` database.

    Args:
        survey_id: primary key of the survey.
        started_at: optional lower bound for ``date_created`` (inclusive).
        finished_at: optional upper bound for ``date_created`` (inclusive).
        answers_pks: optional collection of answer pks to restrict to; when
            given, ``limit`` is replaced by its length.
        limit: optional maximum number of rows; the rows with the highest
            pks are returned.

    Yields:
        Tuples ``(pk, date_created, date_updated, data, source_request)``.

    A bounded request (``limit`` or ``answers_pks``) is always served by a
    single sliced query. Previously a request with ``limit`` but without
    ``answers_pks`` could fall into the paging branch, which re-orders and
    filters the queryset — operations Django rejects once a slice has been
    taken.
    """
    fields = ('pk', 'date_created', 'date_updated', 'data', 'source_request')
    answers_qs = (
        ProfileSurveyAnswer.objects.using(settings.DATABASE_ROLOCAL)
        .filter(survey_id=survey_id)
        .order_by()
    )
    answer_count_qs = (
        AnswerCountByDate.objects.using(settings.DATABASE_ROLOCAL)
        .filter(survey_id=survey_id)
    )
    if started_at:
        answers_qs = answers_qs.filter(date_created__gte=started_at)
        answer_count_qs = answer_count_qs.filter(created__gte=started_at)
    if finished_at:
        answers_qs = answers_qs.filter(date_created__lte=finished_at)
        answer_count_qs = answer_count_qs.filter(created__lte=finished_at)
    if answers_pks:
        answers_qs = answers_qs.filter(pk__in=answers_pks)
        limit = len(answers_pks)

    if limit:
        # The slice bounds the result, so no paging is needed (and paging a
        # sliced queryset is not possible anyway).
        yield from answers_qs.order_by('-pk').values_list(*fields)[:limit]
        return

    # Unbounded request: use the per-date counters to decide whether the
    # whole result is small enough to be read with one query.
    total = answer_count_qs.aggregate(Sum('count'))['count__sum'] or 0
    if total < 10000:
        yield from answers_qs.values_list(*fields)
        return

    # Large result: walk the table in pk order, 1000 rows per query.
    last_pk = None
    while True:
        page_qs = answers_qs if last_pk is None else answers_qs.filter(pk__gt=last_pk)
        page_last_pk = None
        for row in page_qs.order_by('pk').values_list(*fields)[:1000]:
            yield row
            page_last_pk = row[0]
        if page_last_pk is None:
            break
        last_pk = page_last_pk


def get_answers(survey, started_at=None, finished_at=None, answers_pks=None, limit=None):
    """Yield ``Answer`` tuples of a survey from the database and the archive.

    ``started_at`` is clamped to be no earlier than ``survey.date_created``
    and ``finished_at`` to be no later than the current time. Database rows
    are yielded first; the archive is queried only when the survey has a
    ``date_archived`` and the limit (if any) has not been reached.

    NOTE(review): the database query starts at ``date_archived`` and the
    archive query ends at it regardless of how they compare with the
    requested interval — confirm this is intended.
    """
    now = timezone.now()

    started_at = max(started_at or survey.date_created, survey.date_created)
    finished_at = min(finished_at or now, now)
    archived_at = survey.date_archived

    db_started_at = archived_at or started_at
    db_rows = get_answers_data(survey.pk, db_started_at, finished_at, answers_pks, limit)
    fetched = 0
    for fetched, row in enumerate(db_rows, 1):
        yield Answer(*row)

    # do not query the archive when the requested number of answers
    # has already been fetched from the main database
    if limit and fetched >= limit:
        return

    if not archived_at:
        return

    if limit:
        limit -= fetched
    for row in get_archived_data(survey.pk, started_at, archived_at, answers_pks, limit):
        yield Answer(*row)


def get_ordered_questions(questions):
    """Return the questions without a group, ordered by ``(page, position)``."""
    ranked = list(questions.values())
    ranked.sort(key=lambda item: (item.page, item.position))
    return [item for item in ranked if not item.group_id]


def get_groups(questions):
    """Map each group id to its child questions ordered by ``(page, position)``."""
    children_by_group = defaultdict(list)
    for question in questions.values():
        if question.group_id:
            children_by_group[question.group_id].append(question)

    ordered = {}
    for group_id, children in children_by_group.items():
        ordered[group_id] = sorted(children, key=lambda item: (item.page, item.position))
    return ordered


def get_question_group_depth(survey_id):
    """Return a ``{question_id: depth}`` dict from ``QuestionGroupDepth`` rows."""
    depth_rows = (
        QuestionGroupDepth.objects.using(settings.DATABASE_ROLOCAL)
        .filter(survey_id=survey_id)
        .values_list('question_id', 'depth')
    )
    return dict(depth_rows)


def get_organized_questions(questions, group_depth):
    """Yield ``(group_index, question)`` pairs in export column order.

    Plain questions are yielded with ``None`` as the index (read-only ones
    are skipped). For an ``answer_group`` question its children are yielded
    once per repetition, the index running from 0 to depth - 1.
    """
    top_level = get_ordered_questions(questions)
    children_by_group = get_groups(questions)
    for question in top_level:
        if question.answer_type.slug != 'answer_group':
            if not question.answer_type.is_read_only:
                yield None, question
            continue
        # the group may never have been filled in, but the columns
        # for its questions still have to be produced, hence depth 1
        for index in range(group_depth.get(question.pk, 1)):
            for child in children_by_group.get(question.pk, []):
                yield index, child


def get_ordered_choices(choices):
    """Group choice labels by question, ordered by ``(question id, position)``.

    Returns:
        defaultdict mapping a question id to an ``OrderedDict`` of
        ``str(choice pk) -> label``.
    """
    by_question = defaultdict(OrderedDict)
    ranked = sorted(
        choices.values(),
        key=lambda item: (item.survey_question_id, item.position),
    )
    for choice in ranked:
        by_question[choice.survey_question_id][str(choice.pk)] = choice.get_label()
    return by_question


def get_ordered_titles(titles):
    """Group matrix title labels by question into ``(rows, cols)`` pairs.

    Titles are processed in ``(question id, position)`` order; a title whose
    ``type`` is ``'row'`` goes into the first ``OrderedDict`` of the pair and
    any other title into the second one. Keys are ``str(title pk)``.
    """
    by_question = defaultdict(lambda: (OrderedDict(), OrderedDict()))
    ranked = sorted(
        titles.values(),
        key=lambda item: (item.survey_question_id, item.position),
    )
    for title in ranked:
        axis = COL
        if title.type == 'row':
            axis = ROW
        by_question[title.survey_question_id][axis][str(title.pk)] = title.get_label()
    return by_question


def has_quiz_questions(questions):
    """Tell whether any question has a ``param_quiz`` dict with a truthy ``enabled``."""
    return any(
        isinstance(item.param_quiz, dict) and item.param_quiz.get('enabled')
        for item in questions.values()
    )


class AnswerData:
    """Parsed representation of a single answer used to fill export columns.

    The constructor accepts an object with ``answer_id``, ``created_at``,
    ``updated_at``, ``data`` and ``source_request`` attributes; ``data`` and
    ``source_request`` may be JSON strings or already decoded objects.

    Attributes read by getters through ``Getter.attr_names``:
        personal_data: object passed in by the caller, stored as is.
        value_data: question id -> normalised value (see ``get_value_data``).
        scores_data: question id -> non-empty ``scores`` entry of the question.
        base_data: ``id``, ``created_at`` and ``updated_at`` of the answer.
        extra_data: ``ip``, ``uid``, ``cloud_uid`` and ``yandexuid``.
        quiz_data: the ``quiz`` section of the answer or an empty dict.

    Nested sections stored as ``null`` (``cookies``, ``question``,
    ``answer_type``, ``options``) are treated the same way as missing ones.
    """

    def __init__(self, answer, personal_data):
        data = self._load_json(answer.data)
        source_request = self._load_json(answer.source_request)
        # ``cookies`` may be present but null, so ``.get('cookies', {})`` is not enough
        cookies = source_request.get('cookies') or {}
        self.personal_data = personal_data
        self.value_data = {}
        self.scores_data = {}
        self.base_data = {
            'id': answer.answer_id,
            'created_at': answer.created_at,
            'updated_at': answer.updated_at,
        }
        self.extra_data = {
            'ip': source_request.get('ip'),
            'uid': data.get('uid'),
            'cloud_uid': data.get('cloud_uid'),
            'yandexuid': cookies.get('yandexuid'),
        }
        self.quiz_data = data.get('quiz') or {}
        self.init_data(data.get('data'))

    @staticmethod
    def _load_json(raw):
        """Decode ``raw`` when it is a JSON string; map empty values to ``{}``."""
        if isinstance(raw, str):
            raw = json_loads(raw or '{}')
        return raw or {}

    @staticmethod
    def _dig(container, *keys):
        """Follow ``keys`` through nested containers; ``None`` when the path breaks."""
        try:
            for key in keys:
                container = container[key]
        except (KeyError, TypeError):
            # TypeError covers a null (or otherwise non-subscriptable) level
            return None
        return container

    def init_data(self, answer_data):
        """Fill ``value_data`` and ``scores_data`` from the answer's ``data`` section."""
        for question_data in self.get_list(answer_data):
            question_id = self.get_question_id(question_data)

            scores_value = question_data.get('scores')
            if scores_value:
                self.scores_data[question_id] = scores_value

            self.value_data[question_id] = self.get_value_data(question_data)

    def get_list(self, answer_data):
        """Return the items of a list, the values of a dict, or ``[]`` otherwise."""
        if isinstance(answer_data, list):
            return answer_data
        if isinstance(answer_data, dict):
            return answer_data.values()
        return []

    def get_question_id(self, question_data):
        """Return ``question.id`` of the entry or ``None`` when unavailable."""
        return self._dig(question_data, 'question', 'id')

    def get_answer_slug(self, question_data):
        """Return ``question.answer_type.slug`` of the entry or ``None``."""
        return self._dig(question_data, 'question', 'answer_type', 'slug')

    def get_data_source(self, question_data):
        """Return ``question.options.data_source`` of the entry or ``None``."""
        return self._dig(question_data, 'question', 'options', 'data_source')

    def get_value_data(self, question_data):
        """Normalise the ``value`` of one answered question.

        Returns, depending on the answer type slug:
            answer_group: list of ``{child question id: child value}`` dicts,
                one per non-empty fieldset.
            answer_date with a dict value: ``(begin, end)`` tuple.
            answer_files with a list value: list of file paths.
            answer_choices with a list value: set of choice keys, dict of
                ``row key -> col key`` for matrices, or list of texts for
                any other data source.
            anything else: the raw value.
        """
        answer_slug = self.get_answer_slug(question_data)
        question_value = question_data.get('value')

        if answer_slug == 'answer_group':
            return [
                {
                    self.get_question_id(child_data): self.get_value_data(child_data)
                    for child_data in self.get_list(fieldset)
                    if child_data
                }
                for fieldset in self.get_list(question_value)
                if fieldset
            ]
        if answer_slug == 'answer_date':
            if isinstance(question_value, dict):
                return question_value.get('begin'), question_value.get('end')
        elif answer_slug == 'answer_files':
            if isinstance(question_value, list):
                return [file_item.get('path') for file_item in question_value]
        elif answer_slug == 'answer_choices':
            if isinstance(question_value, list):
                data_source = self.get_data_source(question_data)
                if data_source == 'survey_question_choice':
                    return {choice_item.get('key') for choice_item in question_value}
                if data_source == 'survey_question_matrix_choice':
                    return {
                        (title_item.get('row') or {}).get('key'): (title_item.get('col') or {}).get('key')
                        for title_item in question_value
                    }
                return [choice_item.get('text') for choice_item in question_value]
        return question_value

    def get_value(self, getter):
        """Return the getter's value as a string; ``None`` becomes ``''``.

        The getter receives the attributes named in ``getter.attr_names`` as
        positional arguments.
        """
        params = (getattr(self, attr_name) for attr_name in getter.attr_names)
        value = getter.get_value(*params)
        return '' if value is None else str(value)


class Getter:
    """Base class of an export column.

    Class attributes:
        group_name: column group checked by ``AnswerMetadata.filter_getters``
            (``'base'``, ``'question'``, ``'extra'`` or ``'personal'``).
        name: column identifier matched against the requested columns.
        env_type: bit mask of ``ENV_*`` flags the column is available in.
        attr_names: names of the ``AnswerData`` attributes passed, in order,
            to ``get_value``. Kept as an immutable tuple, like in every
            subclass, so that the shared default cannot be mutated.
    """

    group_name = None
    name = None
    env_type = ENV_NONE
    attr_names = ()

    def get_value(self, *args, **kwargs):
        """Return the cell value; must be implemented by subclasses."""
        raise NotImplementedError

    def get_header(self):
        """Return the column header; must be implemented by subclasses."""
        raise NotImplementedError


class QuestionGetter(Getter):
    """Column holding the raw value of one question.

    ``group_index`` is ``None`` for a plain question; for a question inside a
    group it is the zero-based number of the group repetition. Extra keyword
    arguments are accepted and ignored.
    """

    group_name = 'question'
    env_type = ENV_ALL
    attr_names = ('value_data',)

    def __init__(self, question, group_index, **kwargs):
        self.question = question
        self.group_index = group_index
        self.name = str(question.pk)

    def _get_grouped(self, value_data):
        """Look the value up inside the repetition of the parent group."""
        question = self.question
        try:
            fieldsets = value_data[question.group_id]
            fieldset = fieldsets[self.group_index]
            return fieldset[question.pk]
        except (KeyError, IndexError):
            return None

    def _get_plain(self, value_data):
        """Look the value up directly by the question pk."""
        try:
            found = value_data[self.question.pk]
        except KeyError:
            return None
        return found

    def get_value(self, value_data):
        if self.group_index is None:
            return self._get_plain(value_data)
        return self._get_grouped(value_data)

    def get_header(self):
        """Return the label, suffixed with the one-based repetition number."""
        if self.group_index is None:
            return self.get_label()
        return '%s [%s]' % (self.get_label(), self.group_index + 1)

    def get_label(self):
        return self.question.label


class BaseGetter(Getter):
    """Getter of the ``'base'`` group, which ``filter_getters`` always keeps."""

    group_name = 'base'


class ExtraGetter(Getter):
    """Getter of the ``'extra'`` group, kept only when its name is requested."""

    group_name = 'extra'


class AnswerIdGetter(BaseGetter):
    """Column with the answer id taken from ``base_data``."""

    name = 'id'
    env_type = ENV_ALL
    attr_names = ('base_data',)

    def get_value(self, base_data):
        return base_data.get('id')

    def get_header(self):
        return _('ID')


class CreatedAtGetter(BaseGetter):
    """Column with the answer creation time in Django's default timezone."""

    name = 'date_created'
    env_type = ENV_ALL
    attr_names = ('base_data',)

    def __init__(self):
        self.default_timezone = timezone.get_default_timezone()

    def get_value(self, base_data):
        """Return ``created_at`` formatted as ``YYYY-MM-DD HH:MM:SS``.

        String values are returned unchanged. A missing timestamp yields
        ``None`` (rendered as an empty cell by ``AnswerData.get_value``)
        instead of raising ``AttributeError``.
        """
        value = base_data.get('created_at')
        if value is None or isinstance(value, str):
            return value
        return value.astimezone(self.default_timezone).strftime('%Y-%m-%d %H:%M:%S')

    def get_header(self):
        return _('Время создания')


class UpdatedAtGetter(ExtraGetter):
    """Column with the answer modification time in Django's default timezone."""

    name = 'date_updated'
    env_type = ENV_ALL
    attr_names = ('base_data',)

    def __init__(self):
        self.default_timezone = timezone.get_default_timezone()

    def get_value(self, base_data):
        """Return ``updated_at`` formatted as ``YYYY-MM-DD HH:MM:SS``.

        String values are returned unchanged. A missing timestamp yields
        ``None`` (rendered as an empty cell by ``AnswerData.get_value``)
        instead of raising ``AttributeError``.
        """
        value = base_data.get('updated_at')
        if value is None or isinstance(value, str):
            return value
        return value.astimezone(self.default_timezone).strftime('%Y-%m-%d %H:%M:%S')

    def get_header(self):
        return _('Время изменения')


class AnswerIpGetter(ExtraGetter):
    """Intranet-only column with the ``ip`` value from ``extra_data``."""

    name = 'ip'
    env_type = ENV_INTRANET
    attr_names = ('extra_data',)

    def get_value(self, extra_data):
        return extra_data.get('ip')

    def get_header(self):
        return _('IP')


class AnswerUidGetter(ExtraGetter):
    """Intranet-only column with the ``uid`` value from ``extra_data``."""

    name = 'uid'
    env_type = ENV_INTRANET
    attr_names = ('extra_data',)

    def get_value(self, extra_data):
        return extra_data.get('uid')

    def get_header(self):
        return _('UID')


class AnswerYandexUidGetter(ExtraGetter):
    """Intranet-only column with the ``yandexuid`` value from ``extra_data``."""

    name = 'yandexuid'
    env_type = ENV_INTRANET
    attr_names = ('extra_data',)

    def get_value(self, extra_data):
        return extra_data.get('yandexuid')

    def get_header(self):
        return _('YandexUID')


class QuestionScoresGetter(QuestionGetter):
    """Column with the scores recorded for one quiz question.

    The value is read from ``scores_data`` by the question pk. The parent
    constructor is not called, so ``name`` stays ``'scores'`` and no
    ``group_index`` attribute is set on the instance.
    """

    name = 'scores'
    env_type = ENV_ALL
    attr_names = ('scores_data',)

    def __init__(self, question, *args):
        # extra positional arguments are accepted and ignored
        self.question = question

    def get_value(self, scores_data):
        return scores_data.get(self.question.pk)

    def get_header(self):
        return make_label(self.question.label, _('Баллы'))


class SurveyScoresGetter(BaseGetter):
    """Column with the ``scores`` value of the answer's quiz section."""

    name = 'survey_scores'
    env_type = ENV_ALL
    attr_names = ('quiz_data',)

    def get_value(self, quiz_data):
        return quiz_data.get('scores')

    def get_header(self):
        return _('Набрано баллов')


class TotalScoresGetter(BaseGetter):
    """Column with the ``total_scores`` value of the answer's quiz section."""

    name = 'total_scores'
    env_type = ENV_ALL
    attr_names = ('quiz_data',)

    def get_value(self, quiz_data):
        return quiz_data.get('total_scores')

    def get_header(self):
        return _('Всего баллов')


class ChoiceGetter(QuestionGetter):
    """Column for a single-choice question; shows the label of the chosen option."""

    def __init__(self, question, group_index, choices):
        super().__init__(question, group_index)
        self.choices = choices

    def get_value(self, value_data):
        """Return the label of the first key of a non-empty answer set, else ``None``."""
        selected = super().get_value(value_data)
        if not (selected and isinstance(selected, set)):
            return None
        first_key = next(iter(selected))
        return self.choices.get(first_key)


class MultipleChoiceGetter(QuestionGetter):
    """Column for one option of a multiple-choice question."""

    def __init__(self, question, group_index, pk, label):
        super().__init__(question, group_index)
        self.pk = pk
        self.label = label

    def get_value(self, value_data):
        """Return the option label when the option is in the answer set, else ``None``."""
        selected = super().get_value(value_data)
        if isinstance(selected, set) and self.pk in selected:
            return self.label
        return None

    def get_label(self):
        return make_label(self.question.label, self.label)


class TitleGetter(QuestionGetter):
    """Column for one row of a matrix question; shows the chosen column label."""

    def __init__(self, question, group_index, row_pk, row_label, cols):
        super().__init__(question, group_index)
        self.row_pk = row_pk
        self.row_label = row_label
        self.cols = cols

    def get_value(self, value_data):
        """Return the label of the column chosen for this row, else ``None``."""
        answers = super().get_value(value_data)
        if not isinstance(answers, dict):
            return None
        col_pk = answers.get(self.row_pk)
        if not col_pk:
            return None
        return self.cols.get(col_pk)

    def get_label(self):
        return make_label(self.question.label, self.row_label)


class DataSourceGetter(QuestionGetter):
    """Column for a choice question whose options are not stored with the survey."""

    def get_value(self, value_data):
        """Join a list answer with ``FIELD_DELIMITER``; ``None`` for anything else."""
        texts = super().get_value(value_data)
        if not isinstance(texts, list):
            return None
        return FIELD_DELIMITER.join(texts)


class BooleanGetter(QuestionGetter):
    """Column for a boolean question rendered as translated Yes/No."""

    def __init__(self, question, group_index):
        super().__init__(question, group_index)
        # translated once, at construction time
        self.yes = _('Yes')
        self.no = _('No')

    def get_value(self, value_data):
        """Return the Yes/No text for ``True``/``False``; ``None`` otherwise."""
        answer = super().get_value(value_data)
        if answer is True:
            return self.yes
        if answer is False:
            return self.no
        return None


class FilesGetter(QuestionGetter):
    """Column for a file question; shows links to the attached files.

    Keyword arguments ``uploader``, ``upload_files`` and ``tld`` are consumed
    here; the remaining arguments go to ``QuestionGetter``. A downloader is
    created only when uploading is requested and the uploader declares
    ``support_uploading``.
    """

    def __init__(self, *args, **kwargs):
        self.uploader = kwargs.pop('uploader', None)
        self.upload_files = kwargs.pop('upload_files', False)
        self.downloader = None
        if getattr(self.uploader, 'support_uploading', False) and self.upload_files:
            self.downloader = MdsDownloader()
        self.tld = kwargs.pop('tld', settings.DEFAULT_TLD)
        super().__init__(*args, **kwargs)

    def upload(self, path):
        """Copy the file through the uploader; return its result or ``None``."""
        if not (self.upload_files and self.downloader and self.uploader):
            return None
        content = self.downloader.download_file(path)
        if not content:
            return None
        file_name = path.split('/')[-1]
        return self.uploader.upload_file(file_name, content)

    def get_value(self, value_data):
        """Return the joined links of a list answer; ``None`` for anything else.

        Each path is uploaded when possible; otherwise its MDS url is used.
        """
        paths = super().get_value(value_data)
        if not isinstance(paths, list):
            return None
        links = [
            self.upload(path) or get_mds_url(path, self.tld)
            for path in paths
        ]
        return FIELD_DELIMITER.join(links)


class DateRangeGetter(QuestionGetter):
    """Column for one end (begin or end) of a date-range question."""

    def __init__(self, question, group_index, date_range):
        super().__init__(question, group_index)
        self.date_range = date_range
        if self.date_range == DATE_RANGE_BEGIN:
            self.label = _('Начало')
        else:
            self.label = _('Конец')

    def get_value(self, value_data):
        """Return the selected element of a tuple answer; ``None`` otherwise."""
        bounds = super().get_value(value_data)
        if not isinstance(bounds, tuple):
            return None
        return bounds[self.date_range]

    def get_label(self):
        return make_label(self.question.label, self.label)


class PaymentGetter(QuestionGetter):
    """Column for a payment question rendered as a translated sentence."""

    attr_names = ('value_data', 'base_data')

    def get_value(self, value_data, base_data):
        """Describe the payment using the answer id and the paid amount.

        Returns ``None`` when the stored value is not a dict.
        """
        payment = super().get_value(value_data)
        if not isinstance(payment, dict):
            return None
        details = {
            'order': base_data.get('id'),
            'amount': payment.get('amount'),
        }
        return _('Оплата заказа %(order)s на сумму %(amount)s руб') % details


class PersonalDataGetter(Getter):
    """Base of the ``'personal'`` columns filled from a personal-data provider.

    Subclasses set ``name`` to the key read from the user's data dict.
    """

    group_name = 'personal'
    attr_names = ('personal_data', 'extra_data')
    env_type = ENV_INTRANET | ENV_BUSINESS

    def get_value(self, personal_data, extra_data):
        """Return the ``name`` entry of the respondent's data, if any.

        The respondent is looked up by ``uid`` and, when that is empty, by
        ``cloud_uid``.
        """
        user_key = extra_data['uid'] or extra_data.get('cloud_uid')
        if not user_key:
            return None
        profile = personal_data.get_value(user_key)
        if not isinstance(profile, dict):
            return None
        return profile.get(self.name)


class ParamNameGetter(PersonalDataGetter):
    """Personal-data column with the respondent's first name (``param_name``)."""

    name = 'param_name'

    def get_header(self):
        return _('Имя')


class ParamSurnameGetter(PersonalDataGetter):
    """Personal-data column with the respondent's surname (``param_surname``)."""

    name = 'param_surname'

    def get_header(self):
        return _('Фамилия')


class ParamPatronymicGetter(PersonalDataGetter):
    """Business-only personal-data column with the patronymic (``param_patronymic``)."""

    name = 'param_patronymic'
    env_type = ENV_BUSINESS

    def get_header(self):
        return _('Отчество')


class ParamBirthDateGetter(PersonalDataGetter):
    """Personal-data column with the date of birth (``param_birthdate``)."""

    name = 'param_birthdate'

    def get_header(self):
        return _('Дата рождения')


class ParamGenderGetter(PersonalDataGetter):
    """Personal-data column with the gender rendered as translated text."""

    name = 'param_gender'

    def get_value(self, *args, **kwargs):
        """Translate ``GENDER_MALE``/``GENDER_FEMALE``; ``None`` for other values."""
        gender = super().get_value(*args, **kwargs)
        if gender == GENDER_MALE:
            return _('Мужской')
        if gender == GENDER_FEMALE:
            return _('Женский')
        return None

    def get_header(self):
        return _('Пол')


class ParamLoginGetter(PersonalDataGetter):
    """Personal-data column with the login (``yandex_username``)."""

    name = 'yandex_username'

    def get_header(self):
        return _('Логин')


class ParamEmailGetter(PersonalDataGetter):
    """Personal-data column with the email (``param_subscribed_email``)."""

    name = 'param_subscribed_email'

    def get_header(self):
        return _('Email')


class ParamPhoneGetter(PersonalDataGetter):
    """Business-only personal-data column with the phone (``param_phone``)."""

    name = 'param_phone'
    env_type = ENV_BUSINESS

    def get_header(self):
        return _('Телефон')


class ParamPositionGetter(PersonalDataGetter):
    """Business-only personal-data column with the job position (``param_position``)."""

    name = 'param_position'
    env_type = ENV_BUSINESS

    def get_header(self):
        return _('Должность')


class ParamJobPlaceGetter(PersonalDataGetter):
    """Business-only personal-data column with the place of work (``param_job_place``)."""

    name = 'param_job_place'
    env_type = ENV_BUSINESS

    def get_header(self):
        return _('Место работы')


def has_files_questions(questions):
    """Tell whether any question has the ``answer_files`` answer type."""
    return any(
        item.answer_type.slug == 'answer_files'
        for item in questions.values()
    )


class AnswerMetadata:
    """Describes the export columns of a survey and builds their getters.

    ``init_data`` must be called before any other method: it loads the
    survey, its questions, choices, matrix titles and group depths.

    Args:
        survey_id: primary key of the survey.
        uploader: optional object handed to ``FilesGetter`` for file uploads.
        upload_files: whether ``FilesGetter`` should upload attached files.
        tld: top-level domain for file links; ``settings.DEFAULT_TLD`` when empty.
    """

    def __init__(self, survey_id, uploader=None, upload_files=False, tld=None):
        self.survey_id = survey_id
        self.uploader = uploader
        self.upload_files = upload_files
        self.tld = tld or settings.DEFAULT_TLD

    def init_data(self):
        """Load survey metadata and derive the ordered column sources."""
        self.survey_metadata = get_survey_metadata(self.survey_id)
        self.choices = get_ordered_choices(self.survey_metadata.choices)
        self.titles = get_ordered_titles(self.survey_metadata.titles)
        self.group_depth = get_question_group_depth(self.survey_id)
        # Materialised on purpose: get_organized_questions is a generator and
        # would be exhausted after the first pass over the question getters.
        self.questions = list(
            get_organized_questions(self.survey_metadata.questions, self.group_depth)
        )
        self.has_quiz_questions = has_quiz_questions(self.survey_metadata.questions)
        self.dir_id = getattr(self.survey_metadata.survey.org, 'dir_id', None)

    def get_question_getters(self):
        """Yield the getters of all question columns, then the quiz totals.

        The getter class depends on the answer type and its parameters; a
        quiz-enabled question is followed by its scores column.
        """
        for i, question in self.questions:
            slug = question.answer_type.slug
            if slug == 'answer_choices':
                if question.param_data_source == 'survey_question_choice':
                    choices = self.choices.get(question.pk)
                    if choices:
                        if question.param_is_allow_multiple_choice:
                            # one column per option
                            for pk, label in choices.items():
                                yield MultipleChoiceGetter(question, i, pk=pk, label=label)
                        else:
                            yield ChoiceGetter(question, i, choices=choices)
                elif question.param_data_source == 'survey_question_matrix_choice':
                    rows, cols = self.titles.get(question.pk, (OrderedDict(), OrderedDict()))
                    # one column per matrix row
                    for pk, label in rows.items():
                        yield TitleGetter(question, i, row_pk=pk, row_label=label, cols=cols)
                else:
                    yield DataSourceGetter(question, i)
            elif slug == 'answer_boolean':
                yield BooleanGetter(question, i)
            elif slug == 'answer_date':
                if question.param_date_field_type == 'daterange':
                    yield DateRangeGetter(question, i, date_range=DATE_RANGE_BEGIN)
                    yield DateRangeGetter(question, i, date_range=DATE_RANGE_END)
                else:
                    yield QuestionGetter(question, i)
            elif slug == 'answer_files':
                yield FilesGetter(question, i, uploader=self.uploader, upload_files=self.upload_files, tld=self.tld)
            elif slug == 'answer_payment':
                yield PaymentGetter(question, i)
            else:
                yield QuestionGetter(question, i)
            if isinstance(question.param_quiz, dict) and question.param_quiz.get('enabled'):
                yield QuestionScoresGetter(question)

        if self.has_quiz_questions:
            yield SurveyScoresGetter()
            yield TotalScoresGetter()

    def get_base_getters(self):
        """Yield the id, creation time and modification time getters."""
        yield AnswerIdGetter()
        yield CreatedAtGetter()
        yield UpdatedAtGetter()

    def get_extra_getters(self):
        """Yield the uid, yandexuid and ip getters."""
        yield AnswerUidGetter()
        yield AnswerYandexUidGetter()
        yield AnswerIpGetter()

    def get_personal_getters(self):
        """Yield the getters of the respondent's personal data."""
        yield ParamNameGetter()
        yield ParamSurnameGetter()
        yield ParamPatronymicGetter()
        yield ParamBirthDateGetter()
        yield ParamGenderGetter()
        yield ParamLoginGetter()
        yield ParamEmailGetter()
        yield ParamPhoneGetter()
        yield ParamPositionGetter()
        yield ParamJobPlaceGetter()

    def get_getters(self):
        """Yield every getter: base, question, extra, then personal ones."""
        yield from self.get_base_getters()
        yield from self.get_question_getters()
        yield from self.get_extra_getters()
        yield from self.get_personal_getters()

    def filter_getters(self, questions_pks, columns):
        """Yield the getters allowed in this environment and requested by the caller.

        Base getters are always kept; question getters are kept when
        ``questions_pks`` is empty or contains the question pk; personal and
        extra getters are kept only when their name is listed in ``columns``.
        """
        env_type = self.get_env_type()
        for getter in self.get_getters():
            if not getter.env_type & env_type:
                continue
            if getter.group_name == 'base':
                yield getter
            elif getter.group_name == 'question':
                if not questions_pks or getter.question.pk in questions_pks:
                    yield getter
            elif getter.group_name in ('personal', 'extra'):
                if columns and getter.name in columns:
                    yield getter

    def get_env_type(self):
        """Return the ``ENV_*`` flag of the current installation and survey."""
        if settings.IS_BUSINESS_SITE:
            if self.dir_id:
                return ENV_BUSINESS
            return ENV_PERSONAL
        return ENV_INTRANET

    def get_personal_data(self, getters):
        """Return the personal-data provider matching the selected getters."""
        has_personal_getters = any(
            getter.group_name == 'personal'
            for getter in getters
        )
        if not has_personal_getters:
            return EmptyPersonalData()
        if settings.IS_BUSINESS_SITE:
            if self.dir_id:
                return DirPersonalData(self.dir_id)
            return EmptyPersonalData()
        return PassportPersonalData()

    def get_file_name(self, format):
        """Return ``'<date> <survey name>.<format>'`` for the exported file.

        Everything except word characters, spaces, dots and dashes is removed
        from the survey name, which is then cut to 60 characters.
        """
        export_date = timezone.now().strftime('%Y-%m-%d')
        # ``flags`` must be passed by keyword: the fourth positional argument
        # of re.sub is ``count``, which used to cap the replacements at 32.
        survey_name = re.sub(r'[^\w \.-]+', '', self.survey_metadata.survey.get_name(), flags=re.U)
        return '%s %s.%s' % (export_date, survey_name[:60], format)


class BasePersonalData:
    """Personal-data provider backed by an in-memory ``cache`` dict."""

    def __init__(self):
        self.cache = dict()

    def get_value(self, uid):
        """Return the cached entry for ``uid`` or ``None``."""
        cached = self.cache.get(uid)
        return cached

    def prepare(self, answers):
        """Hook receiving the answers; does nothing in the base class."""
        return None


class EmptyPersonalData(BasePersonalData):
    """Provider that has no data: ``get_value`` always returns ``None``."""

    def get_value(self, uid):
        return None


class DirPersonalData(BasePersonalData):
    """Personal data of the users of a Directory organization.

    Unless the organization type is ``portal``, all users of the organization
    are loaded at construction time and stored in ``self.cache`` under their
    ``id`` (as a string) and, when present, under their ``cloud_uid``.
    """

    def __init__(self, dir_id):
        # Always start with an empty cache: for a "portal" organization nothing
        # is loaded, and without this the inherited get_value() would fail with
        # AttributeError because ``self.cache`` was never created.
        super().__init__()
        if not self.is_portal(dir_id):
            self.init_cache(dir_id)

    def is_portal(self, dir_id):
        """Return True when the organization type of ``dir_id`` is ``portal``."""
        client = CachedDirectoryClient()
        org_data = client.get_organization(dir_id, fields='organization_type')
        return org_data.get('organization_type') == 'portal'

    def init_cache(self, dir_id):
        """Load every user of the organization ``dir_id`` into ``self.cache``."""
        fields = 'cloud_uid,name,birthday,gender,contacts,nickname,department.name,position'
        client = CachedDirectoryClient()
        self.cache = {}
        for it in client.get_users(dir_id, fields=fields):
            uid = str(it['id'])
            user_data = self._get_user_data(it)
            self.cache[uid] = user_data
            # The same record is also reachable by the cloud uid, if there is one.
            cloud_uid = it.get('cloud_uid')
            if cloud_uid:
                self.cache[cloud_uid] = user_data

    def _get_user_data(self, user_data):
        """Convert a Directory user record into a dict of export parameters."""
        name = user_data.get('name') or {}
        department = user_data.get('department') or {}
        contacts = user_data.get('contacts') or {}
        return {
            'param_name': name.get('first'),
            'param_surname': name.get('last'),
            'param_patronymic': name.get('middle'),
            'param_birthdate': user_data.get('birthday'),
            'param_gender': self._get_gender(user_data.get('gender')),
            'yandex_username': user_data.get('nickname'),
            'param_subscribed_email': self._get_email(contacts),
            'param_phone': self._get_phone(contacts),
            'param_position': user_data.get('position'),
            'param_job_place': department.get('name'),
        }

    def _get_gender(self, gender):
        """Map ``male``/``female`` to the GENDER_* constants, anything else to None."""
        if gender == 'male':
            return GENDER_MALE
        if gender == 'female':
            return GENDER_FEMALE
        return None

    def _get_email(self, contacts):
        """Return the main email of ``contacts``.

        When no email contact is flagged as main, the last email contact found
        is returned; ``None`` when there is no email or ``contacts`` is not a
        list.
        """
        email = None
        if isinstance(contacts, list):
            for contact in contacts:
                if contact['type'] == 'email':
                    email = contact['value']
                    if contact['main']:
                        return email
        return email

    def _get_phone(self, contacts):
        """Return the value of the first phone contact, or None if there is none."""
        if isinstance(contacts, list):
            for contact in contacts:
                if contact['type'] == 'phone':
                    return contact['value']
        return None


class PassportCache:
    """Bounded cache that evicts the least recently used entry first.

    Entries live in an ``OrderedDict`` ordered from the oldest to the most
    recently used key. A key is promoted by ``has_key()`` and by ``set()``.
    """

    def __init__(self, maxsize=PASSPORT_CACHE_SIZE):
        self.maxsize = maxsize
        self.cache = OrderedDict()

    def has_key(self, key):
        """Return True if ``key`` is cached, marking it as recently used."""
        if key not in self.cache:
            return False
        self.cache.move_to_end(key)
        return True

    def get(self, key):
        """Return the value stored for ``key`` or None; the order is not changed."""
        return self.cache.get(key)

    def set(self, key, value):
        """Store ``value`` under ``key`` and evict the oldest entry on overflow."""
        self.cache[key] = value
        # Assigning to an existing key keeps its old position in an OrderedDict,
        # so promote it explicitly: otherwise a freshly updated entry would be
        # the first candidate for eviction.
        self.cache.move_to_end(key)
        if len(self.cache) > self.maxsize:
            self.cache.popitem(last=False)

    def update(self, cache):
        """Store every ``key -> value`` pair of the mapping ``cache``."""
        for key, value in cache.items():
            self.set(key, value)


class PassportPersonalData(BasePersonalData):
    """Personal data provider backed by blackbox ``userinfo`` requests.

    Results are kept in a ``PassportCache``; ``prepare()`` requests only the
    uids that are not cached yet.
    """

    def __init__(self):
        self.cache = PassportCache()
        if settings.APP_TYPE == 'forms_int':
            blackbox_kwargs = settings.INTERNAL_SITE_BLACKBOX_KWARGS
        else:
            blackbox_kwargs = settings.EXTERNAL_SITE_BLACKBOX_KWARGS
        self.bb = JsonBlackbox(**blackbox_kwargs)
        # Extra arguments sent with every userinfo request.
        self.params = dict(
            dbfields=[bb_name for bb_name, _ in settings.YAUTH_PASSPORT_FIELDS],
            emails='getdefault',
            userip=get_user_ip_address(),
        )

    def _get_email(self, address_list):
        """Return the default address; without one, the last address seen."""
        last_seen = None
        for address in address_list or []:
            last_seen = address['address']
            if address['default']:
                break
        return last_seen

    def get_passport_data(self, pp):
        """Convert one user entry of a blackbox response into export parameters."""
        dbfields = pp.get('dbfields') or {}
        fio_parts = dbfields.get('account_info.fio.uid', '').split()
        if fio_parts:
            # The first token is used as the surname, the last one as the name.
            surname, name = fio_parts[0], fio_parts[-1]
        else:
            surname = name = ''
        return {
            'param_name': name,
            'param_surname': surname,
            'param_birthdate': dbfields.get('account_info.birth_date.uid'),
            'param_gender': dbfields.get('account_info.sex.uid'),
            'yandex_username': pp.get('login'),
            'param_subscribed_email': self._get_email(pp.get('address-list')),
        }

    def prepare(self, answers):
        """Fetch and cache passport data for the uids of ``answers`` not cached yet."""
        missing_uids = set()
        for answer in answers:
            uid = answer.data.get('uid') if answer.data else None
            if not uid:
                continue
            if not self.cache.has_key(uid):  # noqa
                missing_uids.add(uid)

        uid_param = ','.join(missing_uids)
        if not uid_param:
            return

        response = self.bb.userinfo(uid=uid_param, **self.params)
        fetched = {}
        for pp in response.get('users') or []:
            fetched[pp['id']] = self.get_passport_data(pp)
        self.cache.update(fetched)


class Formatter:
    """Base class of the export formatters: collects output in a BytesIO."""

    def __init__(self):
        self._stream = BytesIO()

    def write(self, data):
        """Append ``data`` to the stream; ``str`` input is encoded first."""
        payload = data.encode() if isinstance(data, str) else data
        self._stream.write(payload)

    def get_stream(self):
        """Return the underlying BytesIO object."""
        return self._stream


class JsonFormatter(Formatter):
    """Formatter producing a JSON array with one element per answer.

    Every element is a list of ``[header, value]`` pairs.
    """

    content_type = 'application/json'

    def __init__(self, header):
        super().__init__()
        self.header = header
        # True until the first row is written; used to place the commas.
        self.first = True

    def __enter__(self):
        self.write('[')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.write('\n]\n')

    def writerow(self, values):
        """Append one row, pairing each value with its header."""
        separator = '\n  ' if self.first else ',\n  '
        self.write(separator)
        pairs = list(zip(self.header, values))
        self.write(json_dumps(pairs, ensure_ascii=False))
        self.first = False


class CsvFormatter(Formatter):
    """Formatter producing CSV; the header row is written on ``__enter__``."""

    content_type = 'text/csv'

    def __init__(self, header):
        super().__init__()
        self.header = header
        # csv.writer works with text, so rows are rendered into a scratch
        # StringIO and then moved to the binary stream.
        self.st = StringIO()
        self.csv_writer = csv.writer(self.st)

    def __enter__(self):
        self._writeheader(self.header)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return None

    def _writeheader(self, values):
        self.writerow(values)

    def writerow(self, values):
        """Render one CSV row and append it to the output stream."""
        scratch = self.st
        self.csv_writer.writerow(values)
        self.write(scratch.getvalue())
        # Reset the scratch buffer for the next row.
        scratch.seek(0)
        scratch.truncate(0)


class XlsxFormatter(Formatter):
    """Formatter producing an xlsx workbook with a single sheet.

    The workbook is created on ``__enter__`` and written to the output
    stream on ``__exit__``.
    """

    content_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'

    def __init__(self, header):
        super().__init__()
        self.header = header
        self.wb = None
        self.ws = None

    def __enter__(self):
        workbook = Workbook(write_only=True)
        self.wb = workbook
        self.ws = workbook.create_sheet()
        self._writeheader(self.header)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        archive = ZipFile(self.get_stream(), 'w', ZIP_DEFLATED, allowZip64=True)
        with archive:
            ExcelWriter(self.wb, archive).write_data()

    def _writeheader(self, values):
        """Make the first row bold and write the header into it."""
        header_row = self.ws.row_dimensions[1]
        font_name = header_row.font.name
        header_row.font = Font(name=font_name, bold=True)
        self.writerow(values)

    def writerow(self, values):
        """Append one row; every value is sanitized and encoded to bytes."""
        cells = (self.sanitize(value).encode() for value in values)
        self.ws.append(cells)

    def sanitize(self, s):
        """Remove the characters matched by ILLEGAL_CHARACTERS_RE from ``s``."""
        cleaned = ILLEGAL_CHARACTERS_RE.sub('', s)
        return cleaned


def get_formatter(format, header):
    """Create the formatter for ``format`` (``csv``, ``xlsx`` or ``json``).

    Returns None for any other format name.
    """
    if format == 'csv':
        formatter_class = CsvFormatter
    elif format == 'xlsx':
        formatter_class = XlsxFormatter
    elif format == 'json':
        formatter_class = JsonFormatter
    else:
        return None
    return formatter_class(header)


def get_uploader(upload, user_uid, survey_id):
    """Create the uploader for the destination named by ``upload``.

    Supported names are ``mds``, ``disk``, ``file`` and ``console``; None is
    returned for anything else. Outside of the business site a ``disk``
    request falls back to ``MdsUploader`` when the user has no external uid
    or when ``settings.DISK_NDA_FOLDER`` is not a directory on their disk.
    """
    if upload == 'mds':
        return MdsUploader()
    if upload == 'file':
        return FileUploader()
    if upload == 'console':
        return ConsoleUploader()
    if upload != 'disk':
        return None

    if settings.IS_BUSINESS_SITE:
        return DiskUploader(user_uid, survey_id)

    external_uid = get_external_uid(user_uid)
    if not external_uid:
        return MdsUploader()
    disk_client = DiskClient(external_uid)
    if not disk_client.isdir(settings.DISK_NDA_FOLDER):
        return MdsUploader()
    return DiskUploader(external_uid, survey_id)


class MdsDownloader:
    """Reads files through ``MdsClient``."""

    def __init__(self):
        self.client = MdsClient()

    def download_file(self, path):
        """Return the result of ``MdsClient.get(path)``.

        An ``APIError`` is logged as a warning and None is returned instead.
        """
        try:
            return self.client.get(path)
        except APIError:
            # Logger.warn() is a deprecated alias of Logger.warning().
            logger.warning('File %s not found', path)
            return None


class BaseUploader:
    """No-op uploader defining the interface shared by all uploaders."""

    # Subclasses set this to True when they can store files via upload_file().
    support_uploading = False

    def upload_file(self, file_name, data):
        """Upload a single file; the base implementation does nothing."""
        return None

    def upload_report(self, file_name, content_type, data):
        """Upload the export report; the base implementation does nothing."""
        return None

    def mkdir(self, folder_path):
        """Create a folder; the base implementation does nothing."""
        return None

    def listdir(self, folder_path):
        """Return an empty iterator: the base uploader has no folders to list.

        ``raise StopIteration`` in a plain function does not produce an empty
        iterable, it raises to the caller (and becomes RuntimeError when that
        caller is a generator), so an actual empty iterator is returned.
        """
        return iter(())


class MdsUploader(BaseUploader):
    """Uploader storing the report through ``MdsClient`` under a random key."""

    expire = '3d'  # delete after 3 days

    def __init__(self):
        self.client = MdsClient()

    def upload_report(self, file_name, content_type, data):
        """Upload ``data`` and return the description of the stored report."""
        storage_key = uuid4().hex
        path = self.client.upload(data, storage_key, expire=self.expire)
        logger.debug('MDS %s/get-%s/%s', settings.MDS_PUBLIC_URL, settings.MDS_NAMESPACE, path)
        return dict(
            file_name=file_name,
            path=path,
            content_type=content_type,
            status_code=200,
        )


class DiskUploader(BaseUploader):
    """Uploader storing reports and answer files through ``DiskClient``.

    ``self.folders`` remembers the folders already initialised by this
    uploader and ``self.files`` the file names listed in them, so a file that
    is already present is not uploaded again.
    """

    support_uploading = True

    def __init__(self, user_uid, survey_id):
        self.client = DiskClient(user_uid)
        self.survey_id = survey_id
        self.folders = set()
        self.files = defaultdict(set)

    def get_files_folder(self):
        """Return the ``Files`` folder located inside the report folder."""
        report_folder = self.get_report_folder()
        return os.path.join(report_folder, 'Files')

    def get_report_folder(self):
        """Return the folder of this survey under ``settings.DISK_FOLDER_PATH``."""
        survey_folder = str(self.survey_id)
        return os.path.join(settings.DISK_FOLDER_PATH, survey_folder)

    def mkdir(self, folder_path):
        self.client.mkdir(folder_path)

    def listdir(self, folder_path):
        for entry in self.client.listdir(folder_path):
            yield entry

    def init_new_folder(self, folder_path, check_files):
        """Create ``folder_path`` and, if requested, remember its current files."""
        self.mkdir(folder_path)
        self.folders.add(folder_path)
        if not check_files:
            return
        self.files[folder_path] = set(self.listdir(folder_path))

    def upload_file(self, file_path, data, check_files=True):
        """Upload ``data`` to ``file_path`` and return its interface url.

        A path without a folder part is placed into the files folder. With
        ``check_files`` enabled the upload is skipped when a file with the
        same name was listed in the target folder. ``DiskUploadError`` is
        re-raised as ``ExportError``.
        """
        target_name = os.path.basename(file_path)
        target_folder = os.path.dirname(file_path) or self.get_files_folder()

        if target_folder not in self.folders:
            self.init_new_folder(target_folder, check_files)

        target_path = os.path.join(target_folder, target_name)
        already_present = check_files and target_name in self.files.get(target_folder, [])
        if not already_present:
            logger.info('Upload %s', target_path)
            try:
                self.client.upload(target_path, data)
            except DiskUploadError as e:
                raise ExportError(e.message) from e

        return self.client.get_interface_url(target_path)

    def upload_report(self, file_name, content_type, data):
        """Upload the report into the report folder and describe the result.

        The path is first passed through ``client.check_file_name``. For xlsx
        reports the returned link is the edit url instead of the interface
        url.
        """
        requested_path = os.path.join(self.get_report_folder(), file_name)
        report_path, _exists = self.client.check_file_name(requested_path)
        href = self.upload_file(report_path, data, check_files=False)

        if content_type == XlsxFormatter.content_type:
            href = self.client.get_edit_url(report_path)

        return dict(
            file_name=file_name,
            path=href,
            content_type=content_type,
            status_code=302,
        )


class FileUploader(BaseUploader):
    """Uploader writing the report to a local file named ``file_name``."""

    def upload_report(self, file_name, content_type, data):
        """Write ``data`` to ``file_name`` and return the file description."""
        with open(file_name, 'wb') as report_file:
            report_file.write(data)
        real_path = os.path.realpath(file_name)
        return dict(
            file_name=file_name,
            path=real_path,
            content_type=content_type,
            status_code=200,
        )


class ConsoleUploader(BaseUploader):
    """Uploader printing the report to standard output."""

    def upload_report(self, file_name, content_type, data):
        """Write the decoded report to stdout and return an empty dict."""
        text = data.decode(errors='replace')
        sys.stdout.write(text)
        return dict()


def get_exported_answers_stream(survey_id,
                                started_at=None,
                                finished_at=None,
                                answers_pks=None,
                                questions_pks=None,
                                limit=None,
                                columns=None,
                                format='csv',
                                uploader=None,
                                upload_files=False,
                                tld=None):
    """Render the answers of a survey into an in-memory stream.

    The getters selected by ``questions_pks`` and ``columns`` define the
    columns of the export. Answers are processed in chunks of 100; for every
    chunk the personal data provider is prepared first and then one row per
    answer is written with the formatter chosen by ``format``.

    Returns:
    ExportedAnswers: the content type, file name and stream of the export.
    """
    metadata = AnswerMetadata(survey_id, uploader, upload_files, tld)
    metadata.init_data()
    getters = list(metadata.filter_getters(questions_pks, columns))

    header = []
    for getter in getters:
        header.append(getter.get_header())

    # Created lazily, only once there is at least one chunk of answers.
    personal_data = None
    with get_formatter(format, header) as formatter:
        answers = get_answers(
            survey=metadata.survey_metadata.survey,
            started_at=started_at,
            finished_at=finished_at,
            answers_pks=answers_pks,
            limit=limit,
        )

        for chunk in chunks(answers, size=100):
            batch = list(chunk)
            if not personal_data:
                personal_data = metadata.get_personal_data(getters)
            personal_data.prepare(batch)
            for answer in batch:
                answer_data = AnswerData(answer, personal_data)
                row = (answer_data.get_value(getter) for getter in getters)
                formatter.writerow(row)

    return ExportedAnswers(
        content_type=formatter.content_type,
        file_name=metadata.get_file_name(format),
        stream=formatter.get_stream(),
    )


def _export_answers(survey_id,
                    user_uid=None,
                    started_at=None,
                    finished_at=None,
                    answers_pks=None,
                    questions_pks=None,
                    limit=None,
                    columns=None,
                    format='csv',
                    upload='console',
                    upload_files=False,
                    tld=None):
    """Export the answers of a survey and hand the report to an uploader.

    Parameters:
    survey_id     (int|str):   id of the survey
    user_uid      (str):       uid of the user who requested the export
    started_at    (datetime):  beginning of the date range of exported answers
    finished_at   (datetime):  end of the date range of exported answers
    answers_pks   (list[int]): answers to export
    questions_pks (list[int]): questions to export, all questions are exported
                               by default
    limit         (int):       limit the number of exported answers, turns on
                               reverse ordering by answer_id
    columns       (list[str]): extra fields to export, none are exported by
                               default
                               (param_name, param_surname, param_patronymic, param_gender,
                               param_phone, param_subscribed_email, param_birthdate,
                               param_position, param_job_place, ip, uid, yandexuid, date_updated)
    format        (str):       export format (csv, xlsx, json)
    upload        (str):       destination of the export result (mds, disk, file, console)
    upload_files  (bool):      upload answer files to the disk, only for upload=disk
    tld           (str):       used to build the file download link

    Returns:
    dict: the value returned by the uploader's ``upload_report`` with the
          ``survey_id`` key added
    """
    uploader = get_uploader(upload, user_uid, survey_id)

    exported = get_exported_answers_stream(
        survey_id,
        started_at=started_at,
        finished_at=finished_at,
        answers_pks=answers_pks,
        questions_pks=questions_pks,
        limit=limit,
        columns=columns,
        format=format,
        uploader=uploader,
        upload_files=upload_files,
        tld=tld,
    )

    report_bytes = exported.stream.getvalue()
    result = uploader.upload_report(exported.file_name, exported.content_type, report_bytes)
    result['survey_id'] = survey_id
    return result


def export_answers(*args, **kwargs):
    """Run ``_export_answers`` and report every failure as ``ExportError``.

    An ``ExportError`` raised by the export is propagated unchanged; any other
    exception is wrapped into an ``ExportError`` chained to the original one.
    """
    try:
        result = _export_answers(*args, **kwargs)
    except Exception as e:
        if isinstance(e, ExportError):
            raise
        raise ExportError('Неожиданная ошибка') from e
    return result


def main():
    """Run a sample export of one survey into a local csv file.

    The environment type is switched to ``production`` and the export runs
    with the ``ru`` language active.
    """
    yenv.type = 'production'
    # Other keyword arguments can be added here to narrow the export down,
    # for example started_at / finished_at (datetime), answers_pks or limit.
    params = {
        'survey_id': '6252bbb3b1f76021e04d2435',
        'format': 'csv',
        'upload': 'file',
    }
    with override_lang('ru'):
        export_answers(**params)


if __name__ == '__main__':
    # Script entry point: run main() under the execution profiler.
    from django_tools_log_context.profiler import execution_profiler

    profiler = execution_profiler('main', 'main', 0)
    with profiler:
        main()
