import logging

from enum import Enum

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types.notification import Transport
from sandbox.common.types.task import Status
from sandbox.projects.yabs.qa.solomon.mixin import SolomonTaskMixin, SolomonTaskMixinParameters
from sandbox.projects.yabs.qa.tasks.YabsCollectTestingMetrics.task_status import (
    TrunkTaskStatusCounter,
    TaskStatusCounter,
    TestResultsStatusCounter,
)
from sandbox.projects.yabs.qa.tasks.YabsCollectTestingMetrics.spec import get_spec_resources, get_version_sensors, get_resource_age_sensors, get_ab_experiment_sensors

# Module-level logger used by the task below for debug and error reporting.
logger = logging.getLogger(__name__)


class TaskGroups(Enum):
    # Metric groups collected by YabsCollectTestingMetrics.
    #
    # Each member's value is used to build the per-group context attribute and
    # output parameter names ('last_timestamp_<value>' via field_name(), and
    # '<value>_last_timestamp' for the output parameter group).

    # Groups whose sensors are produced by task-status counters.
    trunk_task_statuses = 'trunk_task_statuses'
    ab_experiment_task_statuses = 'ab_experiment_task_statuses'
    oneshot_task_statuses = 'oneshot_task_statuses'
    brave_tests_task_statuses = 'brave_tests_task_statuses'

    # Groups whose sensors are produced by test-results counters.
    ab_experiment_test_results = 'ab_experiment_test_results'
    oneshot_test_results = 'oneshot_test_results'
    brave_tests_test_results = 'brave_tests_test_results'


def field_name(task_group):
    '''Build the attribute name that stores the last timestamp of a task group.

    The same name is used both as a context attribute and as an output
    parameter, e.g. ``last_timestamp_oneshot_test_results``.

    :param task_group: an object exposing ``value`` (a TaskGroups member).
    :return: ``'last_timestamp_'`` followed by ``task_group.value``.
    '''
    name_template = 'last_timestamp_{group}'
    return name_template.format(group=task_group.value)


# Reusable parameter template holding a single "last processed" timestamp.
# It is meant to be instantiated once per task group with a name suffix and a
# 'tier' label substitution; presumably sdk2 fills '{tier}' in the label from
# label_substs -- confirm against the sdk2 parameters documentation.
class LastProcessedTimestamp(sdk2.Parameters):
    last_timestamp = sdk2.parameters.Integer('Updated timestamp of the last task for {tier}')


