import itertools
import logging
import os
import signal

from sandbox import sdk2

from sandbox.common.errors import TaskFailure
from sandbox.common.types.misc import NotExists
from sandbox.common.types.task import Status
from sandbox.projects.common.binary_task import LastBinaryTaskRelease, LastBinaryReleaseParameters
from sandbox.projects.modadvert import resource_types

import sandbox.common.types.notification as ctn
import sandbox.common.types.task as ctt


class YtCluster(sdk2.parameters.String):
    """String parameter restricted to the supported YT clusters.

    Each choice uses the cluster name both as the stored value and as the label.
    """

    # Generator expression (not a list comprehension) so no loop variable
    # leaks into the class namespace under Python 2.
    choices = tuple(
        (cluster, cluster)
        for cluster in ('arnold', 'hahn', 'markov', 'hume', 'locke', 'freud', 'landau', 'bohr')
    )


class LoggingLevel(sdk2.parameters.String):
    """String parameter restricted to a fixed set of logging level names.

    Each choice uses the level name both as the stored value and as the label.
    """

    # Generator expression keeps the loop variable out of the class namespace (Python 2).
    choices = tuple(
        (level, level)
        for level in ('INFO', 'WARN', 'ERROR', 'DEBUG')
    )


class ModadvertBaseTask(sdk2.Task):
    """Common base class for Modadvert Sandbox tasks.

    ``on_execute`` is a template method that runs, in order:
    ``validate`` -> ``on_before_execute`` -> ``on_execute_inner`` ->
    ``on_after_execute``; subclasses normally override ``on_execute_inner``.
    The class also provides helpers to run subprocesses, spawn and wait for
    subtasks, read vault secrets and unpack tar resources.
    """

    class Context(sdk2.Task.Context):
        # Ids of subtasks enqueued by create_subtask(); read by wait_all_subtasks().
        subtasks = []

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        # E-mail modadvert-reports on FAILURE and on any status of the BREAK group.
        notifications = [
            sdk2.Notification(
                statuses=[ctt.Status.FAILURE, ctt.Status.Group.BREAK],
                recipients=["modadvert-reports@yandex-team.ru"],
                transport=ctn.Transport.EMAIL,
            )
        ]

    def create_env(self):
        """Return the base environment for subprocesses started by run_command().

        Exports ``SANDBOX_SCHEDULER_ID`` when the task has a positive scheduler id.
        Subclasses extend the returned dict.
        """
        env = {}
        if self.scheduler > 0:
            scheduler_id_str = str(self.scheduler)
            logging.info('SANDBOX_SCHEDULER_ID=%s', scheduler_id_str)
            env['SANDBOX_SCHEDULER_ID'] = scheduler_id_str
        return env

    def validate(self):
        """Validation hook run first by on_execute(); the base version only logs."""
        logging.info('starting {} validate()...'.format(self.__class__.__name__))

    def on_before_execute(self):
        """Preparation hook run after validate(); the base version only logs."""
        logging.info('starting {} on_execute()...'.format(self.__class__.__name__))

    def on_execute_inner(self):
        """Main work of the task; override in subclasses. The base version does nothing."""
        pass

    def on_execute(self):
        """Run the validate / before / inner / after pipeline.

        on_after_execute() runs only if the previous steps did not raise.
        """
        self.validate()
        self.on_before_execute()
        self.on_execute_inner()
        self.on_after_execute()

    def on_after_execute(self):
        """Hook run after a successful on_execute_inner(); the base version only logs."""
        logging.info('finished {} on_execute()...'.format(self.__class__.__name__))

    def on_break(self, prev_status, status):
        """Send SIGINT to registered child processes when the task breaks."""
        self.grecefully_kill_child_processes()

    def on_finish(self, prev_status, status):
        """Send SIGINT to registered child processes when the task finishes."""
        self.grecefully_kill_child_processes()

    def on_before_timeout(self, seconds):
        """Send SIGINT to registered child processes before the task times out."""
        self.grecefully_kill_child_processes()

    def grecefully_kill_child_processes(self):
        """Send SIGINT to every process in sdk2.helpers.ProcessRegistry.

        Errors from os.kill (e.g. the process has already exited) are ignored.
        The name is misspelled but kept for compatibility; the lifecycle hooks
        call it, so subclasses overriding it keep working.
        """
        logging.info('Gracefully killing child subprocesses')
        for child_process in sdk2.helpers.ProcessRegistry:
            try:
                logging.info('Sending SIGINT to process {}'.format(child_process.pid))
                os.kill(child_process.pid, signal.SIGINT)
            except OSError:
                # Best effort: the process may already be gone.
                continue
        logging.info('Gracefully killed subprocesses')

    def gracefully_kill_child_processes(self):
        """Correctly spelled alias of grecefully_kill_child_processes()."""
        self.grecefully_kill_child_processes()

    def run_command(self, command, log_prefix=None, env=None):
        """Run ``command`` synchronously with output captured by sdk2 ProcessLog.

        Args:
            command: iterable of arguments; each one is converted with ``str``.
            log_prefix: logger name for the process log; defaults to ``'main'``.
            env: environment for the child process. When falsy (None or empty),
                ``self.create_env()`` is used instead.

        Raises:
            Whatever ``sdk2.helpers.subprocess.check_call`` raises on a non-zero
            exit (presumably subprocess.CalledProcessError — confirm).
        """
        # A real list (not a lazy map) so the log line shows the arguments.
        command = [str(arg) for arg in command]

        if not env:
            env = self.create_env()

        if log_prefix is None:
            log_prefix = 'main'

        logging.info('Running command: {}'.format(command))
        with sdk2.helpers.ProcessRegistry, sdk2.helpers.ProcessLog(self, logger=log_prefix) as pl:
            sdk2.helpers.subprocess.check_call(command, stdout=pl.stdout, stderr=pl.stderr, env=env)
            pl.logger.info('Subprocess has finished successfully')

    def get_default_subtask_parameters(self):
        """Return parameters passed to every subtask by create_subtask(); empty here."""
        return {}

    def create_subtask(self, task, input_parameters, description=''):
        """Create and enqueue a child task and remember its id in Context.subtasks.

        Args:
            task: task class, or a task type name looked up via ``sdk2.Task[...]``.
            input_parameters: parameters merged over get_default_subtask_parameters();
                values equal to ``NotExists`` are dropped and AbstractResource
                values are replaced by their ids.
            description: description of the subtask (also logged).

        Returns:
            The id of the enqueued subtask.
        """
        logging.info('Start {}'.format(description))
        task = sdk2.Task[task] if isinstance(task, basestring) else task
        input_parameters = dict(self.get_default_subtask_parameters(), **input_parameters)

        task_id = task(
            self,
            description=description,
            create_sub_task=True,
            **{
                key: value.id if isinstance(value, resource_types.AbstractResource) else value
                for key, value in input_parameters.iteritems() if value is not NotExists
            }
        ).enqueue().id
        self.Context.subtasks.append(task_id)
        return task_id

    def wait_all_subtasks(self):
        """Return only once every subtask in Context.subtasks has status SUCCESS.

        If some subtask already reached a FINISH/BREAK status other than
        SUCCESS, the subtasks that are not yet in such a status are stopped and
        TaskFailure is raised. Otherwise sdk2.WaitTask is raised (with
        ``wait_all=False``) so the task is resumed as soon as any subtask
        reaches a FINISH/BREAK status, and this check is performed again.
        """
        subtasks = list(self.find(id=self.Context.subtasks).limit(0))
        expected_statuses = Status.Group.FINISH + Status.Group.BREAK

        if all(subtask.status == Status.SUCCESS for subtask in subtasks):
            return

        failed = next(
            (
                subtask for subtask in subtasks
                if subtask.status in expected_statuses and subtask.status != Status.SUCCESS
            ),
            None,
        )
        if failed is not None:
            # Do not leave siblings of a failed subtask running.
            for t in subtasks:
                if t.status not in expected_statuses:
                    t.stop()
            message = 'Subtask {} finished with status {}'.format(failed.id, failed.status)
            raise TaskFailure(message)

        raise sdk2.WaitTask(subtasks, expected_statuses, wait_all=False)

    def get_author_token(self, token_name, vault_user=None):
        """Read vault secret ``token_name`` owned by the task author.

        If that fails and ``vault_user`` is given, retry with ``vault_user`` as
        the owner (errors from the retry propagate); otherwise re-raise.
        """
        try:
            logging.info('Attempt to get token data %s for author %s', token_name, self.author)
            return sdk2.Vault.data(self.author, token_name)
        except Exception:
            if vault_user:
                logging.info('Failed to get token data %s for author. Attempt for another vault user %s', token_name, vault_user)
                return sdk2.Vault.data(vault_user, token_name)
            raise

    def untar_resource(self, resource, dst_dir=None, log_prefix=None):
        """Unpack a tar resource with ``tar -xvf``.

        Args:
            resource: sdk2 resource whose data is a tar archive.
            dst_dir: target directory (created with ``mkdir -p``); current
                directory when omitted.
            log_prefix: process log name; defaults to ``untar_<resource type>``.
        """
        tar_data_path = sdk2.ResourceData(resource).path.as_posix()
        if log_prefix is None:
            log_prefix = 'untar_{}'.format(str(resource.type).lower())

        if dst_dir:
            self.run_command(
                ['mkdir', '-p', dst_dir],
                log_prefix=log_prefix
            )

        self.run_command(
            ['tar', '-xvf', tar_data_path] + (['-C', dst_dir] if dst_dir else []),
            log_prefix=log_prefix
        )


