# -*- coding: utf-8 -*-

import json
import logging
import os
import requests

from sandbox import common
from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
import sandbox.projects.common.binary_task as binary_task

# https://github.yandex-team.ru/morda/main/blob/dev/packages/ether-cms/src/typings/resources

# Module logger. At INFO level, the logger.debug(...) calls in this module are
# suppressed unless the level is lowered.
logger = logging.getLogger('json_schema_validator')
logger.setLevel(logging.INFO)

# HTTP methods that a line of a handler's requests file may specify.
SUPPORTED_HTTP_METHODS = {"GET", "POST"}
# Tab-separated request list inside each handler directory (one request per line).
REQUESTS_FILENAME = "requests"
# Relative directory holding one subdirectory per handler (requests file + *.jsonschema files).
HANDLERS_DIRNAME = "handlers"
# Relative directory holding shared schemas; they are registered under the
# "structures__" key prefix so handler schemas can $ref them.
STRUCTURES_DIRNAME = "structures"
# Task log subdirectories: bodies of responses that passed, bodies of responses
# that failed, and the matching failure descriptions.
SUCCESS_RESPONSES_DIRNAME = "responses"
FAILED_RESPONSES_DIRNAME = "failed_responses"
ERRORS_DIRNAME = "errors"
# Key prefix of the embedded resources (library.python.resource) from which the
# handlers/ and structures/ trees are unpacked at execution time.
RESOURCE_PREFIX = "resfs/file/sandbox/projects/vh/frontend/json_schema_validator"


def remove_prefix(text, prefix):
    """Return ``text`` with a leading ``prefix`` stripped.

    If ``text`` does not start with ``prefix`` it is returned unchanged.
    """
    return text[len(prefix):] if text.startswith(prefix) else text


def create_dir_by_path(path, cur_dir):
    """Create the '/'-separated directory chain ``path`` under ``cur_dir``.

    Components are appended one at a time and every missing directory is
    created. Returns the final joined path. Empty components (e.g. from a
    leading '/') are joined too, so the result may end in a separator.
    """
    target = cur_dir
    for component in path.split('/'):
        target = os.path.join(target, component)
        if os.path.exists(target):
            continue  # already present, nothing to create
        os.makedirs(target)
    return target


def read_schema(file):
    """Parse the JSON file at path ``file`` and return the resulting object."""
    with open(file) as handle:
        document = json.load(handle)
    return document


def read_payload(dir, filename):
    """Return the whole contents of file ``filename`` located in directory ``dir``."""
    payload_path = os.path.join(dir, filename)
    with open(payload_path) as payload_file:
        content = payload_file.read()
    return content


def read_schemas(dir, key_prefix):
    """Load every ``*.jsonschema`` regular file directly inside ``dir``.

    Returns a dict mapping ``key_prefix + <file name>`` to the parsed JSON.
    Subdirectories and files with other extensions are ignored.
    """
    logger.debug("Read schemas from dir {}".format(dir))
    return {
        key_prefix + name: read_schema(os.path.join(dir, name))
        for name in os.listdir(dir)
        if name.endswith(".jsonschema") and os.path.isfile(os.path.join(dir, name))
    }


def merge_dicts(lhs, rhs):
    """Return a new mapping with ``lhs``'s items overlaid by ``rhs``'s.

    Neither input is modified; on duplicate keys the value from ``rhs`` wins.
    """
    combined = lhs.copy()
    combined.update(rhs)
    return combined


