# coding: utf-8
from __future__ import print_function

import collections
import contextlib
import itertools
import logging
import random
import socket
import time

import tornado.gen
import concurrent.futures

from infra.netmon.agent.idl import common_pb2 as common

import _netmon

from . import application
from . import exceptions
from . import rpc
from . import sender
from . import ticker
from . import topology
from . import transformers
from . import tasks
from . import utils
from . import _unistat_callback
from .interfaces import is_mtn_vlan
from .settings import Settings, to_timeval


class TopologyAwareIfaceService(application.IfaceService):
    """Interface service that cross-checks discovered addresses with topology.

    Registered in place of ``application.IfaceService`` (see ``provides``).
    Host addresses are filtered down to those the topology also references,
    and, when ``use_topology_ips`` is enabled, local address -> FQDN lookups
    are answered from topology data instead of the parent implementation.
    """

    provides = (application.IfaceService, )
    # Lifetime (seconds) of the address -> name cache built from topology.
    LOCAL_FQDNS_TTL = 3600

    def __init__(self):
        settings = Settings.current()
        super(TopologyAwareIfaceService, self).__init__(
            allow_virtual=settings.allow_virtual,
            networks=settings.networks,
            allow_mtn_vlan=True
        )
        # address -> interface name; None until the first refresh
        self._local_fqdns = None
        # time.time() of the last successful cache refresh
        self._local_fqdns_ts = None

    @tornado.gen.coroutine
    def _filter_addresses(self, addresses):
        """Keep only the discovered addresses that topology also knows about.

        :param addresses: iterable of ``(family, address)`` pairs discovered
            on the host.
        :returns: (via ``tornado.gen.Return``) ``addresses`` unchanged when
            topology yields no local names or no reference addresses,
            otherwise a set of the discovered pairs present in topology.
        """
        name_list = self._app[BuilderService].get_local_names()
        if not name_list:
            logging.warning("No local names found in topology, will bind to discovered interfaces")
            raise tornado.gen.Return(addresses)

        use_topology_ips = Settings.current().use_topology_ips
        reference_addresses = set()
        for name, ipv4_address, ipv6_address, vlan in name_list:
            if is_mtn_vlan(vlan):
                # don't try to resolve fake fqdns, trust topology addresses
                if ipv4_address is not None:
                    reference_addresses.add((socket.AF_INET, ipv4_address))
                if ipv6_address is not None:
                    reference_addresses.add((socket.AF_INET6, ipv6_address))
                continue

            addrs = yield self._app[application.ResolverService].try_resolve(name)
            if use_topology_ips:
                addrs = utils.coalesce_addresses(addrs, ipv4_address, ipv6_address)
            reference_addresses.update(addrs)

        if not reference_addresses:
            logging.warning("No local names resolved, will bind to discovered interfaces")
            raise tornado.gen.Return(addresses)

        # check if topology has mismatched interfaces
        discovered = set(addresses)
        for pair in discovered:
            if pair not in reference_addresses:
                logging.warning("Address %r missed in topology but discovered on host", pair[1])
        for pair in reference_addresses:
            if pair not in discovered:
                logging.warning("Address %r found in topology but missed in discovered interfaces", pair[1])

        raise tornado.gen.Return(discovered & reference_addresses)

    @tornado.gen.coroutine
    def _refresh_local_fqdns(self):
        """Rebuild the local address -> interface name cache from topology."""
        tree = yield topology.tree(self._app)
        local_ifaces = list(tree.local_interfaces())

        # IPv6 entries first, then IPv4 (families never share an address).
        local_fqdns = {
            iface.ipv6_address: iface.name
            for iface in local_ifaces
            if iface.ipv6_address is not None
        }
        local_fqdns.update({
            iface.ipv4_address: iface.name
            for iface in local_ifaces
            if iface.ipv4_address is not None
        })

        self._local_fqdns = local_fqdns
        self._local_fqdns_ts = time.time()
        logging.debug("Updated local interface name cache")

    @tornado.gen.coroutine
    def _resolve_fqdn_by_address(self, address):
        """Resolve a local ``address`` to its FQDN.

        With ``use_topology_ips`` the topology-backed cache is used. It is
        refreshed when it is empty or older than ``LOCAL_FQDNS_TTL``. An
        address missing from the cache is logged and resolves to None.
        Otherwise the parent implementation is used.
        """
        if not Settings.current().use_topology_ips:
            fqdn = yield super(TopologyAwareIfaceService, self)._resolve_fqdn_by_address(address)
            raise tornado.gen.Return(fqdn)

        # `or` short-circuits, so the timestamp is only read once the cache is filled
        if (
            not self._local_fqdns or
            time.time() - self._local_fqdns_ts > self.LOCAL_FQDNS_TTL
        ):
            yield self._refresh_local_fqdns()

        if address in self._local_fqdns:
            raise tornado.gen.Return(self._local_fqdns[address])
        logging.info("Failed to resolve local address %s", address)
        raise tornado.gen.Return(None)


