import math
import logging
import functools
import collections

import infra.callisto.controllers.sdk.tier as tiers
import infra.callisto.controllers.sdk.blocks as blocks
import infra.callisto.protos.multibeta.slot_state_pb2 as slot_state_pb2


def _count_shard_number_replicas(agent_to_shard_number):
    """Count how many agents are assigned to each shard number.

    :param agent_to_shard_number: mapping agent -> shard number.
    :return: ``collections.defaultdict(int)`` mapping shard number -> agent count.
    """
    replicas = collections.defaultdict(int)
    # Only the assigned shard numbers matter; the agent keys are not needed.
    for shard_number in agent_to_shard_number.values():
        replicas[shard_number] += 1
    return replicas


class _ConfigurationObserved(object):
    """Base interface for the observed state of a single configuration.

    Subclasses must implement ``ready`` and the json / html / progress
    renderings; only ``diagnostics_proto`` has a default implementation.
    """

    def __init__(self, hash_):
        # Hash of the configuration this observation describes.
        self.hash = hash_

    @property
    def ready(self):
        """Whether the configuration is ready; must be overridden."""
        raise NotImplementedError

    def json(self):
        """JSON-serialisable summary; must be overridden."""
        raise NotImplementedError

    def html(self):
        """HTML block for the UI; must be overridden."""
        raise NotImplementedError

    def shards_progress_proto(self):
        """Shard progress message; must be overridden."""
        raise NotImplementedError

    def instances_progress_proto(self):
        """Instance progress message; must be overridden."""
        raise NotImplementedError

    def diagnostics_proto(self):
        """Diagnostics message; the default is an empty Diagnostics."""
        return slot_state_pb2.Component.Diagnostics()


class _ShardlessConfObserved(_ConfigurationObserved):
    """Observed state of a configuration whose instances are not sharded.

    Readiness only compares the observed instance count against
    ``int(target_count * required_total_ratio)``.
    """

    def __init__(
        self,
        required_total_ratio,
        hash_, target_count, observed_count
    ):
        """
        :param required_total_ratio: fraction of ``target_count`` that must be observed.
        :param hash_: configuration hash.
        :param target_count: expected number of instances.
        :param observed_count: observed number of instances.
        """
        super(_ShardlessConfObserved, self).__init__(hash_)
        self._required_total_ratio = required_total_ratio
        self._target_count = target_count
        self._observed_count = observed_count

    def _required_instance_count(self):
        """Minimum observed instances needed to be ready (truncated to int)."""
        return int(self._target_count * self._required_total_ratio)

    @property
    def ready(self):
        return self._observed_count >= self._required_instance_count()

    def json(self):
        # A list of pairs keeps the key order stable: before Python 3.6,
        # keyword arguments passed to OrderedDict arrive in arbitrary order.
        return collections.OrderedDict([
            ('ready', self.ready),
            ('target_cnt', self._target_count),
            ('observed_cnt', self._observed_count),
        ])

    def html(self):
        return blocks.Progress(
            self._observed_count, self._target_count
        )

    def shards_progress_proto(self):
        # Shardless configurations always report zero shard progress.
        return slot_state_pb2.Component.Progress(
            required_count=0,
            observed_count=0,
        )

    def instances_progress_proto(self):
        return slot_state_pb2.Component.Progress(
            required_count=self._required_instance_count(),
            observed_count=self._observed_count,
        )


