# coding: utf-8
from __future__ import print_function

import bisect
import random
import logging

from . import settings


class TargetIface(object):
    """Leaf of the target tree: one network interface that can be probed."""

    __slots__ = ("name", "ipv4_address", "ipv6_address")

    def __init__(self, name, ipv4_address, ipv6_address):
        # Interface name; compared against the tree's source name to skip self-probing.
        self.name = name
        self.ipv4_address = ipv4_address
        self.ipv6_address = ipv6_address

    def __repr__(self):
        # Readable form for debug logs (the default repr only shows an address).
        return "%s(name=%r, ipv4_address=%r, ipv6_address=%r)" % (
            type(self).__name__, self.name, self.ipv4_address, self.ipv6_address)


class Switch(object):
    """Group of target interfaces attached to the same switch."""

    __slots__ = ("targets", )

    def __init__(self):
        # List while the tree is being built; TargetTree.finalize turns it into a tuple.
        self.targets = []

    def __repr__(self):
        # Switches carry no name of their own, so identify them by their targets;
        # this is what shows up in the selectors' debug logging.
        return "%s(targets=%r)" % (type(self).__name__, [t.name for t in self.targets])


class Queue(object):
    """Group of switches belonging to the same queue inside a datacenter."""

    __slots__ = ("switches", "switch_weights", "target_count")

    def __init__(self):
        # List while building; TargetTree.finalize turns it into a tuple.
        self.switches = []
        # Cumulative per-switch weights, filled in by TargetTree.finalize.
        self.switch_weights = []
        # Number of interfaces across all switches of this queue.
        self.target_count = 0

    def __repr__(self):
        # Readable form for the selectors' debug logging.
        return "%s(switches=%d, targets=%d)" % (
            type(self).__name__, len(self.switches), self.target_count)


class Datacenter(object):
    """Group of queues belonging to the same datacenter."""

    __slots__ = ("queues", "target_count")

    def __init__(self):
        # List while building; TargetTree.finalize turns it into a tuple.
        self.queues = []
        # Number of interfaces across all queues of this datacenter.
        self.target_count = 0

    def __repr__(self):
        # Readable form for the selectors' debug logging.
        return "%s(queues=%d, targets=%d)" % (
            type(self).__name__, len(self.queues), self.target_count)


