import collections
import logging
import api
import abstract
import check
import flags
from utils.cpu import cmp_cpu, cpu_count


class EventsCounter(object):
    """Per-cluster tally of named events.

    Counts are kept in a two-level mapping ``cluster -> event -> int``.
    Both levels are defaultdicts, so an unseen cluster/event reads as zero.
    """

    def __init__(self):
        # Outer level lazily creates an int-defaulting counter per cluster.
        self._clusters = collections.defaultdict(
            lambda: collections.defaultdict(int))

    def inc(self, cluster, event):
        """Record one more occurrence of ``event`` for ``cluster``."""
        per_cluster = self._clusters[cluster]
        per_cluster[event] += 1

    def get(self, cluster, event):
        """Return how many times ``event`` was recorded for ``cluster``.

        Lookups go through the defaultdicts. Querying an unseen pair
        therefore inserts it with a zero count, and it then appears in
        ``to_tuples()``.
        """
        per_cluster = self._clusters[cluster]
        return per_cluster[event]

    def to_tuples(self):
        """Flatten all counters into ``[cluster, event, count]`` entries.

        Despite the name, each entry is a list. The ordering follows the
        underlying dict iteration order.
        """
        return [
            [cluster, event, count]
            for cluster, per_cluster in self._clusters.items()
            for event, count in per_cluster.items()
        ]


class FakeMovingAverage(object):
    """Moving-average stand-in that simply remembers the latest sample.

    ``push`` overwrites the stored value, so ``value`` is always the most
    recent sample (or the initial one if nothing has been pushed yet).
    """

    def __init__(self, init):
        # Seed the stored value so ``value`` is readable before any push.
        self._value = init

    def push(self, value):
        """Replace the stored value with ``value``; no history is kept."""
        self._value = value

    value = property(
        lambda self: self._value,
        doc="Most recently pushed value, or the initial one.",
    )


