# coding: utf-8
from __future__ import print_function

import collections
import contextlib
import copy
import errno
import fcntl
import hashlib
import ipaddress
import logging
import os
import signal
import socket
import sys
import time
import zlib

import tornado.gen
import tornado.httpclient

import _netmon

from . import interfaces
from .interfaces import BACKBONE, FASTBONE, NETWORKS  # noqa - re-exporting symbols for convenience

from infra.netmon.agent.idl import common_pb2

# Loopback addresses; is_address_excluded() rejects these.
EXCLUDED_ADDRESSES = {
    (socket.AF_INET6, "::1"),
    (socket.AF_INET, "127.0.0.1"),
}


def _inverted(mapping):
    """Return a new dict mapping every value of *mapping* back to its key."""
    return {value: key for key, value in mapping.iteritems()}


# Socket address family -> protobuf enum value, and the reverse mapping.
FAMILY_TO_PROTO = {
    socket.AF_INET6: common_pb2.INET6,
    socket.AF_INET: common_pb2.INET4,
}
PROTO_TO_FAMILY = _inverted(FAMILY_TO_PROTO)

# Socket address family -> string name.
FAMILY_TO_STRING = {
    socket.AF_INET6: "INET6",
    socket.AF_INET: "INET4",
}

# Protocol enum value <-> string name.
PROTOCOL_TO_STRING = {
    common_pb2.ICMP: "ICMP",
    common_pb2.UDP: "UDP",
    common_pb2.TCP: "TCP",
}
STRING_TO_PROTOCOL = _inverted(PROTOCOL_TO_STRING)

# Traffic class enum value <-> string name.
TRAFFIC_CLASS_TO_STRING = {
    common_pb2.CS0: "CS0",
    common_pb2.CS1: "CS1",
    common_pb2.CS2: "CS2",
    common_pb2.CS3: "CS3",
    common_pb2.CS4: "CS4",
}
STRING_TO_TRAFFIC_CLASS = _inverted(TRAFFIC_CLASS_TO_STRING)

# Microseconds per second: timestamp() multiplies time.time() by US.
US = 10 ** 6
# 10**3; presumably the millisecond counterpart of US — confirm at call sites.
MS = 10 ** 3


def yield_thread():
    """Yield the current thread by calling ``time.sleep(0)``.

    Presumably used to let other threads run (CPython releases the GIL
    while sleeping). Always returns None.
    """
    time.sleep(0)


# Copied from contextlib2.
class ExitStack(object):
    """Context manager for dynamic management of a stack of exit callbacks.

    For example:
        with ExitStack() as stack:
            files = [stack.enter_context(open(fname)) for fname in filenames]
            # All opened files will automatically be closed at the end of
            # the with statement, even if attempts to open files later
            # in the list raise an exception

    This is a trimmed copy: enter_context() is the only public registration
    method; contextlib2's push()/callback()/pop_all()/close() are not present.
    """
    def __init__(self):
        # Exit callbacks, each called as cb(exc_type, exc_value, traceback);
        # __exit__ pops them from the right end, i.e. in LIFO order.
        self._exit_callbacks = collections.deque()

    def _push_cm_exit(self, cm, cm_exit):
        """Helper to correctly register callbacks to __exit__ methods.

        ``cm_exit`` is the ``__exit__`` looked up on ``type(cm)``; the wrapper
        passes ``cm`` explicitly as its first argument when invoked.
        """
        def _exit_wrapper(*exc_details):
            return cm_exit(cm, *exc_details)
        # Mirror a bound method by exposing the wrapped context manager.
        _exit_wrapper.__self__ = cm
        self._push(_exit_wrapper)

    def _push(self, exit):
        """Registers a callback with the standard __exit__ method signature.

        Can suppress exceptions the same way __exit__ methods can.
        Also accepts any object with an __exit__ method (registering a call
        to the method instead of the object itself).

        Returns:
            ``exit`` unchanged, so this can be used as a decorator.
        """
        # We use an unbound method rather than a bound method to follow
        # the standard lookup behaviour for special methods
        _cb_type = type(exit)
        try:
            exit_method = _cb_type.__exit__
        except AttributeError:
            # Not a context manager, so assume its a callable
            self._exit_callbacks.append(exit)
        else:
            self._push_cm_exit(exit, exit_method)
        return exit  # Allow use as a decorator

    def enter_context(self, cm):
        """Enters the supplied context manager.

        If successful, also pushes its __exit__ method as a callback and
        returns the result of the __enter__ method. If __enter__ raises,
        nothing is registered.
        """
        # We look up the special methods on the type to match the with statement
        _cm_type = type(cm)
        # Fetch __exit__ before calling __enter__ so a non-context-manager
        # fails (AttributeError) before any side effects happen.
        _exit = _cm_type.__exit__
        result = _cm_type.__enter__(cm)
        self._push_cm_exit(cm, _exit)
        return result

    def __enter__(self):
        return self

    def __exit__(self, *exc_details):
        """Run every registered callback, most recently registered first.

        Each callback sees the exception that is "current" at its point in the
        unwinding: the original one, None after a callback suppressed it, or
        a new exception raised by a later-registered callback.

        Returns:
            True only if an exception was passed in and some callback
            suppressed it (and no later callback raised).

        Raises:
            The last exception raised by a callback, with its traceback.
        """
        received_exc = exc_details[0] is not None

        # Callbacks are invoked in LIFO order to match the behaviour of
        # nested context managers
        suppressed_exc = False
        pending_raise = False
        while self._exit_callbacks:
            cb = self._exit_callbacks.pop()
            try:
                if cb(*exc_details):
                    # A truthy return swallows the current exception; the
                    # remaining callbacks then see a clean exit.
                    suppressed_exc = True
                    pending_raise = False
                    exc_details = (None, None, None)
            except:
                # A callback raised: that exception replaces the current one
                # and is handed to the remaining (outer) callbacks.
                new_exc_details = sys.exc_info()
                pending_raise = True
                exc_details = new_exc_details
        if pending_raise:
            exc_type, exc_value, exc_tb = exc_details
            # Python 2 three-argument raise keeps the callback's traceback.
            raise exc_type, exc_value, exc_tb
        return received_exc and suppressed_exc


