# coding: utf-8
from __future__ import print_function

import itertools
import logging

import tornado.gen

import _netmon

from . import ticker
from . import application
from . import utils

from infra.netmon.agent.idl import common_pb2

# Value passed as the third argument of ticker.LoopingCall for both benchmarks'
# probe-generation loops. NOTE(review): whether ticker.LoopingCall treats this
# as an interval (e.g. seconds) or a rate is defined in ticker — confirm there.
FREQUENCY = 1


class LoopbackBenchmark(application.Service):
    """Periodically probe this host's own addresses over UDP and TCP.

    The benchmark owns a private ``application.Application`` holding an echo
    service, a UDP prober and a TCP prober (plus a stats endpoint when
    ``settings.stats_port`` is set), and registers itself in it as well.
    ``self._loop`` is a ``ticker.LoopingCall`` wrapping ``_schedule_probes``.
    """

    def __init__(self, settings):
        self._settings = settings

        # (family, address) pairs of local interfaces; filled on the first tick.
        self._initialized = False
        self._listen_ports = settings.echo_port_range
        self._tcp_ports = settings.tcp_port_range
        self._addresses = []

        app = self._app = application.Application()
        if settings.stats_port is not None:
            app.register(application.StatService(port=settings.stats_port))
        app.register(application.EchoService(listen_ports=self._listen_ports))
        app.register(application.UdpService())
        app.register(application.TcpService(listen_ports=self._tcp_ports))
        app.register(self)

        # Created only after the application above has been fully assembled.
        self._loop = ticker.LoopingCall("generator_loop", self._schedule_probes, FREQUENCY)

    def _gen_configs(self, port_range):
        """Build one ProbeConfig per (port, local address, traffic class).

        Each probe targets the same local address it is sent from: the source
        uses port 0 and the destination uses ``port``. Iteration nests
        port -> address -> traffic class, in that order.

        :param port_range: iterable of destination ports.
        :returns: list of ``_netmon.ProbeConfig`` with timing/packet settings
            copied from ``self._settings``.
        """
        settings = self._settings
        configs = []
        for port in port_range:
            for family, addr in self._addresses:
                for tc in settings.traffic_classes:
                    configs.append(_netmon.ProbeConfig(
                        common_pb2.REGULAR_PROBE,
                        _netmon.Address(family, addr, 0),
                        _netmon.Address(family, addr, port),
                        tc,
                        settings.packet_count,
                    ))

        # The remaining knobs are applied as attributes after construction.
        for cfg in configs:
            cfg.timeout = settings.packet_timeout
            cfg.delay = settings.packet_delay
            cfg.start_delay = settings.probe_start_delay
            cfg.packet_size = settings.packet_size
            cfg.packet_ttl = settings.packet_ttl

        return configs

    @tornado.gen.coroutine
    def _schedule_probes(self):
        """Run one round of loopback UDP (echo ports) and TCP probes.

        On the first successful call the local interface list is fetched and
        pushed, in order, to the echo, UDP and TCP services; later calls reuse
        the cached list. Failed reports are logged, not raised.
        """
        app = self._app
        # NOTE(review): IfaceService is not registered in __init__; presumably
        # application.Application provides it — confirm.
        iface_service = app[application.IfaceService]
        echo_service = app[application.EchoService]
        udp_service = app[application.UdpService]
        tcp_service = app[application.TcpService]

        if not self._initialized:
            interfaces = yield iface_service.get_interfaces()
            self._addresses = [(iface.family, iface.address) for iface in interfaces]
            for service in (echo_service, udp_service, tcp_service):
                yield service.on_address_sync(self._addresses)
            self._initialized = True

        udp_reports, tcp_reports = yield [
            udp_service.schedule_checks(self._gen_configs(self._listen_ports)),
            tcp_service.schedule_checks(self._gen_configs(self._tcp_ports)),
        ]

        failed_reports = (r for r in udp_reports + tcp_reports if r.failed)
        for report in failed_reports:
            logging.error("probe failed: %s, %s", report.error, utils.report_to_proto(report))

    @tornado.gen.coroutine
    def cancel(self):
        """Cancel the probe-generation loop and wait for it to finish cancelling."""
        yield self._loop.cancel()

    def start(self):
        """Start the owned application."""
        self._app.start()


