# coding: utf-8
from __future__ import print_function

import socket
import itertools
import logging
import contextlib

import tornado.gen
import tornado.locks
import tornado.ioloop

from infra.netmon.agent.idl import common_pb2

import _netmon

from .. import exceptions
from .. import application
from .. import utils
from .. import traceroute
from ..settings import Settings

# Time quantum shared by this module: it is scaled by utils.US for the
# aggregation-bucket timestamps and by utils.MS for the start delay given to
# the initial probes of a diagnostic session.
RESOLUTION = 5


class IcmpProbeBehavior(object):
    """Protocol hooks used by ProbePlanner for ICMP diagnostics.

    Both endpoints are always built with port 0, at most one probe runs at a
    time, and traceroutes are requested with ``filler.scan(service, 2)``.
    """

    def __init__(self, app, family, source_ip, target_ip, traffic_class, packet_count):
        # Resolve the service first so a missing service fails before any
        # other state is stored on the instance.
        self.service = app[application.IcmpService]
        self._family, self._source_ip, self._target_ip = family, source_ip, target_ip
        self._traffic_class, self._packet_count = traffic_class, packet_count

    def _endpoint(self, ip):
        """Return a port-0 ``_netmon.Address`` for ``ip`` in this family."""
        return _netmon.Address(self._family, ip, 0)

    def create_probe_config(self):
        """Return a REGULAR_PROBE config from the source to the target."""
        source = self._endpoint(self._source_ip)
        target = self._endpoint(self._target_ip)
        return _netmon.ProbeConfig(
            common_pb2.REGULAR_PROBE,
            source,
            target,
            self._traffic_class,
            self._packet_count
        )

    def normalize_parallel_probes(self, count):
        """Ignore the requested ``count``: ICMP probes run one at a time."""
        return 1

    def extract_addresses(self, report):
        """Return the configured (source, target) pair; ``report`` is unused."""
        return self._endpoint(self._source_ip), self._endpoint(self._target_ip)

    def tracert(self, filler):
        """Start a traceroute through ``filler`` on the ICMP service."""
        return filler.scan(self.service, 2)


class UdpProbeBehavior(object):
    """Protocol hooks used by ProbePlanner for UDP diagnostics.

    Each probe targets the next port of ``Settings.current().echo_port_range``
    (cycled indefinitely). Parallelism is capped at ``udp_socket_count``, and
    traceroutes reuse the ports reported by the probe.
    """

    def __init__(self, app, family, source_ip, target_ip, traffic_class, packet_count):
        """Bind the behavior to the UDP service and a source/target pair.

        :param app: application container indexed by service class.
        :param family: address family used for every ``_netmon.Address``.
        :param source_ip: local IP the probes are sent from.
        :param target_ip: remote IP the probes are sent to.
        :param traffic_class: traffic class copied into each probe config.
        :param packet_count: packet count copied into each probe config.
        """
        self.service = app[application.UdpService]
        self._family = family
        self._source_ip = source_ip
        self._target_ip = target_ip
        # Round-robin over the configured echo ports, one port per probe.
        self._target_ports = itertools.cycle(Settings.current().echo_port_range)
        self._traffic_class = traffic_class
        self._packet_count = packet_count

    def create_probe_config(self):
        """Return a REGULAR_PROBE config aimed at the next echo port.

        The source address is built with port 0.
        """
        return _netmon.ProbeConfig(
            common_pb2.REGULAR_PROBE,
            _netmon.Address(self._family, self._source_ip, 0),
            # The builtin next() works on Python 2.6+ and 3; the iterator's
            # .next() method exists only on Python 2.
            _netmon.Address(self._family, self._target_ip, next(self._target_ports)),
            self._traffic_class,
            self._packet_count
        )

    def normalize_parallel_probes(self, count):
        """Cap ``count`` at the configured number of UDP sockets."""
        # There is no point in running more probes than that: they would only
        # reuse the same ports.
        return min(count, Settings.current().udp_socket_count)

    def extract_addresses(self, report):
        """Return (source, target) addresses carrying the ports used by ``report``."""
        # NOTE(review): assumes source_addr/target_addr look like
        # (family, (ip, port)), so [1][1] selects the port. Confirm against the
        # report type.
        return (
            _netmon.Address(self._family, self._source_ip, report.source_addr[1][1]),
            _netmon.Address(self._family, self._target_ip, report.target_addr[1][1])
        )

    def tracert(self, filler):
        """Start a traceroute through ``filler`` on the UDP service."""
        return filler.scan(self.service, 8)


