# -*- coding: utf-8 -*-

import re
import logging
from math import sqrt
from itertools import product

import sandbox.common.types.client as ctc

from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk.svn import Arcadia

from sandbox.projects import resource_types
from sandbox.projects.common.search import bugbanner
from sandbox.projects.common.search import queries as sq


# Ranking experiments that can be marked as changed, as pairs whose second element is
# the experiment code. Codes are matched as '&pron=<code>' when filtering source queries
# and appended as '&pron=<code>' when mixing. The pairs are also the choices of the
# 'changed_experiments' parameter (the first element is presumably the displayed label).
EXPERIMENTS = [
    ('RUS', 'rusexp'),
    ('BEL', 'belexp'),
    ('UKR', 'ukrexp'),
    ('KAZ', 'kazexp'),
    ('RANK', 'rankexp'),
]

# Experiment codes whose '&pron=<code>' marker always excludes a query from the source
# queries, regardless of which experiments are marked as changed.
_ALWAYS_FILTERED_EXPERIMENTS = [
    'tdiexp',
    'noprune',
    'nofastrank',
]


class Queries(parameters.LastReleasedResource):
    """Input parameter: a PLAIN_TEXT_QUERIES resource with the source queries, one per line.

    Presumably defaults to the last released resource of that type (per the
    LastReleasedResource base) — confirm against the Sandbox SDK.
    """
    name = 'queries_source'
    description = 'Queries'
    resource_type = [resource_types.PLAIN_TEXT_QUERIES]


def _create_exp_setup_params(group_name='Experiments setup'):
    """Build the "experiments setup" group of task input parameters.

    Returns a holder class whose ``params`` attribute lists the parameter classes,
    all assigned to the group ``group_name``.
    """
    class ExpSetup(object):
        # Experiments marked as changed; choices come from EXPERIMENTS. The task splits
        # the stored ctx value on spaces, so it appears to be a space-separated string
        # of experiment codes — TODO confirm against SandboxBoolGroupParameter.
        class ExpsChanged(parameters.SandboxBoolGroupParameter):
            name = 'changed_experiments'
            description = 'Changed ranking experiments'
            choices = EXPERIMENTS
            default_value = ''
            group = group_name

        # When set, source queries containing 'exps%3D' are filtered out as well.
        class SnipExpsChanged(parameters.SandboxBoolParameter):
            name = 'snippet_experiments_changed'
            description = 'Snippet experiments changed'
            default_value = False
            group = group_name

        params = [
            ExpsChanged,
            SnipExpsChanged,
        ]

    return ExpSetup


def _create_queries_gen_params(group_name='Queries generation setup'):
    """Build the "queries generation setup" group of task input parameters.

    Returns a holder class whose ``params`` attribute lists the parameter classes,
    all assigned to the group ``group_name``.
    """
    class QueriesGen(object):
        # Number of plain search queries copied to the output. It is also the minimum
        # number of search queries required after filtering (when failing on low amounts).
        # Search queries beyond this count are the pool used for mixing.
        class NumOfSearchQueries(parameters.SandboxIntegerParameter):
            name = 'number_of_search_queries'
            description = 'Number of search queries'
            default_value = 6000
            group = group_name

        # Upper bound on xml search queries copied to the output.
        class NumOfXmlSearchQueries(parameters.SandboxIntegerParameter):
            name = 'number_of_xml_search_queries'
            description = 'Number of xml search queries'
            default_value = 2000
            group = group_name

        # Upper bound on snippets queries copied to the output.
        class NumSnippetQueries(parameters.SandboxIntegerParameter):
            name = 'number_of_snippets_queries'
            description = 'Number of snippets queries'
            default_value = 2000
            group = group_name

        # Number of extra queries produced by appending an unchanged experiment.
        class QueriesWithModifiedExp(parameters.SandboxIntegerParameter):
            name = 'queries_with_modified_experiments'
            description = 'Queries with modified experiments (for each shard)'
            default_value = 1500
            group = group_name

        # Number of extra queries produced by substituting the relevgeo region.
        class QueriesWithModifiedReg(parameters.SandboxIntegerParameter):
            name = 'queries_with_modified_regions'
            description = 'Queries with modified regions (for each shard)'
            default_value = 1500
            group = group_name

        # Target ratio of source queries to regions when mixing regions.
        class QueriesVsRegRatio(parameters.SandboxFloatParameter):
            name = 'queries_vs_regions_ratio'
            description = 'Source queries against regions number ratio'
            default_value = 1.0
            group = group_name

        # When set, fail if too few search / xml / snippets queries remain after filtering.
        class FailOnLowAmountOfQueries(parameters.SandboxBoolParameter):
            name = 'fail_on_low_amount_of_queries'
            description = 'Fail on low amount of queries'
            default_value = True
            group = group_name

        # When set, drop source queries whose reqid contains '-XML-REASK'.
        class FilterXmlReask(parameters.SandboxBoolParameter):
            name = 'filter_xml_reask'
            description = 'Filter XML-REASK queries'
            default_value = True
            group = group_name

        params = [
            NumOfSearchQueries,
            NumOfXmlSearchQueries,
            NumSnippetQueries,
            QueriesWithModifiedExp,
            QueriesWithModifiedReg,
            QueriesVsRegRatio,
            FailOnLowAmountOfQueries,
            FilterXmlReask,
        ]

    return QueriesGen