class BuilderService(application.AppMixin, application.Service):
    """Periodically rebuilds probe target selectors from the network topology.

    The rebuild loop runs hourly (``round_by_interval=True``). It is created
    only when both ``protocols`` and ``max_targets`` are configured.
    Transformers run in a single-worker thread pool because they are CPU
    intensive.
    """

    def __init__(self):
        settings = Settings.current()
        self._loop = None
        self._thread_pool = None
        if settings.protocols and settings.max_targets:
            # let's distribute index rebuilding over cluster because it's CPU intensive
            self._loop = ticker.LoopingCall("builder", self._refresh_indexes, 3600, round_by_interval=True)
            # transformers run here, in a worker thread, not in the coroutine's thread
            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # [(name, ipv4_address, ipv6_address, vlan)] of this host's interfaces in topology
        self._local_names = []
        self._selectors = []

    def get_local_names(self):
        """Return ``(name, ipv4, ipv6, vlan)`` tuples for local interfaces from the last refresh."""
        return self._local_names

    def get_selectors(self):
        """Return the target selectors built by the last successful refresh."""
        return self._selectors

    @tornado.gen.coroutine
    def _ensure_ready(self, group_futures):
        """Wait for all group expansion futures, skipping groups that are not ready.

        Results are collected in completion order. A future failing with
        ``exceptions.GroupNotReady`` is logged and dropped; any other
        exception propagates.
        """
        hosts = []
        wait_iterator = tornado.gen.WaitIterator(*group_futures)

        while not wait_iterator.done():
            try:
                group = yield wait_iterator.next()
            except exceptions.GroupNotReady as exc:
                logging.warning("%s", exc)
            else:
                hosts.append(group)

        raise tornado.gen.Return(hosts)

    @tornado.gen.coroutine
    def _refresh_indexes(self):
        """Reload topology and rebuild local names and selectors.

        If the topology fails to load, or this host is not in it, a warning
        is logged and the previous state is kept.
        """
        # full topology is relatively big so don't store it in memory,
        # only use parts that we need

        # one settings snapshot for the whole refresh, which spans several yields
        current_settings = Settings.current()
        tstart = time.time()

        try:
            tree = yield topology.tree(self._app)
        except Exception as exc:
            logging.warning("Can't load topology: %s", exc)
            return

        local_interfaces = tree.local_interfaces()
        if not local_interfaces:
            logging.warning("Can't find %s in topology", current_settings.hostname)
            return

        self._local_names = [(iface.name, iface.ipv4_address, iface.ipv6_address, iface.vlan)
                             for iface in local_interfaces]

        tloaded = time.time()
        logging.info("Topology loading took %.3f seconds", tloaded - tstart)

        selectors = []

        if current_settings.vlans or current_settings.vrfs:
            transformer = transformers.VlanAwareTransfomer(
                current_settings.vlans,
                current_settings.vrfs
            )
            sel = yield self._thread_pool.submit(transformer.transform, tree)
            selectors.extend(sel)

        if current_settings.groups:
            rpc_client = self._app[rpc.RpcClient]
            host_groups = yield self._ensure_ready(
                [rpc_client.expand_group(name) for name in current_settings.groups]
            )

            if any(host_groups):
                transformer = transformers.GroupAwareTransfomer([set(group) for group in host_groups])
                sel = yield self._thread_pool.submit(transformer.transform, tree)
                selectors.extend(sel)

        # generate automatic targets if:
        # - no other selectors found and scheduled probes are disabled
        # - explicitly enabled through command-line
        if (not selectors and not self._app[ScheduledProbesService].enabled) or current_settings.automatic_targets:
            transformer = transformers.NetworkAwareTransfomer()
            sel = yield self._thread_pool.submit(transformer.transform, tree)
            selectors.extend(sel)

        tbuilt = time.time()

        logging.info("Tree building took %.3f seconds", tbuilt - tloaded)
        logging.info("Index refresh took %.3f seconds", tbuilt - tstart)

        self._selectors = selectors

    @tornado.gen.coroutine
    def cancel(self):
        """Stop the rebuild loop and release the worker thread pool."""
        if self._loop is not None:
            yield self._loop.cancel()
        if self._thread_pool is not None:
            self._thread_pool.shutdown(wait=False)


