import collections
import logging

import infra.callisto.controllers.sdk as sdk
import infra.callisto.controllers.utils.funcs as funcs
import controllerv2
import skip_list


def make_controller(
    yt_observer,
    instance_provider,
    **builder_args
):
    """Create an InnerBuildController for the tier watched by ``yt_observer``.

    The wrapped ``controllerv2.BuildController`` is named after the observer,
    bound to the observer's tier, and receives ``instance_provider`` plus any
    extra keyword arguments in ``builder_args``.
    """
    builder = controllerv2.BuildController(
        yt_observer.name,
        yt_observer.tier,
        instance_provider,
        **builder_args
    )
    return InnerBuildController(yt_observer, builder)


class FailuresNotification(sdk.notify.Notification):
    """Notification about failing build tasks for one generation of a tier.

    ``failed_tasks`` and ``failed_total`` are divided by ``tier.shards_count``.
    A level fires only when BOTH ratios exceed its thresholds. In this file they
    are produced by ``InnerBuildController._fails_count_if_not_built``: the
    number of not-yet-prepared shards with at least one failure, and the sum
    of their failure counts. Tiers with fewer than 10 shards use looser
    thresholds.
    """

    FailedTasks = 'failed-tasks'
    FailedTotal = 'failed-total'

    def __init__(self, tier, failed_tasks, failed_total, generation):
        """
        :param tier: tier object; ``name`` and ``shards_count`` are read.
        :param failed_tasks: number of shards with at least one failure.
        :param failed_total: total number of failures across those shards.
        :param generation: generation timestamp the counters belong to.
        """
        self._tier = tier
        self._failed_tasks = failed_tasks
        self._failed_total = failed_total
        self._generation = generation

    @property
    def level(self):
        """Severity derived from the failure ratios (IDLE when nothing is wrong)."""
        shards_count = self._tier.shards_count
        if shards_count <= 0:
            # An empty tier has no meaningful failure ratio; previously this
            # raised ZeroDivisionError from inside the notification pass.
            return sdk.notify.NotifyLevels.IDLE

        failed_tasks_ratio = self._failed_tasks / float(shards_count)
        failed_total_ratio = self._failed_total / float(shards_count)

        if shards_count < 10:
            # Small tiers: a single failing shard is already a large share,
            # so demand repeated failures (total ratio > 1) before alerting.
            if failed_tasks_ratio > 0.75 and failed_total_ratio > 3.0:
                return sdk.notify.NotifyLevels.ERROR
            if failed_tasks_ratio > 0.5 and failed_total_ratio > 2.0:
                return sdk.notify.NotifyLevels.WARNING
            if failed_tasks_ratio > 0.3 and failed_total_ratio > 1.0:
                return sdk.notify.NotifyLevels.INFO
            return sdk.notify.NotifyLevels.IDLE

        if failed_tasks_ratio > 0.05 and failed_total_ratio > 0.075:
            return sdk.notify.NotifyLevels.ERROR
        if failed_tasks_ratio > 0.01 and failed_total_ratio > 0.02:
            return sdk.notify.NotifyLevels.WARNING
        if failed_tasks_ratio > 0.005 and failed_total_ratio > 0.01:
            return sdk.notify.NotifyLevels.INFO
        return sdk.notify.NotifyLevels.IDLE

    @property
    def message(self):
        """Human-readable text: tier name, generation and both counters."""
        header = '[tier: {}, ts: {}]'.format(self._tier.name, self._generation)
        text = 'failed tasks: {}, failed total: {}'.format(self._failed_tasks, self._failed_total)
        return '{}: {}'.format(header, text)


class TaskState(sdk.notify.ValueNotification):
    """Value notification carrying the number of build tasks in one status.

    ``InnerBuildController.notifications`` emits one instance per
    (namespace, status) pair, passing them as ``labels``.
    """

    name = 'task-state'


