import time
import logging
import datetime

from sandbox import sdk2
from sandbox.common.utils import get_task_link
from sandbox.common.types.task import Status
from sandbox.common.types.notification import Transport
from sandbox.projects.common.yabs.graphite import Graphite, one_min_metric

from killer import get_hanging_tasks_by_status, kill_tasks


def push_metrics(results, backend='graphite'):
    """Send per-(task type, status) kill counts to the metrics backend.

    :param results: ``{task_type: {status: [killed task ids]}}`` as built by
        ``YabsTaskKiller.on_execute``; each leaf list's length becomes one metric value.
    :param backend: metrics backend name; only ``'graphite'`` is supported.
    :raises NotImplementedError: for any other backend (raised before any metric
        is built or sent).
    """
    # Validate the backend first so an unsupported one fails fast without doing work.
    if backend != 'graphite':
        raise NotImplementedError(
            'We\'re old school guys and we\'re not gonna use dat hipster Solomon or whatever else')

    # One timestamp for the whole batch, so all points of a single run line up.
    timestamp = int(time.time())
    metrics = []
    for task_type, kill_results in results.items():
        metric = one_min_metric('tasks_killed', task_type, hostname='sbyt_watchdog')
        metrics.extend(
            {'name': metric(status), 'timestamp': timestamp, 'value': len(killed)}
            for status, killed in kill_results.items()
        )

    Graphite().send(metrics)


def enqueue_murders(thresholds):
    """Lazily flatten a thresholds config into ``(task_type, status, ttl)`` triples.

    :param thresholds: ``{task_type: {status: ttl_seconds}}`` mapping.
    :returns: generator yielding one triple per configured (task type, status)
        pair, in the mapping's iteration order.
    """
    for task_type in thresholds:
        per_status = thresholds[task_type]
        for status in per_status:
            yield task_type, status, per_status[status]


class TaskKillerThresholds(sdk2.parameters.JSON):
    """JSON task parameter holding ``{task_type: {status_name: ttl_seconds}}``.

    ``cast`` only validates the already-parsed value and returns it unchanged.
    """
    required = True

    @classmethod
    def cast(cls, value):
        """Validate the thresholds config and return it as-is.

        :param value: the parameter value; expected to be a dict of dicts.
        :returns: ``value`` unchanged.
        :raises Exception: if ``value`` or any per-task-type entry is not a dict,
            if a status name is not an attribute of ``Status``, or if a threshold
            is not an integer.
        """
        if not isinstance(value, dict):
            raise Exception('Invalid thresholds value type: `%s`, must be `dict`' % type(value))

        for task_type, status_thresholds in value.items():
            if not isinstance(status_thresholds, dict):
                raise Exception('Invalid `%s` value type: `%s`, must be `dict`' % (task_type, type(status_thresholds)))
            for status, threshold in status_thresholds.items():
                if not hasattr(Status, status):
                    raise Exception('Invalid status name: `%s`' % status)
                # bool subclasses int, so JSON `true`/`false` would otherwise be accepted as 1/0.
                if isinstance(threshold, bool) or not isinstance(threshold, int):
                    raise Exception('Invalid `%s.%s` value type: `%s`, must be `int`' % (task_type, status, type(threshold)))

        return value