class PeerToPeerBuilder(object):
    """Build ProbeConfigs from this host towards a set of peer hosts.

    The local hostname and every target are resolved through the resolver.
    For each target, one ``(family, source, target)`` pair is formed per
    address family that both sides resolved to. Only resolved local
    addresses that appear in ``local_addresses`` are used as sources.
    Pairs are produced twice: once for the plain hostnames and once for
    their ``fb-`` prefixed variants.
    """

    def __init__(self, resolver, targets, local_hostname, local_addresses, listen_ports, packet_count, traffic_classes):
        """
        :param resolver: object whose ``try_resolve(name)`` is yielded inside a
            coroutine; presumably resolves to an iterable of
            ``(family, address)`` pairs (or a falsy value on failure) — confirm.
        :param targets: iterable of peer hostnames (deduplicated into a set).
        :param local_hostname: this host's name, resolved to find source addresses.
        :param local_addresses: iterable of 2-tuples; only the second element
            (the address) is kept.
        :param listen_ports: destination ports to probe.
        :param packet_count: packets per probe, passed to ``ProbeConfig``.
        :param traffic_classes: traffic classes to probe with.
        """
        self._targets = set(targets)
        self._local_hostname = local_hostname
        self._local_addresses = set(address for _, address in local_addresses)
        self._listen_ports = listen_ports
        self._resolver = resolver
        self._traffic_classes = traffic_classes
        self._packet_count = packet_count

    def _create_address_map(self, family_address_pairs):
        """Map family -> address, skipping pairs for which utils.is_address_excluded is true.

        If several addresses share a family, the last one wins.
        """
        return {family: address for family, address in family_address_pairs
                if not utils.is_address_excluded(family, address)}

    @tornado.gen.coroutine
    def _generate_pairs(self, ident):
        """Resolve the local host and all targets, then pair addresses by family.

        :param ident: maps a plain hostname to the name that is actually
            resolved (identity, or e.g. an ``fb-`` prefixed variant).
        :returns: list of ``(family, source_address, target_address)`` tuples.
        """
        hostnames = itertools.chain((self._local_hostname, ), self._targets)
        # The result dict is keyed by the *plain* hostname; ``ident`` only
        # changes what is resolved. Lookups below must therefore use the plain
        # hostname too (looking up ident(hostname) never matched for "fb-").
        addresses = yield {
            hostname: self._resolver.try_resolve(ident(hostname))
            for hostname in hostnames
        }

        # Only resolved addresses that are actually configured locally can be sources.
        source_addresses = self._create_address_map(
            (family, address)
            for family, address in (addresses.get(self._local_hostname) or ())
            if address in self._local_addresses
        )

        pairs = []
        for hostname in self._targets:
            target_addresses = self._create_address_map(addresses.get(hostname) or ())
            # Portable replacement for dict.viewkeys() intersection.
            for family in set(source_addresses).intersection(target_addresses):
                pairs.append((family, source_addresses[family], target_addresses[family]))

        raise tornado.gen.Return(pairs)

    def _generate_configs(self, pairs):
        """Build one ProbeConfig per (listen port, pair, traffic class).

        Sources use port 0; destinations use each listen port. Pairs whose
        source is not a local address are skipped.
        """
        config_list = [
            _netmon.ProbeConfig(
                common_pb2.REGULAR_PROBE,
                _netmon.Address(family, source, 0),
                _netmon.Address(family, target, port),
                tc,
                self._packet_count
            )
            for port in self._listen_ports
            for family, source, target in pairs
            if source in self._local_addresses
            for tc in self._traffic_classes
        ]
        return config_list

    @tornado.gen.coroutine
    def generate(self):
        """Return configs for plain-hostname pairs followed by ``fb-`` pairs.

        Both resolutions run concurrently; the output order is unchanged.
        """
        plain_pairs, fb_pairs = yield [
            self._generate_pairs(lambda s: s),
            self._generate_pairs(lambda s: "fb-%s" % s),
        ]
        config_list = self._generate_configs(plain_pairs)
        config_list.extend(self._generate_configs(fb_pairs))
        raise tornado.gen.Return(config_list)


