# coding: utf-8
from __future__ import print_function

import random

import click
import tornado.gen

import _netmon

from infra.netmon.agent.idl import common_pb2

from . import application
from . import utils


class TracertFiller(object):
    """Collect traceroute hops for a single source/target address pair.

    Probes are sent with increasing TTL values; every lost probe contributes
    one hop (TTL plus the reported offender) to a ``TTracerouteInfo`` message.
    """

    def __init__(self, source_address, target_address, source=None, target=None, max_ttl=32, traffic_class=0):
        """Store the probe parameters and create the initially hop-less traceroute message.

        :param source_address: object exposing ``hostname`` and ``port``; probes are sent from it.
        :param target_address: object exposing ``hostname`` and ``port``; probes are sent to it.
        :param source: value stored in the ``Source`` field of the message.
        :param target: value stored in the ``Target`` field of the message.
        :param max_ttl: highest TTL value that ``scan`` will probe.
        :param traffic_class: traffic class handed to every probe config.
        """
        self._source_address = source_address
        self._target_address = target_address
        self._traffic_class = traffic_class
        # TODO: it should be 255, but it's too big for our network
        self._max_ttl = max_ttl
        self._traceroute = self._create_traceroute(source_address, target_address, source, target)

    @staticmethod
    def _create_traceroute(source_address, target_address, source=None, target=None):
        """Return a ``TTracerouteInfo`` carrying both endpoints and no hops yet."""
        source_endpoint = common_pb2.TAddress(Ip=source_address.hostname, Port=source_address.port)
        target_endpoint = common_pb2.TAddress(Ip=target_address.hostname, Port=target_address.port)
        return common_pb2.TTracerouteInfo(
            Source=source,
            Target=target,
            SourceAddress=source_endpoint,
            TargetAddress=target_endpoint
        )

    def _create_config(self, ttl):
        """Return a regular-probe config for this address pair with the given packet TTL."""
        # NOTE(review): the meaning of the trailing ``3`` is not visible in this
        # file -- confirm against the ``_netmon.ProbeConfig`` signature.
        probe = _netmon.ProbeConfig(
            common_pb2.REGULAR_PROBE,
            self._source_address,
            self._target_address,
            self._traffic_class,
            3
        )
        # NOTE(review): units of the timing values below are not stated here
        # (presumably milliseconds) -- confirm against ``_netmon``.
        probe.start_delay = 0
        probe.timeout = 1000
        probe.delay = 100
        probe.packet_size = 100
        probe.packet_ttl = ttl
        return probe

    def _add_hop(self, ttl, report):
        """Append one hop built from a lost probe report.

        Also overwrites the ``Generated`` field of the message with the
        ``generated`` value of this report.
        """
        # Only lost probes are expected here; ``scan`` returns on failed/received ones.
        assert report.lost
        offender = utils.address_to_proto(report.offender)
        self._traceroute.Hops.add(TimeToLive=ttl, Offender=offender)
        self._traceroute.Generated = report.generated

    @tornado.gen.coroutine
    def scan(self, scheduler, threads):
        """Probe TTLs from 1 up to ``max_ttl`` in batches and return the traceroute message.

        :param scheduler: object whose ``schedule_checks(configs)`` yields one report per config.
        :param threads: number of consecutive TTL values probed in one batch.
        :returns: (via ``tornado.gen.Return``) the ``TTracerouteInfo`` accumulated so far.

        Scanning stops at the first report that is marked failed or received;
        reports before it in the same batch have already been recorded as hops.
        """
        ttl_limit = self._max_ttl + 1
        for first_ttl in xrange(1, ttl_limit, threads):
            # The last batch may be shorter so that max_ttl is never exceeded.
            batch_end = min(ttl_limit, first_ttl + threads)
            batch = [self._create_config(ttl) for ttl in xrange(first_ttl, batch_end)]
            reports = yield scheduler.schedule_checks(batch)
            # Reports are matched to TTLs by their position in the batch.
            for ttl, report in enumerate(reports, first_ttl):
                if report.failed or report.received:
                    raise tornado.gen.Return(self._traceroute)
                self._add_hop(ttl, report)
        raise tornado.gen.Return(self._traceroute)


class IcmpBehavior(object):
    """Traceroute behavior backed by ``application.IcmpService``."""

    # Service type used to look the scheduler up in the application.
    service_class = application.IcmpService
    # Fixed port value used when building address combinations.
    port = 0

    def __init__(self, settings):
        """Keep the settings and create the ICMP service instance to register."""
        self._settings = settings
        self.service_object = application.IcmpService()

    def tracert(self, scheduler, source, target):
        """Start a scan from ``source`` to ``target`` probing 2 TTL values per batch.

        Uses the first configured traffic class. Returns whatever
        ``TracertFiller.scan`` returns.
        """
        traffic_class = self._settings.traffic_classes[0]
        filler = TracertFiller(source, target, traffic_class=traffic_class)
        return filler.scan(scheduler, 2)


