import logging
import time

from abc import abstractmethod
from datetime import datetime, timedelta

from sandbox.common.types.task import Status

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Task statuses used as the `status` filter in count_tasks(): everything in the
# BREAK and FINISH status groups except DELETED and STOPPED.
STATUSES = list((Status.Group.BREAK | Status.Group.FINISH) - {Status.DELETED, Status.STOPPED})


def get_tasks(sandbox_client, limit=100, **kwargs):
    """Collect every task matching ``kwargs`` by paging through ``task.read``.

    Each request is issued with ``children=True`` and ``hidden=True`` plus the
    caller's filters, and asks for ``limit`` items starting at the current
    offset.  Paging stops as soon as a page is shorter than ``limit`` or empty.

    :param sandbox_client: client object exposing ``task.read(...)`` that
        returns a mapping with an ``'items'`` list.
    :param limit: page size requested on every call.
    :param kwargs: extra filters forwarded verbatim to ``task.read``.
    :return: list with the items of all fetched pages, in fetch order.
    """
    collected = []
    offset = 0
    while True:
        logger.info('Search for tasks from %d to %d', offset, offset + limit)
        response = sandbox_client.task.read(
            children=True, hidden=True, limit=limit, offset=offset, **kwargs
        )
        page = response['items']
        collected.extend(page)
        # A short (or empty) page means there is nothing left to fetch.
        if len(page) < limit or not page:
            break
        offset += limit
    return collected


def get_test_type_from_tags(tags):
    """Derive the test type from the ``TESTENV-JOB-`` tag.

    The test type is the lower-cased remainder of the tag after the
    ``TESTENV-JOB-`` prefix.  When several tags carry the prefix the last one
    wins; when none does, ``'unknown'`` is returned.

    :param tags: iterable of tag strings.
    :return: test type string.
    """
    prefix = 'TESTENV-JOB-'
    candidates = [tag[len(prefix):].lower() for tag in tags if tag.startswith(prefix)]
    if not candidates:
        return 'unknown'
    return candidates[-1]


def is_shm(tags):
    """Tell whether any tag starts with ``TESTENV-JOB-YABS_SERVER``.

    :param tags: iterable of tag strings.
    :return: ``True`` if at least one tag has that prefix, ``False`` otherwise.
    """
    return any(tag.startswith('TESTENV-JOB-YABS_SERVER') for tag in tags)


def is_trunk_run(tags):
    """Tell whether the tags describe a run without a patch applied.

    That is the case when ``TESTENV-COMMIT-CHECK`` is present, or when
    ``TESTENV-PRECOMMIT-CHECK`` is present and ``WITH-PATCH`` is not.

    :param tags: iterable of tag strings.
    :return: ``True`` or ``False``.
    """
    present = set(tags)
    if 'TESTENV-COMMIT-CHECK' in present:
        return True
    return 'TESTENV-PRECOMMIT-CHECK' in present and 'WITH-PATCH' not in present


def get_tier(tags):
    """Map a task's tags to a tier label.

    Rules, in priority order:

    * ``TESTENV-COMMIT-CHECK``                      -> ``'baseline'``
    * ``TESTENV-RESOURCE-CHECK``                    -> ``'resource'``
    * ``TESTENV-PRECOMMIT-CHECK`` + ``WITHOUT-PATCH`` -> ``'baseline'``
    * ``TESTENV-PRECOMMIT-CHECK`` + ``WITH-PATCH``    -> ``'test'``
    * ``TESTENV-PRECOMMIT-CHECK`` alone             -> ``'cmp'``
    * ``NEW_SPEC``                                  -> ``'resource'``
    * anything else                                 -> ``'test'``

    :param tags: iterable of tag strings.
    :return: one of ``'baseline'``, ``'resource'``, ``'test'``, ``'cmp'``.
    """
    marks = set(tags)

    if 'TESTENV-COMMIT-CHECK' in marks:
        return 'baseline'
    if 'TESTENV-RESOURCE-CHECK' in marks:
        return 'resource'

    if 'TESTENV-PRECOMMIT-CHECK' in marks:
        # WITHOUT-PATCH takes precedence over WITH-PATCH when both are set.
        if 'WITHOUT-PATCH' in marks:
            return 'baseline'
        if 'WITH-PATCH' in marks:
            return 'test'
        return 'cmp'

    return 'resource' if 'NEW_SPEC' in marks else 'test'