class _ShardConfObserved(_ConfigurationObserved):
    """Observed state of a sharded configuration (one conf hash and db timestamp).

    The configuration is ready when enough instances are running overall
    *and* enough shards each have enough running replicas.
    """

    # json() lists at most this many lacking shards.
    _MAX_REMAINING_SHARDS = 3

    def __init__(
        self,
        target_counts, required_total_ratio, required_replicas_ratio, allow_lack_shards_ratio,
        hash_, db_timestamp, observed_counts,
        diagnostics,
    ):
        """
        :param target_counts: mapping shard number -> expected replica count.
        :param required_total_ratio: fraction of all expected instances that must be running.
        :param required_replicas_ratio: fraction of a shard's expected replicas that must be
            running for that shard to count as observed.
        :param allow_lack_shards_ratio: fraction of shards allowed to lack replicas.
        :param hash_: configuration hash.
        :param db_timestamp: timestamp of the observed shards; reported as-is by json().
        :param observed_counts: mapping shard number -> running replica count.
        :param diagnostics: object exposing ``respawns`` and ``oom_kills`` counters,
            each with ``total`` and ``median`` attributes.
        """
        super(_ShardConfObserved, self).__init__(hash_)
        self._target_counts = target_counts
        self._required_total_ratio = required_total_ratio
        self._required_replicas_ratio = required_replicas_ratio
        self._allow_lack_shards_ratio = allow_lack_shards_ratio

        self._db_timestamp = db_timestamp
        self._observed_counts = observed_counts
        self._diagnostics = diagnostics

    def _total_instances(self):
        """Expected instance count summed over all shards."""
        return int(sum(self._target_counts.values()))

    def _total_shards(self):
        return len(self._target_counts)

    def _required_instances(self):
        return int(math.floor(self._total_instances() * self._required_total_ratio))

    def _required_shards(self):
        return int(math.floor(self._total_shards() * (1 - self._allow_lack_shards_ratio)))

    def _observed_instances(self):
        return int(sum(self._observed_counts.values()))

    def _shard_has_enough_replicas(self, shard_number):
        """Return True if ``shard_number`` runs at least the required replicas."""
        # A shard absent from observed_counts never counts, even when the
        # required replica count rounds down to zero. The membership test also
        # avoids inserting keys if observed_counts is a defaultdict.
        if shard_number not in self._observed_counts:
            return False
        required = math.floor(self._target_counts[shard_number] * self._required_replicas_ratio)
        return self._observed_counts[shard_number] >= required

    def _observed_shards(self):
        """Number of target shards that have enough running replicas."""
        return sum(
            1 for shard_number in self._target_counts
            if self._shard_has_enough_replicas(shard_number)
        )

    def _remaining_shards(self):
        """Return up to three lacking shards as ``{str(shard_number): running_count}``."""
        # Stop early instead of slicing ``dict.items()``, which is a
        # non-subscriptable view on Python 3.
        result = {}
        for shard_number in self._target_counts:
            if len(result) >= self._MAX_REMAINING_SHARDS:
                break
            if not self._shard_has_enough_replicas(shard_number):
                result[str(shard_number)] = self._observed_counts.get(shard_number, 0)
        return result

    @property
    def ready(self):
        if self._observed_instances() < self._required_instances():
            return False
        return self._observed_shards() >= self._required_shards()

    def json(self):
        # A list of pairs keeps the key order stable: before Python 3.6,
        # keyword arguments passed to OrderedDict arrive in arbitrary order.
        return collections.OrderedDict([
            ('ready', self.ready),
            ('total_instances', self._total_instances()),
            ('observed_instances', self._observed_instances()),
            ('total_shards', self._total_shards()),
            ('observed_shards', self._observed_shards()),
            ('hash', self.hash),
            ('db_timestamp', self._db_timestamp),
            ('remaining_shards', self._remaining_shards()),
            ('ooms', dict(total=self._diagnostics.oom_kills.total, median=self._diagnostics.oom_kills.median)),
            ('respawns', dict(total=self._diagnostics.respawns.total, median=self._diagnostics.respawns.median)),
        ])

    def html(self):
        return blocks.DoubleProgress(
            self._observed_shards(), self._total_shards(), self._observed_instances(), self._total_instances()
        )

    def shards_progress_proto(self):
        return slot_state_pb2.Component.Progress(
            required_count=self._required_shards(),
            observed_count=self._observed_shards(),
        )

    def instances_progress_proto(self):
        return slot_state_pb2.Component.Progress(
            required_count=self._required_instances(),
            observed_count=self._observed_instances(),
        )

    def diagnostics_proto(self):
        return slot_state_pb2.Component.Diagnostics(
            oom_kills=slot_state_pb2.Component.Diagnostics.OOM(
                total=self._diagnostics.oom_kills.total,
                median=self._diagnostics.oom_kills.median,
            ),
            respawns=slot_state_pb2.Component.Diagnostics.Respawns(
                total=self._diagnostics.respawns.total,
                median=self._diagnostics.respawns.median,
            ),
        )

    def __gt__(self, other):
        """Order observations by their total observed instance count."""
        return self._observed_instances() > other._observed_instances()


class _Observed(object):
    """Per-slot registry of observed configurations keyed by conf hash.

    ``self._slots`` maps slot_id -> {conf_hash: observed configuration};
    subclasses repopulate it, and this base only answers queries.
    """

    def __init__(self):
        self._slots = {}

    def _confs(self, slot_id):
        # Observed configurations of one slot; empty for an unknown slot.
        return self._slots.get(slot_id, {}).values()

    def ready_hashes(self, slot_id):
        return set(conf.hash for conf in self._confs(slot_id) if conf.ready)

    def all_hashes(self, slot_id):
        return set(conf.hash for conf in self._confs(slot_id))

    def _get_conf(self, slot_id, conf_hash):
        return self._slots.get(slot_id, {}).get(conf_hash)

    def _call_conf(self, slot_id, conf_hash, method_name):
        # A missing (falsy) configuration is returned unchanged; otherwise
        # the named rendering method of the configuration is called.
        conf = self._get_conf(slot_id, conf_hash)
        if not conf:
            return conf
        return getattr(conf, method_name)()

    def json(self, slot_id, conf_hash):
        return self._call_conf(slot_id, conf_hash, 'json') or {'ready': False}

    def html(self, slot_id, conf_hash):
        return self._call_conf(slot_id, conf_hash, 'html') or blocks.Block()

    def shards_progress_proto(self, slot_id, conf_hash):
        return self._call_conf(slot_id, conf_hash, 'shards_progress_proto')

    def instances_progress_proto(self, slot_id, conf_hash):
        return self._call_conf(slot_id, conf_hash, 'instances_progress_proto')

    def diagnostics_proto(self, slot_id, conf_hash):
        return self._call_conf(slot_id, conf_hash, 'diagnostics_proto')