class TcpProbeBehavior(object):
    """Protocol hooks used by ProbePlanner for TCP diagnostics.

    Each probe targets the next port of ``Settings.current().tcp_port_range``
    (cycled indefinitely). Only one probe runs at a time, and traceroutes reuse
    the ports reported by the probe.
    """

    def __init__(self, app, family, source_ip, target_ip, traffic_class, packet_count):
        """Bind the behavior to the TCP service and a source/target pair.

        :param app: application container indexed by service class.
        :param family: address family used for every ``_netmon.Address``.
        :param source_ip: local IP the probes are sent from.
        :param target_ip: remote IP the probes are sent to.
        :param traffic_class: traffic class copied into each probe config.
        :param packet_count: packet count copied into each probe config.
        """
        self.service = app[application.TcpService]
        self._family = family
        self._source_ip = source_ip
        self._target_ip = target_ip
        # Round-robin over the configured TCP ports, one port per probe.
        self._target_ports = itertools.cycle(Settings.current().tcp_port_range)
        self._traffic_class = traffic_class
        self._packet_count = packet_count

    def create_probe_config(self):
        """Return a REGULAR_PROBE config aimed at the next TCP port.

        The source address is built with port 0.
        """
        return _netmon.ProbeConfig(
            common_pb2.REGULAR_PROBE,
            _netmon.Address(self._family, self._source_ip, 0),
            # The builtin next() works on Python 2.6+ and 3; the iterator's
            # .next() method exists only on Python 2.
            _netmon.Address(self._family, self._target_ip, next(self._target_ports)),
            self._traffic_class,
            self._packet_count
        )

    def normalize_parallel_probes(self, count):
        """Ignore the requested ``count``: TCP probes run one at a time."""
        return 1

    def extract_addresses(self, report):
        """Return (source, target) addresses carrying the ports used by ``report``."""
        # NOTE(review): assumes source_addr/target_addr look like
        # (family, (ip, port)), so [1][1] selects the port. Confirm against the
        # report type.
        return (
            _netmon.Address(self._family, self._source_ip, report.source_addr[1][1]),
            _netmon.Address(self._family, self._target_ip, report.target_addr[1][1])
        )

    def tracert(self, filler):
        """Start a traceroute through ``filler`` on the TCP service."""
        return filler.scan(self.service, 2)


def _behavior_class(protocol):
    """Return the probe-behavior class that handles ``protocol``.

    :param protocol: a ``common_pb2`` protocol value (UDP, ICMP or TCP).
    :raises exceptions.ValidationError: for any other value.
    """
    candidates = (
        (common_pb2.UDP, UdpProbeBehavior),
        (common_pb2.ICMP, IcmpProbeBehavior),
        (common_pb2.TCP, TcpProbeBehavior),
    )
    for known, behavior in candidates:
        if protocol == known:
            return behavior
    raise exceptions.ValidationError("unknown protocol specified")


def _convert_family(family):
    """Translate a ``common_pb2`` address family into a ``socket`` constant.

    :param family: ``common_pb2.INET4`` or ``common_pb2.INET6``.
    :returns: ``socket.AF_INET`` or ``socket.AF_INET6`` respectively.
    :raises exceptions.ValidationError: for any other value.
    """
    translations = (
        (common_pb2.INET4, socket.AF_INET),
        (common_pb2.INET6, socket.AF_INET6),
    )
    for pb_family, sock_family in translations:
        if family == pb_family:
            return sock_family
    raise exceptions.ValidationError("unknown family specified")


class AggregationState(object):
    """Reports and traceroutes collected during one time bucket.

    ``now`` is the quantised timestamp of the bucket. Failed and lossy reports
    are kept individually. Every non-failed report also feeds the counters and
    the merged RTT histogram that ``dump`` summarises.
    """

    def __init__(self):
        self.now = utils.quanted_timestamp(RESOLUTION * utils.US)
        self._empty = True
        # TODO: move all to US
        packet_timeout = Settings.current().packet_timeout
        self._histogram = _netmon.Histogram(packet_timeout * utils.MS, 3)
        self._received = self._lost = self._tos_changed = 0
        self._reports, self._traceroutes = [], []

    def add_report(self, report):
        """Fold ``report`` into this bucket and return its ``failed`` flag."""
        self._empty = False
        failed = report.failed
        if failed or report.lost:
            # Problematic reports are kept individually via utils.report_to_proto.
            self._reports.append(utils.report_to_proto(report))
        if failed:
            return failed
        incoming = report.histogram
        # Missing or empty histograms have nothing to merge.
        if incoming is not None and incoming.get_total_count():
            _netmon.merge_histogram(self._histogram, incoming)
        self._received += report.received
        self._lost += report.lost
        self._tos_changed += report.tos_changed
        return failed

    def add_traceroute(self, traceroute):
        """Keep a finished traceroute for the next dump."""
        self._traceroutes.append(traceroute)

    def dump(self, result):
        """Append everything collected so far to ``result``.

        Reports and traceroutes are always appended. An AggregatedReports
        entry is added only when at least one report was recorded.
        """
        result.Reports.extend(self._reports)
        result.Traceroutes.extend(self._traceroutes)
        if self._empty:
            return
        (mean, p25, p50, p75, p95, lowest, highest) = \
            utils.extract_properties_from_histogram(self._histogram)
        result.AggregatedReports.add(
            Generated=self.now,

            Received=self._received,
            Lost=self._lost,
            TosChanged=self._tos_changed,

            RoundTripTimeAverage=mean,
            RoundTripTime25=p25,
            RoundTripTime50=p50,
            RoundTripTime75=p75,
            RoundTripTime95=p95,
            RoundTripTimeMinimum=lowest,
            RoundTripTimeMaximum=highest
        )