def parse_updated_time(iso_string):
    """Parse a ``...Z`` timestamp with or without fractional seconds.

    The fractional form (``%Y-%m-%dT%H:%M:%S.%fZ``) is tried first; if it does
    not match, the whole-second form (``%Y-%m-%dT%H:%M:%SZ``) is used.

    :param iso_string: timestamp string ending in a literal ``Z``.
    :return: naive ``datetime`` holding the parsed fields.
    :raises ValueError: if the string matches neither format.
    """
    with_fraction = "%Y-%m-%dT%H:%M:%S.%fZ"
    whole_seconds = "%Y-%m-%dT%H:%M:%SZ"
    try:
        parsed = datetime.strptime(iso_string, with_fraction)
    except ValueError:
        parsed = datetime.strptime(iso_string, whole_seconds)
    return parsed


def count_tasks(
        sandbox_rest_client,
        period=timedelta(hours=1),
        start=None,
        predicate=lambda task: True,
        key=lambda task: task['type'],
        **kwargs
):
    """Count finished tasks per key, split into SUCCESS and FAIL.

    Tasks are fetched with ``get_tasks`` using an ``updated`` range of
    ``start..now`` and ``status=STATUSES``.  Every task accepted by
    ``predicate`` is added to the bucket returned by ``key``; a task counts as
    ``'SUCCESS'`` when its status equals ``Status.SUCCESS`` and as ``'FAIL'``
    otherwise.

    ``key`` is evaluated exactly once per counted task.  Key callables may be
    expensive (e.g. issue REST requests), so re-evaluating them for the lookup,
    the insert and the increment would multiply that cost and could even split
    one task across buckets if the answers differed.

    :param sandbox_rest_client: client passed through to ``get_tasks``.
    :param period: look-back window used when ``start`` is falsy.
    :param start: numeric timestamp of the window start, typically the second
        element returned by the previous call; falsy means ``now - period``.
    :param predicate: callable ``task -> bool``; tasks it rejects are skipped.
    :param key: callable ``task -> hashable`` choosing the bucket.
    :param kwargs: extra filters forwarded to ``get_tasks``.
    :return: tuple ``(counter, last_timestamp)`` where ``counter`` maps each
        key to ``{'SUCCESS': int, 'FAIL': int}`` and ``last_timestamp`` is an
        integer suitable for passing back as ``start``.
    """
    finish = datetime.utcnow()
    logger.debug('Start was: %s', start)
    # NOTE(review): `finish` and the parsed `updated` values are naive UTC,
    # while fromtimestamp()/mktime() below work in local time.  The two
    # conversions cancel each other when the returned value is fed back as
    # `start`, but the number is a true epoch timestamp only if the process
    # runs in UTC - confirm before changing either side.
    start = datetime.fromtimestamp(start) if start else finish - period
    last_timestamp = start

    tasks = get_tasks(
        sandbox_rest_client,
        updated='{}..{}'.format(start.isoformat(), finish.isoformat()),
        status=STATUSES,
        **kwargs
    )

    logger.info('Got %d tasks', len(tasks))
    counter = {}
    for task in tasks:
        if not predicate(task):
            continue

        task_key = key(task)  # evaluate once: the callable may hit the network
        succeeded = task['status'] == Status.SUCCESS
        if not succeeded:
            logger.debug('Task %s #%s has status %s', task['type'], task['id'], task['status'])

        bucket = counter.setdefault(task_key, {'SUCCESS': 0, 'FAIL': 0})
        bucket['SUCCESS' if succeeded else 'FAIL'] += 1

        # Move one second past the newest counted task so that the next run,
        # started from the returned timestamp, does not count it again.
        updated = parse_updated_time(task['time']['updated'])
        last_timestamp = max(last_timestamp, updated + timedelta(seconds=1))

    return counter, int(time.mktime(last_timestamp.timetuple()))


class AbstractTaskStatusCounter(object):
    """Base class holding the settings shared by the task status counters.

    NOTE(review): ``get_sensors`` is decorated with ``abstractmethod`` but the
    class derives from plain ``object``, so the decorator is not enforced and
    the base class itself can be instantiated - confirm whether that is
    intended.
    """

    def __init__(
            self,
            sandbox_rest_client,
            start=None,
            period=timedelta(hours=1),
            owner='YABS_SERVER_SANDBOX_TESTS',
            **kwargs
    ):
        """Store the client and query settings on the instance.

        :param sandbox_rest_client: REST client used by subclasses.
        :param start: stored as ``self.start``.
        :param period: stored as ``self.period``.
        :param owner: stored as ``self.owner``.
        :param kwargs: remaining keyword arguments, stored as ``self.kwargs``.
        """
        self.kwargs = kwargs
        self.owner = owner
        self.period = period
        self.start = start
        self.sandbox_rest_client = sandbox_rest_client

    @abstractmethod
    def get_sensors(self, start=None):
        """Subclasses override this; the base implementation returns ``None``."""