class NamedHttpRequest(object):
    """One HTTP request parsed from a line of a handler's ``requests`` file.

    A line has tab-separated columns::

        <schema file>  <method>  <uri>  [<headers>  [<payload file>]]

    * ``<schema file>`` names the schema the response is validated against and
      is stored as ``"<handler>__<schema file>"``. A value starting with
      ``expect_404_code`` maps to the special schema name ``"404"``.
    * ``<method>`` must be in SUPPORTED_HTTP_METHODS. GET lines have 3 or 4
      columns, POST lines exactly 5.
    * ``<headers>`` holds ``Name: value`` pairs separated by a literal
      backslash-n (the two characters ``\\n``). It may be empty, e.g. for a
      POST sent without extra headers.
    * ``<payload file>`` is passed to ``payload_reader``; the result becomes
      the request body.

    Any parse problem is re-raised as RuntimeError naming ``handler:idx``.
    """

    def __init__(self, handler, idx, raw_request, payload_reader):
        self.handler = handler
        self.idx = idx

        try:
            parts = raw_request.strip().split('\t')
            if len(parts) < 2:
                raise RuntimeError(
                    "Request must contain at least a schema name and a method, got {!r}".format(raw_request))

            if parts[0].startswith("expect_404_code"):
                self.schema_name = "404"
            else:
                self.schema_name = "{}__{}".format(handler, parts[0])

            if parts[1] not in SUPPORTED_HTTP_METHODS:
                raise RuntimeError("Unsupported method: {}".format(parts[1]))
            self.method = parts[1]

            if self.method == "GET":
                if len(parts) < 3 or len(parts) > 4:
                    raise RuntimeError("GET request must have 3 or 4 parts, actual number of parts is {}".format(len(parts)))
            elif self.method == "POST":
                if len(parts) != 5:
                    raise RuntimeError("POST request must have 5 parts, actual number of parts is {}".format(len(parts)))

            self.uri = parts[2]
            self.headers = {}
            if len(parts) >= 4:
                for header in parts[3].split('\\n'):
                    if not header:
                        # Empty headers column (or a trailing separator): nothing to add.
                        continue
                    key, separator, value = header.partition(": ")
                    if not separator:
                        raise RuntimeError("Malformed header {!r}, expected 'Name: value'".format(header))
                    self.headers[key] = value

            self.body = None
            if len(parts) >= 5:
                self.body = payload_reader(parts[4])
        except Exception as e:
            error = "Invalid request format for {}".format(NamedHttpRequest.pretty_name(handler, idx))
            error += "; Caused by: {}".format(e)
            raise RuntimeError(error)

    def __str__(self):
        return "[{}] {} {} - {}".format(self.get_pretty_name(), self.method, self.uri, self.headers)

    def get_pretty_name(self):
        """Return ``"<handler>:<idx>"`` for this request."""
        return NamedHttpRequest.pretty_name(self.handler, self.idx)

    def get_url(self, host, port):
        """Return ``"<host>:<port><uri>"``; ``host`` is expected to carry the scheme."""
        return "{}:{}{}".format(host, port, self.uri)

    @staticmethod
    def pretty_name(handler, idx):
        """Return the display name ``"<handler>:<idx>"``."""
        return "{}:{}".format(handler, idx)


class NamedHttpResponse(object):
    """Result of sending a NamedHttpRequest, tagged with that request's identity.

    ``raw_response`` must expose ``status_code`` and ``content`` (a
    requests.Response in practice). When it is None (no response was
    obtained), the status code is recorded as 0 and the content as "".
    """

    def __init__(self, handler, idx, schema_name, raw_response):
        self.handler = handler
        self.idx = idx
        self.schema_name = schema_name
        if raw_response is None:
            self.status_code = 0
            self.content = ""
        else:
            self.status_code = raw_response.status_code
            self.content = raw_response.content

    def __str__(self):
        template = "[{}] - code:{} - {}"
        return template.format(self.get_pretty_name(), self.status_code, self.content)

    def get_pretty_name(self):
        """Return ``"<handler>:<idx>"``."""
        pretty = "{}:{}".format(self.handler, self.idx)
        return pretty

    def get_json_content(self):
        """Decode the response body as JSON (raises ValueError if it is not JSON)."""
        decoded = json.loads(self.content)
        return decoded