class InnerBuildController(sdk.Controller):
    """Keeps the build tasks of one tier in sync with the latest YT generations.

    ``update`` refreshes ``self.generations`` from ``yt_observer``. ``execute``
    then does three things:

    * schedules builds for every shard of the first non-skipped generation;
    * removes tasks whose generation is outside the allowed window or is
      listed in ``skip_list.SKIP_LIST``;
    * calls ``builder.keep`` on still-building tasks of older generations.
    """

    @property
    def path(self):
        """Controller path, derived from the observer name."""
        return 'builder_{}'.format(self.yt_observer.name)

    def __init__(self, yt_observer, builder):
        """
        :param yt_observer: provides ``name``, ``tier``, ``get_last_generations``,
            ``make_task``, ``source``, ``namespace_prefix`` and ``tracker_href``.
        :param builder: child build controller owning the tasks; it is registered
            as a sub-controller here.
        """
        super(InnerBuildController, self).__init__()
        self.yt_observer = yt_observer
        self.builder = builder
        self.register(self.builder)

        # Generation timestamps from yt_observer.get_last_generations();
        # generations[0] is treated as the current one (presumably newest-first).
        self.generations = []

        self._log = logging.getLogger('builder-{}'.format(self.tier.name))
        # How many generations may have tasks at the same time.
        self._max_shard_versions_allowed = 2

        self.add_handler('/build_progress', self.view_progress)

    def execute(self):
        """Build the current generation and retire tasks outside the window.

        Generations in ``skip_list.SKIP_LIST`` are ignored. Tasks of the first
        ``_max_shard_versions_allowed`` remaining generations are kept. Only the
        first ``_max_shard_versions_allowed - 1`` may stay in 'build' mode.
        """
        generations = [g for g in self.generations if g not in skip_list.SKIP_LIST]
        if not generations:
            return
        self._build_new(self.tier.list_shards(generations[0]))
        self._remove_not_in(generations[:self._max_shard_versions_allowed])
        self._stop_not_in(generations[:self._max_shard_versions_allowed - 1])

    def update(self, reports):
        """Refresh the list of generations from the observer; log when it changes.

        ``reports`` is part of the controller interface and is not used here.
        """
        generations = self.yt_observer.get_last_generations(self._max_shard_versions_allowed)
        if generations != self.generations:
            self._log.info(
                'New generations: %s -> %s',
                _human_readable_generations(self.generations),
                _human_readable_generations(generations)
            )
            self.generations = generations

    def notifications(self):
        """Return failure and per-namespace task-state notifications.

        A FailuresNotification is emitted for the current generation only if at
        least one task of it exists. A TaskState is emitted for every
        (namespace, status) pair. Nothing is emitted before the first generation
        is known.
        """
        result = []
        if not self.generations:
            return result

        current = self.generations[0]
        fails_count = self._fails_count_if_not_built()
        if current in self.generations_status():
            result.append(FailuresNotification(
                tier=self.tier,
                failed_tasks=fails_count[current][FailuresNotification.FailedTasks],
                failed_total=fails_count[current][FailuresNotification.FailedTotal],
                generation=current,
            ))

        counts_by_namespace = _eval_counts_by_namespace(self.builder.tasks_state())
        for namespace, counts in counts_by_namespace.iteritems():
            for status in _STATUSES:
                result.append(TaskState(
                    value=counts[status],
                    labels=dict(namespace=namespace, status=status),
                ))
        return result

    def _shards_state(self):
        """Map each task's parsed shard (from its resource name) to its state."""
        return {
            sdk.tier.parse_shard(state.task.resource_name): state
            for state in self.builder.tasks_state().itervalues()
        }

    def _build_new(self, shards):
        """Start builds for ``shards`` that do not already have a task in 'build' mode."""
        building = {
            shard for shard, state in self._shards_state().iteritems()
            if state.mode == 'build'
        }

        # Create every task first, then submit them all.
        tasks = [self.yt_observer.make_task(shard) for shard in set(shards) - building]
        for task in tasks:
            self.builder.build(task)

    def _remove_not_in(self, generations):
        """Remove tasks outside ``generations`` or in the skip list (unless already removing)."""
        to_remove = [
            state.task.task_id
            for shard, state in self._shards_state().iteritems()
            if (shard.timestamp not in generations or shard.timestamp in skip_list.SKIP_LIST)
            and state.mode != 'remove'
        ]
        for task_id in to_remove:
            self.builder.remove(task_id)

    def _stop_not_in(self, generations):
        """Call ``builder.keep`` for 'build'-mode tasks whose generation is not in ``generations``."""
        to_stop = [
            state.task.task_id
            for shard, state in self._shards_state().iteritems()
            if shard.timestamp not in generations and state.mode == 'build'
        ]
        for task_id in to_stop:
            self.builder.keep(task_id)

    def generations_status(self):
        """Count tasks per generation timestamp, bucketed by ``observed.status``.

        NOTE(review): a status other than BUILD/DONE/IDLE/FAILURE/none raises
        KeyError here; confirm the builder only reports these values.
        """
        ts = collections.defaultdict(lambda: dict(BUILD=0, DONE=0, IDLE=0, FAILURE=0, none=0))
        for shard, state in self._shards_state().iteritems():
            ts[shard.timestamp][state.observed.status] += 1
        return dict(ts)

    def _fails_count_if_not_built(self):
        """Per generation, count failing shards and total failures among unprepared tasks.

        Returns a defaultdict, so unknown generations read as zero counters.
        """
        ts = collections.defaultdict(lambda: {
            FailuresNotification.FailedTotal: 0, FailuresNotification.FailedTasks: 0
        })
        for shard, state in self._shards_state().iteritems():
            if not state.observed.prepared:
                fails = state.observed.total_fails_count
                ts[shard.timestamp][FailuresNotification.FailedTotal] += fails
                ts[shard.timestamp][FailuresNotification.FailedTasks] += int(fails > 0)
        return ts

    def prepared(self):
        """Set of shards whose tasks report ``observed.prepared``."""
        return {
            shard for shard, state in self._shards_state().iteritems()
            if state.observed.prepared
        }

    def timestamps(self):
        """Set of generation timestamps that currently have at least one task."""
        return {
            sdk.tier.parse_shard(state.task.resource_name).timestamp
            for state in self.builder.tasks_state().itervalues()
        }

    @property
    def tier(self):
        """Tier handled by this controller (taken from the observer)."""
        return self.yt_observer.tier

    def json_view(self):
        """JSON summary: current generation, all tracked generations, per-generation status."""
        return {
            'current_deploy_state': self.generations[0] if self.generations else None,
            'current_deploy_states': self.generations,
            'full_state': self.generations_status(),
        }

    def html_view(self):
        """HTML view: one DONE/total progress bar per generation (highest timestamp first) and links."""
        gen_status = self.generations_status()
        href_list = [
            sdk.blocks.Href(
                'progress',
                sdk.request.absolute_path(sdk.request.ctrl_path(self)) + '?handler=/build_progress&viewer=1'
            ),
            sdk.blocks.Href('yt-root', self.yt_observer.source.href),
        ]
        # The tracker link is only meaningful when a namespace prefix is configured.
        if self.yt_observer.namespace_prefix is not None:
            href_list.append(sdk.blocks.Href('tracker', self.yt_observer.tracker_href))
        return sdk.blocks.BuilderView(
            bars=[
                sdk.blocks.Progress(gen_status[gen]['DONE'], sum(gen_status[gen].values()), timestamp=gen)
                for gen in sorted(gen_status, reverse=True)
            ],
            name=self.tier.name,
            href_list=sdk.blocks.HrefList(href_list),
        )

    def __str__(self):
        return 'Inner({})'.format(self.builder)

    def build_progress(self, tier_name):
        """Map resource name -> observed JSON for every task; ``tier_name`` must match this tier."""
        assert tier_name == self.tier.name
        return {
            task_state.task.resource_name: task_state.observed.json()
            for task_state in self.builder.tasks_state().itervalues()
        }

    @sdk.request.add_viewer('build')
    def view_progress(self):
        """Group observed task JSON (plus the shard's full name) by generation timestamp."""
        timestamp_shard_state = collections.defaultdict(list)
        for shard, state in self._shards_state().iteritems():
            shard_state = dict(state.observed.json(), shard=shard.fullname)
            timestamp_shard_state[shard.timestamp].append(shard_state)
        return timestamp_shard_state