class YabsTaskKiller(sdk2.Task):
    """ Kills hanging tasks.

    For every (task type, status) pair in the ``thresholds`` parameter, finds tasks
    that have been in that status longer than the configured TTL (seconds) and
    deletes them. At most ``max_kills_per_invocation`` tasks are killed in total
    across all pairs.
    """
    class Requirements(sdk2.Requirements):
        cores = 1

        # NOTE(review): empty Caches subclass -- presumably means "no caches"; confirm.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        max_restarts = 1
        # NOTE(review): this timestamp is computed once, when the class body runs
        # (module import), not each time a task is created -- confirm intended.
        description = "Killing hanging tasks: %s" % time.strftime("%Y-%m-%d %H:%M:%S")

        notify_on_success = sdk2.parameters.Bool("Send report in case of successfull run", default=False)

        max_kills_per_invocation = sdk2.parameters.Integer("Total maximum kills per invocation", default=1000)

        thresholds = TaskKillerThresholds(
            'Thresholds',
            description='JSON string with thresholds config. Example: {"TASK_TYPE": {"STATUS": 100500, ...}}. All thresholds must be in seconds')

    def generate_report(self, results):
        """Render ``{task_type: {status: [task ids]}}`` as a text report with HTML task links."""
        report_lines = []
        for task_type, kill_results in results.items():
            report_lines.append('Killed `%s` tasks:' % task_type)
            for status, killed in kill_results.items():
                links = ['<a href="%s">%s</a>' % (get_task_link(task_id), task_id) for task_id in killed]
                report_lines.append('  From `%s` status (total %d): [%s]' % (status, len(killed), ', '.join(links)))
        return "\n".join(report_lines)

    def notify_watchers(self, info):
        """Email ``info`` to every recipient from the task's notification settings."""
        recipients = [r for n in self.Parameters.notifications for r in n.recipients]
        logging.info("Will send emails to [%s]", ", ".join(recipients))
        self.server.notification(
            subject='[YaBS] TaskKiller report',
            body=info,
            recipients=recipients,
            transport=Transport.EMAIL,
            urgent=False
        )

    def confess(self, results, kill_counter):
        """Publish the outcome of a run.

        Shows the report in the task info, pushes kill-count metrics, and emails
        watchers when something was killed and ``notify_on_success`` is set.
        """
        info = self.generate_report(results)
        self.set_info("<pre>%s</pre>" % info, do_escape=False)

        push_metrics(results)

        if kill_counter > 0 and self.Parameters.notify_on_success:
            self.notify_watchers(info)

    def _kill_batch(self, task_ids):
        """Delete ``task_ids`` via ``kill_tasks``, log the outcome, and return ``(succeeded, failed)``.

        ``failed`` is treated as a mapping of task id -> dict with ``status`` and ``message`` keys.
        """
        logging.info('About to delete tasks [%s]', ', '.join(map(unicode, task_ids)))
        succeeded, failed = kill_tasks(self.server, task_ids)

        logging.info('Successfully deleted: [%s]', ', '.join(map(unicode, succeeded)))
        for _id, _info in failed.items():
            logging.warning('Cannot delete task %s, status: `%s`, reason: `%s`',
                            get_task_link(_id), _info["status"], _info["message"])
        return succeeded, failed

    def on_execute(self):
        """Kill hanging tasks for every configured (task type, status) pair.

        Stops as soon as ``max_kills_per_invocation`` successful kills are reached,
        then reports via ``confess``.

        :raises Exception: if any task could not be killed (after reporting).
        """
        thresholds = self.Parameters.thresholds
        max_kills = self.Parameters.max_kills_per_invocation
        results = {task_type: {} for task_type in thresholds}
        kill_counter = 0
        overall_failed_counter = 0

        for task_type, status, ttl in enqueue_murders(thresholds):
            # The limit is a total for the whole run, not per (type, status) pair.
            if kill_counter >= max_kills:
                logging.info('Kill limit of %d tasks reached, skipping remaining thresholds', max_kills)
                break

            results[task_type][status] = []
            for tasks_batch in get_hanging_tasks_by_status(self.server, task_type, status, ttl):
                logging.info('Found %d old `%s` tasks in `%s` status updated more than %s from now',
                             len(tasks_batch), task_type, status, str(datetime.timedelta(seconds=ttl)))
                # Trim the batch to the remaining budget so a single batch cannot overshoot it.
                # remaining is always > 0 here because of the limit checks around this loop.
                remaining = max_kills - kill_counter
                task_ids = [t['id'] for t in tasks_batch][:remaining]

                succeeded, failed = self._kill_batch(task_ids)
                overall_failed_counter += len(failed)

                results[task_type][status] += succeeded
                kill_counter += len(succeeded)
                time.sleep(3)  # not to overload sandbox-api
                if kill_counter >= max_kills:
                    break

        self.confess(results, kill_counter)
        if overall_failed_counter > 0:
            raise Exception("Failed to kill {} task(s), see logs for more details".format(overall_failed_counter))