class HandlerSuiteFactory(object):
    """Creates, and caches, one HandlerSuite per handler directory.

    Shared schemas from ``structures_dir`` are loaded once, keyed as
    ``"structures__<file>"``, and added to every handler's schema store so
    handler schemas can ``$ref`` them by that name.
    """

    def __init__(self, structures_dir, handlers_dir):
        self.structures = read_schemas(structures_dir, "structures__")
        self.handlers_dir = handlers_dir
        # handler name -> HandlerSuite. Building a suite re-reads and re-checks
        # every schema file, so it is done once per handler rather than once
        # per response.
        self._suites = {}

    def for_handler(self, handler):
        """Return the HandlerSuite for ``handler``, building it on first use.

        Raises RuntimeError if ``<handlers_dir>/<handler>`` is not a directory.
        Errors from ``Draft4Validator.check_schema`` (invalid schema) propagate.
        Failed builds are not cached.
        """
        suite = self._suites.get(handler)
        if suite is None:
            suite = self._build_suite(handler)
            self._suites[handler] = suite
        return suite

    def _build_suite(self, handler):
        """Load, check and compile all ``*.jsonschema`` files of one handler."""
        # Imported here rather than at module level; jsonschema comes from the
        # task's PipEnvironment requirement.
        from jsonschema import Draft4Validator, RefResolver

        handler_dir = os.path.join(self.handlers_dir, handler)
        if not os.path.isdir(handler_dir):
            raise RuntimeError("Dir {} does not exist".format(handler_dir))

        tests = read_schemas(handler_dir, "{}__".format(handler))
        # Structures win on key clashes (merge_dicts lets the right side override).
        schemastore = merge_dicts(tests, self.structures)

        testcases = {}
        for testname, schema in tests.items():
            Draft4Validator.check_schema(schema)
            # The schema itself is the referrer so local "#/..." refs resolve
            # inside it. Cross-file refs ("structures__x.jsonschema#/...") are
            # looked up in the store.
            resolver = RefResolver("", schema, schemastore)
            testcases[testname] = Draft4Validator(schema, resolver=resolver)

        return HandlerSuite(handler, testcases)


class HandlerSuite(object):
    """The compiled schema validators belonging to a single handler."""

    def __init__(self, handler, testcases):
        self.handler = handler
        # Mapping: schema name -> validator object exposing iter_errors(instance).
        self.testcases = testcases

    def validate_schema(self, testname, json_value):
        """Validate ``json_value`` with the validator registered as ``testname``.

        Returns a list of every error yielded by the validator's
        ``iter_errors``; the list is empty when the value is valid. Raises
        KeyError if ``testname`` is unknown.
        """
        return list(self.testcases[testname].iter_errors(json_value))


