# -*- coding: utf-8 -*-

import difflib
import json
import logging
import re
from sandbox import sdk2
import struct
import tarfile
import io
import StringIO
import httplib

from sandbox.projects.common import utils
from sandbox.projects.yabs.qa.resource_types import (
    YABS_REPORT_RESOURCE,
)
from sandbox.projects.vh.frontend import (
    DolbilkaDumpResult,
)
from sandbox.projects.vh.frontend.dolbilka_plan_creator import VhDolbilkaPlanCreator
from sandbox import common
from sandbox.sdk2.path import Path
from multiprocessing import Pool


class DumpParsingError(RuntimeError):
    """Raised by DolbilkaDumpParser when a value cannot be read from the dump.

    Only the first constructor argument is kept, as ``message``; it is None when
    the error is raised without arguments.
    """

    def __init__(self, *args):
        self.message = args[0] if args else None

    def __str__(self):
        # An empty or missing message renders as the bare class tag.
        if not self.message:
            return "DumpParsingError"
        return "DumpParsingError: {}".format(self.message)


class DolbilkaDumpParser(object):
    """Sequential reader for a binary dolbilka dump of shooting results.

    Integers are decoded with struct's native format ("I" for uint32, "Q" for
    uint64). Every primitive read raises DumpParsingError on a short read, so a
    truncated tail of the dump ends parsing cleanly instead of leaking
    ``struct.error``.
    """

    def __init__(self, dump):
        # Binary file-like object; only .read(n) is used.
        self.dump = dump

    def parse_uint32(self):
        size_byte = self.dump.read(4)
        # Check the exact length: a partial read (1-3 bytes at EOF) would otherwise
        # make struct.unpack raise struct.error, which parse_responses does not catch.
        if len(size_byte) != 4:
            raise DumpParsingError("Cannot parse uint32")
        return struct.unpack("I", size_byte)[0]

    def parse_uint64(self):
        size_byte = self.dump.read(8)
        if len(size_byte) != 8:
            raise DumpParsingError("Cannot parse uint64")
        return struct.unpack("Q", size_byte)[0]

    def parse_bytes_array(self):
        """Read a length-prefixed byte string.

        The length is a uint32; the sentinel 0xffffffff means the real length
        follows as a uint64.
        """
        string_size = self.parse_uint32()
        if string_size == 0xffffffff:
            string_size = self.parse_uint64()

        res = self.dump.read(string_size)
        if len(res) != string_size:
            raise DumpParsingError("Cannot parse string")
        return res

    def skip(self, skip_bytes_count):
        """Discard exactly ``skip_bytes_count`` bytes, raising DumpParsingError at EOF."""
        skipped = self.dump.read(skip_bytes_count)
        if len(skipped) != skip_bytes_count:
            raise DumpParsingError("Cannot skip {} bytes".format(skip_bytes_count))

    def parse_response(self):
        """Parse one response record.

        Layout: err_code (uint32), http_code (uint32), request_url (bytes array),
        16 ignored bytes (two 8-byte fields), data_length (uint64), data (bytes
        array), error (bytes array).

        :raises DumpParsingError: if the record is truncated.
        :raises RuntimeError: if len(data) disagrees with data_length.
        """
        res = {}
        res["err_code"] = self.parse_uint32()
        res["http_code"] = self.parse_uint32()
        res["request_url"] = self.parse_bytes_array()
        # Two 8-byte fields that this task does not use.
        self.skip(8)
        self.skip(8)
        res["data_length"] = self.parse_uint64()
        res["data"] = self.parse_bytes_array()
        res["error"] = self.parse_bytes_array()
        if len(res["data"]) != res["data_length"]:
            raise RuntimeError("Incorrect data length: expected {}, got {}".format(res["data_length"], len(res["data"])))
        return res

    def parse_responses(self):
        """Parse records until the dump is exhausted.

        A trailing record that is incomplete (or an empty tail) stops parsing and
        is dropped; all fully parsed records are returned.
        """
        responses = []
        while True:
            try:
                responses.append(self.parse_response())
            except DumpParsingError:
                return responses


