from resources import get_binary_path
from binary import Binary
import logging
from itertools import chain
import os
import stat

from sandbox import sdk2
import sandbox.common.types.task as ctt


class PCodeSlowStatisticsCoordinator(sdk2.Task):
    """PCode Slow Statistics Coordinator.

    One run of this task:

    1. makes the bundled binary executable and asks it for the list of
       pending operation uuids;
    2. looks up the newest ``YABS_YT_UPSERTER`` / ``RUN_YQL_2`` Sandbox task
       tagged with each uuid and groups the uuids by that task's status;
       pending uuids with no task found are treated as "unlisted";
    3. calls ``Binary.coordinate`` with the failed + unlisted operations and
       the finished operations;
    4. tags/describes itself and enqueues one follow-up Sandbox task per
       action returned by the binary.
    """
    class Parameters(sdk2.Task.Parameters):
        # Cluster name: passed to the binary, used as the upserter's
        # cluster_name and in the CLUSTER_ tag.
        proxy = sdk2.parameters.String(
            label="Cluster for data",
            required=True,
            default='hahn'
        )

        # Contour identifier: passed to the binary and used in the CONTOUR_ tag.
        contour = sdk2.parameters.String(
            label="Contour",
            required=True,
            default='0'
        )

        # Vault entry with the YT token; also forwarded to spawned
        # YABS_YT_UPSERTER tasks.
        yt_token_vault_name = sdk2.parameters.String(
            label="YT token vault name",
            required=True,
        )

        # NOTE(review): currently not read anywhere in this task.
        yql_token_vault_name = sdk2.parameters.String(
            label="YQL token vault name",
            required=True,
        )

        # Vault entry with the StarTrek token handed to Binary.coordinate.
        st_token_vault_name = sdk2.parameters.String(
            label="StarTrek token vault name",
            required=True,
        )

    # noinspection PyMethodMayBeStatic
    def _find_tasks(self, task_type_name, pending_uuids):
        """Yield the newest task of ``task_type_name`` tagged with each uuid.

        The tag is searched as ``uuid.upper()``; of the matching tasks the one
        with the highest id is taken. Uuids with no matching task are skipped.
        """
        # Loop-invariant: resolve the task type once.
        task_type = sdk2.Task[task_type_name]
        for uuid in pending_uuids:
            task = sdk2.Task.find(
                tags=uuid.upper(),
                task_type=task_type,
            ).order(-sdk2.Task.id).first()
            if task is not None:
                yield task

    def _get_operation_uuids_by_statuses(self, pending_uuids):
        """Group pending operation uuids by the status of their Sandbox task.

        :param pending_uuids: re-iterable collection of operation uuids.
        :return: dict mapping a task status to the list of lower-cased uuids
            (read from the task's last tag) whose found task has that status.
            Tasks whose tags could not be read are left out (with a warning).
        """
        tasks = chain(
            self._find_tasks('YABS_YT_UPSERTER', pending_uuids),
            self._find_tasks('RUN_YQL_2', pending_uuids)
        )

        def get_uuid_and_status(task):
            task_status = task.status
            try:
                # Follow-up tasks created by on_execute carry the operation
                # uuid as their last tag.
                tags = self.server.task[task.id].read()['tags']
                return tags[-1].lower(), task_status
            except Exception as e:
                logging.warning('failed to get uuid %s ', e)
                return None, task_status

        uuids_by_status = {}
        for uuid, status in map(get_uuid_and_status, tasks):
            if uuid is not None:
                uuids_by_status.setdefault(status, []).append(uuid)

        return uuids_by_status

    # noinspection PyMethodMayBeStatic
    def _get_failed_operations(self, uuids_by_status):
        """Return an iterator over uuids whose task is STOPPED, FAILURE or EXCEPTION."""
        failed_statuses = (ctt.Status.STOPPED, ctt.Status.FAILURE, ctt.Status.EXCEPTION)
        return chain.from_iterable(uuids_by_status.get(status, []) for status in failed_statuses)

    # noinspection PyMethodMayBeStatic
    def _get_finished_operations(self, uuids_by_status):
        """Return an iterator over uuids whose task finished with SUCCESS."""
        finished_statuses = (ctt.Status.SUCCESS, )
        return chain.from_iterable(uuids_by_status.get(status, []) for status in finished_statuses)

    def _create_action_task(self, action, proxy):
        """Create (but do not enqueue) the Sandbox task that performs ``action``.

        ``action['type'] == "MR"`` creates a ``RUN_YQL_2`` task running
        ``action['query']``; ``'Upsert'`` creates a ``YABS_YT_UPSERTER`` task
        from ``action['source_table']`` into ``action['destination_table']`` on
        cluster ``proxy``.

        :return: the created task, or None (after logging a warning) when the
            action type is not recognised. Previously an unknown type crashed
            with AttributeError on ``None.id``, possibly after other follow-up
            tasks had already been enqueued.
        """
        action_type = action['type']
        if action_type == "MR":
            return sdk2.Task['RUN_YQL_2'](
                None, owner=self.owner,
                description='',
                priority=ctt.Priority(
                    ctt.Priority.Class.SERVICE, ctt.Priority.Subclass.HIGH),
                query=action['query'], trace_query=True
            )
        if action_type == 'Upsert':
            return sdk2.Task['YABS_YT_UPSERTER'](
                None, owner=self.owner,
                yt_token_vault_name=self.Parameters.yt_token_vault_name,
                description='',
                priority=ctt.Priority(
                    ctt.Priority.Class.SERVICE, ctt.Priority.Subclass.HIGH),
                key_columns=action['key_columns'], value_columns=action['value_columns'],
                output_table=action['destination_table'], input_table=action['source_table'],
                cluster_name=proxy
            )
        logging.warning('skipping action %r: unknown action type %r', action['alias'], action_type)
        return None

    # noinspection PyShadowingBuiltins,PyShadowingNames
    def on_execute(self):
        """Run one coordination round (see the class docstring for the steps)."""
        binary_path = get_binary_path()
        st = os.stat(binary_path)
        os.chmod(binary_path, st.st_mode | stat.S_IEXEC)

        proxy = self.Parameters.proxy
        contour = self.Parameters.contour
        yt_token = sdk2.Vault.data(self.Parameters.yt_token_vault_name)
        st_token = sdk2.Vault.data(self.Parameters.st_token_vault_name)

        binary = Binary(binary_path, yt_token, proxy, contour)

        pending_operations = binary.list()

        uuids_by_status = self._get_operation_uuids_by_statuses(pending_operations)

        failed_operations = list(self._get_failed_operations(uuids_by_status))
        finished_operations = list(self._get_finished_operations(uuids_by_status))

        # A set gives O(1) membership tests for the filter below.
        pinged_operations = set(chain.from_iterable(uuids_by_status.values()))
        unlisted_operations = [
            operation for operation in pending_operations
            if operation not in pinged_operations
        ]

        # Publish the classification first; the description is overwritten
        # below once the binary's observations are known.
        statuses_teardown = repr(uuids_by_status) + '\n---------\n' + repr(unlisted_operations)
        self.server.task[self.id] = dict(
            description=statuses_teardown)

        actions_and_observations, time_literal = binary.coordinate(
            failed_operations + unlisted_operations, finished_operations, st_token)
        # Materialize: the pairs are traversed twice (actions and observations).
        actions_and_observations = list(actions_and_observations)

        # Must be a list, not a lazy map(): it is iterated twice (aliases and
        # the enqueue loop). A Python 3 map iterator would be exhausted by the
        # first pass, and no follow-up task would ever be enqueued.
        actions = [
            action_and_observation[0]
            for action_and_observation in actions_and_observations
            if action_and_observation[0] is not None
        ]
        observations = [
            action_and_observation[1]
            for action_and_observation in actions_and_observations
        ]

        not_spaces_time_literal = time_literal.replace(' ', '_')
        contour_tag = 'CONTOUR_%s' % contour
        cluster_tag = 'CLUSTER_%s' % proxy

        aliases = list(set(action['alias'] for action in actions))
        self.server.task[self.id] = dict(
            tags=([not_spaces_time_literal, 'PCODE_SLOW_STATISTICS', contour_tag, cluster_tag] + aliases))

        def get_observation_description(observation):
            return '\n'.join('{}: {}'.format(key, value) for key, value in observation.items())

        description = '\n----------\n'.join(
            [statuses_teardown] + [get_observation_description(o) for o in observations]
        )
        self.server.task[self.id] = dict(
            description=description)

        for action in actions:
            task_to_execute = self._create_action_task(action, proxy)
            if task_to_execute is None:
                continue
            # The uuid goes last: _get_operation_uuids_by_statuses reads tags[-1].
            self.server.task[task_to_execute.id] = dict(
                tags=["PCODE_SLOW_STATISTICS", contour_tag, cluster_tag, action['alias'], action['uuid']]
            )
            task_to_execute.enqueue()


# Module-level handle to the task class. Presumably this is the name the
# Sandbox loader looks up to discover the task in this module — confirm
# against the framework before renaming.
__TASK__ = PCodeSlowStatisticsCoordinator
