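"""
    Base requester classes for sdk2 tasks: they send requests to a local service,
    save its responses and can optionally check the responses for stability.
"""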
import itertools
import logging

from sandbox import common
from sandbox import sdk2
from sandbox.sandboxsdk import paths

import sandbox.projects.websearch.middlesearch.resources as ms_resources
from . import requester_core
from . import requester
from sandbox.projects import resource_types
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import link_builder as lb
from sandbox.projects.common import utils
from sandbox.projects.common.search import bugbanner2
from sandbox.projects.common.search.response.diff import protodiff as prt
from sandbox.projects.common.differ import printers

# Separator written after each text-mode response; the "{}" placeholder is
# later filled with the response number via str.format.
RESPONSE_SEP = "".join(["\n", "-" * 100, "{}\n"])


class RequesterTask(bugbanner2.BugBannerTask):
    """
    Base sdk2 task that sends requests to a local service and saves the responses.

    Subclasses customize behaviour through the hook methods: query_transformer,
    create_differ, parse_request_for_diff, write_responses, prepare_to_save,
    stop_on_error and post_process_response.
    """

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.Group('Requester options') as requester_options:
            requests = sdk2.parameters.Resource(
                "Requests",
                resource_type=[
                    resource_types.PLAIN_TEXT_QUERIES,
                    ms_resources.WebMiddlesearchPlainTextQueries,
                    ms_resources.WebMiddlesearchApphostRequests,
                ],
                required=True,
            )
            queries_limit = sdk2.parameters.Integer(
                "Limit number of used requests (0 = all)",
                required=False,
                default=0,
            )
            recheck_n = sdk2.parameters.Integer(
                "Recheck response N times (0 - ask once)",
                required=False,
                default=0,
            )
            request_timeout = sdk2.parameters.Integer(
                "Request timeout (valgrind crutch, default = 60sec)",
                required=False,
                default=60,
            )
            max_retries = sdk2.parameters.Integer(
                "Retry every query N times (0 - ask once)",
                required=False,
                default=0,
            )
            n_workers = sdk2.parameters.Integer(
                "Number of parallel processes used (0 = ncpu)",
                required=False,
                default=0,
            )

        with sdk2.parameters.Group('Response check options') as response_check_options:
            max_empty_responses_rate = sdk2.parameters.Float(
                "Maximum allowed empty responses rate",
                required=False,
                default=0,
            )
            min_single_response_size = sdk2.parameters.Integer(
                "Minimum size of single response (bytes, 0 = don't check)",
                required=False,
                default=0,
            )

    def query_transformer(self, base_url):
        """
        Return the transformer applied to queries sent to base_url.

        The base implementation returns None; subclasses override it.
        """
        return None

    def iter_requests(self):
        """
        Lazily iterate over the requests stored in the `requests` resource, one per line.

        When recheck_n is set, every request is repeated recheck_n + 1 times in a
        row so that the responses to it can be compared for stability (SEARCH-4579).
        queries_limit (0 = all) counts distinct requests, so it is scaled by the
        number of repeats.
        """
        requests_path = str(sdk2.ResourceData(self.Parameters.requests).path)
        req_iter = fu.read_line_by_line(requests_path)
        q_limit = self.Parameters.queries_limit
        if self.Parameters.recheck_n:
            # SEARCH-4579
            # repeat same query n_repeats times, and then compare responses on them to achieve stability checking
            n_repeats = self.Parameters.recheck_n + 1
            q_limit *= n_repeats
            req_iter = itertools.chain.from_iterable(itertools.repeat(x, n_repeats) for x in req_iter)
        # islice(..., None) means "no limit" when q_limit == 0
        return itertools.islice(req_iter, q_limit or None)

    def iter_responses(self, queries_iterable, target_port, n_workers=None):
        """
        Send queries to http://localhost:<target_port> and yield the responses.

        Items are yielded as produced by requester_core.response_yielder;
        save_responses unpacks each one as (r_num, req, r_status, r_data).
        n_workers overrides the n_workers task parameter when truthy.
        """
        logging.info("Iter responses")
        # itertools.izip (Python 2) stays lazy: the repeat() iterators are infinite
        # and the zip is bounded only by queries_iterable.
        queries_data_iterator = itertools.izip(
            queries_iterable,
            itertools.repeat(self.Parameters.request_timeout),
            itertools.repeat(self.Parameters.max_retries),
            itertools.repeat(self.query_transformer("http://localhost:{}".format(target_port)))
        )
        n_workers = n_workers or self.Parameters.n_workers
        for response in requester_core.response_yielder(n_workers, queries_data_iterator):
            yield response

    def save_responses(self, target_port, save_to_path):
        """
        Query the service on target_port and write every response to save_to_path.

        With recheck_n > 0 each request is sent recheck_n + 1 times in a row: the
        response with r_num % n_repeats == 0 becomes the benchmark and the following
        ones are compared against it via check_responses_stability.

        Raises:
            common.errors.TaskFailure: a query failed and stop_on_error is True, or
                a stability check is needed but create_differ returned None.
            common.errors.TaskError: responses to the same request differ.
        """
        self.prepare_to_save()
        n_repeats = self.Parameters.recheck_n + 1
        protodiff = None
        if n_repeats > 1:
            unstable_diff_path = self._unstable_diff_path()
            paths.make_folder(unstable_diff_path)
            printer = printers.PrinterToHtml(
                unstable_diff_path,
                write_compact_diff=False,
                pair_head_template="response {obj_index}",
                obj_head_template="query"
            )
            protodiff = self.create_differ(printer)
        benchmark = None
        with utils.TimeCounter("SAVING RESPONSES"):
            with open(save_to_path, "w") as out_f:
                iter_req = self.iter_requests()
                for r_num, req, r_status, r_data in self.iter_responses(iter_req, target_port):
                    if not r_status and self.stop_on_error:
                        raise common.errors.TaskFailure("Failed to query #{}: {}".format(r_num, r_data))
                    r_data = self.post_process_response(r_num, req, r_status, r_data)
                    if n_repeats > 1:
                        # NOTE(review): grouping by r_num assumes response_yielder yields
                        # responses in request order — confirm for parallel workers.
                        if r_num % n_repeats == 0:
                            benchmark = r_data
                        else:
                            self.check_responses_stability(protodiff, benchmark, r_data, req)
                    self.write_responses(out_f, r_data, r_num)

    def check_responses_stability(self, protodiff, benchmark, checking, req):
        """
        Compare a repeated response against the benchmark response for the same request.

        On any difference the HTML diff is finalized and published as a
        BASESEARCH_RESPONSES_COMPARE_RESULT resource, linked in the task info.

        Raises:
            common.errors.TaskFailure: protodiff is None (no differ configured).
            common.errors.TaskError: the responses differ.
        """
        if protodiff is None:
            raise common.errors.TaskFailure("Response differ is not set up")
        protodiff.compare_single_pair(benchmark, checking, title=self.parse_request_for_diff(req))
        if protodiff.get_diff_count() > 0:
            protodiff._finalize_printer()
            res = resource_types.BASESEARCH_RESPONSES_COMPARE_RESULT(
                self, "Unstable diff", self._unstable_diff_path()
            )
            sdk2.ResourceData(res).ready()
            self.set_info("Unstable diff: {}".format(lb.resource_link(res.id)), do_escape=False)
            raise common.errors.TaskError("Search is not stable")

    def create_differ(self, output_printer=None):
        """
        Return a sandbox.projects.common.differ.differ.DifferBase instance or None.

        save_responses passes the HTML printer for the unstable diff as
        output_printer. The base implementation returns None, which makes the
        stability check fail with "Response differ is not set up".
        """
        return None

    def parse_request_for_diff(self, req):
        """Return the title shown for req in the stability diff (req itself by default)."""
        return req

    def write_responses(self, out_f, r_data, r_num):
        """Writes the content r_data of response #r_num to out_f file (no-op by default)."""
        pass

    def prepare_to_save(self):
        """Hook called once at the start of save_responses; no-op by default."""
        pass

    @property
    def stop_on_error(self):
        """Whether a failed query aborts save_responses with TaskFailure (True by default)."""
        return True

    def post_process_response(self, r_num, req, r_status, r_data):
        """Transform a response before it is checked and written; identity by default."""
        return r_data

    def _unstable_diff_path(self):
        """Directory that receives the HTML diff of unstable responses."""
        return "unstable_diff"


class SearchRequesterTask(RequesterTask):
    """
    Requester task specialised for search services.

    Queries are rewritten by requester_core.CgiQueryTransformer according to the
    'Query transform options' parameters, and responses are treated as Protobuf
    (stability diffs go through prt.Protodiff).
    """
    class Parameters(RequesterTask.Parameters):
        with sdk2.parameters.Group('Query transform options') as query_transform_options:
            make_binary = sdk2.parameters.Bool(
                "Get responses in binary format",
                default=True,
            )
            additional_cgi = sdk2.parameters.String(
                "Add custom cgi params (ex.: &p1=v1&p2=v2)",
                default="",
                required=False,
            )
            stabilize = sdk2.parameters.Bool(
                "Stabilize responses",
                default=True,
            )
            need_dbgrlv = sdk2.parameters.Bool(
                "Get debug relevance props",
                default=True,
            )
            get_all_factors = sdk2.parameters.Bool(
                "Get all factors",
                default=False,
            )
            use_dcfm = sdk2.parameters.Bool(
                "Use DCFM (SEARCH-841)",
                default=False,
            )
            # https://st.yandex-team.ru/SEARCH-4373#1509019888000
            log_src_groupings = sdk2.parameters.Bool(
                "Enable logging srcgroupings",
                default=False,
            )
            collection = sdk2.parameters.String(
                "Search collections (yandsearch by default)",
                default="yandsearch",
                required=False,
            )

    def query_transformer(self, base_url):
        """Build a CGI query transformer for base_url from the task parameters."""
        opts = self.Parameters
        return requester_core.CgiQueryTransformer(
            base_url=base_url,
            make_binary=opts.make_binary,
            additional_cgi=opts.additional_cgi,
            stabilize=opts.stabilize,
            need_dbgrlv=opts.need_dbgrlv,
            log_src_groupings=opts.log_src_groupings,
            get_all_factors=opts.get_all_factors,
            use_dcfm=opts.use_dcfm,
            collection=opts.collection,
        )

    def create_differ(self, output_printer):
        """Return a Protodiff that writes to output_printer, skipping configured fields."""
        ignored_fields = self.skip_fields_in_stability_check()
        return prt.Protodiff(output_printer, skip_fields=ignored_fields)

    @staticmethod
    def skip_fields_in_stability_check():
        """
        Fields ignored when comparing repeated responses.

        None selects Protodiff's defaults
        (projects/common/search/response/diff/protodiff.py/_DEFAULT_SKIP_FIELDS).
        """
        return None

    def write_responses(self, out_f, r_data, r_num):
        """
        Append response #r_num to out_f.

        Binary mode delegates to requester.write_binary_data; text mode writes
        the response without surrounding newlines, followed by RESPONSE_SEP.
        """
        if self.Parameters.make_binary:
            requester.write_binary_data(out_f, r_data)
            return
        out_f.write(r_data.strip("\n\r"))
        out_f.write(RESPONSE_SEP.format(r_num + 1))
