import logging

import sandbox.common.types.task as ctt
from sandbox import sdk2

from sandbox.projects.yabs.audit.executor import YabsAudit
from sandbox.projects.yabs.audit.lib import BaseAuditBinTask, Args


class YabsAuditCoordinator(BaseAuditBinTask):
    """Audit Coordinator.

    On each run: reads the pending audit tasks, translates the Sandbox status
    of the ``YabsAudit`` subtask behind each of them into an audit status,
    passes the tasks that got a status to ``coordinate`` and launches a new
    ``YabsAudit`` subtask for every task ``coordinate`` returns.
    """
    class Parameters(BaseAuditBinTask.Parameters):
        startrek_token_secret_id = sdk2.parameters.YavSecret(
            label="Startrek token secret id",
            required=True,
            description='secret should contain keys: startrek_token',
        )

        startrek_queue = sdk2.parameters.String(
            label="Startrek queue",
            default="TEST",
            required=True,
        )

        abc_token_secret_id = sdk2.parameters.YavSecret(
            label="ABC token secret id",
            required=True,
            description='secret should contain keys: abc_token',
        )

        solomon_token_secret_id = sdk2.parameters.YavSecret(
            label="Solomon token secret id",
            required=True,
            description='secret should contain keys: solomon_token',
            default='sec-01eca9c44wtrkdjsfygbz7pdtr'
        )

    @staticmethod
    def find_audit_tasks(ids):
        """Look up Sandbox tasks by id.

        :param ids: list of Sandbox task ids.
        :return: iterable of the found Sandbox tasks, at most ``len(ids)`` of
            them; an empty list (without querying Sandbox) when ``ids`` is
            empty.
        """
        if not ids:
            return []
        return sdk2.Task.find(id=ids).limit(len(ids))

    def create_audit_task(self, task, enqueue=False):
        """Create a ``YabsAudit`` subtask for a single audit task.

        :param task: audit task mapping; its ``name`` selects the audit class
            and ``name``/``circuit`` form the subtask description. The whole
            mapping is forwarded to the subtask as its ``task`` parameter.
        :param enqueue: if true, enqueue the subtask right after creating it.
        :return: id of the created Sandbox task.
        """
        from yabs.stat.yabs_yt_audit.lib import find_audit_class_by_name

        audit_class = find_audit_class_by_name(task['name'])
        # The audit class may override the YQL token secret; fall back to the
        # coordinator's own parameter when it does not. Ask only once.
        yql_token_secret_id = audit_class.get_yql_token_secret_id()
        if yql_token_secret_id is None:
            yql_token_secret_id = self.Parameters.yql_token_secret_id

        sb_task = YabsAudit(
            None,
            release_version=self.Parameters.release_version,
            owner=self.owner,
            description='{name} {circuit}'.format(**task),
            priority=ctt.Priority(ctt.Priority.Class.SERVICE,
                                  ctt.Priority.Subclass.HIGH),
            yt_proxy=self.Parameters.yt_proxy,
            yt_path_prefix=self.Parameters.yt_path_prefix,
            queue_table_name=self.Parameters.queue_table_name,
            history_table_name=self.Parameters.history_table_name,
            yt_token_secret_id=self.Parameters.yt_token_secret_id,
            yql_token_secret_id=yql_token_secret_id,
            task=task,
            ch_scheme=self.Parameters.ch_scheme,
            ch_host=self.Parameters.ch_host,
            ch_port=self.Parameters.ch_port,
            ch_user=self.Parameters.ch_user,
            ch_password_secret_id=self.Parameters.ch_password_secret_id,
            ch_db=self.Parameters.ch_db,
            ch_secure=self.Parameters.ch_secure,
        )
        if enqueue:
            sb_task.enqueue()
        return sb_task.id

    def on_execute(self):
        """Refresh pending audit task statuses and launch the next tasks.

        Each pending task receives a ``status`` derived from its Sandbox
        subtask. Only tasks with a non-empty status are passed to
        ``coordinate``. Every task returned by ``coordinate`` is started as a
        new enqueued ``YabsAudit`` subtask and, unless ``args.debug`` is set,
        written back with the new ``task_id`` via ``update_task``.
        """
        from yabs.stat.yabs_yt_audit.lib import coordinate, list_tasks, update_task
        from yabs.stat.yabs_yt_audit.task import Task

        args = Args(self.Parameters)
        pending_tasks = list_tasks(args)
        logging.info("Pending tasks: %s", pending_tasks)

        # Sandbox status -> audit status. Statuses not listed here map to
        # None (presumably subtasks that have not finished yet) and such
        # tasks are not passed to `coordinate`. Every terminal "break" status
        # must be listed, otherwise a task that ended in it would stay
        # pending forever; NO_RES and EXPIRED were previously missing.
        descriptions = {
            ctt.Status.SUCCESS: 'success',
            ctt.Status.FAILURE: 'diverged',
            ctt.Status.EXCEPTION: 'failed',
            ctt.Status.TIMEOUT: 'failed',
            ctt.Status.DELETED: 'failed',
            ctt.Status.STOPPED: 'failed',
            ctt.Status.NO_RES: 'failed',
            ctt.Status.EXPIRED: 'failed',
        }
        sb_tasks = self.find_audit_tasks([task['task_id'] for task in pending_tasks])
        sb_statuses = {sb_task.id: sb_task.status for sb_task in sb_tasks}
        logging.info("Sandbox statuses: %s", sb_statuses)

        for audit_task in pending_tasks:
            # A subtask Sandbox did not return is treated as deleted.
            sb_status = sb_statuses.get(audit_task['task_id'], ctt.Status.DELETED)
            audit_task['status'] = descriptions.get(sb_status)

        tasks_for_launching = coordinate(args, [Task(**t) for t in pending_tasks if t['status']], self)

        for task in tasks_for_launching:
            task['task_id'] = self.create_audit_task(task, enqueue=True)
            logging.info('Started task: %s', task)
            # In debug mode the subtask is still launched, but the new
            # task_id is not persisted.
            if not args.debug:
                update_task(args, Task(**task))
