# -*- coding: utf-8 -*-

import contextlib
import logging
import shutil
import tempfile
import filecmp
import os
import os.path
import random
import requests
import signal
import socket
import tarfile
import time
import yatest.common
from yatest.common import network, ExecutionTimeoutError


class PythonTest(object):
    """Base class for test suites that run a Java server from build tarballs.

    ``setup_class`` unpacks the classpath tarballs and a bundled JDK into the
    working directory, starts ``java_class_name()`` in the background and
    polls ``get_base_url()`` until it answers HTTP. ``teardown_class`` then
    checks that the JVM exits on SIGTERM.

    Subclasses must implement ``tar_paths()``, ``java_class_name()`` and
    ``java_args()``; the other hooks have defaults.
    NOTE(review): nothing here passes ``get_server_port()`` to the JVM, so
    subclasses presumably do that via ``java_args()``/``program_args()``.
    """

    # Cached result of find_free_port(); 0 means "not chosen yet".
    port = 0

    @classmethod
    def logger(cls):
        """Return the logger used by the test helpers."""
        return logging.getLogger("test_logger")

    @classmethod
    def tar_paths(cls):
        """Return build paths (resolved with yatest.common.binary_path) of the
        tarballs unpacked into the classpath directory. Must be overridden."""
        raise NotImplementedError("Subclass should implement tar_paths() classmethod")

    @classmethod
    def jar_paths(cls):
        """Return extra directories added to the classpath as ``<dir>/*`` entries."""
        return []

    @classmethod
    def java_class_name(cls):
        """Return the fully qualified main class to run. Must be overridden."""
        raise NotImplementedError("Subclass should implement java_class_name() classmethod")

    @classmethod
    def java_args(cls):
        """Return JVM options placed after ``jvm_args()`` and before the main
        class. Must be overridden."""
        raise NotImplementedError("Subclass should implement java_args() classmethod")

    @classmethod
    def jvm_args(cls):
        """Return base JVM options (by default a 2 GiB max heap)."""
        return ['-Xmx2g']

    @classmethod
    def program_args(cls):
        """Return arguments passed to the main class after its name."""
        return []

    @classmethod
    def get_server_start_timeout(cls):
        """Seconds setup_class waits for the server to answer HTTP."""
        return 180

    @classmethod
    def get_server_shutdown_timeout(cls):
        """Seconds teardown_class waits for the JVM to exit after SIGTERM."""
        return 20

    @classmethod
    def find_free_port(cls):
        """Return a TCP port that was free at the moment of the call.

        The probe socket is closed before returning, so another process could
        still take the port before the server binds it.
        """
        with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(('', 0))
            return s.getsockname()[1]

    @classmethod
    def untar(cls, tar_path, dir_path):
        """Extract the archive ``tar_path`` into ``dir_path``.

        The directory is created if missing and reused if it already exists:
        execute_java extracts every entry of tar_paths() into the same
        classpath directory, which made a plain ``os.mkdir`` fail on the
        second archive.
        """
        if not os.path.isdir(dir_path):
            os.makedirs(dir_path)
        with contextlib.closing(tarfile.open(tar_path, 'r')) as tf:
            tf.extractall(dir_path)

    @classmethod
    def execute_java(cls, class_name, args, program_args):
        """Start ``class_name`` with the bundled JDK without waiting for it.

        :param class_name: fully qualified main class.
        :param args: JVM options inserted after ``jvm_args()`` and before the
            main class.
        :param program_args: arguments passed to the main class.
        :return: the handle from ``yatest.common.execute(cmd, wait=False)``.
        """
        # A fresh, randomly named directory per launch for the classpath.
        classpath_dir = 'classpath' + str(random.random())
        for tar_path in cls.tar_paths():
            cls.untar(yatest.common.binary_path(tar_path), classpath_dir)

        # The JDK lives in a fixed directory shared by every test class run
        # from this working directory; unpack it only once.
        if not os.path.exists(os.path.join('jdk', 'bin', 'java')):
            cls.untar('jdk.tar', 'jdk')

        cmd = [
            './jdk/bin/java',
            # NOTE(review): the trailing ':' leaves an empty classpath entry,
            # which the JVM may treat as the current directory — confirm intent.
            "-cp", "".join([d + "/*:" for d in [classpath_dir] + cls.jar_paths()]),
            "-Djava.library.path=" + classpath_dir,
        ] + cls.jvm_args() + args + [class_name] + program_args
        cls.logger().info("execute_java: " + str(cmd))
        return yatest.common.execute(cmd, wait=False)

    @classmethod
    def get_server_port(cls):
        """Return the class's server port, picking a free one on first use."""
        if not cls.port:
            cls.port = cls.find_free_port()
        return cls.port

    @classmethod
    def get_base_url(cls):
        """Return ``http://127.0.0.1:<port>/`` and store it in ``cls.base_url``."""
        cls.base_url = "http://127.0.0.1:{}/".format(cls.get_server_port())
        return cls.base_url

    @classmethod
    def setup_class(cls):
        """Start the Java server and wait until it answers HTTP.

        If waiting fails, the JVM is killed and the port manager released
        here, because pytest does not run teardown_class after a failed
        setup_class and the process would otherwise be leaked.
        """
        cls.pm = network.PortManager()
        cls.server = cls.execute_java(cls.java_class_name(), cls.java_args(), cls.program_args())
        started = False
        try:
            cls.wait_for_connection(cls.get_server_start_timeout())
            started = True
        finally:
            # try/finally (not except+raise) keeps the original error intact.
            if not started:
                if cls.server.running:
                    cls.server.kill()
                cls.pm.release()

    @classmethod
    def teardown_class(cls):
        """Release ports and check that the JVM shuts down on SIGTERM.

        :raises Exception: if ``server.wait`` times out after
            get_server_shutdown_timeout() seconds; a still-running process is
            force-killed in any case.
        """
        cls.pm.release()
        # At the very end, check that the process terminates on SIGTERM.
        os.kill(cls.server.process.pid, signal.SIGTERM)
        try:
            cls.server.wait(timeout=cls.get_server_shutdown_timeout())
        except ExecutionTimeoutError as e:
            raise Exception("Seems like java can't shutdown gracefully on SIGTERM: "+str(e))
        finally:
            if cls.server.running:
                cls.server.kill()

    @classmethod
    def wait_for_connection(cls, timeout):
        """Poll get_base_url() until the server returns any HTTP response.

        :param timeout: overall budget in seconds.
        :raises Exception: if the server process exits while waiting.
        :raises requests.ConnectionError, requests.Timeout: if no response
            arrives within ``timeout`` seconds.
        """
        cls.logger().info("wait for connection")
        deadline = time.time() + timeout
        while True:
            if not cls.server.running:
                raise Exception("server is down")
            # Bound each probe by the remaining budget (at least 1s). Without
            # a timeout, a server that accepts the connection but never
            # responds would block here forever and the deadline would never
            # be checked.
            probe_timeout = max(deadline - time.time(), 1)
            try:
                requests.get(cls.get_base_url(), allow_redirects=False, timeout=probe_timeout)
                cls.logger().info("connection acquired")
                return
            except (requests.ConnectionError, requests.Timeout):
                if time.time() >= deadline:
                    raise
                time.sleep(1)


