# coding: utf-8
from __future__ import print_function

import json
import random
import socket

import click
import tornado.gen

import _netmon

from . import application
from . import utils
from . import rpc

from google.protobuf.json_format import MessageToJson
from infra.netmon.agent.idl import common_pb2


class IcmpBehavior(object):
    """Probe behavior for ICMP targets.

    Exposes the service class/instance used to run ICMP checks and a
    ``port`` value for building probe configs.
    """

    service_class = application.IcmpService

    def __init__(self):
        # ICMP needs no settings: the service is built with defaults.
        self.service_object = application.IcmpService()

    @property
    def port(self):
        """Always 0: ICMP has no ports, so a constant placeholder is used."""
        return 0


class UdpBehavior(object):
    """Probe behavior for UDP targets.

    Each access to ``port`` picks a fresh random port from the configured
    echo port range.
    """

    service_class = application.UdpService

    def __init__(self, settings):
        self.service_object = application.UdpService()
        # Kept so ``port`` can read the echo port range lazily.
        self._settings = settings

    @property
    def port(self):
        """A port drawn at random from ``settings.echo_port_range``."""
        candidates = self._settings.echo_port_range
        return random.choice(candidates)


class TcpBehavior(object):
    """Probe behavior for TCP targets.

    The TCP service is told to listen on ``settings.tcp_port_range``, and
    probes target a random port from that same range.
    """

    service_class = application.TcpService

    def __init__(self, settings):
        self._settings = settings
        port_range = settings.tcp_port_range
        self.service_object = application.TcpService(listen_ports=port_range)

    @property
    def port(self):
        """A port drawn at random from ``settings.tcp_port_range``."""
        candidates = self._settings.tcp_port_range
        return random.choice(candidates)


def _get_behavior(protocol, settings):
    """Return the probe behavior object matching ``protocol``.

    Args:
        protocol: a ``common_pb2`` protocol value; ``ICMP``, ``UDP`` and
            ``TCP`` are supported.
        settings: settings passed to the UDP/TCP behaviors (the ICMP
            behavior takes none).

    Raises:
        RuntimeError: if ``protocol`` is not one of the supported values.
    """
    if protocol == common_pb2.ICMP:
        return IcmpBehavior()
    if protocol == common_pb2.UDP:
        return UdpBehavior(settings)
    if protocol == common_pb2.TCP:
        return TcpBehavior(settings)
    # Name the offending value so a misconfigured protocol is diagnosable.
    raise RuntimeError('Unsupported probe protocol: {!r}'.format(protocol))