class PeerToPeerBenchmark(application.Service):
    """Periodically probe a set of peer hosts over UDP and TCP.

    Owns a private ``application.Application`` with echo, UDP, interface and
    TCP services (plus a stats endpoint when ``settings.stats_port`` is set)
    and registers itself in it. Each tick builds fresh probe configs via
    ``PeerToPeerBuilder`` and schedules them.
    """

    def __init__(self, settings, targets):
        """
        :param settings: benchmark settings object (ports, packet parameters, ...).
        :param targets: iterable of peer hostnames to probe.
        """
        self._settings = settings
        self._targets = targets
        self._listen_ports = self._settings.echo_port_range
        self._tcp_ports = self._settings.tcp_port_range

        self._app = application.Application()
        if self._settings.stats_port is not None:
            self._app.register(application.StatService(port=self._settings.stats_port))
        self._app.register(application.EchoService(listen_ports=self._listen_ports))
        self._app.register(application.UdpService())
        self._app.register(application.IfaceService(True))
        self._app.register(application.TcpService(listen_ports=self._tcp_ports))
        self._app.register(self)

        # Created only once self._app is fully populated, since _schedule_probes reads it.
        self._loop = ticker.LoopingCall("generator_loop", self._schedule_probes, FREQUENCY)

    def _adjust_configs(self, config_list):
        """Copy timing/packet settings onto each config in place and return the list."""
        for config in config_list:
            config.timeout = self._settings.packet_timeout
            config.delay = self._settings.packet_delay
            config.start_delay = self._settings.probe_start_delay
            config.packet_size = self._settings.packet_size
            config.packet_ttl = self._settings.packet_ttl

        return config_list

    @tornado.gen.coroutine
    def _schedule_probes(self):
        """Run one round of peer-to-peer UDP and TCP probes; log failed reports."""
        iface_service = self._app[application.IfaceService]
        echo_service = self._app[application.EchoService]
        udp_service = self._app[application.UdpService]
        tcp_service = self._app[application.TcpService]
        # NOTE(review): ResolverService is not registered in __init__;
        # presumably application.Application provides it — confirm.
        resolver = self._app[application.ResolverService]

        local_hostname = self._settings.current().hostname
        # Fetch one snapshot shared by both builders so UDP and TCP probes of a
        # round derive from the same interface state. Materialized as a list
        # because each builder iterates it.
        local_addresses = list((yield iface_service.get_addresses()))

        udp_builder = PeerToPeerBuilder(
            resolver,
            self._targets,
            local_hostname,
            local_addresses,
            echo_service.listen_ports,
            self._settings.packet_count,
            self._settings.traffic_classes
        )

        tcp_builder = PeerToPeerBuilder(
            resolver,
            self._targets,
            local_hostname,
            local_addresses,
            self._tcp_ports,
            self._settings.packet_count,
            self._settings.traffic_classes
        )

        udp_configs, tcp_configs = yield [
            udp_builder.generate(),
            tcp_builder.generate()
        ]

        udp_reports, tcp_reports = yield [
            udp_service.schedule_checks(self._adjust_configs(udp_configs)),
            tcp_service.schedule_checks(self._adjust_configs(tcp_configs))
        ]

        for report in udp_reports + tcp_reports:
            if report.failed:
                logging.error("probe failed: %s, %s", report.error, utils.report_to_proto(report))

    @tornado.gen.coroutine
    def cancel(self):
        """Cancel the probe-generation loop and wait for it to finish cancelling."""
        yield self._loop.cancel()

    def start(self):
        """Start the owned application."""
        self._app.start()