class TestModelGeneratorResult(PythonTest):
    """Base class for tests that run the model generator tool once per class.

    setup_class runs ``ru.yandex.direct.model.generator.Tool`` to completion
    with ``[conf_dir, "-o", tmpdir] + additional_args()``. Subclasses must
    define ``conf_dir`` and may override ``additional_args()``.
    NOTE(review): ``-o`` presumably selects the output directory — confirm.
    """

    # Fresh temporary directory per test class, created in setup_class() and
    # removed in teardown_class(). It used to be created once at import time,
    # which shared it between subclasses: the first teardown deleted it, and
    # later classes lost their output dir.
    tmpdir = None

    @classmethod
    def tar_paths(cls):
        """Return the tarball that provides the model generator's classpath."""
        return ["direct/libs/model-generator/libs-model-generator.tar"]

    @classmethod
    def java_class_name(cls):
        """Return the generator's main class."""
        return "ru.yandex.direct.model.generator.Tool"

    @classmethod
    def setup_class(cls):
        """Create this class's tmpdir and run the generator until it exits.

        The tmpdir is removed here if the run fails, because pytest does not
        call teardown_class after a failed setup_class.
        """
        cls.tmpdir = tempfile.mkdtemp(prefix="result")
        succeeded = False
        try:
            cls.server = cls.execute_java(cls.java_class_name(), cls.java_args(), cls.program_args()).wait()
            succeeded = True
        finally:
            if not succeeded:
                shutil.rmtree(cls.tmpdir, ignore_errors=True)

    @classmethod
    def teardown_class(cls):
        """Remove the generator output directory created by setup_class."""
        shutil.rmtree(cls.tmpdir)

    @classmethod
    def java_args(cls):
        """No extra JVM options beyond jvm_args()."""
        return []

    @classmethod
    def additional_args(cls):
        """Return extra generator arguments appended after the output option."""
        return []

    @classmethod
    def program_args(cls):
        """Return the generator command line: config dir, ``-o tmpdir``, extras."""
        return [cls.conf_dir, "-o", cls.tmpdir] + cls.additional_args()


def compare_dirs(dir1, dir2):
    """Recursively check whether two directory trees have identical contents.

    Only ``.DS_Store`` entries are ignored; passing this list replaces
    filecmp's default ignore list. Common files are compared byte-by-byte
    (``shallow=False``). The check stops at the first difference found.

    :return: ``(True, "Directories are equal")`` if the trees match,
        otherwise ``(False, message)`` describing the first difference.
    """
    dirs_cmp = filecmp.dircmp(dir1, dir2, [".DS_Store"])

    if dirs_cmp.left_only:
        return False, "Directory differences are found in " + dir1 + ": " + str(dirs_cmp.left_only)
    if dirs_cmp.right_only:
        # These entries exist only in dir2, so name dir2 in the message.
        return False, "Directory differences are found in " + dir2 + ": " + str(dirs_cmp.right_only)
    if dirs_cmp.common_funny:
        # Same name on both sides but different entry types (e.g. a file vs a
        # directory). These names appear in neither common_files nor
        # common_dirs, so without this check they would go uncompared.
        return False, "Entry type differences are found in " + dir1 + ": " + str(dirs_cmp.common_funny)
    if dirs_cmp.funny_files:
        return False, "Can't compare directories " + dir1 + ": " + str(dirs_cmp.funny_files)

    _, mismatch, errors = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False)
    if mismatch:
        return False, "File differences are found in " + dir1 + ": " + str(mismatch)
    if errors:
        return False, "Errors occurred while comparing files " + dir1 + ": " + str(errors)

    for common_dir in dirs_cmp.common_dirs:
        result, message = compare_dirs(os.path.join(dir1, common_dir), os.path.join(dir2, common_dir))
        if not result:
            return False, message

    return True, "Directories are equal"