class BasesearchObserved(_Observed):
    """Tracks observed basesearch configurations per slot.

    Expected replicas per shard are counted from ``agent_shard_number_map``
    (agent -> shard number).
    """

    def __init__(
            self,
            agent_shard_number_map,
            required_total_ratio,
            required_replicas_ratio,
            allow_lack_shards_ratio=0,
    ):
        super(BasesearchObserved, self).__init__()
        expected_replicas = _count_shard_number_replicas(agent_shard_number_map)
        # update() reads this attribute by name; subclasses may replace it.
        self._configuration_factory_ = functools.partial(
            _ShardConfObserved,
            target_counts=expected_replicas,
            required_total_ratio=required_total_ratio,
            required_replicas_ratio=required_replicas_ratio,
            allow_lack_shards_ratio=allow_lack_shards_ratio,
        )

    def update(self, reports):
        """Rebuild per-slot observations from raw agent ``reports``.

        If one hash is seen with several db timestamps, the candidate that
        compares greater (more observed instances) wins; ties keep the first.
        """
        by_slot = collections.defaultdict(dict)
        for key, state in _observed_slots_state(reports).items():
            slot_id, db_timestamp, conf_hash = key
            candidate = self._configuration_factory_(
                hash_=conf_hash,
                db_timestamp=db_timestamp,
                observed_counts=state.shards,
                diagnostics=_Diagnostics(respawns=state.respawns, oom_kills=state.oom_kills),
            )
            current = by_slot[slot_id].get(conf_hash)
            if current is None or candidate > current:
                by_slot[slot_id][conf_hash] = candidate
        self._slots = by_slot


class BasesearchYpObserved(BasesearchObserved):
    """Basesearch observation where every shard number of ``tier`` expects
    ``replication`` replicas, instead of counting them from an agent map."""

    def __init__(
            self,
            tier, replication,
            required_total_ratio,
            required_replicas_ratio,
            allow_lack_shards_ratio=0,
    ):
        # Deliberately skips BasesearchObserved.__init__ (which needs an agent
        # map) and runs only the next initializer in the MRO.
        super(BasesearchObserved, self).__init__()
        expected_replicas = dict.fromkeys(tier.list_shard_numbers(), replication)
        self._configuration_factory_ = functools.partial(
            _ShardConfObserved,
            target_counts=expected_replicas,
            required_total_ratio=required_total_ratio,
            required_replicas_ratio=required_replicas_ratio,
            allow_lack_shards_ratio=allow_lack_shards_ratio,
        )


class MmetaObserved(_Observed):
    """Tracks observed mmeta configurations per slot.

    All agents of a slot are counted under the single pseudo-shard ``(0, 0)``.
    The same ratio drives both the total-instance and per-shard replica
    requirements.
    """

    def __init__(
            self,
            slots_agents,
            required_total_ratio,
    ):
        super(MmetaObserved, self).__init__()
        self._slots_agents = slots_agents
        self._configuration_factory_ = functools.partial(
            _ShardConfObserved,
            allow_lack_shards_ratio=0,
            required_replicas_ratio=required_total_ratio,
            required_total_ratio=required_total_ratio,
        )

    def update(self, reports):
        """Rebuild per-slot observations from raw agent ``reports``.

        If one hash is seen with several db timestamps, the candidate that
        compares greater (more observed instances) wins; ties keep the first.
        """
        by_slot = collections.defaultdict(dict)
        for key, state in _observed_slots_state(reports).items():
            slot_id, db_timestamp, conf_hash = key
            # Unknown slots get no expected agents (empty target counts).
            slot_agents = self._slots_agents.get(slot_id, [])
            expected = _count_shard_number_replicas(dict.fromkeys(slot_agents, (0, 0)))
            candidate = self._configuration_factory_(
                hash_=conf_hash,
                db_timestamp=db_timestamp,
                target_counts=expected,
                observed_counts=state.shards,
                diagnostics=_Diagnostics(respawns=state.respawns, oom_kills=state.oom_kills),
            )
            current = by_slot[slot_id].get(conf_hash)
            if current is None or candidate > current:
                by_slot[slot_id][conf_hash] = candidate
        self._slots = by_slot