class ModadvertBaseBinaryTask(ModadvertBaseTask, LastBinaryTaskRelease):
    """
    Base class for Sandbox binary tasks. See more about binary tasks: https://wiki.yandex-team.ru/sandbox/tasks/binary/
    On execute this task automatically runs last released binary for this task.

    The hooks below call each base class explicitly (instead of using ``super``)
    so that the LastBinaryTaskRelease hook runs first, followed by the
    ModadvertBaseTask one.
    """

    class Parameters(ModadvertBaseTask.Parameters):
        # Parameter group provided by LastBinaryTaskRelease (presumably controls
        # which binary release is used — confirm in binary_task module).
        ext_params = LastBinaryReleaseParameters()

    def on_save(self):
        # ModadvertBaseTask defines no on_save of its own, so the second call
        # resolves through its bases (sdk2.Task).
        LastBinaryTaskRelease.on_save(self)
        ModadvertBaseTask.on_save(self)

    def on_execute(self):
        # LastBinaryTaskRelease's hook first, then the validate/before/inner/after pipeline.
        LastBinaryTaskRelease.on_execute(self)
        ModadvertBaseTask.on_execute(self)


class ModadvertBaseYtTask(ModadvertBaseTask):
    """Base for tasks working with YT: vault credentials plus cluster selection.

    Each entry of the ``tokens`` parameter maps an environment variable name to
    a vault secret name owned by ``vault_user``. create_env() exports all of
    them, together with ``YT_USER``, to commands started via run_command().
    """

    class Parameters(ModadvertBaseTask.Parameters):
        with sdk2.parameters.Group('Credentials') as credentials_group:
            vault_user = sdk2.parameters.String('Vault user', default='MODADVERT')
            tokens = sdk2.parameters.Dict('Tokens', default={'YT_TOKEN': 'yt-token'})

        with sdk2.parameters.Group('Clusters') as clusters_group:
            yt_proxy_url = YtCluster('YT master cluster', default='arnold')
            yt_worker_proxy_url = YtCluster('YT master cluster for dynamic tables', default='markov')

    def get_default_subtask_parameters(self):
        """Forward credentials and the master cluster to every subtask."""
        params = self.Parameters
        # Copy first so the dict returned by the parent is never mutated.
        merged = dict(super(ModadvertBaseYtTask, self).get_default_subtask_parameters())
        merged.update(
            vault_user=params.vault_user,
            tokens=params.tokens,
            yt_proxy_url=params.yt_proxy_url,
        )
        return merged

    def get_yt_token(self):
        """Return the secret referenced by ``tokens['YT_TOKEN']`` from vault_user's vault."""
        owner = self.Parameters.vault_user
        secret_name = self.Parameters.tokens['YT_TOKEN']
        return sdk2.Vault.data(owner, secret_name)

    def create_env(self):
        """Extend the base environment with all configured tokens and ``YT_USER``."""
        env = super(ModadvertBaseYtTask, self).create_env()
        owner = self.Parameters.vault_user
        env.update(
            (env_name, sdk2.Vault.data(owner, secret_name))
            for env_name, secret_name in self.Parameters.tokens.iteritems()
        )
        env['YT_USER'] = owner
        return env