class TargetTree(object):
    """Hierarchy of probe targets: datacenter -> queue -> switch -> interface.

    Interfaces are added with ``push`` and the tree is then frozen with
    ``finalize``. Finalizing converts all containers to tuples, drops the
    build-time lookup maps and precomputes each queue's cumulative switch
    weights. The pushed interface whose name equals ``source.name`` is
    remembered as the probe source (``source_*`` attributes).
    """

    __slots__ = (
        "datacenters", "source_type", "source_name", "source_switch",
        "source_queue", "source_datacenter", "_target_count",
        "_dc_map", "_queue_map", "_switch_map", "_source_object_name", "_finalized",
        "source_ipv4_address", "source_ipv6_address",
    )

    def __init__(self, source, iterator=None):
        """Create a tree for ``source`` (any object with a ``name`` attribute).

        If ``iterator`` is given (even an empty one), every interface it yields
        is pushed and the tree is finalized immediately.
        """
        self.datacenters = []

        # Filled in by push() once the interface named like the source is seen.
        self.source_type = None
        self.source_name = None
        self.source_switch = None
        self.source_queue = None
        self.source_datacenter = None

        self.source_ipv4_address = None
        self.source_ipv6_address = None

        self._target_count = 0

        # Build-time lookup maps, keyed by the full path so identically named
        # queues/switches in different parents stay distinct.
        self._dc_map = {}
        self._queue_map = {}
        self._switch_map = {}
        self._source_object_name = source.name

        self._finalized = False

        # `is not None` so that an empty list finalizes the tree just like an
        # empty generator does (a plain truthiness test skipped empty lists).
        if iterator is not None:
            self._from_iterable(iterator)

    def __len__(self):
        """Return the number of interfaces pushed into the tree."""
        return self._target_count

    def is_valid(self):
        """Return True if the source was found and at least one other host exists."""
        # tree should contain ourself and at least one another host
        return self.source_name is not None and self._target_count > 1

    def _from_iterable(self, iterator):
        for iface in iterator:
            self.push(iface)

        self.finalize()

    def push(self, iface):
        """Add one interface to the tree; silently ignored after ``finalize``.

        ``iface`` must provide ``name``, ``datacenter``, ``queue``, ``switch``,
        ``network_type``, ``ipv4_address`` and ``ipv6_address`` attributes.
        """
        if self._finalized:
            return

        iface_name = iface.name
        datacenter_name = iface.datacenter
        queue_name = iface.queue
        switch_name = iface.switch
        network_type = iface.network_type
        ipv4_address = iface.ipv4_address
        ipv6_address = iface.ipv6_address

        dc = self._dc_map.get(datacenter_name)
        if dc is None:
            dc = self._dc_map[datacenter_name] = Datacenter()
            self.datacenters.append(dc)

        queue = self._queue_map.get((datacenter_name, queue_name))
        if queue is None:
            queue = self._queue_map[(datacenter_name, queue_name)] = Queue()
            dc.queues.append(queue)

        switch = self._switch_map.get((datacenter_name, queue_name, switch_name))
        if switch is None:
            switch = self._switch_map[(datacenter_name, queue_name, switch_name)] = Switch()
            queue.switches.append(switch)

        switch.targets.append(TargetIface(iface_name, ipv4_address, ipv6_address))

        # Switch size is implicit (len(targets)); higher levels keep counters.
        dc.target_count += 1
        queue.target_count += 1
        self._target_count += 1

        if iface_name == self._source_object_name:
            self.source_name = iface_name
            self.source_switch = switch
            self.source_queue = queue
            self.source_datacenter = dc
            self.source_type = network_type
            self.source_ipv4_address = ipv4_address
            self.source_ipv6_address = ipv6_address

    def finalize(self):
        """Freeze the tree and precompute switch weights; safe to call repeatedly."""
        # The maps below are deleted on the first call, so a second call used
        # to fail with AttributeError.
        if self._finalized:
            return
        self._finalized = True

        # Lookup maps are only needed by push(); release them.
        del self._dc_map
        del self._queue_map
        del self._switch_map
        del self._source_object_name

        self.datacenters = tuple(self.datacenters)
        for dc in self.datacenters:
            dc.queues = tuple(dc.queues)
            for queue in dc.queues:
                queue.switches = tuple(queue.switches)
                for switch in queue.switches:
                    switch.targets = tuple(switch.targets)
                queue.switch_weights = self._cumulative_switch_weights(queue)

    @staticmethod
    def _cumulative_switch_weights(queue):
        """Return a tuple of cumulative target shares for ``queue``'s switches.

        Entry ``i`` is the fraction of the queue's targets held by switches
        ``0..i``; the last entry is forced to exactly 1.0 to absorb float
        rounding, so bisecting a value from ``random.random()`` always lands on
        a valid switch index.
        """
        total = float(queue.target_count)
        weights = []
        running = 0
        for switch in queue.switches:
            running += len(switch.targets) / total
            weights.append(running)
        if weights:
            weights[-1] = 1.0
        return tuple(weights)


class BaseSelector(object):
    """Common base for selectors that pick probe targets from a target tree.

    Subclasses provide ``_select_targets``; ``select`` post-filters its
    result so the source interface never probes itself.
    """

    def __init__(self, tree):
        self._tree = tree

    def _select_targets(self):
        """Return candidate target interfaces; subclasses must override."""
        raise NotImplementedError()

    @property
    def source_name(self):
        """Name of the interface acting as the probe source."""
        tree = self._tree
        return tree.source_name

    @property
    def source_type(self):
        """Network type recorded for the probe source."""
        tree = self._tree
        return tree.source_type

    @property
    def source_ipv4_address(self):
        """IPv4 address of the probe source."""
        tree = self._tree
        return tree.source_ipv4_address

    @property
    def source_ipv6_address(self):
        """IPv6 address of the probe source."""
        tree = self._tree
        return tree.source_ipv6_address

    def select(self):
        """Return the chosen target interfaces, with the source itself removed."""
        candidates = self._select_targets()
        own_name = self._tree.source_name
        return [candidate for candidate in candidates if candidate.name != own_name]