def _human_readable_generations(generations):
    """Render generation timestamps as YT state strings (used for logging)."""
    to_state = funcs.timestamp_to_yt_state
    return list(map(to_state, generations))


_STATUSES = ['prepared', 'building', 'idle', 'failed', 'dead', 'none']


def _status(observed):
    """Return the first name in _STATUSES that is a truthy attribute of ``observed``.

    Missing attributes count as falsy. ``'none'`` is returned when no
    attribute matches.
    """
    return next(
        (name for name in _STATUSES if getattr(observed, name, None)),
        'none'
    )


def _eval_counts_by_namespace(tasks_state):
    """Count tasks per namespace and status bucket.

    :param tasks_state: mapping whose values expose ``task.namespace`` and
        ``observed``; tasks with an empty/falsy namespace are skipped.
    :returns: defaultdict namespace -> {status: count} covering every entry
        of _STATUSES.
    """
    def zero_counts():
        return dict.fromkeys(_STATUSES, 0)

    per_namespace = collections.defaultdict(zero_counts)
    for entry in tasks_state.itervalues():
        ns = entry.task.namespace
        observed_state = entry.observed
        if not ns:
            continue
        _update_counts(observed_state, per_namespace[ns])
    return per_namespace


def _update_counts(resource_state, counts):
    """Increment the ``counts`` bucket that matches ``resource_state``'s status."""
    bucket = _status(resource_state)
    counts[bucket] = counts[bucket] + 1