class StringSocket(object):
    """In-memory stand-in for a socket, holding a complete raw HTTP response.

    It provides the ``makefile`` method that httplib.HTTPResponse reads from (see
    normalize_http_response). Each call returns a fresh StringIO positioned at the
    start of the stored payload; any mode/buffering arguments are ignored.
    """

    def __init__(self, response):
        self.response = response

    def makefile(self, *_unused_args, **_unused_kwargs):
        reader = StringIO.StringIO(self.response)
        return reader


def dump_response(http_response, http_response_length):
    """Render a parsed HTTP response as CRLF-joined text with headers sorted by name.

    Output layout: "<status>  <reason>" (two spaces), then one "name: value" line
    per header, an empty line, and finally at most ``http_response_length`` bytes
    of the body.
    """
    status_line = str(http_response.status) + "  " + http_response.reason
    sorted_headers = sorted(http_response.getheaders(), key=lambda pair: pair[0])

    parts = [status_line]
    parts.extend(name + ": " + value for name, value in sorted_headers)
    parts.append("")
    parts.append(http_response.read(http_response_length))
    return "\r\n".join(parts)


def normalize_http_response(raw_http_response):
    """Parse a raw HTTP response and re-render it in dump_response's canonical form.

    The raw text is fed to httplib through a StringSocket. The body read is capped
    at the raw response length, which is enough to consume the whole body.
    """
    parsed = httplib.HTTPResponse(StringSocket(raw_http_response))
    parsed.begin()
    return dump_response(parsed, len(raw_http_response))


# Compiled once; raw string avoids the invalid "\w" escape in a plain literal.
_SHOOT_REQID_RE = re.compile(r"shoot_reqid=(\w+)")


def get_shoot_reqid(message):
    """Extract the ``shoot_reqid=<word chars>`` value from a request URL/string.

    Returns the sentinel "nonexistent" (and logs it) when no reqid is present.
    """
    m = _SHOOT_REQID_RE.search(message)
    if m:
        return m.group(1)
    logging.info("cannot parse reqid from %s", message)
    return "nonexistent"


def prettify_response(lines, flaky_fields, flaky_fields_html):
    """Strip run-to-run noise from a normalized response before diffing.

    :param lines: normalize_http_response output split into lines (status line,
        headers, blank separator, body lines).
    :param flaky_fields: field names to remove/mask (see delete_flaky_fields and
        parse_response_json).
    :param flaky_fields_html: subset of flaky_fields matched without a leading quote.
    :return: a new list of lines; the caller's list is not modified. An empty input
        is returned as-is.

    One line is dropped first. If a "<!DOCTYPE html>" line exists, the line just
    before it is dropped. Otherwise the second-to-last line is dropped. The original
    code calls this the "ans size" line; for a one-line body it is presumably the
    blank header/body separator (TODO confirm). The last remaining line is then
    treated as the body. If it parses as JSON, it is re-rendered with flaky keys
    masked; otherwise flaky fields are cut out textually.
    """
    if not lines:
        return lines

    # Work on a copy so the caller's list is left untouched.
    lines = list(lines)

    try:
        doctype_index = lines.index("<!DOCTYPE html>")
    except ValueError:
        doctype_index = None

    if doctype_index is not None:
        # Guard index 0: lines[-1] would wrap around and delete the body instead.
        if doctype_index > 0:
            del lines[doctype_index - 1]
    elif len(lines) >= 2:
        # A single-line response has no second-to-last line to drop.
        del lines[-2]

    beginning = lines[:-1]
    content = lines[-1:]

    delete_flaky_fields(flaky_fields, flaky_fields_html, beginning)

    if is_json(content[0]):
        content = parse_response_json(content[0], flaky_fields)
    else:
        delete_flaky_fields(flaky_fields, flaky_fields_html, content)

    return beginning + content


def is_json(obj):
    """Return True if ``obj`` is a string that json.loads accepts, else False.

    Only ValueError (malformed JSON) is treated as "not JSON"; other errors,
    e.g. a non-string argument, propagate.
    """
    try:
        json.loads(obj)
    except ValueError:
        return False
    return True