# Code from arcadia (sky install)
# https://a.yandex-team.ru/arc/trunk/arcadia/skynet/packages/build_v2/install/skyinstall.py#L502
def detect_hostname():
    """Return this host's fqdn, falling back to the short hostname.

    The name from ``socket.gethostname()`` is resolved forward to its IP
    addresses, and each IP is then resolved back to a name. The original
    hostname is returned when it is among the reverse names, when forward
    resolution reports "not found", or when there is not exactly one
    candidate. A single reverse name that starts with the hostname is
    returned instead.

    Raises:
        socket.gaierror, socket.herror: resolver failures whose errno is not
            one of EAI_ADDRFAMILY / EAI_NONAME / EAI_NODATA.
    """
    hostname = socket.gethostname()
    on_cygwin = sys.platform == 'cygwin'

    if on_cygwin and '.' not in hostname:
        hostname = socket.getfqdn(hostname).split()[0]

    # Resolver codes treated as "name not found"; missing constants become None.
    not_found_errs = tuple(
        getattr(socket, code_name, None)
        for code_name in ('EAI_ADDRFAMILY', 'EAI_NONAME', 'EAI_NODATA')
    )

    try:
        addrinfo = socket.getaddrinfo(hostname, 0, 0, socket.SOCK_STREAM, 0)
    except (socket.gaierror, socket.herror) as ex:
        if ex.errno not in not_found_errs:
            raise
        # Unable to resolve ourselves.
        return hostname

    # addrinfo entries are (family, type, proto, canonname, sockaddr).
    ips = set(entry[4][0] for entry in addrinfo)

    fqdns = set()
    for ip in ips:
        try:
            reverse_name = socket.gethostbyaddr(ip)[0]
        except (socket.gaierror, socket.herror) as ex:
            # NOTE(review): herror from gethostbyaddr() usually carries an
            # h_errno code (e.g. HOST_NOT_FOUND), not an EAI_* code, so an IP
            # without reverse DNS may raise here instead of being skipped —
            # confirm this is intended.
            if ex.errno not in not_found_errs:
                raise
            continue
        if on_cygwin:
            # On cygwin gethostbyaddr can return all DNS prefixes at once.
            reverse_name = reverse_name.split()[0]
        fqdns.add(reverse_name)

    if hostname in fqdns:
        return hostname
    if len(fqdns) == 1:
        only_fqdn = next(iter(fqdns))
        if only_fqdn.startswith(hostname):
            return only_fqdn
    # No candidate, or several of them: no way to choose one.
    return hostname