def _get_experiments():
    """Return the experiment codes (second element of each EXPERIMENTS pair), in order."""
    return [code for _label, code in EXPERIMENTS]


def _is_experiment_query(query, snip_experiments_changed):
    """Return True if ``query`` must be dropped because of the experiments it carries.

    The query is dropped when it contains ``&pron=<code>`` for any known or
    always-filtered experiment code. When ``snip_experiments_changed`` is truthy, it is
    also dropped if it contains ``exps%3D``.
    """
    markers = ['&pron=' + code for code in _get_experiments() + _ALWAYS_FILTERED_EXPERIMENTS]
    if any(marker in query for marker in markers):
        return True
    if not snip_experiments_changed:
        return False
    return 'exps%3D' in query


# Parameter groups used by this module; their ``params`` lists are concatenated into
# MixQueriesExperimentsRegions.input_parameters.
exp_setup_params = _create_exp_setup_params()
queries_gen_params = _create_queries_gen_params()

_region_pattern = re.compile('(?<=relevgeo%3D)(\d+)')


class MixQueriesExperimentsRegions(bugbanner.BugBannerTask):
    """
        **Description**

        * Loads the source queries
        * Filters them
        * For a small share of the queries, varies the region and the experiment
    """
    type = 'MIX_QUERIES_EXPERIMENTS_REGIONS'

    input_parameters = [Queries] + exp_setup_params.params + queries_gen_params.params
    execution_space = 5 * 1024  # 5 Gb
    client_tags = ctc.Tag.GENERIC & ctc.Tag.Group.LINUX

    @property
    def footer(self):
        """Return an HTML rendering of the result protocol (newlines become ``<br/>``)."""
        return "<h3>Results</h3><br/>{}".format(
            self.ctx.get('result_protocol', '').replace('\n', '<br/>'))

    def on_enqueue(self):
        """Disable experiment mixing if nothing is left to mix with; pre-create the output resource."""
        SandboxTask.on_enqueue(self)

        # When every experiment is marked as changed, no unchanged experiment is left
        # to add to queries, so that part of the generation is turned off up front.
        if not self._get_not_changed_experiments():
            self.ctx['queries_with_modified_experiments'] = 0

        self.ctx['out_resource_id'] = str(self.create_resource(
            self.descr,
            'plain_text_queries.txt',
            resource_types.PLAIN_TEXT_QUERIES
        ).id)

    def on_execute(self):
        """Generate the mixed queries and write them, one per line, into the output resource.

        Raises SandboxTaskFailureError if fewer queries than required were generated.
        """
        self.add_bugbanner(bugbanner.Banners.WebBaseSearch)

        queries = self._load_source_queries()

        dest_queries = self._make_mixed_queries(queries)

        # XML and snippets queries are only copied "up to N", so just the search
        # and mixed parts are required here.
        expected_queries = (
            self.ctx['number_of_search_queries'] +
            self.ctx['queries_with_modified_experiments'] +
            self.ctx['queries_with_modified_regions']
        )

        if len(dest_queries) < expected_queries:
            raise SandboxTaskFailureError('generated only {} queries'.format(len(dest_queries)))

        # The output resource was created in on_enqueue.
        resource = self._read_resource(int(self.ctx['out_resource_id']), sync=False)

        with open(resource.abs_path(), 'w') as dest:
            dest.write('\n'.join(dest_queries))

        resource.mark_ready()

    def _get_not_changed_experiments(self):
        """Return the sorted experiment codes that are not listed in ``changed_experiments``.

        The ctx value is split on whitespace; a missing, empty or None value means
        no experiment is marked as changed. Sorting keeps the mixing order reproducible.
        """
        changed = set((self.ctx.get('changed_experiments') or '').split())
        return sorted(set(_get_experiments()) - changed)

    @staticmethod
    def _ceil_div(numerator, denominator):
        """Return the integer ceiling of ``numerator / denominator`` (positive ``denominator``)."""
        # '//' keeps this integer division on both Python 2 and Python 3.
        return numerator // denominator + (1 if numerator % denominator else 0)

    def _load_source_queries(self):
        """Read the source queries resource and append ``&hr=da`` to every non-empty line."""
        resource = self._read_resource(self.ctx['queries_source'])

        logging.info('loading queries')

        with open(resource.abs_path(), 'r') as f:
            # Empty lines (e.g. the one after a trailing newline) are not queries.
            queries = [q + '&hr=da' for q in f.read().split("\n") if q]

        logging.info('num of loaded queries: %s', len(queries))

        return queries

    def _make_mixed_queries(self, queries):
        """Build the full output query list and record a human-readable protocol in ctx.

        The output is, in order: the first ``number_of_search_queries`` search queries,
        up to ``number_of_xml_search_queries`` xml queries, up to
        ``number_of_snippets_queries`` snippets queries, then queries mixed with
        experiments and queries mixed with regions. The mixing uses search queries that
        were not selected verbatim.
        """
        filtered_queries = self._filter_queries(queries)
        search_queries, xml_search_queries, snippets_queries = _split_queries(
            filtered_queries,
            self.ctx['number_of_search_queries'],
            self.ctx['fail_on_low_amount_of_queries'],
        )

        result_protocol = (
            '{} input queries\n'
            '{} input queries after filtering, splitted into...\n'
            '{} search queries\n'
            '{} xml search queries\n'
            '{} snippets queries\n\n'
        ).format(
            len(queries),
            len(filtered_queries),
            len(search_queries),
            len(xml_search_queries),
            len(snippets_queries)
        )

        selected_search_queries = search_queries[0:self.ctx['number_of_search_queries']]
        xml_search_queries = xml_search_queries[0:self.ctx['number_of_xml_search_queries']]
        snippets_queries = snippets_queries[0:self.ctx['number_of_snippets_queries']]

        result = selected_search_queries + xml_search_queries + snippets_queries

        result_protocol += (
            'output queries:\n'
            '{} selected search queries\n'
            '{} xml search queries\n'
            '{} snippets queries\n'
        ).format(
            len(selected_search_queries),
            len(xml_search_queries),
            len(snippets_queries)
        )
        logging.info(result_protocol)

        # Only search queries not already emitted verbatim are used as mixing sources.
        search_queries = search_queries[self.ctx['number_of_search_queries']:]
        logging.debug("Collect %s serach queries.", len(search_queries))
        (
            mixed_queries,
            number_of_used_queries,
            number_of_not_changed_experiments
        ) = self._make_mixed_w_experiments(search_queries)
        result += mixed_queries

        result_protocol += '{} queries were made mixing {} source queries with {} experiments\n'.format(
            len(mixed_queries), number_of_used_queries, number_of_not_changed_experiments
        )

        # Region mixing draws from source queries not consumed by experiment mixing.
        search_queries = search_queries[number_of_used_queries:]
        (
            mixed_queries,
            number_of_used_queries,
            number_of_used_regions
        ) = self._make_mixed_w_regions(search_queries)
        result += mixed_queries

        result_protocol += '{} queries were made mixing {} source queries with {} regions\n'.format(
            len(mixed_queries), number_of_used_queries, number_of_used_regions
        )

        result_protocol += '\n{} overall queries'.format(len(result))

        self.ctx['result_protocol'] = result_protocol

        return result

    def _filter_queries(self, queries):
        """Keep only the queries accepted by ``is_ok_query`` under the task's filter settings."""
        logging.info('filtering queries')
        snip_exp_changed = self.ctx['snippet_experiments_changed']
        filter_xml_reask = self.ctx['filter_xml_reask']
        filtered_queries = [q for q in queries if is_ok_query(q, snip_exp_changed, filter_xml_reask)]
        logging.info('num of filtered queries: %s', len(filtered_queries))

        return filtered_queries

    def _make_mixed_w_experiments(self, queries):
        """Build ``queries_with_modified_experiments`` queries, each with one unchanged experiment added.

        Takes just enough leading ``queries`` so that crossing them with every unchanged
        experiment yields at least the requested number, then truncates to exactly that.

        Returns ``(mixed_queries, number_of_source_queries_used, number_of_experiments)``.
        Raises SandboxTaskFailureError if no unchanged experiment exists or too few
        queries were produced.
        """
        needs_mixed_n = self.ctx['queries_with_modified_experiments']
        if needs_mixed_n == 0:
            return [], 0, 0

        experiments = self._get_not_changed_experiments()
        experiments_n = len(experiments)
        if experiments_n == 0:
            raise SandboxTaskFailureError('All experiments changed, no experiments left to mix queries with')

        used_n = min(self._ceil_div(needs_mixed_n, experiments_n), len(queries))
        logging.debug("Use %s queries and %s experiment for mix creation", used_n, experiments_n)
        mixed = _mix_queries_w_experiments(queries[:used_n], experiments)[:needs_mixed_n]

        if len(mixed) < needs_mixed_n:
            raise SandboxTaskFailureError('generated only {} queries with modified experiments'.format(len(mixed)))

        return mixed, used_n, experiments_n

    def _make_mixed_w_regions(self, queries):
        """Build ``queries_with_modified_regions`` queries by substituting their relevgeo region.

        R regions and Q source queries are chosen with Q / R close to
        ``queries_vs_regions_ratio`` and Q * R at least the requested count:
        R = sqrt(needed / ratio), Q = ceil(needed / R). Every query is crossed with every
        region, and the result is truncated to the requested count.

        Returns ``(mixed_queries, number_of_source_queries_used, number_of_regions_used)``.
        Raises SandboxTaskFailureError on region-loading problems, bad settings, or
        when too few queries were produced.
        """
        needs_mixed_n = self.ctx['queries_with_modified_regions']
        if needs_mixed_n == 0:
            # Nothing to generate: skip exporting the region lists altogether.
            return [], 0, 0

        regions = self._load_all_regions()
        # Sanity threshold: far fewer regions means the lists were not loaded properly.
        if len(regions) < 5000:
            raise SandboxTaskFailureError('failed to load the list of regions')

        ratio = self.ctx['queries_vs_regions_ratio']
        if ratio <= 0:
            raise SandboxTaskFailureError(
                'queries_vs_regions_ratio must be positive, got {}'.format(ratio))

        used_regions_n = min(int(sqrt(needs_mixed_n / float(ratio))), len(regions))
        if used_regions_n == 0:
            raise SandboxTaskFailureError(
                'no regions selected: queries_with_modified_regions={} is too small '
                'for queries_vs_regions_ratio={}'.format(needs_mixed_n, ratio))

        used_n = min(self._ceil_div(needs_mixed_n, used_regions_n), len(queries))
        mixed = _mix_queries_w_regions(queries[:used_n], regions[:used_regions_n])[:needs_mixed_n]

        if len(mixed) < needs_mixed_n:
            raise SandboxTaskFailureError('generated only {} queries with modified regions'.format(len(mixed)))

        return mixed, used_n, used_regions_n

    def _load_all_regions(self):
        """Export the relevgeo_converter region lists from Arcadia and return region values.

        Each file line is tab-separated; the second column is taken as the region value
        (blank lines are skipped). A value listed in several files is kept only for the
        last file in ``doc_names`` that lists it. The result interleaves the files
        round-robin, one value per file per pass, so any prefix draws from all files.
        Order within a file is arbitrary (set pop order).
        """
        # NOTE(review): the double slash in 'arc//trunk' is kept as-is — confirm it is intended.
        base_url = 'https://a.yandex-team.ru/arc//trunk/arcadia/kernel/relevgeo_converter/'
        doc_names = [
            'belarus.txt',
            'countries.txt',
            'earth.txt',
            'geo.c2p',
            'kazakhstan.txt',
            'reg_russia.txt',
            'russia.txt',
            'ukraine.txt',
            'usa.txt',
            'all_regions.txt'
        ]

        doc_sets = []
        for doc_name in doc_names:
            local_path = self.abs_path(doc_name)
            Arcadia.export(base_url + doc_name, local_path)
            doc_regions = set()
            with open(local_path) as doc:
                for line in doc:
                    stripped = line.strip()
                    if not stripped:
                        # A blank line has no second column to read.
                        continue
                    doc_regions.add(stripped.split('\t')[1])
            doc_sets.append(doc_regions)

        # Make the sets disjoint: drop from each set whatever a later file also lists.
        for i, iset in enumerate(doc_sets):
            for jset in doc_sets[i + 1:]:
                iset -= jset

        # Round-robin: one value from each not-yet-exhausted file per pass.
        regions = []
        pending = [doc_set for doc_set in doc_sets if doc_set]
        while pending:
            for doc_set in pending:
                regions.append(doc_set.pop())
            pending = [doc_set for doc_set in pending if doc_set]

        return regions


