import cStringIO
import gzip

from sandbox import sdk2
import logging
import requests
import random
from sandbox import common
import os
from sandbox.sandboxsdk import svn

from sandbox.projects.common import time_utils as tu
from sandbox.projects.common import decorators
from sandbox.projects.prs_ops import resources
from functools import partial
from sandbox.projects.common.search import bugbanner2
from sandbox.projects.release_machine.core import const as rm_const
from sandbox.projects.release_machine.components.configs.prs_ops import PrsOpsCfg
from sandbox.projects.release_machine import rm_notify as rm_notify
from sandbox.projects.common.sdk_compat import task_helper


@rm_notify.notify2()
class GetSampleForPrsOPs(bugbanner2.BugBannerTask):
    """
    Get sample for prs_ops.

    Downloads the latest kosher-web queries/ratings dumps from FML, randomly samples up to
    ``num_samples`` query ids that also occur in the ratings dump, are not in a blocked
    region and (optionally) have an allowed lang value, then publishes every dump filtered
    down to the sampled ids as a separate resource.
    """

    _RATINGS = "ratings.tsv"
    _QUERIES = "queries.tsv"
    _RATINGS_FULL = "ratings_full.tsv"
    _QUERIES_FULL = "queries_full.tsv"

    # Dumps downloaded by on_execute and republished by _create_resources, in this order.
    file_names = (_QUERIES_FULL, _RATINGS, _QUERIES, _RATINGS_FULL)

    # Resource type used to publish each dump.
    res_types_files = {
        _RATINGS: resources.PRS_OPS_RATINGS,
        _RATINGS_FULL: resources.PRS_OPS_RATINGS,
        _QUERIES: resources.PRS_OPS_QUERIES,
        _QUERIES_FULL: resources.PRS_OPS_QUERIES,
    }

    # Column indices within a tab-separated dump line.
    _ID_POSITION = 0
    _LANG_POSITION = 3
    _REGION_POSITION = 2

    # Download URL of a gzip-compressed dump; formatted with the dump file name.
    _DUMP_URL_TEMPLATE = (
        "https://fml.yandex-team.ru/download/dump/parse/file"
        "?latest-kosher-web&file-name={}.gz"
    )
    # Arcadia file listing region ids (one integer per line) whose queries are never sampled.
    _BLOCKED_REGIONS_URL = "arcadia:/arc/trunk/arcadia/quality/query_pool/prs_ops/lib/lrs.txt"
    # Blocked region used when the Arcadia list cannot be exported or parsed.
    # Must be an int: regions parsed by _get_info are ints.
    _FALLBACK_BLOCKED_REGION = 9999

    class Requirements(sdk2.Task.Requirements):
        ram = 60 * 1024
        disk_space = 40 * 1024

    class Parameters(sdk2.Task.Parameters):
        num_samples = sdk2.parameters.Integer("Number of samples", default=1000, required=True)
        seed = sdk2.parameters.Bool("Seed", default=False)
        with seed.value[True]:
            seed_num = sdk2.parameters.Integer("Seed number", default=0)

        lang_flag = sdk2.parameters.Bool("Choose countries", default=False)
        with lang_flag.value[True]:
            # Compared against the lang column of the queries dump.
            langs = sdk2.parameters.String("Countries (separated by comma, CAPS)", default="RU,UA,BY,KZ")
        used_for_testenv = sdk2.parameters.Bool("Mark if use for testenv update", default=False)

    @decorators.retries(3, raise_class=common.errors.TaskFailure)
    def _get_unpacked_file(self, url):
        """
        Download a gzip-compressed file and return a reader over its decompressed content.

        The compressed payload is kept in memory; decompression happens lazily while the
        returned ``gzip.GzipFile`` is read.
        """
        logging.debug("trying to get {}".format(url))
        gzipped_file = requests.get(url, timeout=60)
        logging.debug("response {}".format(gzipped_file.status_code))
        # Fail on HTTP errors here, inside the decorated call, rather than later with an
        # obscure gzip error when an error page is decompressed.
        gzipped_file.raise_for_status()
        return gzip.GzipFile(fileobj=cStringIO.StringIO(gzipped_file.content))

    def _load_blocked_regions(self):
        """
        Return the set of region ids whose queries must not be sampled.

        Exports the region list from Arcadia. On any failure (export or parsing) it logs a
        warning and falls back to ``{_FALLBACK_BLOCKED_REGION}`` so sampling can proceed.
        """
        try:
            regions_file = os.path.abspath("regions_to_block")
            svn.Arcadia.export(self._BLOCKED_REGIONS_URL, regions_file)
            with open(regions_file, "r") as f:
                # int() tolerates surrounding whitespace; blank lines are skipped.
                block_region = set(int(line) for line in f if line.strip())
            logging.debug("checkout was successfull, regions: {}".format(sorted(block_region)))
        except Exception:
            block_region = {self._FALLBACK_BLOCKED_REGION}
            logging.warning("checkout wasn't successfull\n", exc_info=True)
        return block_region

    def _get_sampled_id(self, _file, langs, num_samples, ratings):
        """
        Randomly choose up to ``num_samples`` query ids from the queries dump.

        An id is eligible when it also occurs in ``ratings``, its region is not blocked
        and, if ``langs`` is non-empty, its lang value is in ``langs``.

        :param _file: queries dump (file-like, iterated by line; rewound afterwards)
        :param langs: allowed lang values; an empty set means no lang filtering
        :param num_samples: maximum number of ids to return
        :param ratings: ratings dump (file-like, iterated by line; rewound afterwards)
        :return: set of sampled integer ids
        """
        block_region = self._load_blocked_regions()

        ratings_id = set(self._get_info(req, lang_req=False)["id"] for req in ratings)
        ratings.seek(0)
        logging.debug("ratings id set done")

        # Candidate order follows the file, which keeps seeded sampling reproducible.
        list_id = []
        for req in _file:
            info = self._get_info(req, lang_req=True)
            if info["region"] in block_region or info["id"] not in ratings_id:
                continue
            if langs and info["lang"] not in langs:
                continue
            list_id.append(info["id"])
        _file.seek(0)
        logging.debug("list_id from queries done")

        if len(list_id) < num_samples:
            logging.error("Not enough requests for chosen langs, only {} sampled".format(len(list_id)))
        return set(random.sample(list_id, min(num_samples, len(list_id))))

    @staticmethod
    def _get_info(request, lang_req):
        """
        Parse fields of one tab-separated dump line.

        :param request: raw dump line
        :param lang_req: when true, also parse the region and lang columns
        :return: dict with "id" (int) and, if ``lang_req``, "region" (int) and "lang" (str)
        """
        splitted_req = request.split("\t")
        info = {"id": int(splitted_req[GetSampleForPrsOPs._ID_POSITION])}
        if lang_req:
            info["region"] = int(splitted_req[GetSampleForPrsOPs._REGION_POSITION])
            # strip() drops the line terminator in case lang is the last column; no-op otherwise.
            info["lang"] = splitted_req[GetSampleForPrsOPs._LANG_POSITION].strip()
        return info

    def _create_resources(self, files_dict, sampled_id):
        """
        Publish every dump in ``file_names`` filtered down to the ids in ``sampled_id``.

        Kept lines are written ordered by id; lines with equal ids keep their file order.
        All resources get the same timestamp prefix, computed once.
        """
        timestamp = tu.date_ymdhm(sep="_")
        for file_name in GetSampleForPrsOPs.file_names:
            resource = GetSampleForPrsOPs.res_types_files[file_name](
                self,
                file_name,
                "{}__{}".format(timestamp, file_name),
                used_for_testenv=self.Parameters.used_for_testenv,
                used_for_testenv_big=False,
            )
            if file_name in (GetSampleForPrsOPs._QUERIES, GetSampleForPrsOPs._RATINGS_FULL):
                resource.used_for_testenv = False
            elif self.Parameters.num_samples > 100000 and self.Parameters.used_for_testenv:
                resource.used_for_testenv_big = True
                resource.used_for_testenv = False
            resource_data = sdk2.ResourceData(resource)

            # Filter before sorting: only sampled lines are held and sorted, each line is
            # parsed once, and the int key sorts the same as before while also working on
            # Python 3.
            selected = []
            for req in files_dict[file_name]:
                req_id = self._get_info(req, lang_req=False)["id"]
                if req_id in sampled_id:
                    selected.append((req_id, req))
            selected.sort(key=lambda pair: pair[0])

            with open(str(resource_data.path), "w") as out_f:
                out_f.writelines(req for _, req in selected)

            resource_data.ready()

    def on_enqueue(self):
        task_helper.ctx_field_set(self, rm_const.COMPONENT_CTX_KEY, PrsOpsCfg.name)

    def on_execute(self):
        """Download the dumps, sample query ids and publish the filtered dumps."""
        self.add_bugbanner(bugbanner2.Banners.PrsOps)

        if self.Parameters.seed:
            # Makes sampling repeatable for the same seed number and the same dumps.
            random.seed(self.Parameters.seed_num)

        files_dict = dict()
        for name in GetSampleForPrsOPs.file_names:
            files_dict[name] = self._get_unpacked_file(self._DUMP_URL_TEMPLATE.format(name))

        langs = set()
        if self.Parameters.lang_flag:
            # Accept "RU, UA" as well as "RU,UA"; ignore empty items and an empty value.
            langs = set(
                lang.strip() for lang in (self.Parameters.langs or "").split(",") if lang.strip()
            )
            if not langs:
                logging.warning("Language filter is enabled but no languages were given; not filtering by lang")

        sampled_id = self._get_sampled_id(
            files_dict[GetSampleForPrsOPs._QUERIES_FULL],
            langs,
            self.Parameters.num_samples,
            files_dict[GetSampleForPrsOPs._RATINGS],
        )

        self._create_resources(files_dict, sampled_id)
