import logging
import datetime
import functools
import collections

import infra.callisto.controllers.utils.entities as entities
import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.sdk.blocks as blocks
import infra.callisto.controllers.sdk.notify as notify
import infra.callisto.controllers.utils.funcs as funcs
import infra.callisto.controllers.slots.state as slots_state
import report


class Intl2Groups(notify.ValueNotification):
    # Value notification named 'intl2-groups'. IntL2Controller.notifications
    # emits one per distinct target timestamp ("generation") with the number
    # of slots whose target carries it.
    name = 'intl2-groups'


class Intl2Instances(notify.ValueNotification):
    # Value notification named 'intl2-instances'. IntL2Controller.notifications
    # emits it with a 'status' label ('match-target' or 'switching') and the
    # number of agents in that status.
    name = 'intl2-instances'


class IntL2Controller(sdk.Controller):
    """Roll the intl2 agents of one contour towards a common slot target.

    ``update_state`` derives a target mapping (slot -> state) from the slot
    states it is given, moves agents towards it while limiting how many of
    them switch at once, and runs safety checks. ``gencfg`` renders the
    per-agent targets as configs keyed by ``(host, port)``.
    """

    # Not read anywhere in this class; presumably consumed by sdk.Controller.
    reports_alive_threshold = datetime.timedelta(minutes=4)

    @property
    def id(self):
        """Controller id: ``'intl2_'`` followed by the contour name."""
        return 'intl2_' + self._contour.name

    @property
    def tags(self):
        return {'a_itype_intl2'}

    def __init__(self, contour, agents, max_switching=3, unsafe=False):
        """
        :param contour: contour whose ``slots`` are managed; also used to
            build empty slot mappings.
        :param agents: agents to manage (``host``, ``port`` and ``instance``
            attributes are read from them).
        :param max_switching: limit on how many agents may switch at once.
        :param unsafe: skip safety checks and report every slot as unused.
        """
        super(IntL2Controller, self).__init__()
        self._contour = contour
        self._agents = agents
        # Per-agent target mapping; every agent starts with an empty mapping.
        self._agents_targets = {agent: slots_state.empty_mapping(contour) for agent in agents}
        # Common target that all agents are moved towards.
        self._target = slots_state.empty_mapping(contour)
        self._max_switching = max_switching
        self._unsafe = unsafe
        # Latest converted report per agent, replaced by update().
        self._reports = {}

    def notifications(self):
        """Return value notifications describing the rollout.

        One ``Intl2Groups`` per distinct target timestamp (value: number of
        slots whose target carries it), plus ``Intl2Instances`` counts of
        agents matching the target and of agents that are switching.
        """
        timestamps = collections.defaultdict(int)
        for target in self._target.values():
            if target:
                timestamps[target.timestamp] += 1
        groups = [
            Intl2Groups(value=cnt, labels=dict(generation=ts, status='in-target'))
            for ts, cnt in timestamps.items()
        ]
        instances = [
            Intl2Instances(value=len(self._agents_match_target()), labels=dict(status='match-target')),
            Intl2Instances(value=len(self._switching_agents()), labels=dict(status='switching')),
        ]
        return groups + instances

    def set_context(self, context):
        """Restore the target and per-agent targets produced by ``get_context``.

        ``None`` is ignored. Managed agents missing from the saved state (or
        saved with a falsy mapping) get an empty mapping; saved agents that
        this controller does not manage are dropped.
        """
        if context is None:
            return
        self._target = slots_state.mapping_from_dict(context['target'])
        prev_targets = agents_targets_from_json(context['state'])
        self._agents_targets = {
            agent: prev_targets.get(agent) or slots_state.empty_mapping(self._contour)
            for agent in self._agents
        }

    def execute(self):
        """Give every agent whose target is empty a copy of the common target.

        Does nothing while the common target is itself empty.
        """
        empty_map = slots_state.empty_mapping(self._contour)
        # Loop-invariant; evaluated once instead of per agent.
        has_target = self._target != empty_map
        if not has_target:
            return
        for agent, target_mapping in self._agents_targets.items():
            if target_mapping == empty_map:
                # Assigning to an existing key does not resize the dict, so
                # this is safe while iterating.
                _log.info('set intl2(%s) ts because current is empty', agent.instance)
                self._agents_targets[agent] = self._target.copy()

    def update(self, reports):
        """Replace the stored reports with converted reports of managed agents.

        Reports of unknown agents are skipped. The rest go through
        ``report.convert_report_to_intl2_report`` via
        ``funcs.imap_ignoring_exceptions`` (presumably dropping reports whose
        conversion raises) and are keyed by ``rep.agent``.
        """
        convert_func = functools.partial(report.convert_report_to_intl2_report, self._contour)
        reports = funcs.imap_ignoring_exceptions(convert_func, _filter_reports(reports, self._agents))
        self._reports = {rep.agent: rep for rep in reports}

    def update_state(self, slots_states):
        """Recompute the common target, re-target agents and run safety checks.

        NOTE(review): agents are re-targeted before the checks run, so a
        failing check raises after ``self._agents_targets`` has changed.
        """
        _log.debug('update targets')
        target = _generate_target(self._contour, slots_states)

        for slot, old, new in _diff_targets(self._target, target):
            _log.info('Slot %s ts: %s -> %s', slot, old, new)

        self._target = target
        self._reconfigure()
        self._check()

    def gencfg(self):
        """Render ``{(host, port): {str(slot): state.json()}}`` for every agent."""
        return {
            (agent.host, agent.port): {
                str(slot): target_mapping[slot].json() for slot in target_mapping
            }
            for agent, target_mapping in self._agents_targets.items()
        }

    def get_context(self):
        """Serialize the common target and per-agent targets for ``set_context``."""
        return {
            'target': slots_state.mapping_to_dict(self._target),
            'state': agents_targets_to_json(self._agents_targets),
        }

    def unused_slots(self):
        """Return contour slots that no reporting agent uses and the target leaves empty.

        A slot counts as used when it has a truthy state in a reporting
        agent's observed or target mapping, or in the common target. In
        unsafe mode every contour slot is returned.
        """
        if self._unsafe:
            return set(self._contour.slots)
        res = set(self._contour.slots)
        for intl2 in self._state.values():
            res -= intl2.used_slots
        res.difference_update(slot for slot, target in self._target.items() if target)
        _log.debug('unused slots: %s', res)
        return res

    def _check(self):
        """Run all safety checks unless running in unsafe mode."""
        if self._unsafe:
            _log.warning('Unsafe mode: skipping all checks')
            return
        self._check_not_too_many_reconfigured()
        self._check_not_null_mapping()
        self._check_enough_alive_instances()

    def _reconfigure(self):
        """Point agents at the common target while limiting concurrent switches.

        Reporting agents are visited in random order. An agent is re-targeted
        when its target differs from ``self._target`` and either fewer than
        ``max_switching`` agents are switching or it is switching already.
        The switching set is updated after each re-target so later agents in
        the same pass see the new count.
        """
        state = self._state
        switching = {agent for agent, intl2 in state.items() if intl2.switching}
        for agent, intl2 in funcs.shuffled(state.items()):
            # ``intl2`` reflects the agent's state before this pass touched it.
            if len(switching) < self._max_switching or intl2.switching:
                if self._agents_targets[agent] != self._target:
                    self._agents_targets[agent] = self._target.copy()
                    _log.debug('set intl2(%s) ts', agent.instance)
                    if _IntL2(self._agents_targets[agent], intl2.observed_state).switching:
                        switching.add(agent)
                    else:
                        switching.discard(agent)

    def _check_not_too_many_reconfigured(self):
        """Raise RuntimeError when more than ``max_switching`` reporting agents
        have a target that differs from their observed state."""
        cnt = sum(
            1 for agent, intl2 in self._state.items()
            if self._agents_targets[agent] != intl2.observed_state
        )
        if cnt > self._max_switching:
            raise RuntimeError('too many reconfiguring')

    def _check_not_null_mapping(self):
        """Raise RuntimeError when the common target is an empty mapping."""
        if self._target == slots_state.empty_mapping(self._contour):
            raise RuntimeError('empty mapping')

    def _check_enough_alive_instances(self):
        """Fail unless more than 75% of managed agents have a report.

        Raises AssertionError explicitly (not via ``assert``) so the check
        is not stripped under ``python -O``.
        """
        alive, total = len(self._reports), len(self._agents)
        if not alive > 0.75 * total:
            raise AssertionError('not enough alive intl2 instances: %d of %d' % (alive, total))

    def _switching_agents(self):
        """Return ``instance`` of every reporting agent that is switching."""
        return [agent.instance for agent, intl2 in self._state.items() if intl2.switching]

    @property
    def _state(self):
        """Per-agent ``_IntL2`` for agents that have a report (built on each access)."""
        return {
            agent_: _IntL2(self._agents_targets[agent_], report_.state)
            for agent_, report_ in self._reports.items()
        }

    def json_view(self):
        return {
            'slots': {
                slot: state.json() for slot, state in self._target.items()
            },
            'switching': self._switching_agents(),
        }

    def html_view(self):
        return blocks.Intl2View(
            target_state=blocks.Intl2Target({slot: state.json() for slot, state in self._target.items()}),
            progress=blocks.Progress(
                done=len(self._agents_match_target()),
                total=len(self._agents)
            ),
        )

    def _agents_match_target(self):
        """Return managed agents whose target equals the common target, that
        have a report, and that are not switching.

        Always returns a list so ``len()`` works on Python 2 and 3; ``_state``
        is built once instead of twice per agent.
        """
        state = self._state
        return [
            agent for agent in self._agents
            if self._agents_targets[agent] == self._target
            and agent in state
            and not state[agent].switching
        ]

    def observed_slot_usage(self, slot_name):
        """Return the fraction of reporting agents whose target has a truthy
        state for ``slot_name``.

        NOTE(review): despite the name, this reads ``self._agents_targets``
        (targets), not observed state — confirm that is intended.
        Raises ZeroDivisionError when there are no reports.
        """
        reporting = list(self._reports)
        used = sum(1 for agent in reporting if self._agents_targets[agent][slot_name])
        return used / float(len(reporting))


