# coding: utf-8
import logging
import logging.handlers

import functools
import json
import time

import jsonschema
import sys
import socket
import os

import library.python.retry as retry_lib
from library.python.retry import RetryConf

# Module-level logger named after this module. A NullHandler is attached to it so
# the logger always has at least one handler of its own.
_logger = logging.getLogger(__name__)
_logger.addHandler(logging.NullHandler())

# Default retry configuration: the retry library's DEFAULT_CONF with customised waiting.
# NOTE(review): presumably `delay` is the initial pause in seconds, `backoff` its
# multiplier and `jitter` a random spread - confirm against library.python.retry.
DEFAULT_RETRY_CONF = retry_lib.DEFAULT_CONF.waiting(delay=0.15, backoff=2, jitter=0.1)


def _retry_call(retry_conf: RetryConf, retry_conf_kwargs, func, *args, **kwargs):
    """Run ``func`` through ``retry_lib.retry_call``.

    :param retry_conf: retry configuration, passed to ``retry_call`` as ``conf``.
    :param retry_conf_kwargs: mapping of extra keyword arguments for ``retry_call``.
    :param func: callable to invoke.
    :param args: positional arguments for ``func`` (forwarded as ``f_args``).
    :param kwargs: keyword arguments for ``func`` (forwarded as ``f_kwargs``).
    :return: whatever ``retry_lib.retry_call`` returns.
    """
    fixed_options = {"f_args": args, "f_kwargs": kwargs, "conf": retry_conf}
    # Unpacking both mappings in the call keeps the "duplicate keyword" TypeError
    # if ``retry_conf_kwargs`` tries to redefine one of the fixed options.
    return retry_lib.retry_call(func, **fixed_options, **retry_conf_kwargs)


def retry_callable(retry_conf: RetryConf, **retry_conf_kwargs):
    """Build a caller that runs functions with the given retry settings.

    The returned object is a ``functools.partial`` of ``_retry_call`` with the
    configuration pre-bound, to be used as ``caller(func, *args, **kwargs)``.

    :param retry_conf: retry configuration bound into the caller.
    :param retry_conf_kwargs: extra keyword arguments bound into the caller and
        later handed to ``retry_lib.retry_call``.
    :return: ``functools.partial`` object accepting ``func, *args, **kwargs``.
    """
    bound_caller = functools.partial(_retry_call, retry_conf, retry_conf_kwargs)
    return bound_caller


def _wait_call(func, timeout, iteration_sleep, f_args, f_kwargs):
    """Poll ``func`` until it returns a falsy value.

    ``func`` is always called at least once; the timeout is only checked after
    a call that returned a truthy value.

    :param func: callable to poll; called as ``func(*f_args, **f_kwargs)``.
    :param timeout: maximum waiting time in seconds, or ``None`` to wait forever.
    :param iteration_sleep: pause between two polls, in seconds (may be fractional).
    :param f_args: positional arguments for ``func``.
    :param f_kwargs: keyword arguments for ``func``.
    :return: ``None`` once ``func`` returned a falsy value.
    :raises Exception: if ``func`` is still truthy after ``timeout`` seconds.
    """
    # functools.partial objects and callable instances have no __name__;
    # fall back to repr() instead of failing with AttributeError.
    func_name = getattr(func, "__name__", None) or repr(func)
    # Monotonic clock: the timeout must not be affected by system clock adjustments.
    started_at = time.monotonic()
    # %s (not %d) so that fractional sleeps such as 0.5 are not logged as 0.
    _logger.debug("wait for %s (args %s, kwargs %s) with timeout %s and iteration_sleep %s",
                  func_name, f_args, f_kwargs, timeout, iteration_sleep)
    while func(*f_args, **f_kwargs):
        if timeout is not None and time.monotonic() - started_at > timeout:
            raise Exception("Timeout reached while waiting for " + func_name)
        _logger.debug("%s failed, sleep until next iteration ...", func_name)
        time.sleep(iteration_sleep)