class VhFrontendJsonSchemaValidator(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Shoot the requests listed for every handler at ``host:port`` and
    validate the JSON responses against that handler's JSON schemas.

    Request lists and schemas are unpacked from embedded resources under
    RESOURCE_PREFIX into HANDLERS_DIRNAME / STRUCTURES_DIRNAME. Response bodies
    are dumped into the task log subdirectories SUCCESS_RESPONSES_DIRNAME and
    FAILED_RESPONSES_DIRNAME, and failure details into ERRORS_DIRNAME. The
    task raises TaskFailure if any request does not pass.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = (
            PipEnvironment('jsonschema', '2.5.1'),
        )

    class Parameters(sdk2.Task.Parameters):
        host = sdk2.parameters.String(
            "Host to shoot at",
            default="http://hamster.yandex.net",
            required=True
        )
        port = sdk2.parameters.String(
            "Port to shoot at",
            default=80,  # NOTE(review): int default on a String parameter
            required=True
        )
        ext_params = binary_task.LastBinaryReleaseParameters()
        with sdk2.parameters.Output():
            # FIXME: add list of failed validates

            # Set by on_execute: True iff every response passed validation.
            is_schema_valid = sdk2.parameters.Bool(
                "is schema success"
            )

    def read_requests_for_handler(self, handler_name, dir):
        """Yield a NamedHttpRequest for every non-blank line of ``<dir>/requests``.

        Requests are numbered by their 1-based line number, so skipping blank
        lines does not shift the indices of later requests. POST payload file
        names are resolved relative to ``dir``.
        """
        logger.debug("Read requests from dir {}".format(dir))
        with open(os.path.join(dir, REQUESTS_FILENAME)) as f:
            for index, raw_request in enumerate(f, 1):
                if not raw_request.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                yield NamedHttpRequest(handler_name, index, raw_request, lambda filename: read_payload(dir, filename))

    def read_requests(self):
        """Collect the requests of every handler subdirectory, in name order."""
        result = []
        for filename in sorted(os.listdir(self.handlers_dir)):
            subdir = os.path.join(self.handlers_dir, filename)
            if os.path.isdir(subdir):
                result.extend(self.read_requests_for_handler(filename, subdir))
        return result

    def send_request_with_retry(self, request, host, port, attempts=10, delay=2, timeout=60):
        """Send ``request`` to ``host:port`` and wrap the outcome in a NamedHttpResponse.

        Makes up to ``attempts`` tries, sleeping ``delay`` seconds between
        them. Each try has a ``timeout``-second limit.
        - A successful response, or any 4xx response, ends the loop.
        - 5xx responses and transport errors (connection failures, timeouts)
          are retried.
        The last HTTP response received is returned. If none was ever
        received, the result wraps None and reports status code 0.
        """
        from time import sleep
        url = request.get_url(host, port)
        response = None
        for attempt in range(attempts):
            try:
                if request.method == "GET":
                    response = requests.get(url, headers=request.headers, timeout=timeout)
                elif request.method == "POST":
                    response = requests.post(url, headers=request.headers, data=request.body, timeout=timeout)
                else:
                    raise RuntimeError("Request method {} is not supported".format(request.method))
                response.raise_for_status()
                break
            except requests.RequestException as e:
                failed = e.response
                logger.info("Attempt %d/%d for %s failed: %s", attempt + 1, attempts, request.get_pretty_name(), e)
                # A requests.Response is falsy for statuses >= 400, so it must be
                # compared with None instead of tested for truthiness.
                if failed is not None and failed.status_code and failed.status_code // 100 == 4:
                    break  # client errors will not change on retry
            if attempt + 1 < attempts:
                sleep(delay)
        return NamedHttpResponse(request.handler, request.idx, request.schema_name, response)

    def normalize_scheme(self, host):
        """Prefix ``host`` with 'http://' unless it already has an http(s) scheme."""
        if not host.startswith(('http://', 'https://')):
            host = 'http://' + host
        return host

    def get_responses(self, named_requests):
        """Send every request to the configured host/port; return responses in order."""
        host = self.normalize_scheme(self.Parameters.host)
        port = self.Parameters.port
        return [self.send_request_with_retry(request, host, port) for request in named_requests]

    def dump_failed_response(self, response, error_message):
        self.log_path(FAILED_RESPONSES_DIRNAME, response.get_pretty_name()).write_bytes(response.content)
        self.log_path(ERRORS_DIRNAME, response.get_pretty_name()).write_bytes(error_message)

    def dump_success_response(self, response):
        self.log_path(SUCCESS_RESPONSES_DIRNAME, response.get_pretty_name()).write_bytes(response.content)

    def validate_responses(self, suite_factory, named_responses):
        """Check status codes and validate response bodies against their schemas.

        Returns ``(validated_statuses, error_count)``.
        - ``validated_statuses`` maps each response's pretty name
          (``handler:idx``) to a verdict string.
        - ``error_count`` is the number of responses that failed.
        Responses of ``expect_404_code`` requests (schema name "404") only
        have their status code checked. Unexpected exceptions while validating
        a response are logged and counted as failures.
        """
        validated_statuses = {}
        error_count = 0
        for current_response in named_responses:
            name = current_response.get_pretty_name()
            status_code = current_response.status_code

            if current_response.schema_name == "404":
                # No schema is named "404"; these requests only check the status code.
                if status_code == 404:
                    validated_statuses[name] = "OK"
                else:
                    validated_statuses[name] = "Status code is {}, not 404".format(status_code)
                    error_count += 1
                continue

            if status_code != 200:
                validated_statuses[name] = "Status code is {}, not 200".format(status_code)
                error_count += 1
                continue

            try:
                suite = suite_factory.for_handler(current_response.handler)
                errors = suite.validate_schema(current_response.schema_name, current_response.get_json_content())
                if not errors:
                    validated_statuses[name] = "OK"
                    self.dump_success_response(current_response)
                else:
                    error_message = "".join(
                        str(error.message) + "\n\n" + str(error.context) + ";\n\n"
                        for error in errors
                    )
                    validated_statuses[name] = "Response is invalid (by schema), see log1/errors"
                    self.dump_failed_response(current_response, "Response is invalid by schema:\n" + error_message)
                    error_count += 1
            except Exception as e:
                # Per-response boundary: one broken response must not abort the run.
                logger.exception('Something wrong')
                validated_statuses[name] = "Something wrong, see log1/errors"
                self.dump_failed_response(current_response, "Error\n" + str(e))
                error_count += 1

        return validated_statuses, error_count

    def create_dir_from_resource(self, create_dir):
        """Unpack embedded resources under ``RESOURCE_PREFIX/<create_dir>/`` into
        the local directory ``create_dir``, preserving their relative layout.
        """
        from library.python import resource

        # The trailing separator keeps e.g. "handlers" from also matching the
        # keys of a sibling tree such as "handlers_old".
        prefix = str(os.path.join(RESOURCE_PREFIX, create_dir, ''))
        for key in resource.iterkeys(prefix=prefix):
            relative = remove_prefix(str(key), prefix)
            cur_dir = create_dir_by_path(os.path.dirname(relative), create_dir)
            with open(os.path.join(cur_dir, os.path.basename(relative)), 'wb') as f:
                f.write(resource.find(key))

    def on_execute(self):
        logger.info("host: {}, port: {}".format(self.normalize_scheme(self.Parameters.host), self.Parameters.port))

        self.handlers_dir = HANDLERS_DIRNAME
        self.create_dir_from_resource(HANDLERS_DIRNAME)
        self.structures_dir = STRUCTURES_DIRNAME
        self.create_dir_from_resource(STRUCTURES_DIRNAME)
        logger.info("handlers_dir: {}".format(self.handlers_dir))
        logger.info("structures_dir: {}".format(self.structures_dir))
        for label, dirname in (("failed_responses_dir", FAILED_RESPONSES_DIRNAME),
                               ("errors_dir", ERRORS_DIRNAME),
                               ("success_responses_dir", SUCCESS_RESPONSES_DIRNAME)):
            log_dir = self.log_path(dirname)
            log_dir.mkdir(parents=True)
            logger.info("{}: {}".format(label, log_dir.name))

        named_requests = self.read_requests()
        logger.info("named_requests:\n {}".format("\n".join(map(str, named_requests))))

        suite_factory = HandlerSuiteFactory(self.structures_dir, self.handlers_dir)
        named_responses = self.get_responses(named_requests)
        logger.debug("named_responses:\n {}".format("\n".join(map(str, named_responses))))

        validation_result, error_count = self.validate_responses(suite_factory, named_responses)
        logger.info("validation_result:\n {}".format(validation_result))
        self.Parameters.is_schema_valid = error_count == 0
        if error_count != 0:
            raise common.errors.TaskFailure("Failed {}/{}. Validation error: {}".format(
                error_count, len(named_responses), validation_result))