class IntObserved(_Observed):
    """Tracks observed shardless configurations per slot.

    Running instances are counted against the number of agents configured
    for the slot in ``slots_agents``.
    """

    def __init__(
            self,
            slots_agents,
            required_total_ratio,
    ):
        super(IntObserved, self).__init__()
        self._configuration_factory_ = functools.partial(
            _ShardlessConfObserved,
            required_total_ratio=required_total_ratio,
        )
        self._slots_agents = slots_agents

    def update(self, reports):
        """Rebuild per-slot observations from raw agent ``reports``."""
        result = collections.defaultdict(dict)
        for (slot_id, hash_), observed_count in _observed_slots_state_int(reports).items():
            # Same fallback as MmetaObserved: a reported slot missing from
            # slots_agents has no expected agents, instead of raising KeyError
            # and aborting the update for every slot.
            agents = self._slots_agents.get(slot_id, ())
            result[slot_id][hash_] = self._configuration_factory_(
                target_count=len(agents),
                observed_count=observed_count,
                hash_=hash_,
            )
        self._slots = result


def _iter_slots_configurations(reports):
    """Yield ``(slot_id, configuration)`` pairs from every report.

    ``reports`` is a mapping whose values are reports with a ``'slots'``
    mapping of slot id -> iterable of configurations. Malformed reports are
    skipped (pairs already yielded from them are kept) and counted. The count
    is logged at info level once all reports are processed, i.e. only if the
    generator is exhausted.
    """
    failed = 0
    for report in reports.values():
        try:
            for slot_id, configurations in report['slots'].items():
                for conf in configurations:
                    yield slot_id, conf
        # AttributeError covers a 'slots' value that is not a mapping (e.g. a
        # list), which previously escaped and broke every consumer.
        except (AttributeError, TypeError, ValueError, KeyError):
            failed += 1
    if failed:
        logging.info('failed to parse %s/%s reports', failed, len(reports))


class _Diagnostics(object):
    """Respawn and OOM-kill summaries over a group of configurations."""

    class _Counter(object):
        """Total and (upper) median of a sequence of integer counters."""

        def __init__(self, counts):
            ordered = sorted(counts)
            self.total = sum(ordered)
            # ``//`` keeps the index an int on Python 3 too. For even lengths this
            # picks the upper median, as before. An empty input yields 0 instead
            # of raising IndexError.
            self.median = ordered[len(ordered) // 2] if ordered else 0

    def __init__(self, respawns, oom_kills):
        """
        :param respawns: per-configuration respawn counts.
        :param oom_kills: per-configuration OOM-kill counts.
        """
        self.respawns = self._Counter(respawns)
        self.oom_kills = self._Counter(oom_kills)


class _SlotState(object):
    """Mutable accumulator for one group of observed configurations."""

    def __init__(self):
        # Per-configuration respawn and OOM-kill counts (separate lists).
        self.respawns, self.oom_kills = [], []
        # Counter keyed by shard number; missing shards read as 0.
        self.shards = collections.defaultdict(int)


def _observed_slots_state(reports):
    """Group shard configurations from ``reports`` by slot, shard timestamp and hash.

    Only configurations carrying both ``'shard'`` and ``'conf_hash'`` are used.
    Returns a defaultdict mapping ``(slot_id, shard.timestamp, int(conf_hash))``
    to a ``_SlotState``. Running configurations are counted per
    ``(group_number, shard_number)``; respawn and OOM-kill counts are recorded
    for every matching configuration, running or not.
    """
    states = collections.defaultdict(_SlotState)

    for slot_id, conf in _iter_slots_configurations(reports):
        shard_path = conf.get('shard')
        if not shard_path or not conf.get('conf_hash'):
            continue
        # Only the last path component is parsed as the shard name.
        parsed = tiers.parse_shard(shard_path.split('/')[-1])
        shard_key = (parsed.group_number, parsed.shard_number)

        state = states[slot_id, parsed.timestamp, int(conf['conf_hash'])]
        state.respawns.append(int(conf.get('respawn_count', 0)))
        state.oom_kills.append(int(conf.get('oom_kills', 0)))
        if conf.get('running'):
            state.shards[shard_key] += 1

    return states


def _observed_slots_state_int(reports):
    """Count running configurations per ``(slot_id, int(conf_hash))``.

    Configurations without a ``'conf_hash'`` or not marked ``'running'`` are
    ignored. Returns a ``collections.defaultdict(int)``.
    """
    counts = collections.defaultdict(int)
    for slot_id, conf in _iter_slots_configurations(reports):
        if not conf.get('conf_hash') or not conf.get('running'):
            continue
        counts[slot_id, int(conf['conf_hash'])] += 1
    return counts
