from sandbox.projects import resource_types

import sandbox.common.types.client as ctc
import sandbox.common.types.resource as ctr

from sandbox import common
from sandbox import sdk2

from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

import logging
import requests
import StringIO


class CardRecommenderShardAcceptanceTest(sdk2.Task):
    """Acceptance test for a card recommender shard.

    Replays the same recorded HTTP requests against a production and a
    priemka instance. The task fails if too many requests fail, if priemka
    returns too many card ids that production did not return, or if the
    average card relevance differs too much between the two.
    """

    class Parameters(sdk2.Task.Parameters):
        # Upper bound on the number of recorded requests replayed per instance.
        requests_count = sdk2.parameters.Integer(
            "Count of requests for testing",
            default=10000,
            required=True
        )
        # Base address prepended to every recorded URL (candidate side).
        priemka_address = sdk2.parameters.String(
            "Priemka address",
            required=True,
        )
        # Base address prepended to every recorded URL (baseline side).
        production_address = sdk2.parameters.String(
            "Production address",
            required=True,
        )
        # Thresholds below are percentages (compared against value / 100).
        items_max_delta = sdk2.parameters.Integer(
            "Maximum percent of different items in priemka and production responses",
            default=20,
            required=True
        )
        relevance_max_delta = sdk2.parameters.Integer(
            "Maximum percent of average relevance diff in priemka and production responses",
            default=20,
            required=True
        )
        max_failed_requests_percent = sdk2.parameters.Integer(
            "Maximum percent of failed requests",
            default=1,
            required=True
        )
        # Optional; when set it is added to every request URL as `exp_name`.
        exp_name = sdk2.parameters.String(
            "Recommender experiment name",
            required=False
        )

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.GENERIC & (ctc.Tag.SAS | ctc.Tag.MAN | ctc.Tag.VLA)

    def parse_next_request(self, f):
        """Read the next recorded HTTP request from the open file ``f``.

        A record is a line holding the request size in bytes, then that many
        bytes of raw HTTP request, then a two-byte CRLF separator. The raw
        request is a request line ("METHOD URL ..."), "Key: value" header
        lines up to a blank line, and the body.

        Returns:
            ``(method, url, headers, body)``, or ``None`` at end of file.
        """
        request_size = f.readline()
        if request_size == '':
            return None  # end of file

        request = f.read(int(request_size))
        f.read(2)  # skip the CRLF that terminates each record

        stream = StringIO.StringIO(request)
        first_line = stream.readline().split(' ')
        method = first_line[0]
        url = first_line[1]
        if self.Parameters.exp_name:
            # The recorded URL may not have a query string yet.
            # NOTE(review): exp_name is appended without URL-encoding.
            separator = '&' if '?' in url else '?'
            url += "{}exp_name={}".format(separator, self.Parameters.exp_name)

        headers = {}
        while True:
            header = stream.readline().rstrip('\r\n')
            if header == '':
                break  # blank line (or EOF) ends the header section
            # Split on the first colon only: values such as "Host: h:8080"
            # contain colons themselves.
            key, sep, value = header.partition(':')
            if not sep:
                logging.warning("Skipping malformed header line: %r", header)
                continue
            headers[key.strip()] = value.strip()

        body = stream.read()

        return (method, url, headers, body)

    def fetch_responses(self, address, queries_file_path, requests_count):
        """Replay up to ``requests_count`` recorded requests against ``address``.

        Args:
            address: base address prepended to every recorded URL.
            queries_file_path: path to the recorded requests file, in the
                format read by ``parse_next_request``.
            requests_count: maximum number of requests to send.

        Returns:
            A list with one entry per request sent. Each entry is the decoded
            JSON body, or ``None`` when the request failed: a transport error
            after retries, a status for which ``response.ok`` is false, or a
            body that is not valid JSON.
        """
        retry = Retry(
            total=3,
            backoff_factor=0.3,
        )
        adapter = HTTPAdapter(max_retries=retry)

        responses = []
        with requests.Session() as session:
            session.mount('http://', adapter)
            session.mount('https://', adapter)

            with open(queries_file_path, 'r') as queries_file:
                for i in range(requests_count):
                    next_request = self.parse_next_request(queries_file)
                    if next_request is None:
                        break
                    method, url, headers, body = next_request
                    full_url = address + url

                    req = requests.Request(method, full_url, data=body, headers=headers)
                    # Preparation stays outside the try, so a misconfigured
                    # address fails loudly instead of counting as failed requests.
                    prepared_request = session.prepare_request(req)
                    try:
                        response = session.send(prepared_request)
                    except requests.exceptions.RequestException:
                        # Count against the failed-request budget instead of
                        # aborting the whole run.
                        logging.exception(
                            "Request failed. Url: %s. Request number: %d", full_url, i
                        )
                        responses.append(None)
                        continue

                    if not response.ok:
                        logging.error(
                            "Error in request processing. Url: %s. Request number: %d. Code: %d. Response: %s",
                            full_url,
                            i,
                            response.status_code,
                            response.content
                        )
                        responses.append(None)
                        continue

                    try:
                        responses.append(response.json())
                    except ValueError:
                        logging.exception(
                            "Response is not valid JSON. Url: %s. Request number: %d", full_url, i
                        )
                        responses.append(None)

        return responses

    def get_response_stats(self, response):
        """Return ``(ids, sum_rel)`` for a decoded response.

        ``ids`` is the set of ``card["id"]`` values in ``response["cards"]``.
        ``sum_rel`` is the float sum of ``card["rel"]`` over all cards.
        """
        ids = set()
        sum_rel = 0.0

        for card in response["cards"]:
            ids.add(card["id"])
            sum_rel += card["rel"]

        return ids, sum_rel

    def compare_responses(self, left_responses, right_responses):
        """Compare two aligned response lists and store the stats in Context.

        ``left_responses`` is the baseline and ``right_responses`` the
        candidate (``on_execute`` passes production, then priemka). A pair is
        skipped, and counted as failed, when either side is ``None``.

        Stores in ``self.Context``:
            diff_items: number of right-side card ids absent from the matching
                left response, summed over compared pairs.
            total_items: number of distinct right-side card ids, summed over
                compared pairs.
            avg_rel_left, avg_rel_right: per-response average relevance,
                averaged over compared (successful) pairs.

        Raises:
            common.errors.TaskFailure: the lists differ in length, are empty,
                or the percentage of failed pairs exceeds
                ``max_failed_requests_percent``.
        """
        if len(left_responses) != len(right_responses):
            raise common.errors.TaskFailure("Different count of responses")

        total_responses = len(left_responses)
        if total_responses == 0:
            raise common.errors.TaskFailure("No responses to compare")

        diff_items = 0
        total_items = 0
        avg_rel_left = 0.0
        avg_rel_right = 0.0
        failed_requests = 0

        for left, right in zip(left_responses, right_responses):
            if left is None or right is None:
                failed_requests += 1
                continue

            left_ids, left_rel = self.get_response_stats(left)
            right_ids, right_rel = self.get_response_stats(right)

            diff_items += len(right_ids - left_ids)
            total_items += len(right_ids)
            if left_ids:
                avg_rel_left += left_rel / len(left_ids)
            if right_ids:
                avg_rel_right += right_rel / len(right_ids)

        # Float arithmetic: with Python 2 integer division, 1.99% would
        # floor to 1% and pass a 1% limit.
        failed_percent = failed_requests * 100.0 / total_responses
        if failed_percent > self.Parameters.max_failed_requests_percent:
            raise common.errors.TaskFailure(
                "Too many failed requests: {} of {}".format(failed_requests, total_responses)
            )

        # Average only over pairs that were actually compared, so failed
        # requests do not dilute the reported averages.
        compared = total_responses - failed_requests
        if compared > 0:
            avg_rel_left /= compared
            avg_rel_right /= compared

        self.Context.diff_items = diff_items
        self.Context.total_items = total_items
        self.Context.avg_rel_left = avg_rel_left
        self.Context.avg_rel_right = avg_rel_right

    def on_execute(self):
        """Replay the queries on both instances and enforce the thresholds.

        Raises:
            common.errors.TaskFailure: no queries resource was found, priemka
                returned no items, or a threshold was exceeded.
        """
        queries_resource = sdk2.Resource.find(
            resource_types.PLAIN_TEXT_QUERIES,
            state=ctr.State.READY,
            attrs=dict(collections_card_recommender_perf_plan=True)
        ).first()
        # NOTE(review): assumes .first() returns None when nothing matches.
        if queries_resource is None:
            raise common.errors.TaskFailure(
                "No READY PLAIN_TEXT_QUERIES resource with collections_card_recommender_perf_plan attribute found"
            )
        queries_resource_data = sdk2.ResourceData(queries_resource)
        queries_path = str(queries_resource_data.path)

        prod_responses = self.fetch_responses(
            self.Parameters.production_address,
            queries_path,
            self.Parameters.requests_count
        )
        priemka_responses = self.fetch_responses(
            self.Parameters.priemka_address,
            queries_path,
            self.Parameters.requests_count
        )

        # Production is the baseline (left), priemka the candidate (right).
        self.compare_responses(prod_responses, priemka_responses)

        diff_items = self.Context.diff_items
        total_items = self.Context.total_items
        if total_items == 0:
            raise common.errors.TaskFailure("No items in priemka responses to compare")
        if float(diff_items) / total_items > self.Parameters.items_max_delta / 100.0:
            raise common.errors.TaskFailure(
                "Too many different items in responses: {} of {}".format(diff_items, total_items)
            )

        avg_rel_prod = self.Context.avg_rel_left
        avg_rel_priemka = self.Context.avg_rel_right
        rel_diff = abs(avg_rel_priemka - avg_rel_prod)
        # abs() in the denominator stops a negative average from producing a
        # negative ratio, which would silently pass the check.
        rel_base = abs(avg_rel_priemka)
        if rel_base == 0:
            # A zero baseline means any non-zero difference is unbounded.
            relevance_too_different = rel_diff > 0
        else:
            relevance_too_different = rel_diff / rel_base > self.Parameters.relevance_max_delta / 100.0
        if relevance_too_different:
            raise common.errors.TaskFailure(
                "Too big difference in average relevance: prod={} priemka={}".format(avg_rel_prod, avg_rel_priemka)
            )