class ProbeAccumulator(object):
    """Feed reports into per-time-bucket AggregationState objects.

    When the quantised timestamp moves on, the previous bucket is dumped into
    ``result`` before a fresh one is started.
    """

    def __init__(self, result, on_changed=None):
        self._result, self._on_changed = result, on_changed
        self._current_state = None

    def dump(self):
        """Flush the current bucket (if any) into ``result``, then call ``on_changed``.

        ``on_changed`` is invoked even when there was nothing to flush.
        """
        state = self._current_state
        if state is not None:
            state.dump(self._result)
            self._current_state = None
        callback = self._on_changed
        if callback is not None:
            callback()

    def _get_actual_state(self):
        """Return the bucket for the current time quantum, rotating a stale one."""
        state = self._current_state
        if state is not None:
            stale = utils.quanted_timestamp(RESOLUTION * utils.US) != state.now
            if not stale:
                return state
            self.dump()
        self._current_state = AggregationState()
        return self._current_state

    def add_report(self, report):
        """Record ``report`` and return what the bucket's add_report returns."""
        return self._get_actual_state().add_report(report)

    def add_traceroute(self, report):
        """Record a finished traceroute in the current bucket."""
        return self._get_actual_state().add_traceroute(report)


class ProbePlanner(object):
    """Run one time-boxed diagnostic session and collect its results.

    run() keeps scheduling probes through the protocol behavior until
    ``Duration`` elapses or a failure is recorded. The first probe that
    completes without failure, and every later probe that reports lost
    packets, also schedules a traceroute. Reports and traceroutes are
    aggregated into ``result`` through a ProbeAccumulator.
    """

    def __init__(self, app, arguments, result, on_changed=None):
        """Prepare the session; nothing is sent until run() is called.

        :param app: application container; the protocol behavior looks its
            service up with ``app[application.<Protocol>Service]``.
        :param arguments: diagnostic request; Protocol, Family,
            SourceAddress.Ip, TargetAddress.Ip, TrafficClass,
            ProbesInParallel, Duration and PacketSize are read from it.
        :param result: message that receives Reports, Traceroutes and
            AggregatedReports.
        :param on_changed: optional callback that the accumulator calls each
            time it dumps.
        :raises exceptions.ValidationError: on an unknown protocol or family.
        """
        self._app = app
        self._settings = Settings.current()

        self._arguments = arguments
        self._result = result
        # An unset (falsy) ProbesInParallel falls back to udp_socket_count.
        self._probes_in_parallel = self._arguments.ProbesInParallel or self._settings.udp_socket_count
        # Added to IOLoop.time() in run(), so presumably seconds; defaults to 60.
        self._duration = self._arguments.Duration or 60
        self._on_changed = on_changed

        self._behavior = _behavior_class(self._arguments.Protocol)(
            app=self._app,
            family=_convert_family(self._arguments.Family),
            source_ip=self._arguments.SourceAddress.Ip,
            target_ip=self._arguments.TargetAddress.Ip,
            traffic_class=self._arguments.TrafficClass,
            packet_count=self._settings.diagnostic_packet_count
        )

        self._loop = tornado.ioloop.IOLoop.current()

        # This is the protocol-specific concurrency limit. The same counter is
        # later decremented in _create_probe_config to count the probes that
        # still get a start delay.
        self._probes_to_delay = self._behavior.normalize_parallel_probes(self._probes_in_parallel)
        self._probe_semaphore = tornado.locks.Semaphore(self._probes_to_delay)

        # Traceroutes get their own slots, with the same limit as probes.
        self._tracert_semaphore = tornado.locks.Semaphore(self._probes_to_delay)
        # Forces a traceroute after the first non-failed probe, lost or not.
        self._tracert_always = True

        # Notified whenever a task finishes. run() waits on it until
        # _running_tasks drops back to zero.
        self._stop_condition = tornado.locks.Condition()
        self._running_tasks = 0

        self._failed = False
        # Last exception recorded by a task; re-raised at the end of run().
        self._exception = None

        self._accumulator = ProbeAccumulator(self._result, on_changed=self._on_changed)

    @tornado.gen.coroutine
    def _schedule_task(self, semaphore, func, *args, **kwargs):
        """Wait for a slot in ``semaphore``, then start ``func`` on the IOLoop.

        The task is counted in _running_tasks before the (possibly blocking)
        acquire, so run() cannot see zero while it is still pending. ``func``
        is expected to give the slot back through _task_context.
        """
        self._running_tasks += 1
        yield semaphore.acquire()
        self._loop.add_callback(func, *args, **kwargs)

    @contextlib.contextmanager
    def _task_context(self, semaphore):
        """Guard one task body.

        Any ``Exception`` is logged, recorded in _failed/_exception and
        swallowed (not re-raised). On exit the semaphore slot is released,
        _running_tasks is decremented and run() is notified.
        """
        try:
            yield
        except Exception as exc:
            logging.exception("Diagnostic failed: %s", exc)
            self._failed = True
            self._exception = exc
        finally:
            semaphore.release()
            # Every task should have been counted by _schedule_task.
            assert self._running_tasks
            self._running_tasks -= 1
            self._stop_condition.notify()

    def _create_probe_config(self):
        """Build the next probe config from the behavior plus settings.

        The first ``_probes_to_delay`` configs get a start delay of
        ``RESOLUTION * utils.MS``; later ones start immediately.
        """
        config = self._behavior.create_probe_config()
        config.timeout = self._settings.packet_timeout
        config.delay = self._settings.diagnostic_packet_delay
        if self._probes_to_delay:
            # NOTE(review): every delayed probe gets the same delay, so the
            # initial batch is not staggered relative to itself. Confirm intent.
            self._probes_to_delay -= 1
            config.start_delay = RESOLUTION * utils.MS
        else:
            config.start_delay = 0
        # A falsy PacketSize in the request falls back to the settings value.
        config.packet_size = self._arguments.PacketSize or self._settings.diagnostic_packet_size
        config.packet_ttl = self._settings.diagnostic_packet_ttl
        config.histogram = True
        return config

    @tornado.gen.coroutine
    def _schedule_probe(self):
        """Run one probe, record its report and possibly schedule a traceroute."""
        with self._task_context(self._probe_semaphore):
            # One config is submitted; exactly one report is expected back.
            [report] = yield self._behavior.service.schedule_checks([self._create_probe_config()])
            failed = self._accumulator.add_report(report)
            self._failed = self._failed or failed

        # If the with-block swallowed an exception, _failed is True, so the
        # possibly-unbound ``report`` is never read here.
        if not self._failed and (report.lost or self._tracert_always):
            self._tracert_always = False
            # _schedule_task runs synchronously up to its first yield, so
            # _running_tasks is incremented before run() can observe zero; no race.
            yield self._schedule_task(self._tracert_semaphore, self._schedule_tracert, report)

    @tornado.gen.coroutine
    def _schedule_tracert(self, report):
        """Traceroute between the addresses the behavior extracts from ``report``.

        The traceroute result is added to the accumulator.
        """
        with self._task_context(self._tracert_semaphore):
            source_address, target_address = self._behavior.extract_addresses(report)
            filler = traceroute.TracertFiller(
                source_address=source_address,
                target_address=target_address,
                traffic_class=self._arguments.TrafficClass
            )
            self._accumulator.add_traceroute((yield self._behavior.tracert(filler)))

    @tornado.gen.coroutine
    def run(self):
        """Run the diagnostic session to completion.

        Schedules probes until ``Duration`` elapses or a failure is recorded,
        waits for every outstanding probe/traceroute task, dumps the
        accumulator into ``result`` and finally re-raises the last exception
        recorded by a task, if any.
        """
        deadline = self._loop.time() + self._duration
        while self._loop.time() < deadline and not self._failed:
            yield self._schedule_task(self._probe_semaphore, self._schedule_probe)

        while self._running_tasks:
            yield self._stop_condition.wait()
        self._accumulator.dump()

        if self._exception is not None:
            raise self._exception

    @property
    def success(self):
        """True unless a probe report failed or a task raised."""
        return not self._failed