def get_address_family(address):
    """Return socket.AF_INET / socket.AF_INET6 for a textual IP address.

    Returns None when *address* cannot be converted to text or parsed as an
    IP address (any ValueError from the conversion or from the parser).
    """
    try:
        version = ipaddress.ip_address(unicode(address)).version
    except ValueError:
        return None
    return {6: socket.AF_INET6, 4: socket.AF_INET}.get(version)


def is_address_excluded(family, name):
    """Return True if (family, name) must not be used as a probe address.

    Excluded are the entries of EXCLUDED_ADDRESSES plus any IPv6 address
    whose text starts with "fe80::1%" (fe80::1 with a zone suffix).
    """
    return (
        (family, name) in EXCLUDED_ADDRESSES
        or (family == socket.AF_INET6 and name.startswith("fe80::1%"))
    )


def is_same_network(address1, address2):
    """Return whether two IPv6 addresses belong to the same network class.

    They match when both are backbone, both are fastbone, or neither is
    backbone nor fastbone (classification by interfaces.HostIPv6Address).
    """
    first, second = (
        interfaces.HostIPv6Address(unicode(address)) for address in (address1, address2)
    )
    return (
        (first.is_backbone and second.is_backbone)
        or (first.is_fastbone and second.is_fastbone)
        or not (first.is_fastbone or first.is_backbone
                or second.is_fastbone or second.is_backbone)
    )


# Summary statistics of a round-trip-time histogram; all fields are None when
# the histogram is missing or empty.
HistogramProperties = collections.namedtuple(
    "HistogramProperties",
    "average rtt25 rtt50 rtt75 rtt95 min_rtt max_rtt",
)


def extract_properties_from_histogram(hist):
    """Summarize *hist* into a HistogramProperties tuple.

    *hist* must provide get_total_count(), get_mean(),
    get_value_at_percentile(p), get_min() and get_max(). If *hist* is None
    or reports a zero total count, every field of the result is None.
    """
    if hist is None or not hist.get_total_count():
        return HistogramProperties(*[None] * len(HistogramProperties._fields))

    mean = hist.get_mean()
    percentiles = {
        "rtt%d" % percent: hist.get_value_at_percentile(percent)
        for percent in (25, 50, 75, 95)
    }
    return HistogramProperties(
        average=mean,
        min_rtt=hist.get_min(),
        max_rtt=hist.get_max(),
        **percentiles
    )


def address_to_proto(address=None):
    """Convert *address* into a common_pb2.TAddress, or return None for None.

    The Ip field is taken from ``address[1][0]`` and Port from
    ``address[1][1]``; the concrete address type is defined by callers.
    """
    if address is None:
        return None
    return common_pb2.TAddress(Ip=address[1][0], Port=address[1][1])


def report_to_proto(report, source=None, target=None):
    """Build a common_pb2.TProbeReport message from a probe *report*.

    Percentile, minimum and maximum round-trip times come from
    ``report.histogram``. RoundTripTimeAverage uses ``report.average``, not
    the histogram mean. Addresses are converted with address_to_proto().
    """
    stats = extract_properties_from_histogram(report.histogram)
    return common_pb2.TProbeReport(
        Type=report.type,

        Source=source,
        Target=target,

        Family=FAMILY_TO_PROTO[report.family],
        Protocol=report.protocol,

        SourceAddress=address_to_proto(report.source_addr),
        TargetAddress=address_to_proto(report.target_addr),

        Failed=report.failed,
        Error=report.error,

        Received=report.received,
        Congested=report.congested,
        Corrupted=report.corrupted,
        Lost=report.lost,
        TosChanged=report.tos_changed,

        RoundTripTimeAverage=report.average,
        RoundTripTime25=stats.rtt25,
        RoundTripTime50=stats.rtt50,
        RoundTripTime75=stats.rtt75,
        RoundTripTime95=stats.rtt95,
        RoundTripTimeMinimum=stats.min_rtt,
        RoundTripTimeMaximum=stats.max_rtt,

        Truncated=report.truncated,

        Offender=address_to_proto(report.offender),

        Generated=report.generated,
        TypeOfService=report.type_of_service,
    )


def round_by_interval(interval, started, data=None):
    """Return the end of the *interval*-aligned slot containing *started*.

    When *data* (bytes) is given, a stable per-key offset in [0, interval),
    derived from the MD5 of *data*, is added so that different keys get
    different deadlines.
    """
    slot_start = started - started % interval
    deadline = slot_start + interval
    if data is None:
        return deadline
    digest = hashlib.md5(data).hexdigest()
    return deadline + int(digest, 16) % interval