class LeveledSelector(BaseSelector):
    """Select targets at each network level.

    Targets are sampled from (1) the source's own switch, (2) the other
    switches of the source's queue, (3) the other queues of the source's
    datacenter and (4) the queues of every other datacenter. At each level
    the configured ``per_*_target_count`` is divided by the size of the
    source's own group at that level, and scaled by a per-destination weight
    in levels 2-4. Sampling is with replacement, so the counts are targets,
    not exact numbers.

    The tree must contain the source host: ``source_switch``,
    ``source_queue`` and ``source_datacenter`` are dereferenced without a
    None check (see ``TargetTree.is_valid``).
    """

    def __init__(self,
                 tree,
                 per_switch_target_count=None,
                 per_queue_target_count=None,
                 per_intra_datacenter_target_count=None,
                 per_inter_datacenter_target_count=None,
                 max_probability=None,
                 minimal_weight=None):
        """Store the tree and resolve options.

        Any argument left as None falls back to the same-named attribute of
        ``settings.current()``. ``max_probability`` and ``minimal_weight`` are
        clamped to [0, 1].
        """
        super(LeveledSelector, self).__init__(tree)

        opt = settings.current()

        def option(value, name):
            # Only None means "use the default", so explicit zeros are honoured
            # (previously `value or default` silently replaced 0 for the last two).
            return getattr(opt, name) if value is None else value

        self._per_switch_target_count = float(
            option(per_switch_target_count, "per_switch_target_count"))
        self._per_queue_target_count = float(
            option(per_queue_target_count, "per_queue_target_count"))
        self._per_intra_datacenter_target_count = float(
            option(per_intra_datacenter_target_count, "per_intra_datacenter_target_count"))
        self._per_inter_datacenter_target_count = float(
            option(per_inter_datacenter_target_count, "per_inter_datacenter_target_count"))

        # sometimes we can have only one host in switch, so don't check too much
        self._max_probability = self._clamp_unit(option(max_probability, "max_probability"))
        self._minimal_weight = self._clamp_unit(option(minimal_weight, "minimal_weight"))

    @staticmethod
    def _clamp_unit(value):
        """Clamp ``value`` into [0, 1] and return it as a float."""
        return float(max(min(value, 1), 0))

    def _maybe_yes(self, probability):
        """Return True with the given probability, capped at ``max_probability``."""
        # TODO: why we choose beta distribution with those alpha and beta
        # NB: Beta(1, 1) is the uniform distribution on [0, 1].
        probability = min(self._max_probability, probability)
        return random.betavariate(1, 1) <= probability

    def _bound_weight(self, weight):
        """Map a weight in [0, 1] onto [minimal_weight, 1]."""
        return self._minimal_weight + (1 - self._minimal_weight) * weight

    def _run_trials(self, probability, select_one):
        """Call ``select_one`` according to a fractional expected count.

        The integer part of ``probability`` gives unconditional calls; the
        fractional remainder is the chance (via ``_maybe_yes``) of one more.
        """
        tries, remainder = divmod(probability, 1)
        # range, not xrange: xrange does not exist on Python 3.
        for _ in range(int(tries)):
            select_one()
        if remainder and self._maybe_yes(remainder):
            select_one()

    def _select_multiple_from_list(self, probability, target_list):
        """Pick random targets (with replacement) from ``target_list``; return a set."""
        selected = set()
        if not target_list:
            return selected

        def select_one():
            selected.add(random.choice(target_list))

        self._run_trials(probability, select_one)
        return selected

    def _select_targets_inside_switch(self):
        """Select targets in same switch."""
        # at least one (ourself) should be in that switch
        switch_targets = self._tree.source_switch.targets
        probability = self._per_switch_target_count / len(switch_targets)
        logging.debug("Switch %s probability: %0.3f", self._tree.source_switch, probability)
        return self._select_multiple_from_list(probability, switch_targets)

    def _select_targets_inside_queue(self):
        """Select targets in each switch pair belongs to same queue."""
        source_switch = self._tree.source_switch
        probability = self._per_queue_target_count / len(source_switch.targets)
        max_target_count = float(max(
            len(switch.targets) for switch in self._tree.source_queue.switches))
        target_list = []
        for target_switch in self._tree.source_queue.switches:
            if source_switch == target_switch or not target_switch.targets:
                continue
            # bigger switches get proportionally more probes, floored by minimal_weight
            weight = self._bound_weight(len(target_switch.targets) / max_target_count)
            logging.debug("Queue probability (%s -> %s): %0.3f",
                          source_switch, target_switch, probability * weight)
            target_list.extend(self._select_multiple_from_list(probability * weight, target_switch.targets))
        return target_list

    def _select_multiple_from_queue(self, probability, queue):
        """Pick random targets from ``queue``, choosing switches by target share."""
        selected = set()

        def select_one():
            # switch_weights is cumulative and ends at 1.0 (see TargetTree.finalize)
            idx = bisect.bisect_left(queue.switch_weights, random.random())
            switch = queue.switches[idx]
            logging.debug("Switch %s selected from queue", switch)
            selected.add(random.choice(switch.targets))

        self._run_trials(probability, select_one)
        return selected

    def _select_targets_inside_datacenter(self):
        """Select targets in each queue pair belongs to same datacenter."""
        source_queue = self._tree.source_queue
        probability = self._per_intra_datacenter_target_count / source_queue.target_count
        max_target_count = float(max(
            queue.target_count for queue in self._tree.source_datacenter.queues))
        target_list = []
        for target_queue in self._tree.source_datacenter.queues:
            if source_queue == target_queue or not target_queue.target_count:
                continue
            weight = self._bound_weight(target_queue.target_count / max_target_count)
            logging.debug("Inside DC probability (%s -> %s): %0.3f",
                          source_queue, target_queue, probability * weight)
            target_list.extend(self._select_multiple_from_queue(probability * weight, target_queue))
        return target_list

    def _select_targets_among_datacenter(self):
        """Select targets in other datacenters."""
        source_datacenter = self._tree.source_datacenter
        probability = self._per_inter_datacenter_target_count / source_datacenter.target_count
        # normalise against the largest queue anywhere in the tree
        max_target_count = float(max(
            queue.target_count
            for dc in self._tree.datacenters
            for queue in dc.queues))
        target_list = []
        for target_datacenter in self._tree.datacenters:
            if source_datacenter == target_datacenter:
                continue
            for target_queue in target_datacenter.queues:
                # try to check bigger queues more frequent
                weight = self._bound_weight(target_queue.target_count / max_target_count)
                logging.debug("Among DC probability (%s -> %s, %s): %0.3f",
                              source_datacenter, target_datacenter, target_queue,
                              probability * weight)
                target_list.extend(self._select_multiple_from_queue(probability * weight, target_queue))
        return target_list

    def _select_targets(self):
        """Concatenate the picks from all four levels (may include the source)."""
        target_list = []
        # NB: datacenters can have different size and shape
        target_list.extend(self._select_targets_inside_switch())
        target_list.extend(self._select_targets_inside_queue())
        target_list.extend(self._select_targets_inside_datacenter())
        target_list.extend(self._select_targets_among_datacenter())
        return target_list


class PeerToPeerSelector(BaseSelector):
    """Select every interface in the tree as a target."""

    def _select_targets(self):
        """Return all interfaces, walking datacenters, queues and switches in order."""
        return [
            target
            for datacenter in self._tree.datacenters
            for queue in datacenter.queues
            for switch in queue.switches
            for target in switch.targets
        ]