def is_ok_query(q, snip_exp_changed, filter_xml_reask):
    """Tell whether ``q`` may be used as a source query.

    A query is accepted only when all of the following hold:
    * it carries no filtered experiment (see ``_is_experiment_query``);
    * it contains a ``relevgeo`` region;
    * it has no ``info`` parameter;
    * it is an ``ms=proto`` request;
    * if ``filter_xml_reask`` is set, its ``reqid`` (if present) is not an XML-REASK one.
    """
    params = sq.parse_cgi_params(q)

    if _is_experiment_query(q, snip_exp_changed):
        return False
    if _region_pattern.search(q) is None:
        return False
    if 'info' in params:
        return False
    if 'ms' not in params or params['ms'][0] != 'proto':
        return False
    if filter_xml_reask and 'reqid' in params:
        return '-XML-REASK' not in params['reqid'][0]
    return True


def _check_low_amount(queries, queries_type, min_queries):
    if (len(queries) < min_queries):
        raise SandboxTaskFailureError('Too low amount of {} queries: {}, need {}. '.format(
            queries_type,
            len(queries),
            min_queries,
        ))


def _split_queries(queries, min_search_queries, fail_on_low_amount_of_queries):
    """Split ``queries`` into ``(search, xml_search, snippets)`` lists, keeping input order.

    Classification is checked in priority order:
    * ``DF=da`` makes a snippets query;
    * otherwise, a ``reqid`` containing ``-XML`` makes an xml search query;
    * anything else is a search query.

    When ``fail_on_low_amount_of_queries`` is set, SandboxTaskFailureError is raised if
    there are fewer than ``min_search_queries`` search queries, or fewer than 10 xml or
    10 snippets queries.
    """
    logging.info('splitting queries')

    search_queries, xml_search_queries, snippets_queries = [], [], []

    for query in queries:
        params = sq.parse_cgi_params(query)
        if 'DF' in params and params['DF'][0] == 'da':
            bucket = snippets_queries
        elif 'reqid' in params and '-XML' in params['reqid'][0]:
            bucket = xml_search_queries
        else:
            bucket = search_queries
        bucket.append(query)

    # Sanity check: every input query must land in exactly one bucket.
    split_total = sum(len(part) for part in (search_queries, xml_search_queries, snippets_queries))
    if split_total != len(queries):
        raise SandboxTaskFailureError('Queries length mismatch')

    if fail_on_low_amount_of_queries:
        for part, part_type, minimum in (
            (search_queries, 'search', min_search_queries),
            (xml_search_queries, 'xml', 10),
            (snippets_queries, 'snippets', 10),
        ):
            _check_low_amount(part, part_type, minimum)

    return search_queries, xml_search_queries, snippets_queries


def _change_region(query, new_region):
    """Replace every relevgeo region id in ``query`` with ``new_region``.

    Returns the rewritten query, or None if ``query`` contains no relevgeo region.
    """
    replaced, count = _region_pattern.subn(str(new_region), query)
    return replaced if count >= 1 else None


def _set_experiment(query, experiment):
    return '{}&pron={}'.format(query, experiment)


def _mix_queries_w_regions(queries, regions):
    """Cross every query with every region, substituting the region into the query.

    Combinations whose query has no relevgeo region are skipped. All regions of the
    first query come first, then those of the second query, and so on.
    """
    result = []
    for query, region in product(queries, regions):
        changed = _change_region(query, region)
        if changed is not None:
            result.append(changed)
    return result


def _mix_queries_w_experiments(queries, experiments):
    """Cross every query with every experiment, appending ``&pron=<experiment>``.

    All experiments of the first query come first, then those of the second query, etc.
    """
    # Materialize once so a one-shot iterable is reused for every query.
    experiment_list = tuple(experiments)
    return [
        _set_experiment(query, experiment)
        for query in queries
        for experiment in experiment_list
    ]


# Module-level task export; presumably the name the Sandbox SDK uses to find the task class — confirm.
__Task__ = MixQueriesExperimentsRegions