class ModadvertBaseRunBinaryTask(ModadvertBaseYtTask):
    """Unpack the latest YA_PACKAGE named ``resource_name`` and run a command from it.

    Subclasses set ``resource_name`` and override ``create_command``.
    """

    # Value of the ``resource_name`` attribute of the YA_PACKAGE to unpack; set by subclasses.
    resource_name = None

    def create_command(self):
        """Return the command (list of arguments) to run; override in subclasses."""
        return []

    def get_latest_resource(self, resource_name):
        """Return the newest READY YA_PACKAGE whose ``resource_name`` attribute matches.

        Returns whatever ``.first()`` yields on an empty query (None) when
        nothing matches.
        """
        logging.info('Attempt to find resource with name: {}'.format(resource_name))
        return sdk2.Resource.find(
            resource_type=resource_types.YA_PACKAGE,
            state='READY',
            attr_name='resource_name',
            attr_value=resource_name,
            order='-id'
        ).first()

    def load_latest_resource(self, resource_name):
        """Unpack the latest matching YA_PACKAGE into the current directory.

        Raises:
            TaskFailure: if no READY resource with that name exists.
        """
        resource = self.get_latest_resource(resource_name)
        if resource is None:
            # Fail with a clear message instead of an opaque error from ResourceData(None).
            raise TaskFailure(
                'No READY YA_PACKAGE resource found with resource_name={!r}'.format(resource_name)
            )
        # Same `tar -xvf <path>` invocation under the 'untar' log as before.
        self.untar_resource(resource, log_prefix='untar')

    def on_execute_inner(self):
        """Unpack the package, then run create_command() under the 'main' log."""
        self.load_latest_resource(self.resource_name)
        self.run_command(self.create_command(), log_prefix='main')