class YabsCollectTestingMetrics(SolomonTaskMixin, sdk2.Task):
    '''Collect yabs testing (aka ShM) metrics'''

    # Workflow (see on_execute):
    #   * for every TaskGroups member a counter collects sensors starting from the
    #     timestamp stored by the previous finished run of the same scheduler;
    #   * sensors are added to the Solomon push client and the new timestamp is
    #     stored in the context; publish_results() copies it to the output
    #     parameters from every finish hook (success, failure, break);
    #   * spec metrics (resource versions/ages, AB experiment) are collected last;
    #   * individual failures are accumulated and reported as one TaskFailure.

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 4096

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        solomon_parameters = SolomonTaskMixinParameters()

        with sdk2.parameters.Output:
            # NOTE(review): the loop below registers objects under the same
            # '<group>_last_timestamp' names, so these three explicit
            # declarations look redundant; the timestamps themselves are written
            # to 'last_timestamp_<group>' (field_name). Confirm before removing.
            trunk_task_statuses_last_timestamp = sdk2.parameters.Integer('Updated timestamp of the last trunk updated task')
            ab_experiment_task_statuses_last_timestamp = sdk2.parameters.Integer('Updated timestamp of the last ab_experiment updated task')
            oneshot_task_statuses_last_timestamp = sdk2.parameters.Integer('Updated timestamp of the last oneshot updated task')

            # One LastProcessedTimestamp template per task group; the suffix
            # presumably exposes its 'last_timestamp' field as
            # 'last_timestamp_<group>', which is the name field_name() returns.
            for task_group in TaskGroups:
                sdk2.helpers.set_parameter(
                    '{}_last_timestamp'.format(task_group.value),
                    LastProcessedTimestamp(suffix='_' + task_group.value, label_substs={'tier': task_group.value})
                )

    def on_create(self):
        super(YabsCollectTestingMetrics, self).on_create()
        if self.Context.copy_of:
            # A copied task inherits the original's notifications; drop the
            # Juggler ones (presumably so manual copies don't affect monitoring).
            self.Parameters.notifications = [
                notification for notification in self.Parameters.notifications
                if notification.transport != Transport.JUGGLER
            ]

    def publish_results(self):
        '''Copy per-group timestamps from the context to the output parameters.

        Each group is guarded by its own memoize stage. Falsy timestamps
        (missing, None, 0) are not published.
        '''
        for task_group in TaskGroups:
            with self.memoize_stage['set_output_{}'.format(task_group.value)]:
                param_name = field_name(task_group)
                last_timestamp = getattr(self.Context, param_name)
                if last_timestamp:
                    setattr(self.Parameters, param_name, last_timestamp)

    def on_success(self, *args, **kwargs):
        super(YabsCollectTestingMetrics, self).on_success(*args, **kwargs)
        self.publish_results()

    def on_failure(self, *args, **kwargs):
        super(YabsCollectTestingMetrics, self).on_failure(*args, **kwargs)
        self.publish_results()

    def on_break(self, *args, **kwargs):
        super(YabsCollectTestingMetrics, self).on_break(*args, **kwargs)
        self.publish_results()

    def get_last_timestamp(self):
        '''Read the last processed timestamps stored by the previous run.

        Finds the most recent finished YabsCollectTestingMetrics task of this
        task's scheduler (or of a hard-coded scheduler when this task is a copy)
        and reads its per-group ``last_timestamp_<group>`` output parameters.

        :return: dict mapping TaskGroups members to the stored timestamps; empty
            when there is no scheduler or no finished task was found, so callers
            must use ``.get()``.
        '''
        last_timestamp = {}
        if self.Context.copy_of:
            # NOTE(review): hard-coded scheduler id, presumably the production
            # scheduler of this task -- confirm.
            scheduler = 489327
            logger.debug('Task is a copy, will search for tasks in scheduler %s', scheduler)
        else:
            scheduler = self.scheduler
        logger.debug('Will search for tasks in scheduler %s', scheduler)
        if scheduler:
            last_task = YabsCollectTestingMetrics.find(scheduler=scheduler, status=Status.Group.FINISH, order='-id').limit(1).first()
            if last_task:
                logger.debug('Found task %s', last_task.id)
                for task_group in TaskGroups:
                    last_timestamp[task_group] = getattr(last_task.Parameters, field_name(task_group))

        return last_timestamp

    def on_execute(self):
        '''Collect all metric groups and add them to the Solomon push client.

        Every group is collected independently. A failing group keeps its
        previous timestamp so the next run starts from the same point.

        :raises TaskFailure: after everything else has been collected, if any
            task group or spec metric could not be collected.
        '''
        failures = []
        last_timestamp = self.get_last_timestamp()
        logger.debug('last timestamps: %s', last_timestamp)

        # task group -> (sensor getter, Solomon cluster, Solomon service)
        counters = {
            TaskGroups.trunk_task_statuses: (
                TrunkTaskStatusCounter(self.server, author='robot-testenv'),
                'trunk',
                'task_statuses',
            ),
            TaskGroups.oneshot_task_statuses: (
                TaskStatusCounter(self.server, tags=['ONESHOT-TEST', 'CONTENT-SYSTEM-SETTINGS-CHANGE-TEST'], all_tags=False),
                'oneshot',
                'task_statuses',
            ),
            TaskGroups.ab_experiment_task_statuses: (
                TaskStatusCounter(self.server, tags=['AB-EXPERIMENT-TEST']),
                'ab_experiment',
                'task_statuses',
            ),
            TaskGroups.brave_tests_task_statuses: (
                TaskStatusCounter(
                    self.server,
                    tags=[
                        'TESTENV-JOB-YABS_SERVER_BRAVE_TESTS_FT',
                        'TESTENV-JOB-YABS_SERVER_BRAVE_TESTS_FT_SAAS',
                        'TESTENV-JOB-YABS_SERVER_BRAVE_TESTS_PERFORMANCE_META',
                        'TESTENV-JOB-YABS_SERVER_BRAVE_TESTS_PERFORMANCE',
                        'TESTENV-JOB-YABS_SERVER_BRAVE_TESTS_SANITIZE',
                    ],
                    all_tags=False,
                ),
                'brave_tests',
                'task_statuses',
            ),

            TaskGroups.ab_experiment_test_results: (
                TestResultsStatusCounter(self.server, type='YABS_SERVER_TEST_AB_EXPERIMENT'),
                'ab_experiment',
                'test_results',
            ),
            TaskGroups.oneshot_test_results: (
                TestResultsStatusCounter(self.server, type='YABS_SERVER_ONE_SHOT'),
                'oneshot',
                'test_results',
            ),
            TaskGroups.brave_tests_test_results: (
                TestResultsStatusCounter(self.server, type='YABS_SERVER_RUN_BRAVE_TESTS'),
                'brave_tests',
                'test_results',
            ),
        }

        for task_group, (sensor_getter, cluster, service) in counters.items():
            # May be None: get_last_timestamp() returns {} when there is no
            # previous run (first run, manual run without a scheduler).
            previous_timestamp = last_timestamp.get(task_group)
            try:
                sensors, timestamp = sensor_getter.get_sensors(start=previous_timestamp)
                self.solomon_push_client.add(sensors, cluster=cluster, service=service)
                setattr(self.Context, field_name(task_group), timestamp)
            except Exception:
                # Best effort: record the failure and continue with other groups.
                logger.error('Failed to get %s', task_group.value, exc_info=True)
                failures.append(task_group.value)
                # Keep the previous watermark. Indexing last_timestamp here used
                # to raise KeyError when there was no previous run, aborting the
                # whole loop instead of reporting a TaskFailure.
                setattr(self.Context, field_name(task_group), previous_timestamp)

        nanny_token = None
        if self.Parameters.solomon_token_yav_secret_id:
            secret = sdk2.yav.Secret(self.Parameters.solomon_token_yav_secret_id, self.Parameters.solomon_token_yav_secret_version)
            nanny_token = secret.data()['nanny_token']

        try:
            resources_from_spec, server_resources_from_spec = get_spec_resources()

            self.solomon_push_client.add(get_version_sensors(self.server, server_resources_from_spec, nanny_token=nanny_token), cluster='oneshot', service='spec')
            self.solomon_push_client.add(get_resource_age_sensors(resources_from_spec), cluster='oneshot', service='spec')
        except Exception:
            logger.error('Failed to get oneshot resource metrics', exc_info=True)
            failures.append('oneshot_spec')

        try:
            self.solomon_push_client.add(get_ab_experiment_sensors(self.server, nanny_token), cluster='ab_experiment', service='spec')
        except Exception:
            logger.error('Failed to get ab_experiment resource metrics', exc_info=True)
            failures.append('ab_experiment_spec')

        if failures:
            raise TaskFailure('Failed to get {} metrics'.format(', '.join(failures)))