class ScheduledProbesService(application.AppMixin, application.Service):
    """Keeps the NOC SLA probe schedule fetched from the backend.

    The refresh loop, which runs every 60 s with ``round_by_interval=True``,
    exists only when ``noc_sla_urls`` is configured. The service is truthy
    when it currently holds protocols and at least one schedule with targets.
    """

    Target = collections.namedtuple("Target", ("hostname", "ip"))
    Schedule = collections.namedtuple("Schedule", ("targets",
                                                   "traffic_classes",
                                                   "packet_count",
                                                   "family",
                                                   "vlan"  # may be None
                                                   ))

    def __init__(self):
        self._loop = None
        if Settings.current().noc_sla_urls:
            self._loop = ticker.LoopingCall("scheduled_probes", self._refresh_schedule, 60, round_by_interval=True)
        self._type = common.NOC_SLA_PROBE
        self._protocols = []
        self._schedules = []

    @tornado.gen.coroutine
    def _refresh_schedule(self):
        """Fetch the probe schedule and store it if it is not empty.

        :returns: (via ``tornado.gen.Return``) the schedule TTL in seconds,
            raised to at least 60, or None when the backend returned an empty
            schedule. In that case the previous schedule is kept.
        :raises exceptions.BackendUrlsNotReady: re-raised after logging.
        """
        try:
            response = yield self._app[rpc.RpcClient].scheduled_probes(self._type)
        except exceptions.BackendUrlsNotReady:
            logging.warning("Failed to refresh probe schedule: backend urls aren't initialized yet")
            raise

        # dirty hack for checking of schedule emptiness
        if not response.ListFields():
            return

        protocols = list(response.Protocols)
        if not protocols:
            logging.warning("No protocols in probe schedule, enable UDP by default")
            protocols = [common.UDP]

        schedules = self._extract_schedules(response)

        ttl = int(response.Ttl)
        # don't request schedule too often
        if ttl < 60:
            logging.warning("Probe schedule ttl %d secs is too small, increase it to 60 secs", ttl)
            ttl = 60

        logging.info("New probe schedule fetched from server, protocols: %s, ttl %d secs, schedules: [%s]",
                     ", ".join(utils.PROTOCOL_TO_STRING[p] for p in protocols),
                     ttl,
                     ", ".join(self._describe_schedule(s) for s in schedules))
        for schedule in schedules:
            logging.debug("Probe schedule targets: %s", ", ".join([target.hostname for target in schedule.targets]))

        self._protocols = protocols
        self._schedules = schedules

        raise tornado.gen.Return(ttl)

    @staticmethod
    def _describe_schedule(schedule):
        """Return a one-line human-readable summary of ``schedule`` for logging."""
        return "({} targets, {} packets, family {}, {}tcs {})".format(
            len(schedule.targets),
            schedule.packet_count,
            utils.FAMILY_TO_STRING[schedule.family],
            "vlan {}, ".format(schedule.vlan) if schedule.vlan is not None else "",
            ", ".join(utils.TRAFFIC_CLASS_TO_STRING[tc] for tc in schedule.traffic_classes)
        )

    def _extract_schedules(self, proto):
        """Convert ``proto.Subschedules`` into ``Schedule`` tuples.

        Subschedules without targets or with ``packet_count <= 0`` are
        dropped. The total number of targets kept is capped at
        ``max_scheduled_targets``, and earlier subschedules take precedence.
        A negative limit is treated as 0.
        """
        target_limit = max(0, Settings.current().max_scheduled_targets)
        total_targets = sum(len(x.Targets) for x in proto.Subschedules)
        if total_targets > target_limit:
            logging.warning("Too many targets (%d) in new probe schedule, reduce them to %d", total_targets, target_limit)

        schedules = []
        for subschedule_proto in proto.Subschedules:
            subschedule = self._extract_subschedule(subschedule_proto, target_limit)
            if subschedule.targets and subschedule.packet_count > 0:
                target_limit -= len(subschedule.targets)
                schedules.append(subschedule)
            if target_limit <= 0:
                break

        return schedules

    def _extract_subschedule(self, proto, target_limit):
        """Convert one subschedule proto, keeping at most ``target_limit`` targets.

        CS0 is used when the proto lists no traffic classes.
        """
        traffic_classes = list(proto.TrafficClasses)
        if not traffic_classes:
            logging.warning("No traffic classes in probe schedule, enable CS0 by default")
            traffic_classes = [common.CS0]

        # clamp at 0: a negative slice bound would drop targets from the end instead
        limit = max(0, min(target_limit, len(proto.Targets)))
        targets = [self.Target(x.Hostname, x.Ip) for x in proto.Targets[:limit]]

        return self.Schedule(targets=targets,
                             traffic_classes=traffic_classes,
                             packet_count=int(proto.PacketCount),
                             family=utils.PROTO_TO_FAMILY[proto.Family],
                             vlan=proto.Vlan if proto.HasField("Vlan") else None)

    @property
    def enabled(self):
        """True when the refresh loop exists, i.e. ``noc_sla_urls`` was configured."""
        return self._loop is not None

    def __nonzero__(self):
        return bool(self._protocols and any(s.targets for s in self._schedules))

    # Python 3 truthiness protocol; without it instances would always be truthy
    __bool__ = __nonzero__

    @property
    def protocols(self):
        """Protocols of the current schedule (empty until the first fetch)."""
        return self._protocols

    @property
    def schedules(self):
        """``Schedule`` tuples of the current schedule."""
        return self._schedules

    @tornado.gen.coroutine
    def cancel(self):
        """Stop the refresh loop if it was started."""
        if self._loop is not None:
            yield self._loop.cancel()