def parse_response_json(line, blacklist):
    """Decode a JSON body, mask blacklisted keys and pretty-print it.

    :param line: JSON text of the response body.
    :param blacklist: keys whose values are replaced via replace_flaky_fields.
    :return: list of lines of the re-serialized JSON (sorted keys, 4-space indent,
        non-ASCII kept as-is), so that diffs are stable and line-oriented.
    """
    masked = replace_flaky_fields(json.loads(line), blacklist)
    pretty = json.dumps(
        masked,
        sort_keys=True,
        indent=4,
        separators=(',', ': '),
        ensure_ascii=False,
        encoding='utf8',
    )
    return pretty.splitlines()


def replace_flaky_fields(obj, blacklist):
    """Return a copy of a decoded JSON value with blacklisted keys masked.

    Every dict key found in ``blacklist`` has its value replaced by 'FAKE_FIELD',
    at any depth: inside nested dicts, lists, and lists of lists, including when
    the top-level value is itself a list. Scalars are returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            key: 'FAKE_FIELD' if key in blacklist else replace_flaky_fields(value, blacklist)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [replace_flaky_fields(item, blacklist) for item in obj]
    return obj


def _find_flaky_field(line, field, quoted_only):
    """Return the index of the first eligible occurrence of ``field`` in ``line``, or -1.

    With ``quoted_only``, only occurrences immediately preceded by a single or
    double quote count. Unquoted ones (e.g. "url" inside "logging_url") are skipped
    rather than ending the search, and an occurrence at index 0 never counts.
    """
    position = line.find(field)
    if not quoted_only:
        return position
    while position != -1 and (position == 0 or line[position - 1] not in "'\""):
        position = line.find(field, position + 1)
    return position


def _cut_flaky_field(line, field_start):
    """Remove ``line`` from ``field_start`` through the next comma, or truncate there if none follows."""
    field_end = line.find(',', field_start)
    if field_end != -1 and field_end + 1 < len(line):
        return line[:field_start] + line[field_end + 1:]
    return line[:field_start]


def delete_flaky_fields(flaky_fields, flaky_fields_html, lines):
    """Cut every occurrence of each flaky field out of ``lines``, in place.

    Fields listed in ``flaky_fields_html`` are matched anywhere in a line. All
    other fields are matched only where a quote directly precedes them, i.e. as
    quoted keys. Each match is removed from the field name through the following
    comma, or to the end of the line if no comma follows. Returns None.
    """
    for index, line in enumerate(lines):
        for field in flaky_fields:
            if not field:
                # An empty needle matches everywhere and the loop would never end.
                continue
            quoted_only = field not in flaky_fields_html
            field_start = _find_flaky_field(line, field, quoted_only)
            while field_start != -1:
                line = _cut_flaky_field(line, field_start)
                field_start = _find_flaky_field(line, field, quoted_only)
        lines[index] = line


# Field names treated as flaky (expected to differ between shootings of the same
# request). They are removed or masked by prettify_response before diffing.
# FIXME: Maybe separate json fields and other (ex. headers)
_FLAKY_FIELDS = [
    'apphost-reqid', 'update_time', 'expirationTimestamp',
    'signature', 'watchSessionId', 'reqid', 'search-reqid', 'req_id',
    'content_url', 'ugc_token', 'ugc_not_interested_token', 'current_time',
    'likes', 'dislikes', 'rating_kp', 'subscribers_count',
    'start_position', 'x-content-version', 'x-processing-host', 'x-yandex-req-id', 'cuid',
    'url', 'logging_url',
]
# Entries of _FLAKY_FIELDS that delete_flaky_fields matches anywhere in a line,
# not only right after a quote (presumably HTTP header names).
_FLAKY_FIELDS_HTML = ['x-content-version', 'x-processing-host', 'x-yandex-req-id', 'cuid']


def calculate_diff(item):
    """Diff one stable/test response pair.

    :param item: tuple (item_id, stable_request, test_request, stable_response,
        test_response) as built by VhFrontendCountDiff.get_shoot_diff. Requests are
        dicts with "request", "headers" and "handler"; responses are parsed dump
        records with "http_code" and "data".
    :return: dict with "id", "request", "search", "stable_ans", "test_ans",
        "status" ("passed" or "failed") and "diff" (unified diff text, empty on pass).

    This runs in a multiprocessing pool worker, so any failure is logged with its
    traceback here before being re-raised to the parent.
    """
    try:
        item_id, stable_request, test_request, stable_response, test_response = item

        elem = {
            "id": item_id,
            "request": "GET {stable_request}\r\n{test_request}\r\n{headers}\r\n\r\n".format(
                stable_request=stable_request["request"],
                test_request=test_request["request"],
                headers=test_request["headers"],
            ),
            "search": {
                "stable_code": str(stable_response["http_code"]),
                "test_code": str(test_response["http_code"]),
                "handler": test_request["handler"],
            },
        }

        stable_lines = normalize_http_response(stable_response["data"]).splitlines()
        test_lines = normalize_http_response(test_response["data"]).splitlines()

        # Keep the unprettified answers for the report before masking anything.
        elem["stable_ans"] = '\n'.join(stable_lines)
        elem["test_ans"] = '\n'.join(test_lines)
        stable_lines = prettify_response(stable_lines, _FLAKY_FIELDS, _FLAKY_FIELDS_HTML)
        test_lines = prettify_response(test_lines, _FLAKY_FIELDS, _FLAKY_FIELDS_HTML)

        try:
            unified_diff = "\n".join(
                difflib.unified_diff(
                    stable_lines,
                    test_lines,
                    fromfile="stable",
                    tofile="test",
                    lineterm="",
                )
            )
        except Exception:
            # Possibly byte/unicode line mixing under Python 2. Report the pair as a
            # failure, but keep the reason visible in the logs instead of hiding it.
            logging.exception("Can't build unified diff for item %s", item_id)
            unified_diff = "bad content"

        if unified_diff:
            elem["status"] = "failed"
            elem["diff"] = unified_diff
        else:
            elem["status"] = "passed"
            elem["diff"] = ""

        return elem
    except Exception:
        # print stack trace if exception raised
        logging.exception("Can't calculate diff")
        raise


class VhFrontendCountDiff(sdk2.Task):
    """
        Count diff for VH frontend Back2Back

        Loads the stable and test dolbilka dumps and the request plans that
        produced them. It diffs the responses pairwise in a process pool
        (calculate_diff) and publishes an HTML report in the task log directory.
    """
    class Parameters(sdk2.Task.Parameters):
        stable_responses = sdk2.parameters.Resource(
            "Stable stand dump results",
            resource_type=DolbilkaDumpResult,
            required=True,
        )
        test_responses = sdk2.parameters.Resource(
            "Test stand dump results",
            resource_type=DolbilkaDumpResult,
            required=True,
        )
        stable_plan_creator = sdk2.parameters.Task(
            "Task that created stable plan",
            task_type=VhDolbilkaPlanCreator,
            required=True,
        )
        test_plan_creator = sdk2.parameters.Task(
            "Task that created test plan",
            task_type=VhDolbilkaPlanCreator,
            required=True,
        )

        with sdk2.parameters.Output():
            is_test_success = sdk2.parameters.Bool(
                "zero diff"
            )
        # flaky_fields = sdk2.parameters.List(
        #     'Fields which will be unique for every request (eg. "apphost-reqid)"',
        #     default=['"apphost-reqid"', 'X-Content-Version', '"update_time"', '"expirationTimestamp"',
        #              '"signature"', '"watchSessionId"'],
        # )

    def on_execute(self):
        diff = self.get_shoot_diff()

        tests_number, failed_test_number = self.save_diff(diff)
        self.copy_report_resource()
        self.add_report_link(tests_number, failed_test_number)

        self.Parameters.is_test_success = failed_test_number == 0

    def add_report_link(self, tests_number, failed_test_number):
        """Put a link to the HTML report plus failure/test counts into the task info."""
        log_folder = self.log_path().name
        logging.info("logpath %s", log_folder)

        # The report is opened through the proxy host ("://sandbox" -> "://proxy.sandbox").
        report_link = common.utils.get_task_link(
            "{task_id}/{log_folder}/{report_file}".format(
                task_id=self.id,
                log_folder=log_folder,
                report_file="index.html",
            )
        ).replace(
            "://sandbox",
            "://proxy.sandbox"
        )

        self.set_info(
            "".join(
                [
                    "<a href={report} target=\"_blank\">Отчет</a>".format(report=report_link),
                    "<br>Failures: {fails}".format(fails=failed_test_number),
                    "<br>Tests: {tests}".format(tests=tests_number),
                ]
            ),
            do_escape=False,
        )

    def copy_report_resource(self):
        """Unpack the report viewer (YABS_REPORT_RESOURCE) into the task log directory.

        Files under "build" go to static/build and other "static" files to static.
        main.standalone.chunked.html becomes index.html. Member paths are flattened
        to their basename before extraction.
        """
        report_resource_path = Path(
            utils.sync_last_stable_resource(YABS_REPORT_RESOURCE)
        ).resolve()
        logging.info("report_resource_path = {report_resource_path}".format(report_resource_path=report_resource_path))
        with tarfile.open(str(report_resource_path)) as tar:
            current_path = self.log_path()
            for member in tar.getmembers():
                if member.isreg():
                    if "build" in member.name:
                        member.name = Path(member.name).name
                        tar.extract(member, str(current_path / "static" / "build"))
                    elif "static" in member.name:
                        member.name = Path(member.name).name
                        tar.extract(member, str(current_path / "static"))
                    elif member.name.endswith("main.standalone.chunked.html"):
                        member.name = Path(member.name).name
                        tar.extract(member, str(current_path))
                        current_path.joinpath("main.standalone.chunked.html").rename(current_path / "index.html")

    def match_request_with_response(self, requests_list, responses_list):
        """Pair plan requests with dump responses by shoot_reqid.

        Returns (requests, responses) as parallel lists in response order. Responses
        whose reqid has no unmatched request, and requests with no response, are
        logged and dropped.
        """
        reqid_to_request = {}
        for request in requests_list:
            reqid_to_request[get_shoot_reqid(request["request"])] = request

        matched_requests, matched_responses = [], []
        for response in responses_list:
            reqid = get_shoot_reqid(response["request_url"])
            request = reqid_to_request.pop(reqid, None)
            if request is None:
                logging.info("duplicated response shoot_reqid={}".format(reqid))
            else:
                matched_requests.append(request)
                matched_responses.append(response)

        for reqid in reqid_to_request:
            logging.info("request shoot_reqid={} doesn't have response".format(reqid))

        return matched_requests, matched_responses

    def get_shoot_diff(self):
        """Load both stands' data, match requests to responses and diff them in a pool."""
        stable_responses = self.get_responses(self.Parameters.stable_responses)
        test_responses = self.get_responses(self.Parameters.test_responses)
        stable_requests = self.get_requests(self.Parameters.stable_plan_creator)
        test_requests = self.get_requests(self.Parameters.test_plan_creator)

        stable_requests, stable_responses = self.match_request_with_response(stable_requests, stable_responses)
        test_requests, test_responses = self.match_request_with_response(test_requests, test_responses)

        logging.info(
            "Responses from stable {responses}/{requests}".format(
                responses=len(stable_responses),
                requests=len(stable_requests),
            )
        )
        logging.info(
            "Responses from test {responses}/{requests}".format(
                responses=len(test_responses),
                requests=len(test_requests),
            )
        )

        # NOTE(review): stable and test items are paired by position, and each side is
        # in its own dump's response order. Confirm both dumps answer in the same order
        # (or pair by a shared key); otherwise unrelated requests get diffed.
        shot_count = len(stable_responses)
        items = zip(range(shot_count), stable_requests, test_requests, stable_responses, test_responses)

        pool = Pool(processes=30)
        try:
            diff_list = pool.map_async(calculate_diff, items).get()
        except BaseException:
            # Don't wait for the remaining workers once a diff has already failed.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            # Reap worker processes instead of leaking them for the rest of the task.
            pool.join()

        return list(diff_list)

    def get_requests(self, plan_creator):
        """Load the request plan produced by a VhDolbilkaPlanCreator task.

        Each non-blank line is tab-separated: time delay, request, headers. The
        headers column separates entries with a literal backslash-n backslash-n
        sequence, which is converted to CRLF. The handler is the third component of
        the request split on "/" and "?" ("" if there is none).
        """
        requests_resource = sdk2.Resource[plan_creator.Parameters.requests]
        requests_resource_path = sdk2.ResourceData(requests_resource).path

        requests = []
        with requests_resource_path.open() as f:
            for line in f:
                fields = line.strip().split("\t", 2)
                if fields == [""]:
                    # Skip blank lines instead of failing to unpack them.
                    continue
                if len(fields) == 2:
                    # strip() also eats the tab before an empty headers column.
                    fields.append("")
                _time_delay, request, headers = fields
                request_parts = re.split(r"[/?]", request)
                handler = request_parts[2] if len(request_parts) > 2 else ""
                requests.append(
                    {
                        "request": request,
                        "headers": "\r\n".join(headers.split("\\n\\n")),
                        "handler": handler,
                    },
                )
        return requests

    def get_responses(self, response_resource):
        """Parse every complete response record from a DolbilkaDumpResult resource."""
        response_resource_path = sdk2.ResourceData(response_resource).path

        with response_resource_path.open("rb") as f:
            parser = DolbilkaDumpParser(f)
            responses = parser.parse_responses()
        return responses

    @staticmethod
    def _make_test_record(elem):
        """Build the per-test JSON payload shared by diff/ and logs_full_answers/."""
        return {
            "diff": elem["diff"],
            "handler": elem["search"]["handler"],
            "id": elem["id"],
            "name": str(elem["id"]),
            "request": elem["request"],
            "search": elem["search"],
            "status": elem["status"],
            "stable_ans": elem["stable_ans"],
            "test_ans": elem["test_ans"],
        }

    def _write_diff_record(self, elem, record):
        """Write the ASCII-escaped record to diff/<id>.json; log and continue on failure."""
        try:
            self.log_path("diff", "{elem_id}.json".format(elem_id=elem["id"])).write_bytes(
                json.dumps(record)
            )
        except Exception:
            logging.exception("Error in parsing response on: {request}".format(request=elem["request"]))

    def _write_full_answers_record(self, elem, record):
        """Write the record to logs_full_answers/<id>.json as UTF-8 with non-ASCII kept.

        Falls back to ASCII-escaped JSON if that serialization fails. Never raises:
        one bad record must not abort the whole report.
        """
        path = str(self.log_path("logs_full_answers", "{elem_id}.json".format(elem_id=elem["id"])))
        logging.info(path)
        try:
            payload = json.dumps(record, ensure_ascii=False)
        except Exception:
            logging.info("Error in test_ans: {request}".format(request=elem["test_ans"]))
            try:
                payload = json.dumps(record)
            except Exception:
                logging.exception("Cannot serialize full answers for test %s", elem["id"])
                return
        # json.dumps may return text or bytes depending on inputs/Python version;
        # write bytes so both cases work with a binary file.
        if not isinstance(payload, bytes):
            payload = payload.encode("utf-8")
        try:
            with io.open(path, "wb") as f:
                f.write(payload)
        except (IOError, OSError):
            logging.exception("Cannot write %s", path)

    def save_diff(self, diff):
        """Write per-test JSON files and report.json into the task log directory.

        :param diff: iterable of calculate_diff results.
        :return: (tests_number, failed_test_number).
        """
        self.log_path("tests").mkdir(parents=True)
        # Expose the per-test files under "diff" as well.
        self.log_path("diff").symlink_to(self.log_path("tests"))
        self.log_path("logs_full_answers").mkdir(parents=True)

        search = {}
        result = []
        failed_test_number = 0
        for elem in diff:
            for key, value in elem["search"].items():
                search.setdefault(key, set()).add(value)
            result.append({
                "status": elem["status"],
                "search": elem["search"],
                "name": str(elem["id"]),
                "id": int(elem["id"]),
                "diffLinesCount": 1,
            })
            if elem["status"] == "failed":
                failed_test_number += 1

            record = self._make_test_record(elem)
            self._write_diff_record(elem, record)
            self._write_full_answers_record(elem, record)

        search = {tag: list(values) for tag, values in search.items()}
        tests_number = len(result)
        self.log_path("report.json").write_bytes(
            json.dumps(
                {
                    "search": search,
                    "meta": [
                        {
                            "value": tests_number,
                            "title": "Tests",
                        }, {
                            "value": failed_test_number,
                            "title": "Failures",
                        },
                    ],
                    "results": result,
                },
            )
        )

        return tests_number, failed_test_number