def should_be_updated(interval, started, data=None):
    """Return True once the deadline computed by round_by_interval() has passed.

    The optional *data* shifts the deadline by a per-key offset, which spreads
    updates of different keys across the interval.
    """
    deadline = round_by_interval(interval, started, data)
    return time.time() > deadline


@contextlib.contextmanager
def do_backup(path):
    """Move *path* aside to "<name>.old" and restore it if the body fails.

    On entry *path* is renamed to a sibling "<name>.old" (if the rename raises
    OSError, e.g. because *path* does not exist, no backup is made). If the
    with-body raises anything, the backup is renamed back over *path* and the
    exception propagates. On success the ".old" file is left in place.
    """
    backup_path = path.parent.joinpath("{}.old".format(path.name))
    try:
        path.rename(backup_path)
        backed_up = True
    except OSError:
        backed_up = False

    completed = False
    try:
        yield
        completed = True
    finally:
        # Any exception (or generator close) before completion restores the file.
        if backed_up and not completed:
            backup_path.rename(path)


def kill_process(pid, signum=signal.SIGKILL):
    """Send *signum* to *pid*.

    Returns:
        True if the signal was delivered, False if no such process exists
        (ESRCH).

    Raises:
        OSError: for any other failure (e.g. EPERM).
    """
    try:
        os.kill(pid, signum)
    except OSError as exc:
        if exc.errno == errno.ESRCH:
            return False
        raise
    return True


def should_file_be_updated(interval, path, data=None):
    """Return True if the file at *path* should be regenerated.

    That is the case when stat() fails with OSError (e.g. the file is missing),
    when the file is empty, or when should_be_updated() says the deadline
    derived from the file's mtime has passed.
    """
    try:
        file_stat = path.stat()
    except OSError:
        return True
    if file_stat.st_size == 0:
        return True
    return should_be_updated(interval, file_stat.st_mtime, data)