class AgentService(application.Service):
    """Top-level agent service: wires up the application and runs probe rounds.

    The constructor creates an ``application.Application``, registers every
    service the given settings require (plus itself) and sets up the
    ``agent_loop`` ticker, which calls ``_one_shot`` on every tick.
    """

    # One probe to run: names are kept for reports, addresses are probed.
    Pair = collections.namedtuple("Pair", ("family",
                                           "source", "source_addr",
                                           "target", "target_addr",
                                           "traffic_class", "packet_count"))

    # Schedule hostnames with these prefixes are fake; their IP from the
    # schedule is used as is, without DNS resolution.
    _FAKE_HOSTNAME_PREFIXES = ('vlan688@', 'vlan788@')

    def __init__(self, settings, start_delay=0, start_echo=True):
        """Register all services required by ``settings``.

        :param settings: agent settings object.
        :param start_delay: initial delay of the ``agent_loop`` ticker.
        :param start_echo: whether to register ``application.EchoService``.
        """
        self._settings = settings

        self._app = application.Application()

        ticker_interval = self._settings.check_interval
        if self._settings.uniform_probes:
            # In uniform mode AgentService itself generates randomized
            # intervals and passes them to probe configs. Ticker waits for probes
            # to finish and immediately schedules the next iteration.
            # Here check_interval is a (low, high) range, see _pick_smoothing_interval().
            ticker_interval = min(ticker_interval)
        # delay probe scheduling for some time to not interfere with builder
        self._loop = ticker.LoopingCall(
            "agent_loop", self._one_shot, ticker_interval,
            start_delay=start_delay
        )

        if self._settings.stats_port is not None:
            self._app.register(application.StatService(port=self._settings.stats_port))

        self._app.register(application.BackendMaintainerService())

        if start_echo:
            self._app.register(application.EchoService(listen_ports=self._settings.echo_port_range))
        if common.UDP in self._settings.protocols:
            self._app.register(application.UdpService())
        if common.ICMP in self._settings.protocols:
            self._app.register(application.IcmpService())
        if common.TCP in self._settings.protocols:
            self._app.register(application.TcpService(listen_ports=self._settings.tcp_port_range))
        if self._settings.link_poller:
            self._app.register(application.LinkService())

        self._app.register(rpc.RpcClient())
        self._app.register(ScheduledProbesService())
        self._app.register(BuilderService())
        if self._settings.networks is not None:
            self._app.register(TopologyAwareIfaceService())
            self._app.register(tasks.TaskExecutor())
            self._app.register(tasks.TaskDispatcher())
        else:
            self._app.register(application.IfaceService(allow_virtual=self._settings.allow_virtual, networks=None))
        self._app.register(sender.SenderService())
        self._app.register(self)

    @tornado.gen.coroutine
    def _process_probes(self, pairs, service, config_list, protocol, smoothing_interval=None):
        """Run probe configs through ``service`` and convert the results to protos.

        :param pairs: ``Pair`` list matching ``config_list`` element by element.
        :param service: probe service providing ``schedule_checks``.
        :param config_list: ``_netmon.ProbeConfig`` list; timeout, size, ttl
            and delays are set on it in place.
        :param protocol: not used by this method.
        :param smoothing_interval: when truthy, spread each probe's packets
            evenly over this interval instead of using ``packet_delay``.
        :returns: (via ``tornado.gen.Return``) list of report protos.
        """
        if not config_list:
            raise tornado.gen.Return([])

        for config in config_list:
            config.timeout = self._settings.packet_timeout
            config.packet_size = self._settings.packet_size
            config.packet_ttl = self._settings.packet_ttl

            if smoothing_interval:
                config.delay = smoothing_interval / config.packet_count
                config.start_delay = config.delay
            else:
                config.delay = self._settings.packet_delay
                config.start_delay = self._settings.probe_start_delay

        reports = yield service.schedule_checks(config_list)
        # assumes reports come back in config order — TODO confirm with schedule_checks
        raise tornado.gen.Return([
            utils.report_to_proto(report, pair.source, pair.target)
            for pair, report in zip(pairs, reports)
        ])

    def _make_probe_configs(self, resolved_pairs, probe_type, target_ports=None):
        """Build one ``_netmon.ProbeConfig`` per pair.

        The source port is always 0. For each pair the target port is picked
        at random from ``target_ports``, or is 0 when ``target_ports`` is None.
        """
        return [
            _netmon.ProbeConfig(
                probe_type,
                _netmon.Address(pair.family, pair.source_addr, 0),
                _netmon.Address(
                    pair.family,
                    pair.target_addr,
                    random.choice(target_ports) if target_ports is not None else 0
                ),
                pair.traffic_class,
                pair.packet_count
            )
            for pair in resolved_pairs
        ]

    def _process_icmp_probes(self, resolved_pairs, probe_type, smoothing_interval=None):
        """Start ICMP probes for ``resolved_pairs``; returns the processing future."""
        configs = self._make_probe_configs(resolved_pairs, probe_type)
        return self._process_probes(
            resolved_pairs, self._app[application.IcmpService], configs, common.ICMP, smoothing_interval
        )

    def _process_udp_probes(self, resolved_pairs, probe_type, smoothing_interval=None):
        """Start UDP probes to random echo ports; returns the processing future."""
        configs = self._make_probe_configs(resolved_pairs, probe_type, self._settings.echo_port_range)
        return self._process_probes(
            resolved_pairs, self._app[application.UdpService], configs, common.UDP, smoothing_interval
        )

    def _process_tcp_probes(self, resolved_pairs, probe_type, smoothing_interval=None):
        """Start TCP probes to random TCP ports; returns the processing future."""
        configs = self._make_probe_configs(resolved_pairs, probe_type, self._settings.tcp_port_range)
        return self._process_probes(
            resolved_pairs, self._app[application.TcpService], configs, common.TCP, smoothing_interval
        )

    @tornado.gen.coroutine
    def _resolve_hostname(self, hostname, ipv4_address=None, ipv6_address=None, ignore_errors=True):
        """Resolve ``hostname``, merging in topology addresses when ``use_topology_ips`` is set."""
        addresses = yield self._app[application.ResolverService].try_resolve(hostname, ignore_errors)
        if self._settings.use_topology_ips:
            addresses = utils.coalesce_addresses(addresses, ipv4_address, ipv6_address)
        raise tornado.gen.Return(addresses)

    def _check_iface_dscp_compatibility(self, iface, tc):
        """Return whether traffic class ``tc`` may be sent from ``iface``.

        ``tc == 0`` is allowed on backbone6, IPv4 and MTN-vlan interfaces.
        Other classes are checked against the interface's network type.
        """
        if tc == 0:
            return (iface.backbone6 or
                    iface.family == socket.AF_INET or
                    is_mtn_vlan(iface.vlan))
        else:
            if iface.backbone6:
                network_type = utils.BACKBONE
            elif iface.fastbone6:
                network_type = utils.FASTBONE
            else:
                network_type = None
            return transformers.check_dscp_compatibility(network_type, tc)

    @tornado.gen.coroutine
    def _build_regular_pairs(self, selector_list, resolved_hostnames):
        """Build probe pairs for randomly sampled topology targets.

        Sources and targets are resolved into ``resolved_hostnames``, which
        is updated in place. Returns (via ``tornado.gen.Return``) the pairs
        whose source and target addresses share a family, one per
        DSCP-compatible traffic class. Returns an empty list unless both
        ``packet_count`` and ``max_targets`` are positive.
        """
        if not (self._settings.packet_count > 0 and self._settings.max_targets > 0):
            raise tornado.gen.Return([])

        targets = list({
            (selector, target)
            for selector in selector_list
            for target in selector.select()
        })

        if len(targets) > self._settings.max_targets:
            targets = random.sample(targets, self._settings.max_targets)

        for selector, target in targets:
            if selector.source_name not in resolved_hostnames:
                resolved_hostnames[selector.source_name] = yield self._resolve_hostname(
                    selector.source_name, selector.source_ipv4_address, selector.source_ipv6_address)

            if target.name not in resolved_hostnames:
                resolved_hostnames[target.name] = yield self._resolve_hostname(
                    target.name, target.ipv4_address, target.ipv6_address)

        triplets = [
            (selector.source_name, target.name, tc)
            for selector, target in targets
            for tc in self._settings.traffic_classes
            if transformers.check_dscp_compatibility(selector.source_type, tc)
        ]

        raise tornado.gen.Return([
            self.Pair(
                source_family,
                source, source_addr,
                target, target_addr,
                # traffic class is only applied to IPv6 probes
                tc if source_family == socket.AF_INET6 else 0,
                self._settings.packet_count
            )
            for source, target, tc in triplets
            for source_family, source_addr in resolved_hostnames[source]
            for target_family, target_addr in resolved_hostnames[target]
            if source_family == target_family
        ])

    @tornado.gen.coroutine
    def _schedule_target_addrs(self, iface, target, resolved_hostnames):
        """Return (via ``tornado.gen.Return``) addresses to probe for a scheduled target.

        Fake hostnames and DNS errors (a None resolution result) fall back to
        the schedule IP. An outdated schedule IP is reported but the resolved
        addresses are used. When nothing of ``iface.family`` resolves, the
        schedule IP is used only for IPv4 with ``ignore_ipv4_dns_fails``;
        otherwise the result is an empty list.
        """
        if target.hostname.startswith(self._FAKE_HOSTNAME_PREFIXES):
            # don't try to resolve fake hostname, use ip from schedule
            raise tornado.gen.Return([target.ip])

        # resolve hostnames from schedule
        if target.hostname not in resolved_hostnames:
            resolved_hostnames[target.hostname] = yield self._resolve_hostname(
                target.hostname,
                target.ip if iface.family == socket.AF_INET else None,
                target.ip if iface.family == socket.AF_INET6 else None,
                ignore_errors=False
            )

        resolved_target_addrs = resolved_hostnames[target.hostname]
        if resolved_target_addrs is None:
            # DNS error, try to use ip from schedule
            raise tornado.gen.Return([target.ip])

        target_addrs = [
            addr
            for family, addr in resolved_target_addrs
            if family == iface.family
        ]
        if target_addrs:
            if target.ip not in target_addrs:
                _unistat_callback.push_signal(_unistat_callback.ScheduleOutdatedIps, 1.0)
                logging.warning("Outdated address in schedule for host %s: actual %s, got %s",
                                target.hostname, target_addrs, target.ip)
            raise tornado.gen.Return(target_addrs)

        if iface.family == socket.AF_INET and self._settings.ignore_ipv4_dns_fails:
            raise tornado.gen.Return([target.ip])

        logging.info("DNS: host %s (family %s) not found", target.hostname, iface.family)
        raise tornado.gen.Return([])

    @tornado.gen.coroutine
    def _build_scheduled_pairs(self, scheduled_probes, resolved_hostnames):
        """Build probe pairs from the NOC SLA schedule for every local interface.

        A schedule applies to an interface of the same family whose vlan
        matches. Schedules without a vlan skip MTN-vlan interfaces. Only
        DSCP-compatible traffic classes are used.
        """
        scheduled_pairs = []
        interfaces = yield self._app[application.IfaceService].get_interfaces()
        for iface in interfaces:
            for schedule in scheduled_probes.schedules:
                if iface.family != schedule.family:
                    continue

                if (
                    (schedule.vlan is not None and schedule.vlan != iface.vlan) or
                    (schedule.vlan is None and is_mtn_vlan(iface.vlan))
                ):
                    continue

                compatible_tcs = [tc for tc in schedule.traffic_classes
                                  if self._check_iface_dscp_compatibility(iface, tc)]
                if not compatible_tcs:
                    continue

                for target in schedule.targets:
                    target_addrs = yield self._schedule_target_addrs(iface, target, resolved_hostnames)
                    for tc in compatible_tcs:
                        scheduled_pairs.extend([
                            self.Pair(
                                iface.family,
                                iface.fqdn, iface.address,
                                target.hostname, target_addr,
                                tc,
                                schedule.packet_count
                            )
                            for target_addr in target_addrs
                        ])

        raise tornado.gen.Return(scheduled_pairs)

    def _pick_smoothing_interval(self):
        """Return a random smoothing interval for uniform mode, or None.

        None is returned when uniform mode is off, or when the picked interval
        minus ``packet_timeout`` is below 10 s (logged as an error).
        """
        if not self._settings.uniform_probes:
            return None
        smoothing_interval = to_timeval(random.randint(*self._settings.check_interval)) - self._settings.packet_timeout
        if smoothing_interval < to_timeval(10):
            logging.error("Smoothing interval %s is too short, ignoring uniform mode", smoothing_interval)
            return None
        return smoothing_interval

    @tornado.gen.coroutine
    def _spawn_probes(self):
        """Run one round of regular and scheduled probes.

        :returns: (via ``tornado.gen.Return``) list of report protos; empty
            when there is nothing to probe.
        """
        selector_list = self._app[BuilderService].get_selectors()
        scheduled_probes = self._app[ScheduledProbesService]
        if not selector_list and not scheduled_probes:
            logging.warning("No selectors and scheduled probes found, do nothing")
            raise tornado.gen.Return([])

        # hostname -> resolution result, shared by regular and scheduled probes
        resolved_hostnames = {}
        resolved_pairs = yield self._build_regular_pairs(selector_list, resolved_hostnames)
        scheduled_pairs = yield self._build_scheduled_pairs(scheduled_probes, resolved_hostnames)

        if not resolved_pairs and not scheduled_pairs:
            logging.warning("No meaning pairs found, do nothing")
            raise tornado.gen.Return([])

        logging.debug("Will check %d random + %d scheduled pairs", len(resolved_pairs), len(scheduled_pairs))

        smoothing_interval = self._pick_smoothing_interval()

        # scheduled probes run only for protocols enabled both locally and in the schedule
        protocol_processors = (
            (common.ICMP, self._process_icmp_probes),
            (common.UDP, self._process_udp_probes),
            (common.TCP, self._process_tcp_probes),
        )
        processors = []
        for protocol, process in protocol_processors:
            if protocol not in self._settings.protocols:
                continue
            processors.append(process(resolved_pairs, common.REGULAR_PROBE, smoothing_interval))
            if protocol in scheduled_probes.protocols:
                processors.append(process(scheduled_pairs, common.NOC_SLA_PROBE, smoothing_interval))

        report_list = list(itertools.chain.from_iterable((yield processors)))

        raise tornado.gen.Return(report_list)

    @tornado.gen.coroutine
    def _one_shot(self):
        """Run one agent iteration and enqueue all reports to ``SenderService``.

        When ``link_poller`` is enabled, link status reports are placed before
        the probe reports.
        """
        report_list = yield self._spawn_probes()

        if self._settings.link_poller:
            link_reports = yield self._app[application.LinkService].report_status()
            link_reports_proto = [utils.report_to_proto(rep, self._settings.hostname, self._settings.hostname) for rep in link_reports]
            report_list = link_reports_proto + report_list

        for report in report_list:
            logging.debug("New report finished: \n%s", report)

        self._app[sender.SenderService].enqueue(report_list)

    @contextlib.contextmanager
    def create_context(self, app):
        """Run the wrapped code inside ``utils.create_pid(pid_path)``; ``app`` is unused."""
        with utils.create_pid(self._settings.pid_path):
            yield

    @tornado.gen.coroutine
    def cancel(self):
        """Stop the agent loop."""
        yield self._loop.cancel()

    def start(self):
        """Start the application."""
        self._app.start()

    def run_once(self):
        """Prepare the application and block until ``self._loop.wait`` completes."""
        self._app.prepare()
        self._app.run_sync(self._loop.wait)