class PingCommand(application.Service):
    """Probe a list of hosts once, print the results and stop.

    The command owns an ``application.Application`` and registers on it:
    interface discovery, the protocol-specific probe service, an RPC client
    (only when ``noc_sla_urls`` is configured) and itself. ``start`` schedules
    the probes and runs the event loop. Then either ``_print_reports`` or
    ``_print_summary`` prints the outcome and stops the loop.
    """

    def __init__(self, settings, protocol, json,
                 summarize, print_successful, print_failed,
                 hosts_per_sec, packet_size, allow_mtn_vlan,
                 hostnames):
        """Adjust settings for this run and build the application.

        Args:
            settings: settings object, mutated in place. ``unistat_pusher``
                is disabled; ``packet_size`` and ``dns_resolve_ip4`` may be
                overridden.
            protocol: protocol name, looked up upper-cased in
                ``utils.STRING_TO_PROTOCOL``.
            json: print JSON instead of plain text. (Shadows the ``json``
                module inside this method only; kept for caller compatibility.)
            summarize: print a per-outcome summary instead of full reports.
            print_successful: in summary mode, also list successful and
                semifailed targets.
            print_failed: in summary mode, also list failed and semifailed
                targets.
            hosts_per_sec: probe start delays are spaced by
                ``utils.MS / hosts_per_sec``. Every (source, target) probe
                config is staggered, not every host.
            packet_size: if not None, overrides ``settings.packet_size``.
            allow_mtn_vlan: forwarded to ``application.IfaceService``.
            hostnames: probe targets; each one is passed to the resolver.
        """
        self._settings = settings
        self._hostnames = hostnames
        self._json = json
        self._summarize = summarize
        self._print_successful = print_successful
        self._print_failed = print_failed
        self._hosts_per_sec = hosts_per_sec
        self._behavior = _get_behavior(utils.STRING_TO_PROTOCOL[protocol.upper()], settings)
        # Resolved target address -> hostname it came from. Filled by
        # _schedule_probes so reports (which carry addresses) can be labelled.
        self._targets = {}

        self._settings.unistat_pusher = False
        if packet_size is not None:
            self._settings.packet_size = packet_size

        # If any target is itself an IPv4 address, turn on IPv4 resolution.
        # (Presumably so IPv4 literals resolve to themselves; confirm.)
        if any(utils.get_address_family(host) == socket.AF_INET for host in self._hostnames):
            self._settings.dns_resolve_ip4 = True

        self._app = application.Application()
        self._app.register(application.IfaceService(networks=settings.networks, allow_mtn_vlan=allow_mtn_vlan))
        self._app.register(self._behavior.service_object)
        if self._settings.current().noc_sla_urls:
            self._app.register(rpc.RpcClient())
        self._app.register(self)

    def _print_reports(self, future):
        """Print every report produced by ``future``, then stop the event loop.

        In JSON mode a single object mapping target hostname to the report
        (converted via ``MessageToJson``) is printed. Otherwise each report is
        printed as its target hostname followed by the proto ``repr``.

        ``stop_loop`` runs in ``finally``: if the probe run failed,
        ``future.result()`` raises, and without it the command would hang.
        The exception still propagates to the caller of the callback.
        """
        try:
            result = {}
            for report in future.result():
                proto_report = utils.report_to_proto(report)
                target_hostname = self._targets[report.target_addr[1][0]]

                if self._json:
                    # NOTE(review): keyed by hostname only. When a host yields
                    # several reports (e.g. several source or resolved
                    # addresses), only the last one survives in the output.
                    result[target_hostname] = json.loads(MessageToJson(proto_report))
                else:
                    click.echo('TargetHostName: {}'.format(target_hostname))
                    click.echo(repr(proto_report))

            if self._json:
                click.echo(json.dumps(result))
        finally:
            self._app.stop_loop()

    def _print_summary(self, future):
        """Print outcome counts (and optionally target lists), then stop the loop.

        Each report is classified by ``report.received``:

        * failed: nothing came back;
        * semifailed: some, but fewer than ``settings.packet_count``;
        * successful: everything else.

        ``print_successful`` adds the successful and semifailed targets to the
        output; ``print_failed`` adds the failed and semifailed targets.

        ``stop_loop`` runs in ``finally`` so a failed probe run does not hang
        the command. The exception still propagates.
        """
        try:
            reports = future.result()
            packet_count = self._settings.packet_count

            success_hosts = []
            semifailed_hosts = []
            failed_hosts = []
            for report in reports:
                host = self._targets[report.target_addr[1][0]]
                if report.received == 0:
                    failed_hosts.append(host)
                elif report.received < packet_count:
                    semifailed_hosts.append(host)
                else:
                    # Single pass: every report lands in exactly one bucket.
                    # received > packet_count is not dropped, and
                    # packet_count == 0 does not count a host twice.
                    success_hosts.append(host)

            if self._json:
                result = {
                    "success_cnt": len(success_hosts),
                    "failed_cnt": len(failed_hosts),
                    "semifailed_cnt": len(semifailed_hosts)
                }
                if self._print_successful:
                    result["success_hosts"] = success_hosts
                    result["semifailed_hosts"] = semifailed_hosts
                if self._print_failed:
                    result["failed_hosts"] = failed_hosts
                    result["semifailed_hosts"] = semifailed_hosts
                click.echo(json.dumps(result))
            else:
                click.echo("{} probes sent, {} successful, {} failed, {} semifailed".format(
                    len(reports),
                    len(success_hosts),
                    len(failed_hosts),
                    len(semifailed_hosts)
                ))
                if self._print_successful:
                    click.echo("Successful probe targets: {}".format(' '.join(success_hosts)))
                if self._print_failed:
                    click.echo("Failed probe targets: {}".format(' '.join(failed_hosts)))
                if self._print_successful or self._print_failed:
                    click.echo("Semifailed probe targets: {}".format(' '.join(semifailed_hosts)))
        finally:
            self._app.stop_loop()

    @tornado.gen.coroutine
    def _send_reports(self, reports):
        """Upload ``reports`` to every configured ``noc_sla_urls`` endpoint.

        Each report is converted to a proto, annotated with the FQDN of the
        local interface that sent it (None if no interface matches) and the
        target hostname. All URLs are sent to at once by yielding a dict of
        futures. Any non-200 response code is printed. HTTP errors go through
        ``utils.suppress_http_errors``, which presumably turns them into
        a returned code instead of raising; confirm.
        """
        ifaces = yield self._app[application.IfaceService].get_interfaces()

        def get_hostname(addr):
            # First interface whose address matches wins.
            for iface in ifaces:
                if iface.address == addr:
                    return iface.fqdn
            return None

        reports = [utils.report_to_proto(rep,
                                         get_hostname(rep.source_addr[1][0]),
                                         self._targets[rep.target_addr[1][0]])
                   for rep in reports]
        send_reports_wrapper = utils.suppress_http_errors(self._app[rpc.RpcClient].send_reports)

        response_codes = yield {
            url: send_reports_wrapper(
                reports,
                url
            )
            for url in self._settings.current().noc_sla_urls
        }
        for url, code in response_codes.items():
            if code != 200:
                click.echo("Sending to {} failed: response code {}".format(url, code))

    @tornado.gen.coroutine
    def _schedule_probes(self):
        """Resolve targets, run one probe per (source, target) pair, return reports.

        Fills ``self._targets`` with resolved address -> hostname. Builds a
        ``_netmon.ProbeConfig`` for every pair produced by
        ``utils.permutate_addresses`` and staggers their ``start_delay`` by
        ``utils.MS / hosts_per_sec``. Then hands the configs to the protocol
        service. When ``noc_sla_urls`` is configured, the reports are also
        uploaded before being returned.
        """
        source_addresses = yield self._app[application.IfaceService].get_addresses()
        target_addresses = []

        target_addresses_by_host = yield {host: self._app[application.ResolverService].try_resolve(host) for host in self._hostnames}

        for host, addresses in target_addresses_by_host.items():
            target_addresses += addresses
            for target_addr in addresses:
                # NOTE(review): stored under target_addr[1], but reports are
                # looked up by report.target_addr[1][0]. Presumably both are
                # the bare IP string; confirm against ResolverService.
                self._targets[target_addr[1]] = host

        probe_type = common_pb2.NOC_SLA_PROBE if self._settings.current().noc_sla_urls else common_pb2.REGULAR_PROBE

        config_list = [
            _netmon.ProbeConfig(probe_type, source, target, self._settings.traffic_classes[0], self._settings.packet_count)
            for source, target in utils.permutate_addresses(source_addresses, target_addresses, self._behavior.port)
        ]

        # NOTE(review): raises ZeroDivisionError if hosts_per_sec is 0;
        # presumably the CLI option rejects that. Confirm.
        delay_inc = float(utils.MS) / self._hosts_per_sec
        start_delay = 0
        for config in config_list:
            config.start_delay = int(start_delay)
            config.timeout = self._settings.packet_timeout
            config.delay = self._settings.packet_delay
            config.packet_size = self._settings.packet_size
            config.packet_ttl = self._settings.packet_ttl

            start_delay += delay_inc

        reports = yield self._app[self._behavior.service_class].schedule_checks(config_list)
        if self._settings.current().noc_sla_urls:
            yield self._send_reports(reports)

        raise tornado.gen.Return(reports)

    def start(self):
        """Install signal handlers, schedule probes and run the loop until done.

        The printing callback attached to the probe future stops the loop.
        """
        self._app.setup_signals()
        reports = self._schedule_probes()
        if self._summarize:
            reports.add_done_callback(self._print_summary)
        else:
            reports.add_done_callback(self._print_reports)
        self._app.start_loop()