class TaskStatusCounter(AbstractTaskStatusCounter):
    """Counts task statuses grouped by task type and tier."""

    def get_sensors(self, start=None):
        """Build ``task_status_count`` sensors for the configured owner.

        :param start: window start forwarded to ``count_tasks``.
        :return: tuple ``(sensors, last_timestamp)``; every sensor is a dict
            with ``'labels'`` (task_type, task_status, tier, sensor) and
            ``'value'``.
        """
        def type_and_tier(task):
            return task['type'], get_tier(task['tags'])

        counter, last_timestamp = count_tasks(
            self.sandbox_rest_client,
            period=self.period,
            start=start,
            key=type_and_tier,
            owner=self.owner,
            **self.kwargs
        )
        logger.info('Got counts: %s', counter)

        sensors = [
            {
                'labels': {
                    'task_type': task_type,
                    'task_status': task_status,
                    'tier': tier,
                    'sensor': 'task_status_count',
                },
                'value': value,
            }
            for (task_type, tier), status_values in counter.items()
            for task_status, value in status_values.items()
        ]
        return sensors, last_timestamp


class TrunkTaskStatusCounter(AbstractTaskStatusCounter):
    """Counts statuses of tasks accepted by ``is_shm``, grouped by task type,
    test type and tier."""

    def get_sensors(self, start=None):
        """Build ``task_status_count`` sensors for the configured owner.

        Only tasks whose tags satisfy ``is_shm`` are counted.

        :param start: window start forwarded to ``count_tasks``.
        :return: tuple ``(sensors, last_timestamp)``; every sensor is a dict
            with ``'labels'`` (task_type, task_status, test_type, tier,
            sensor) and ``'value'``.
        """
        def has_shm_tag(task):
            return is_shm(task['tags'])

        def group_key(task):
            tags = task['tags']
            return task['type'], get_test_type_from_tags(tags), get_tier(tags)

        counter, last_timestamp = count_tasks(
            self.sandbox_rest_client,
            period=self.period,
            start=start,
            predicate=has_shm_tag,
            key=group_key,
            owner=self.owner,
            **self.kwargs
        )

        logger.info('Got counts: %s', counter)

        sensors = [
            {
                'labels': {
                    'task_type': task_type,
                    'task_status': task_status,
                    'test_type': test_type,
                    'tier': tier,
                    'sensor': 'task_status_count',
                },
                'value': value,
            }
            for (task_type, test_type, tier), status_values in counter.items()
            for task_status, value in status_values.items()
        ]
        return sensors, last_timestamp


def not_dry_run(task):
    """Tell whether the task's ``dry_run`` input parameter is absent or ``False``.

    The check is an identity comparison with ``False``, so any other value
    (``True``, ``None``, ``0``, ...) makes the function return ``False``.

    :param task: mapping with an ``'input_parameters'`` mapping.
    :return: ``True`` or ``False``.
    """
    dry_run_flag = task['input_parameters'].get('dry_run', False)
    return dry_run_flag is False


def get_subtask_failure_reason(sandbox_rest_client, task):
    """Classify a task failure by the types of its non-successful children.

    Children are read with ``task.read(parent=<id>, limit=100)``; every child
    whose status is not ``'SUCCESS'`` contributes its type.  The first entry
    of the priority table below that matches one of those types decides the
    reason.

    :param sandbox_rest_client: client exposing ``task.read(...)`` returning a
        mapping with an ``'items'`` list.
    :param task: mapping with the parent task's ``'id'``.
    :return: ``'main_task_failed'`` when no child failed, a table reason when
        one matches, ``'subtask_<TYPE>_failed'`` when exactly one unlisted
        type failed, otherwise ``'subtasks_failed'``.
    """
    # Ordered by priority: the first matching row wins.
    reasons_by_priority = (
        (('YABS_SERVER_VALIDATE_AB_EXPERIMENT',), 'validation_failed'),
        # build
        (('BUILD_YABS_SERVER',), 'build_failed'),
        # oneshot
        (('EXECUTE_YT_ONESHOT',), 'get_oneshot_table_list_failed'),
        # base_gen
        (
            ('YABS_SERVER_GENERATE_CONFIG_WITH_NEW_AB_EXPERIMENT', 'YABS_SERVER_MAKE_BIN_BASES'),
            'base_generation_failed',
        ),
        # base_size
        (
            ('YABS_SERVER_DB_SIZE_AGGREGATE', 'YABS_SERVER_BASE_SIZE_AGGREGATE'),
            'base_size_aggregate_failed',
        ),
        (('YABS_SERVER_CHKDB_CMP', 'YABS_SERVER_BASE_SIZE_CMP'), 'base_size_cmp_failed'),
        # ft
        (('YABS_SERVER_B2B_FUNC_SHOOT_2',), 'ft_shoot_task_failed'),
        (('YABS_SERVER_B2B_FUNC_SHOOT_CMP',), 'ft_cmp_task_failed'),
        # meta_load
        (('YABS_SERVER_PREPARE_STAT_STUB',), 'meta_load_prepare_task_failed'),
        (('YABS_SERVER_PERFORMANCE_META_CMP',), 'meta_load_cmp_task_failed'),
        # stat_load
        (('YABS_SERVER_STAT_PERFORMANCE_PREPARE_DPLAN',), 'stat_load_prepare_task_failed'),
        (('YABS_SERVER_STAT_PERFORMANCE_BEST_CMP_2',), 'stat_load_cmp_task_failed'),
    )

    subtasks = sandbox_rest_client.task.read(parent=task['id'], limit=100)['items']
    failed_types = {
        subtask['type'] for subtask in subtasks if subtask['status'] != 'SUCCESS'
    }
    if not failed_types:
        return 'main_task_failed'

    for task_types, reason in reasons_by_priority:
        if any(task_type in failed_types for task_type in task_types):
            return reason

    # fallback
    if len(failed_types) != 1:
        return 'subtasks_failed'
    (only_failed_type,) = failed_types
    return 'subtask_{}_failed'.format(only_failed_type)