class ModadvertBaseRunSupermoderation(ModadvertBaseYtTask):
    """Base for tasks that run the supermoderation binary.

    Before execution it unpacks the config, service configs and binaries
    resources. It then runs the binary with options built from the task
    parameters.

    NOTE(review): ``binaries_resource`` is used but not declared in this
    Parameters class; presumably subclasses declare it — confirm.
    """

    class Parameters(ModadvertBaseYtTask.Parameters):
        with ModadvertBaseYtTask.Parameters.credentials_group() as credentials_group:
            tokens = sdk2.parameters.Dict(
                'Tokens',
                default={
                    'YT_TOKEN': 'yt-token',
                    'MODADVERT_CLIENT_ID': 'modadvert-client-id',
                    'MODADVERT_CLIENT_SECRET': 'modadvert-client-secret',
                }
            )

        with ModadvertBaseYtTask.Parameters.clusters_group() as clusters_group:
            automatic_worker_cluster_detection = sdk2.parameters.Bool('Automatically detect worker cluster', default=True)
            with automatic_worker_cluster_detection.value[False]:
                yt_worker_proxy_url = YtCluster('YT worker cluster', default='arnold')

        with sdk2.parameters.Group('Config') as config_group:
            config = sdk2.parameters.Resource(
                'Resource with supermoderation config',
                resource_type=resource_types.SUPERMODERATION_CONFIG_TYPES,
            )

            service_confs = sdk2.parameters.Resource(
                'Resource with supermoderation service config',
                resource_type=resource_types.SUPERMODERATION_SERVICE_CONFIG_TYPES,
            )

        with sdk2.parameters.Group('Launch parameters') as launch_group:
            iterations = sdk2.parameters.Integer('Number of iterations', required=False, default=10)
            sleep_seconds = sdk2.parameters.Integer('Sleep seconds', required=False, default=1)
            testing = sdk2.parameters.Bool('Testing mode', default=False, required=False)
            logging_level = LoggingLevel('Logging level', default='INFO')
            profile_performance = sdk2.parameters.Bool('Profile performance', default=False, required=False)
            cmd_options = sdk2.parameters.Dict('Binary options', default={})
            cmd_flags = sdk2.parameters.List('Binary flags', default=[])

    def get_base_cmd(self):
        """Return the command prefix: the entry point of the binaries resource."""
        binaries = self.Parameters.binaries_resource
        return [binaries.entry_point]

    def create_command(self):
        """Build the full supermoderation command line from the task parameters.

        Requires ``self.service_confs_dir``, which on_before_execute() sets.
        When profiling is enabled, this also creates a MODADVERT_CPROFILE_STATS
        resource for ./profile_stats.
        """
        params = self.Parameters
        base_cmd = self.get_base_cmd()
        option_pairs = [
            ('--yt-proxy-url', params.yt_proxy_url),
            ('--conf', './config.yaml'),
            ('--service-conf-dir', self.service_confs_dir),
            ('--iterations', params.iterations),
            ('--sleep-seconds', params.sleep_seconds),
            ('--logging-level', params.logging_level),
        ]
        cmd = base_cmd + [token for pair in option_pairs for token in pair]

        if not params.automatic_worker_cluster_detection:
            cmd += ['--yt-worker-proxy-url', params.yt_worker_proxy_url]

        if params.testing:
            cmd.append('--testing')

        if params.profile_performance:
            stats_dir = './profile_stats'
            resource_types.MODADVERT_CPROFILE_STATS(self, 'cProfile stats', stats_dir)
            cmd += ['--profile-stats-dir', stats_dir]

        # Free-form options are appended as "key value" pairs, flags verbatim.
        for option_name, option_value in params.cmd_options.iteritems():
            cmd += [option_name, option_value]
        cmd.extend(params.cmd_flags)

        return cmd

    def on_before_execute(self):
        """Unpack config (cwd), service configs (service_confs/) and binaries (cwd)."""
        super(ModadvertBaseRunSupermoderation, self).on_before_execute()
        self.service_confs_dir = 'service_confs'
        params = self.Parameters
        self.untar_resource(params.config)
        self.untar_resource(params.service_confs, dst_dir=self.service_confs_dir)
        self.untar_resource(params.binaries_resource)

    def on_execute_inner(self):
        """Run the supermoderation command under the 'main' log."""
        command = self.create_command()
        self.run_command(command, log_prefix='main')