class VmController(object):
    """Reconcile running PSI containers with the configured set.

    Each ``check()`` pass does three things:

    1. Recomputes a CPU scaling coefficient (``_calc_coeff``).
    2. Destroys unwanted or invalid containers and starts missing ones
       (``_fix_diff``).
    3. Pushes mem/cpu limits to containers whose current limits differ
       from the desired ones (``_apply_limits``).
    """

    _logger = logging.getLogger('VmController')
    # cpu_count() * 100 -- presumably total host CPU in percent units
    # (100 per core); evaluated once, at class-definition time.
    _cpu_total = cpu_count() * 100.0

    def __init__(self, configs, monitor, trigger, cpu_min, mem_min):
        """
        :param configs: container configs. They are iterated (each item has
            ``porto_name``, ``cluster``, ``cpu``, ``mem``, ``mode``) and also
            indexed by porto_name.
        :param monitor: view of running containers. It provides
            ``list_psi()`` and is indexable by porto_name.
        :param trigger: exposes ``on``, ``value`` and ``high_watermark``.
        :param cpu_min: containers whose scaled max CPU is below this are
            not kept running.
        :param mem_min: containers whose max memory is below this are not
            kept running.
        """
        super(VmController, self).__init__()
        self._configs = configs
        self._monitor = monitor
        self._trigger = trigger

        self._cpu_destroy_on = cpu_min
        self._mem_destroy_on = mem_min
        # Scales every config's CPU limits; recomputed by _calc_coeff().
        self._cpu_coeff = 1.0

        self._events_counter = EventsCounter()

    def _running(self):
        """Return the set of porto names the monitor currently reports."""
        return {vm.porto_name for vm in self._monitor.list_psi()}

    def _should_run(self):
        """Return the set of porto names that should be running.

        When the PSI stop flag is raised, nothing should run.
        """
        if not flags.psi_stop_flag():
            return {config.porto_name for config in self._configs}
        return set()

    def _kill_invalid_psi(self):
        """Destroy running containers that are unwanted, dead or invalid.

        A destroyed container bumps exactly one ``destroy_*`` event for its
        cluster. No further checks are run against it in the same pass, so
        it is never destroyed or counted twice.
        """
        running = self._running()

        for porto_name in running - self._should_run():
            self._logger.info('destroying %s reason: not present in configs', porto_name)
            # The config may no longer exist, so take the cluster from the monitor.
            self._events_counter.inc(self._monitor[porto_name].cluster, 'destroy_not_in_config')
            api.ensure_destroyed(porto_name)
            running.discard(porto_name)

        for porto_name in running:
            config = self._configs[porto_name]

            # Explicit `is True`: a non-bool result is not treated as dead.
            if api.is_dead(porto_name) is True:
                self._logger.info('%s is dead', porto_name)
                self._events_counter.inc(config.cluster, 'destroy_dead')
                api.ensure_destroyed(porto_name)
                continue

            if config != self._monitor[porto_name]:
                self._logger.info('destroying %s reason: invalid psi_configuration', porto_name)
                self._events_counter.inc(config.cluster, 'destroy_invalid_config')
                api.ensure_destroyed(porto_name)
                continue

            if not flags.psi_disable_custom_checks():
                # `check` here is the imported module, not the method below.
                status = check.check(config.cluster, porto_name)
                if status is not None:
                    self._logger.info('destroying %s reason: %s', porto_name, status)
                    self._events_counter.inc(config.cluster, 'destroy_check_{}'.format(status))
                    api.ensure_destroyed(porto_name)

    def _ensure_started(self, config, cpu_start, mem_start):
        """Launch ``config``'s container if ``api.is_started`` reports False.

        The container's mode selects ``launch_vm`` (os) or ``launch_app``
        (app). A ``launch`` event is counted whenever a launch is attempted.
        """
        if api.is_started(config.porto_name) is False:
            self._logger.info('%s not started', config.porto_name)
            self._events_counter.inc(config.cluster, 'launch')
            if config.mode == abstract.ContainerMode.os:
                api.launch_vm(config, cpu_start, mem_start)
            elif config.mode == abstract.ContainerMode.app:
                api.launch_app(config, cpu_start, mem_start)

    def _fix_diff(self):
        """Bring the running set in line with the configured set.

        Invalid containers are destroyed first. Each wanted container that
        fits the resource thresholds is then started, provided it is
        already started or the trigger is on. Wanted containers that no
        longer fit are destroyed.
        """
        self._kill_invalid_psi()

        for porto_name in self._should_run():
            if not self._is_enough_resources(porto_name):
                self._logger.info('not enough resources for %s', porto_name)
                if api.is_started(porto_name):
                    api.ensure_destroyed(porto_name)
                continue

            if api.is_started(porto_name) or self._trigger.on:
                mem, cpu = self._mem_for(porto_name), self._cpu_for(porto_name)
                # Launch at the minimum limits; _apply_limits sets the full range.
                self._ensure_started(self._configs[porto_name], cpu.min, mem.min)

    def _cpu_for(self, porto_name):
        """Return a copy of the configured CPU limits scaled by the coefficient."""
        cpu = self._configs[porto_name].cpu.copy()
        cpu *= self._cpu_coeff
        return cpu

    def _mem_for(self, porto_name):
        """Return the configured memory limits (not scaled)."""
        return self._configs[porto_name].mem

    def _is_mem_eq(self, porto_name):
        """True if the monitored mem min/max equal the desired ones."""
        desired = self._mem_for(porto_name)
        current = self._monitor[porto_name].mem
        return current['min'] == desired.min and current['max'] == desired.max

    def _is_cpu_eq(self, porto_name):
        """True if the monitored cpu min/max match the desired ones per ``cmp_cpu``."""
        desired = self._cpu_for(porto_name)
        current = self._monitor[porto_name].cpu
        return (cmp_cpu(current['min'], desired.min) == 0 and
                cmp_cpu(current['max'], desired.max) == 0)

    def _is_enough_resources(self, porto_name):
        """False if scaled max CPU or max memory falls below the destroy thresholds."""
        mem, cpu = self._mem_for(porto_name), self._cpu_for(porto_name)
        return not (cpu.max < self._cpu_destroy_on or mem.max < self._mem_destroy_on)

    def _apply_limits(self):
        """Push desired mem/cpu limits to started containers whose limits differ."""
        for porto_name in self._running():
            if not api.is_started(porto_name):
                continue
            if not self._is_mem_eq(porto_name):
                mem = self._mem_for(porto_name)
                api.apply_mem(porto_name, mem.min, mem.max)
            if not self._is_cpu_eq(porto_name):
                cpu = self._cpu_for(porto_name)
                api.apply_cpu(porto_name, cpu.min, cpu.max)

    def _calc_coeff(self):
        """Recompute ``_cpu_coeff`` and clamp it to [0, 1].

        The available share is the headroom below the trigger's high
        watermark plus what PSI containers currently use. That share is
        divided by the share all configs would need at their max CPU. When
        the needed share is zero, the previous coefficient is kept.
        """
        # Trigger value/high_watermark presumably in percent (hence / 100).
        left = (self._trigger.high_watermark - self._trigger.value) / 100.0
        left += sum(one.cpu['usage'].value for one in self._monitor.list_psi()) / self._cpu_total
        cpu_needed = sum(one.cpu.max for one in self._configs) / self._cpu_total
        if cpu_needed != 0:
            self._cpu_coeff = max(0.0, min(1.0, left / cpu_needed))
        self._logger.debug('cpu_coeff: %s', self._cpu_coeff)

    def check(self):
        """Run one full reconciliation pass."""
        self._calc_coeff()
        self._fix_diff()
        self._apply_limits()

    def stop(self, namespace=None):
        """Destroy ``namespace``, or every monitored PSI container if not given."""
        if namespace:
            api.ensure_destroyed(namespace)
        else:
            for container in self._monitor.list_psi():
                api.ensure_destroyed(container.porto_name)

    @property
    def events_counter(self):
        """The ``EventsCounter`` accumulating launch/destroy events."""
        return self._events_counter