def get_failure_reason(sandbox_rest_client, task):
    """Return a failure reason label for a task, or ``None`` if it succeeded.

    For ``YABS_SERVER_TEST_AB_EXPERIMENT`` and ``YABS_SERVER_ONE_SHOT`` tasks
    the task context is read first and may yield a specific reason; in every
    other case the reason comes from ``get_subtask_failure_reason``.

    :param sandbox_rest_client: client exposing ``task[<id>].context.read()``
        and whatever ``get_subtask_failure_reason`` needs.
    :param task: mapping with ``'status'``, ``'type'`` and ``'id'``.
    :return: reason string, or ``None`` when the status is ``'SUCCESS'``.
    """
    if task['status'] == 'SUCCESS':
        return None

    task_type = task['type']
    if task_type == 'YABS_SERVER_TEST_AB_EXPERIMENT':
        context = sandbox_rest_client.task[task['id']].context.read()
        # Only an explicit False counts; a missing value falls through.
        if context.get('validation_passed') is False:
            return 'validation_failed'
    elif task_type == 'YABS_SERVER_ONE_SHOT':
        context = sandbox_rest_client.task[task['id']].context.read()
        failed_stage = context.get('_failed_stage')
        if failed_stage == 'check_base_changes_after_oneshot_application':
            return 'zero_changes_in_bases'

    return get_subtask_failure_reason(sandbox_rest_client, task)


def get_failed_stage(sandbox_rest_client, task):
    """Return the ``_failed_stage`` value from a non-successful task's context.

    :param sandbox_rest_client: client exposing ``task[<id>].context.read()``.
    :param task: mapping with ``'status'`` and ``'id'``.
    :return: ``None`` when the status is ``'SUCCESS'`` (the context is not
        read in that case) or when the context has no ``'_failed_stage'`` key.
    """
    if task['status'] != 'SUCCESS':
        context = sandbox_rest_client.task[task['id']].context.read()
        return context.get('_failed_stage')
    return None


class TestResultsStatusCounter(AbstractTaskStatusCounter):
    """Counts statuses of non-dry-run tasks grouped by failure reason and
    failed stage.

    NOTE(review): ``self.owner`` is not forwarded to ``count_tasks`` here -
    confirm that querying without an owner filter is intended.
    """

    def get_sensors(self, start=None):
        """Build ``test_result_status_count`` sensors.

        Only tasks accepted by ``not_dry_run`` are counted.  The
        ``failure_reason`` and ``failed_stage`` labels are added only when
        their value is truthy.

        :param start: window start forwarded to ``count_tasks``.
        :return: tuple ``(sensors, last_timestamp)``; every sensor is a dict
            with ``'labels'`` and ``'value'``.
        """
        client = self.sandbox_rest_client

        def reason_and_stage(task):
            return get_failure_reason(client, task), get_failed_stage(client, task)

        counter, last_timestamp = count_tasks(
            client,
            period=self.period,
            start=start,
            predicate=not_dry_run,
            key=reason_and_stage,
            **self.kwargs
        )

        sensors = []
        for (failure_reason, failed_stage), status_values in counter.items():
            # Labels that depend only on the bucket, not on the status.
            optional_labels = {}
            if failure_reason:
                optional_labels['failure_reason'] = failure_reason
            if failed_stage:
                optional_labels['failed_stage'] = failed_stage

            for task_status, value in status_values.items():
                labels = {
                    'task_status': task_status,
                    'sensor': 'test_result_status_count',
                }
                labels.update(optional_labels)
                sensors.append({'labels': labels, 'value': value})
        return sensors, last_timestamp