def acquire_file_lock(stream):
    """Try to take an exclusive, non-blocking flock() on *stream*'s descriptor.

    The descriptor is first marked close-on-exec, so exec'd child processes
    do not inherit it.

    Returns:
        True if the lock was taken, False if it is held elsewhere (EWOULDBLOCK).

    Raises:
        IOError: for any other flock() failure.
    """
    descriptor = stream.fileno()
    current_flags = fcntl.fcntl(descriptor, fcntl.F_GETFD)
    fcntl.fcntl(descriptor, fcntl.F_SETFD, current_flags | fcntl.FD_CLOEXEC)
    try:
        fcntl.flock(descriptor, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except IOError as err:
        if err.errno == errno.EWOULDBLOCK:
            return False
        raise
    return True


def create_pid(lock_path):
    """Lock *lock_path* exclusively and write the current PID into it.

    Keep the returned file object open for as long as the lock must be held.

    Returns:
        The open lock file (text mode "a+"), truncated and containing our PID.

    Raises:
        Exception: if the lock is already held; the message includes the
            file's current contents (presumably the holder's PID).
    """
    lock_file = lock_path.open(mode="a+")
    try:
        # "a+" creates the file without truncating it, so the current holder's
        # PID is still readable for the error message if locking fails.
        lock_file.seek(0)
        if not acquire_file_lock(lock_file):
            raise Exception(
                "Can't grab exclusive lock '{0}', already acquired pid: {1}".format(lock_path, repr(lock_file.read()))
            )
        lock_file.truncate(0)
        lock_file.write(unicode(os.getpid()))
        lock_file.flush()
    except BaseException:
        # The caller never receives the file object on failure, so close it
        # here instead of leaking the descriptor.
        lock_file.close()
        raise
    return lock_file


def timestamp():
    """Return the current wall-clock time as an integer number of microseconds."""
    seconds = time.time()
    return int(seconds * US)


def quanted_timestamp(resolution, ts=None):
    """Round a microsecond timestamp DOWN to a multiple of *resolution*.

    Args:
        resolution: quantum, in the same units as *ts* (microseconds).
        ts: timestamp to quantize; defaults to timestamp() (now, in
            microseconds). An explicit 0 is honoured.

    Returns:
        ``ts - ts % resolution``, i.e. the start of the slot containing *ts*.
    """
    if ts is None:
        ts = timestamp()
    return ts - ts % resolution


def permutate_addresses(source_addresses, target_addresses, port):
    """Pair every usable source address with every usable target address.

    A (source, target) pair is kept when both have the same family, neither is
    excluded by is_address_excluded(), and, for non-IPv4 families, both are in
    the same network class per is_same_network(). Sources get port 0 and
    targets get *port*.

    Returns:
        List of (source _netmon.Address, target _netmon.Address) tuples, in
        source-major order.
    """
    pairs = []
    for source_family, source_address in source_addresses:
        for target_family, target_address in target_addresses:
            if source_family != target_family:
                continue
            if is_address_excluded(source_family, source_address):
                continue
            if is_address_excluded(target_family, target_address):
                continue
            if source_family != socket.AF_INET and not is_same_network(source_address, target_address):
                continue
            pairs.append((
                _netmon.Address(source_family, source_address, 0),
                _netmon.Address(target_family, target_address, port),
            ))
    return pairs


def coalesce_addresses(addresses, ipv4_address, ipv6_address):
    """Return at most one IPv4 and one IPv6 (family, address) pair.

    For each family, an address already present in *addresses* wins (the last
    one if there are several); otherwise the respective fallback parameter is
    used if it is non-empty. IPv4 comes first in the result.
    """
    by_family = dict(addresses)
    result = []
    for family, fallback in ((socket.AF_INET, ipv4_address), (socket.AF_INET6, ipv6_address)):
        if family in by_family:
            result.append((family, by_family[family]))
        elif fallback:
            result.append((family, fallback))
    return result


def suppress_http_errors(func):
    """Wrap coroutine *func* so that it never raises an Exception.

    The returned wrapper is a tornado coroutine resolving to:
      * 200 when *func* completes without raising;
      * the HTTP status code when *func* raises
        ``tornado.httpclient.HTTPError`` with code 429 or >= 500;
      * -1 for any other ``Exception`` (including other HTTP errors).
    Each failure is logged at WARNING level with abbreviated arguments.
    """
    def _abbreviate(value, limit=100):
        # str() of a non-ASCII unicode object raises UnicodeEncodeError on
        # Python 2; fall back to repr() so logging cannot escape the wrapper.
        try:
            text = str(value)
        except UnicodeError:
            text = repr(value)
        return text[:limit] + '...' if len(text) > limit else text

    @tornado.gen.coroutine
    def wrapper(*args, **kwargs):
        try:
            yield func(*args, **kwargs)
        except Exception as exc:
            logging.warning(
                "Failed call to %s with args %s, kwargs %s: %s",
                func,
                [_abbreviate(arg) for arg in args],
                {_abbreviate(k): _abbreviate(v) for k, v in kwargs.iteritems()},
                exc
            )
            if isinstance(exc, tornado.httpclient.HTTPError) and (exc.code == 429 or exc.code >= 500):
                raise tornado.gen.Return(exc.code)
            raise tornado.gen.Return(-1)
        raise tornado.gen.Return(200)
    return wrapper


@tornado.gen.coroutine
def make_protobuf_request(request, response, url, headers=None, connect_timeout=5, request_timeout=15):
    """POST protobuf *request* to *url* and parse the reply body into *response*.

    The serialized request is compressed with zlib.compress() before
    sending. Content-Type and Content-Encoding headers always override
    any caller-supplied values.

    Args:
        request: protobuf message to serialize and send.
        response: protobuf message filled in place via ParseFromString().
        url: endpoint URL.
        headers: optional extra headers; the caller's object is not modified.
        connect_timeout, request_timeout: passed through to tornado's fetch().

    Raises:
        Whatever AsyncHTTPClient.fetch() or ParseFromString() raise (for
        example tornado.httpclient.HTTPError on non-2xx replies).
    """
    # Copy so the caller's mapping is never mutated by the update below;
    # copy.copy keeps the original mapping type (e.g. tornado HTTPHeaders).
    headers = {} if headers is None else copy.copy(headers)

    headers.update({
        "Content-Type": "application/x-coded-protobuf",
        # NOTE(review): zlib.compress() produces zlib (RFC 1950) framing, while
        # "gzip" names the RFC 1952 format. Kept as-is on the assumption that the
        # receiving server accepts it — confirm before changing either side.
        "Content-Encoding": "gzip",
    })
    reply = yield tornado.httpclient.AsyncHTTPClient().fetch(
        url,
        method='POST',
        connect_timeout=connect_timeout,
        request_timeout=request_timeout,
        body=zlib.compress(request.SerializeToString()),
        decompress_response=True,
        headers=headers
    )
    response.ParseFromString(reply.body)