def negate(func):
    """Return a callable with the signature of ``func`` that yields ``not func(...)``.

    Metadata of ``func`` (name, docstring, ...) is copied onto the result via
    ``functools.wraps``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        return not outcome

    return wrapper


def wait_while(timeout=None, iteration_sleep=1):
    """Decorator factory: turn a predicate into a blocking "wait while it is true" call.

    Calling the decorated function polls the original one through ``_wait_call``
    with the same arguments until it returns a falsy value.

    :param timeout: maximum waiting time handed to ``_wait_call``; ``None`` means no limit.
    :param iteration_sleep: pause between polls handed to ``_wait_call``.
    :return: decorator whose wrapper returns the result of ``_wait_call``.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return _wait_call(func, timeout, iteration_sleep, args, kwargs)

        return wrapper

    return decorator


def wait_until(timeout=None, iteration_sleep=1):
    """Decorator factory: turn a predicate into a blocking "wait until it is true" call.

    Calling the decorated function polls the negated original one through
    ``_wait_call`` with the same arguments, i.e. it keeps polling while the
    original returns a falsy value.

    :param timeout: maximum waiting time handed to ``_wait_call``; ``None`` means no limit.
    :param iteration_sleep: pause between polls handed to ``_wait_call``.
    :return: decorator whose wrapper returns the result of ``_wait_call``.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            inverted = negate(func)
            return _wait_call(inverted, timeout, iteration_sleep, args, kwargs)

        return wrapper

    return decorator


def read_file(filepath, encoding=None):
    """Return the whole content of a text file.

    :param filepath: path of the file; it is opened in text mode.
    :param encoding: text encoding passed to ``open``. ``None`` (the default)
        keeps the platform default encoding, which is what the function used
        before this parameter existed; pass e.g. ``"utf-8"`` for
        locale-independent results.
    :return: the file content as ``str``.
    :raises OSError: if the file cannot be opened or read.
    :raises UnicodeDecodeError: if the content is not valid in the chosen encoding.
    """
    with open(filepath, encoding=encoding) as f:
        return f.read()


def jdumps(data, newline=False, pretty=False, safe_serialization=True, **kwargs):
    """Serialize ``data`` to a JSON string.

    :param data: object to serialize.
    :param newline: append ``"\\n"`` to the result when it does not already end with one.
    :param pretty: indent by 4 spaces and use ``","`` / ``": "`` separators.
    :param safe_serialization: convert objects ``json`` cannot encode with ``str``
        instead of raising ``TypeError``.
    :param kwargs: extra ``json.dumps`` options; they override the ones set above.
    :return: the JSON text.
    """
    options = {}
    if pretty:
        options["indent"] = 4
        options["separators"] = (",", ": ")
    if safe_serialization:
        options["default"] = str
    # Explicit caller options win over the presets.
    options.update(kwargs)

    serialized = json.dumps(data, **options)
    ends_with_newline = serialized[-1] == "\n"
    if newline and not ends_with_newline:
        return serialized + "\n"
    return serialized


def get_json_value(data, node_path):
    """Return the value stored in nested ``data`` under a dot-separated path.

    Each path component is used as a (string) key for one nesting level,
    e.g. ``"a.b"`` yields ``data["a"]["b"]``.

    :param data: nested container indexed by the path components.
    :param node_path: dot-separated key path.
    :return: the value found at the end of the path.
    :raises KeyError: if a component is missing in a ``dict`` level.
    """
    return functools.reduce(lambda node, key: node[key], node_path.split("."), data)


def set_json_value(data, node_path, value):
    """Replace the value stored in nested ``data`` under a dot-separated path.

    The target key must already exist: the previous value is read before the
    new one is written.

    :param data: nested container, modified in place.
    :param node_path: dot-separated key path of the value to replace.
    :param value: new value.
    :return: the value previously stored at ``node_path``.
    :raises KeyError: if a component (including the last one) is missing in a ``dict`` level.
    """
    *parent_keys, leaf_key = node_path.split(".")
    container = data
    for key in parent_keys:
        container = container[key]

    previous = container[leaf_key]
    container[leaf_key] = value
    return previous


def init_logger(console_logger_name=None, console_level=logging.INFO, background_handler="stderr",
                background_level=logging.DEBUG):
    """Reset the root logger and install a console and a background handler.

    Warnings are redirected to logging, all existing root handlers are dropped
    and the root level is set to ``NOTSET``. Two handlers are then added, in
    this order: the console handler and the background handler.

    :param console_logger_name: when not ``None``, a stdout handler is installed
        that is filtered by this logger name and prints the bare message;
        otherwise a ``NullHandler`` takes its place.
    :param console_level: level of the console handler.
    :param background_handler: ``"stderr"`` or ``"stdout"`` to log to that stream,
        a ready handler object to use as is, or ``None`` for a ``NullHandler``.
    :param background_level: level set on the background handler.
    :return: the configured root logger.
    """
    logging.captureWarnings(True)
    root = logging.getLogger()
    root.handlers.clear()
    root.setLevel(logging.NOTSET)

    if console_logger_name is None:
        console = logging.NullHandler()
    else:
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(console_level)
        console.addFilter(logging.Filter(name=console_logger_name))
        console.setFormatter(logging.Formatter("%(message)s"))

    if background_handler is None:
        background = logging.NullHandler()
    else:
        named_streams = {"stderr": sys.stderr, "stdout": sys.stdout}
        # The isinstance check comes first so that handler objects are never hashed.
        if isinstance(background_handler, str) and background_handler in named_streams:
            background = logging.StreamHandler(named_streams[background_handler])
        else:
            background = background_handler
        # The pid is baked into the format string once, at configuration time.
        record_format = (
            "[%(asctime)s]\t%(levelname)-8s\t" + str(os.getpid()) +
            "\t%(filename)20s:%(funcName)-20s\t%(threadName)-15s\t%(name)-15s\t%(message)s"
        )
        background.setLevel(background_level)
        background.setFormatter(logging.Formatter(record_format))

    for handler in (console, background):
        root.addHandler(handler)
    return root


def tcp_ping(host, port, timeout=None) -> bool:
    """Check whether a TCP connection to ``host:port`` can be opened and shut down.

    The probe uses an ``AF_INET6`` stream socket. Any ``Exception`` raised while
    connecting or shutting the connection down is logged as a warning and
    reported as ``False``.

    :param host: host to connect to.
    :param port: TCP port to connect to.
    :param timeout: socket timeout in seconds; ``None`` leaves the socket default.
    :return: ``True`` if connect and shutdown both succeeded, ``False`` otherwise.
    """
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    if timeout is not None:
        sock.settimeout(timeout)

    # The context manager closes the socket on every way out of the block.
    with sock:
        try:
            sock.connect((host, port))
            sock.shutdown(socket.SHUT_RDWR)
        except Exception as e:
            _logger.warning(f"can't connect to {host}:{port}: {type(e)} {e}")
            return False
        else:
            return True


def check_json_schema_or_die(data, schema):
    """Validate ``data`` against a JSON schema, logging and re-raising on failure.

    :param data: instance to validate.
    :param schema: JSON schema handed to ``jsonschema.validate``.
    :raises jsonschema.exceptions.ValidationError: if ``data`` does not match
        ``schema``; the failing path, value and message are logged at CRITICAL
        level before the error is re-raised unchanged.
    """
    schema_dump = jdumps(schema, pretty=False)
    _logger.debug("validating %s with schema %s", data, schema_dump)
    try:
        jsonschema.validate(data, schema)
    except jsonschema.exceptions.ValidationError as e:
        # e.path holds the keys/indices leading to the offending value.
        failed_path = ".".join(map(str, e.path))
        _logger.critical("bad event schema in path .%s with value %s - %s", failed_path,
                         e.instance, e.message)
        raise


def json_schema_make_all_required(schema):
    """Mark every property of every object in ``schema`` as required, in place.

    Object schemas (``"type": "object"`` with ``"properties"``) get a
    ``"required"`` list naming all their properties, and their property schemas
    are processed recursively. Array schemas are descended through ``"items"``
    at any nesting depth, both for a single item schema and for a list of
    positional item schemas. Other schemas are left untouched.

    :param schema: JSON schema (a mapping) or a boolean schema; mappings are
        modified in place.
    :return: ``None``.
    """
    # Boolean schemas (``true`` / ``false``) are valid JSON Schema but carry no properties.
    if isinstance(schema, bool):
        return

    if schema.get("type") == "array" and "items" in schema:
        items = schema["items"]
        # ``items`` is either one schema or a list of per-position schemas.
        item_schemas = items if isinstance(items, (list, tuple)) else [items]
        # Recurse (instead of unwrapping one level) so arrays of arrays are handled too.
        for item_schema in item_schemas:
            json_schema_make_all_required(item_schema)
        return

    if not (schema.get("type") == "object" and "properties" in schema):
        return

    properties = schema["properties"]
    schema["required"] = list(properties)

    for prop_schema in properties.values():
        json_schema_make_all_required(prop_schema)