def merge_configs(*configs):
    """Recursively merge configuration values; earlier arguments take precedence.

    When at least two arguments are given and all of them are dicts, the result
    is a new dict over the union of their keys. Each key's value is merged
    recursively from the configs that contain it, in argument order. In every
    other case the first argument is returned as-is. So among conflicting
    non-dict values the earliest one wins.

    Raises:
        IndexError: when called with no arguments.
    """
    mergeable = len(configs) > 1 and all(isinstance(config, dict) for config in configs)
    if not mergeable:
        return configs[0]

    all_keys = set()
    for config in configs:
        all_keys |= set(config)

    return {
        key: merge_configs(*[config[key] for config in configs if key in config])
        for key in all_keys
    }


def chunk_iterator(sequence, size):
    """Yield consecutive lists of at most ``size`` items from ``sequence``.

    Copy-paste from modadvert/libs/utils/iterator because of inability to
    import from SB tasks. ``size`` is validated eagerly, when the function is
    called. The last chunk may be shorter, and an empty input yields nothing.

    Raises:
        ValueError: if ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError('positive `size` expected but {} found'.format(size))
    source = iter(sequence)

    def _chunks():
        while True:
            chunk = list(itertools.islice(source, size))
            if not chunk:
                return
            yield chunk

    return _chunks()