class UdpBehavior(object):
    """Traceroute behavior backed by ``application.UdpService``."""

    # Service type used to look the scheduler up in the application.
    service_class = application.UdpService

    def __init__(self, settings):
        """Keep the settings and create the UDP service instance to register."""
        self._settings = settings
        self.service_object = application.UdpService()

    @property
    def port(self):
        """Return a port picked at random from ``settings.echo_port_range`` on every access."""
        candidates = self._settings.echo_port_range
        return random.choice(candidates)

    def tracert(self, scheduler, source, target):
        """Start a scan from ``source`` to ``target`` probing 8 TTL values per batch.

        Uses the first configured traffic class. Returns whatever
        ``TracertFiller.scan`` returns.
        """
        traffic_class = self._settings.traffic_classes[0]
        filler = TracertFiller(source, target, traffic_class=traffic_class)
        return filler.scan(scheduler, 8)


class TcpBehavior(object):
    """Traceroute behavior backed by ``application.TcpService``."""

    # Service type used to look the scheduler up in the application.
    service_class = application.TcpService

    def __init__(self, settings):
        """Keep the settings and create a TCP service listening on ``settings.tcp_port_range``."""
        self._settings = settings
        listen_ports = self._settings.tcp_port_range
        self.service_object = application.TcpService(listen_ports=listen_ports)

    @property
    def port(self):
        """Return a port picked at random from ``settings.tcp_port_range`` on every access."""
        candidates = self._settings.tcp_port_range
        return random.choice(candidates)

    def tracert(self, scheduler, source, target):
        """Start a scan from ``source`` to ``target`` probing 2 TTL values per batch.

        Uses the first configured traffic class. Returns whatever
        ``TracertFiller.scan`` returns.
        """
        traffic_class = self._settings.traffic_classes[0]
        filler = TracertFiller(source, target, traffic_class=traffic_class)
        return filler.scan(scheduler, 2)


def _get_behavior(protocol, settings):
    """Create the behavior object matching a protocol constant.

    :param protocol: one of ``common_pb2.ICMP``, ``common_pb2.UDP`` or ``common_pb2.TCP``.
    :param settings: settings object passed unchanged to the behavior constructor.
    :returns: an ``IcmpBehavior``, ``UdpBehavior`` or ``TcpBehavior`` instance.
    :raises RuntimeError: if ``protocol`` is none of the supported constants.
    """
    if protocol == common_pb2.ICMP:
        return IcmpBehavior(settings)
    if protocol == common_pb2.UDP:
        return UdpBehavior(settings)
    if protocol == common_pb2.TCP:
        return TcpBehavior(settings)
    # Name the offending value so the failure can be diagnosed from the traceback alone.
    raise RuntimeError('Unsupported protocol: {!r}'.format(protocol))


class TracertCommand(object):
    """Command that traces routes to a hostname and prints every non-empty result."""

    def __init__(self, settings, protocol, hostname):
        """Pick the protocol behavior and assemble the application with its services.

        :param settings: settings object handed to the protocol behavior.
        :param protocol: protocol name; upper-cased and looked up in ``utils.STRING_TO_PROTOCOL``.
        :param hostname: host name to resolve and trace.
        """
        self._settings = settings
        self._hostname = hostname
        protocol_id = utils.STRING_TO_PROTOCOL[protocol.upper()]
        self._behavior = _get_behavior(protocol_id, settings)

        app = application.Application()
        self._app = app
        app.register(application.IfaceService())
        app.register(self._behavior.service_object)

    @tornado.gen.coroutine
    def _dispatch(self):
        """Run a traceroute for every source/target address combination.

        :returns: (via ``tornado.gen.Return``) the list of results produced by
            the behavior's ``tracert`` calls, one per address combination.
        """
        local_addresses = yield self._app[application.IfaceService].get_addresses()
        # NOTE(review): ResolverService is not registered in __init__ -- presumably
        # the Application provides it on its own; confirm.
        remote_addresses = yield self._app[application.ResolverService].try_resolve(self._hostname)

        # The behavior's port is read once and shared by all combinations.
        pairs = utils.permutate_addresses(local_addresses, remote_addresses, self._behavior.port)

        pending = []
        for source, target in pairs:
            scheduler = self._app[self._behavior.service_class]
            pending.append(self._behavior.tracert(scheduler, source, target))

        # Yielding the list waits for all scans and gives their results as a list.
        results = yield pending
        raise tornado.gen.Return(results)

    def start(self):
        """Run the dispatch synchronously and echo each traceroute that has at least one hop."""
        traceroutes = self._app.run_sync(self._dispatch)
        for traceroute in traceroutes:
            if not traceroute.Hops:
                continue
            click.echo(repr(traceroute))
