# -*- coding: utf-8 -*-
import logging
import os
import threading
import time

import sandbox.common.types.task as ctt
import sandbox.sdk2 as sdk2
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.inventori.common import InventoryRunTaskTemplate
from sandbox.projects.inventori.common import binary_task
from sandbox.projects.inventori.common import resources
from sandbox.projects.inventori.common.resources import RunMode
from sandbox.projects.inventori.common.utils import report_status


class ReleaseParameters(sdk2.Parameters):
    """
    Release-related parameters of the binary task.

    The parameter list is produced by ``binary_task.binary_release_parameters_list``
    with ``stable=True``.
    """
    ext_params = binary_task.binary_release_parameters_list(
        stable=True,
    )


class RunInventoriTasks(InventoryRunTaskTemplate.InventoriRunTaskTemplateVersioned):
    """
    Sandbox task that runs Inventori tasks in one of the ``RunMode`` modes:

    * SIMPLE      - run on ``yt_yql_cluster`` if that cluster is available;
    * REPLICATION - run on a master cluster, then copy the result/output tables
                    to a slave cluster via Transfer Manager;
    * ASYNC       - spawn one SIMPLE child task per available cluster and wait
                    until all of them finish.

    Start (0), success (1) and failure (-1) statuses are sent through
    ``report_status`` when ``send_status_to_solomon`` is enabled.
    """
    SANDBOX_BASE_API_URL = 'https://sandbox.yandex-team.ru/api/v1.0'
    # Task statuses after which a (child) task is considered done, successfully or not.
    FINISHED_STATUSES = ctt.Status.Group.FINISH | ctt.Status.Group.BREAK

    class Requirements(sdk2.Requirements):
        """
        Small requirements, as most tasks simply make YT or YQL calls.
        """
        disk_space = 2 * 1024  # 2 GiB

        # Requirements for multislot agents
        cores = 1  # < 16
        ram = 512  # 512 MiB < 64 GiB

        class Caches(sdk2.Requirements.Caches):
            pass  # Do not use any shared caches (required for running on multislot agent)

    class Parameters(InventoryRunTaskTemplate.get_run_params(
        resources.INVENTORI_TASK_INITIALIZERS,
        base_class=InventoryRunTaskTemplate.get_run_params(resources.InventoriBaseTaskParams,
                                                           base_class=ReleaseParameters)
    )):
        with sdk2.parameters.Output(reset_on_restart=False):
            result_tables = sdk2.parameters.List('Result tables')

    def get_default_task_name(self):
        """
        Return the task name derived from the configured task type.

        :return: ``self.task_name`` if a task type is set, otherwise None
        """
        if self.has_task_type():
            return self.task_name
        return None

    @property
    def exec_timestamp(self):
        """
        ``task_params.exec_time`` as an integer Unix timestamp.

        ``time.mktime`` interprets the time tuple in the host's local time zone.
        """
        return int(time.mktime(self.task_params.exec_time.timetuple()))

    def _report_status(self, status, yt_cluster=None):
        """
        Send a run status via ``report_status`` if ``send_status_to_solomon`` is enabled.

        :type status: int
        :param status: 0 - start, -1 - fail, 1 - successfully finished
        :type yt_cluster: str | list[str] | None
        :param yt_cluster: hahn | arnold; defaults to ``task_params.yt_yql_cluster``
        """
        if not self.task_params.send_status_to_solomon:
            return
        yt_cluster = yt_cluster or self.task_params.yt_yql_cluster
        report_status(self, self.exec_timestamp, yt_cluster, status)

    def on_save(self):
        """
        Fill the default task name, add tags and force ``push_tasks_resource``.
        """
        super(RunInventoriTasks, self).on_save()

        if not self.Parameters.task_name and self.has_task_type():
            self.Parameters.task_name = self.task_name

        if self.Parameters.task_name:
            self.append_tag(self.Parameters.task_name)

        if self.Parameters.run_mode == RunMode.SIMPLE:  # noqa
            self.append_tag(self.Parameters.yt_yql_cluster)  # noqa

        self.Parameters.push_tasks_resource = True

    def _ensure_child_tasks_finished(self, child_task_ids, raise_on_unsuccessful=False):
        """
        Stop child tasks that have not finished yet and optionally fail on unsuccessful ones.

        :type child_task_ids: list[int] | None
        :param child_task_ids: ids of the child tasks; nothing is done when empty
        :type raise_on_unsuccessful: bool
        :param raise_on_unsuccessful: raise if any child is not in SUCCESS status
        :raises SandboxTaskFailureError: if ``raise_on_unsuccessful`` is set and
            at least one child did not finish with SUCCESS
        """
        self.logger.info('Process child tasks %s', child_task_ids)
        if not child_task_ids:
            return
        unsuccessful_tasks = []
        for task_id in child_task_ids:
            task = sdk2.Task[task_id]
            self.logger.info('Child %s status was: %s', task_id, task.status)
            if task.status != ctt.Status.SUCCESS:
                unsuccessful_tasks.append(task_id)
            if task.status not in self.FINISHED_STATUSES:
                task.stop()
        if raise_on_unsuccessful and unsuccessful_tasks:
            raise SandboxTaskFailureError('Not all task are successfully finished: {}'.format(
                ','.join(map(str, unsuccessful_tasks))
            ))

    def _get_cluster_resolver(self):
        """
        Build a ``ClusterResolver`` authenticated with this task's YT and OAuth tokens.
        """
        from inventori.pylibs.utils.cluster_resolver import ClusterResolver

        return ClusterResolver(
            yt_token=self.task_params.yt_token,
            infra_token=self.task_params.oauth_token,
        )

    def spawn_child_tasks_for_clusters(self, clusters):
        """
        Create and enqueue one SIMPLE-mode copy of this task per cluster.

        Each child receives this task's parameters plus the calculated defaults,
        with ``run_mode`` forced to SIMPLE and ``yt_yql_cluster`` set to its cluster.

        :type clusters: list[str]
        :return: the enqueued child tasks, in the order of ``clusters``
        """
        common_params = dict(list(self.Parameters))  # noqa
        common_params['run_mode'] = RunMode.SIMPLE
        common_params['is_custom_run_mode_for_this_task'] = 'False'

        self.logger.debug('pass calculated defaults %s', self.calculated_sandbox_param_defaults)
        common_params.update(self.calculated_sandbox_param_defaults)

        self.logger.debug('common_params %s', common_params)

        child_tasks = []
        for cluster in clusters:
            # Safe to mutate between iterations: ** unpacking copies the dict per call.
            common_params['yt_yql_cluster'] = cluster

            created_task = RunInventoriTasks(
                self,
                description='Child of https://sandbox.yandex-team.ru/task/{}'.format(self.id),
                owner=self.owner,
                # Child kill_timeout is the parent's kill_timeout plus 2 minutes.
                kill_timeout=self.Parameters.kill_timeout + (2 * 60),  # noqa
                **common_params
            ).save().enqueue()
            self.logger.info('Created sub task https://sandbox.yandex-team.ru/task/%s', created_task.id)
            child_tasks.append(created_task)

        for n, task in enumerate(child_tasks):
            self.logger.info('#%s child task id: %s', n, task.id)
        self.logger.info('Start waiting now task above for status: %s ...', self.FINISHED_STATUSES)

        return child_tasks

    def _waiting_tm_tasks(self, cluster_replication, seconds_to_sleep=7):
        """
        Block until ``cluster_replication`` reports no running copy tasks.

        There is no own timeout; the wait is bounded only by the task's kill_timeout.

        :return: total number of seconds slept
        """
        self.logger.info('_waiting_tm_tasks')
        summary_sleeping_time = 0

        while cluster_replication.is_copy_tasks_still_running():
            summary_sleeping_time += seconds_to_sleep
            time.sleep(seconds_to_sleep)

        return summary_sleeping_time

    def _do_transfer(self, tables, master_cluster, slave_cluster):
        """
        Copy ``tables`` from ``master_cluster`` to ``slave_cluster`` with Transfer Manager.

        The started TM tasks are linked in the task info. If ``is_waiting_for_tm_tasks``
        is set, this also waits for them and lists their final statuses. Status
        reporting for the slave cluster is left to the caller.
        """
        from inventori.pylibs.utils.cluster_replication import ClusterReplication

        cluster_replication = ClusterReplication(yt_token=self.task_params.yt_token)

        tm_tasks = cluster_replication.tm_copy_tables(
            master_cluster=master_cluster, slave_cluster=slave_cluster,
            tables=tables)

        self.set_info('Start next tm_tasks:{}'.format(
            ''.join(
                '\n  * {0} <a href="https://transfer-manager.yt.yandex-team.ru/task?id={1}">{1}</a>'.format(
                    table, task_id)
                for table, task_id in tm_tasks.items()
            ),
        ), do_escape=False)

        if self.task_params.is_waiting_for_tm_tasks:
            summary_sleeping_time = self._waiting_tm_tasks(cluster_replication)
            # NOTE(review): assumes get_final_task_statuses() yields statuses in the
            # same order as tm_tasks.items() - TODO confirm.
            final_statuses = cluster_replication.get_final_task_statuses()
            self.set_info('Waited TM tasks for {seconds} seconds{statuses}'.format(
                seconds=summary_sleeping_time,
                statuses=''.join(
                    '\n  * {0} <a href="https://transfer-manager.yt.yandex-team.ru/task?id={1}">{1}</a> - {2}'.format(
                        table, task_id, status)
                    for (table, task_id), status in zip(tm_tasks.items(), final_statuses)
                ),
            ), do_escape=False)

    def _run_in_simple_mode(self, available_clusters):
        """
        Run the base task on ``yt_yql_cluster`` if it is among ``available_clusters``.
        """
        cluster = self.task_params.yt_yql_cluster
        if cluster not in available_clusters:
            self.set_info("Cluster {cluster} is not available. I won't run anything".format(
                cluster=cluster))
            return

        self._report_status(0)
        super(RunInventoriTasks, self).on_execute()
        self._report_status(1)

    def _run_in_replication_mode(self):
        """
        Run the base task on the master cluster, then replicate tables to the slave cluster.
        """
        master_cluster, slave_cluster = self._get_cluster_resolver().get_master_slave_cluster_pair(
            self.task_params.using_tables
        )
        # Stored in Context so that _on_stop can report failures for these clusters.
        self.Context.master_cluster, self.Context.slave_cluster = master_cluster, slave_cluster
        self.append_tag(master_cluster)
        self.Context.save()  # noqa

        self._report_status(0, yt_cluster=master_cluster)
        if slave_cluster and self.task_params.is_waiting_for_tm_tasks:
            self._report_status(0, yt_cluster=slave_cluster)

        super(RunInventoriTasks, self).on_execute()

        self._report_status(1, yt_cluster=master_cluster)
        self.Context.master_cluster_been_processed = True
        self.Context.save()  # noqa

        if not slave_cluster:
            self.set_info('Slave cluster is unavailable')
            return

        # NOTE(review): Context.result_tables is presumably filled during the base
        # on_execute - confirm. Both sources may be unset, hence the `or []`.
        tables = list(self.Context.result_tables or []) + list(self.Parameters.output_tables or [])
        self._do_transfer(tables, master_cluster, slave_cluster)
        self._report_status(1, yt_cluster=slave_cluster)

    def _start_async_children(self, available_clusters):
        """
        Spawn one child per available cluster and suspend until all of them finish.

        :raises sdk2.WaitTask: always, to wait for the spawned children
        """
        child_tasks = self.spawn_child_tasks_for_clusters(available_clusters)
        self.Context.child_task_ids = [t.id for t in child_tasks]
        self.Context.save()  # noqa
        raise sdk2.WaitTask(
            child_tasks,
            self.FINISHED_STATUSES,
            wait_all=True,
            timeout=self.Parameters.kill_timeout,  # noqa
        )

    def on_execute(self):
        """
        Dispatch execution according to ``task_params.run_mode``.

        The ``initialize`` stage does the actual work. The ``finalize`` stage
        checks the child tasks after an ASYNC wait.
        """
        self.logger.info('on_execute() %s %s', threading.current_thread(), os.getpid())

        with self.memoize_stage.initialize():
            available_clusters = self._get_cluster_resolver().find_available_clusters()
            run_mode = self.task_params.run_mode

            if run_mode == RunMode.SIMPLE:
                self._run_in_simple_mode(available_clusters)
                return

            if not available_clusters:
                self.set_info('<b>There is no available clusters!</b>', do_escape=False)
                return

            if run_mode == RunMode.REPLICATION:
                self._run_in_replication_mode()
                return

            if run_mode == RunMode.ASYNC:
                # Raises sdk2.WaitTask; the finalize stage below handles the children
                # once this task is resumed.
                self._start_async_children(available_clusters)

        with self.memoize_stage.finalize():
            if self.task_params.run_mode == RunMode.ASYNC:
                self._ensure_child_tasks_finished(self.Context.child_task_ids, raise_on_unsuccessful=True)

    def _on_stop(self):
        """
        Handle stopping of the task.

        In ASYNC mode, unfinished children are stopped (the base ``_on_stop`` is not
        called on this path). Otherwise, failure statuses are reported for the affected
        clusters and the base implementation is invoked.
        """
        self.logger.info('_on_stop')
        run_mode = self.task_params.run_mode
        # Compare with the enum member, exactly as on_execute does.
        if run_mode == RunMode.ASYNC:
            self.logger.info('async _on_stop')
            self._ensure_child_tasks_finished(self.Context.child_task_ids)
            return

        self.logger.info('regular _on_stop')
        if run_mode == RunMode.SIMPLE:
            self._report_status(-1)
        if run_mode == RunMode.REPLICATION:
            if self.Context.master_cluster and not self.Context.master_cluster_been_processed:
                self._report_status(-1, self.Context.master_cluster)
            if self.Context.slave_cluster:
                self._report_status(-1, self.Context.slave_cluster)
        super(RunInventoriTasks, self)._on_stop()