class _IntL2(object):
    def __init__(self, target_state, observed_state):
        self.target_state = target_state
        self.observed_state = observed_state

    @property
    def switching(self):
        if any(self.target_state.values()):
            return self.target_state != self.observed_state
        return False

    @property
    def used_slots(self):
        a = {slot for slot, state in self.observed_state.iteritems() if state}
        b = {slot for slot, state in self.target_state.iteritems() if state}
        return set.union(a, b)


def agents_targets_to_json(agents_targets):
    """Serialize ``{agent: mapping}`` into a list of
    ``{'instance': ..., 'target': ...}`` dicts."""
    serialized = []
    for agent, target in agents_targets.items():
        serialized.append({
            'instance': agent.instance,
            'target': slots_state.mapping_to_dict(target),
        })
    return serialized


def agents_targets_from_json(agents_targets):
    """Rebuild ``{Agent: mapping}`` from the list produced by agents_targets_to_json.

    The ``'instance'`` string is split on ':' and the parts are passed to
    ``entities.Agent`` positionally (presumably host and port — verify).
    """
    result = {}
    for item in agents_targets:
        mapping = slots_state.mapping_from_dict(item['target'])
        result[entities.Agent(*item['instance'].split(':'))] = mapping
    return result


def _generate_target(tier, slots_states):
    """Build a target mapping holding one selected slot state per group.

    Starts from an empty mapping for ``tier`` and, for each group of non-empty
    slot states, stores the pair chosen by ``_select_slot_state``.
    """
    target = slots_state.empty_mapping(tier)
    groups = _states_per_group(slots_states)
    chosen = (_select_slot_state(group) for group in groups.values())
    for slot, state in chosen:
        target[slot] = state
    return target


def _states_per_group(slots_states):
    used = collections.defaultdict(set)
    for slot, state in slots_states.items():
        if state:
            used[state.group_number].add((slot, state))
    return used


def _select_slot_state(slots_states):
    slot, state = sorted(slots_states, key=lambda (slot_, state_): (-state_.timestamp, slot_))[0]
    return slot, state


def _diff_targets(old, new):
    res = []
    for slot in old:
        if new[slot] != old[slot]:
            res.append((slot, old[slot], new[slot]))
    return res


def _filter_reports(reports, known_agents):
    return (rep for agent, rep in reports.iteritems() if agent in known_agents)


# Module logger. It is defined after the code that uses it, which works
# because the name is only looked up when those functions/methods run.
_log = logging.getLogger(__name__)
