import datetime
import logging
import os
import time

import requests
from sandbox import sdk2
from sandbox.common.errors import TaskFailure, VaultError
from sandbox.common.types import task as ctt

from .ci import get_tests_count, create_session, get_tests_finish_timestamp, get_ci_check_url


logger = logging.getLogger(__name__)


# Minimum time that must pass after the tests finished before the failure
# counters are read from CI (the task sleeps for the remainder if needed).
CI_SYNC_DELAY = datetime.timedelta(minutes=10)
# Default for the task's ``ci_token_vault_key`` parameter.
CI_TOKEN_DEFAULT_VAULT_KEY = 'yabs-cs-sb-ci-token'

# Default pause between polls of the CI test status.
CHECK_TESTS_STATUS_DEFAULT_INTERVAL = datetime.timedelta(minutes=30)
# Overall deadline for one task, counted from its first run.
TASK_TIMEOUT = datetime.timedelta(days=3)

# Repository path of the functional tests. It is handed to the CI URL builder
# rather than used on the local filesystem, so it is always '/'-separated
# (``os.path.join`` would produce backslashes on Windows).
YABS_SERVER_TEST_PATH = 'yabs/server/test/ft'
YABS_SERVER_CI_PROJECT_ID = 291


class NoData(Exception):
    """Raised when CI returns no test statistics for the requested revision."""


def get_ci_token(vault_key=CI_TOKEN_DEFAULT_VAULT_KEY):
    """Read the CI token from the Sandbox vault.

    :param vault_key: name of the vault record holding the token.
    :returns: the token, or ``None`` if the vault lookup raised ``VaultError``
        (the error is logged, not propagated).
    """
    try:
        return sdk2.Vault.data(vault_key)
    except VaultError as error:
        # Both placeholders need an argument, otherwise logging reports a
        # formatting error and this message is never emitted.
        logger.error('Failed to get "%s" token from vault: %s', vault_key, error)
        return None


def get_failed_tests_count(revision, project_id, session):
    """Return the number of failed tests of each type for ``revision``.

    :param revision: revision whose CI results are queried.
    :param project_id: CI project id forwarded to ``get_tests_count``.
    :param session: session forwarded to ``get_tests_count``.
    :returns: dict mapping ``'broken_deps'`` and ``'regular'`` to the
        ``fails_broken_deps`` / ``fails_regular`` values of the first record
        returned by CI (0 when a value is absent).
    :raises NoData: if CI returned no records for the revision.
    """
    tests_count = get_tests_count(revision, project_id, session)
    logger.debug('tests_count: %s', tests_count)

    # An empty or missing response is reported as NoData (the task retries
    # later) instead of surfacing as IndexError/TypeError.
    if not tests_count:
        raise NoData('No data for revision {}'.format(revision))

    first_record = tests_count[0]
    return {
        test_type: first_record.get('fails_' + test_type, 0)
        for test_type in ('broken_deps', 'regular')
    }


class YabsServerCheckFT(sdk2.Task):
    """Check yabs functional tests.

    Polls CI until the tests for ``revision`` have a finish timestamp. Then it
    makes sure at least ``CI_SYNC_DELAY`` has passed since that moment, reads
    the failed-test counters, and fails if any of them is non-zero. The task
    gives up with ``TaskFailure`` once ``TASK_TIMEOUT`` has elapsed since its
    first run.
    """
    name = 'YABS_SERVER_CHECK_FT'

    class Parameters(sdk2.Task.Parameters):
        max_restarts = 0
        kill_timeout = 20 * 60  # 20 min

        revision = sdk2.parameters.Integer('Revision', required=True)
        test_path = sdk2.parameters.String('Test path', default=YABS_SERVER_TEST_PATH)
        ci_project_id = sdk2.parameters.Integer('CI project id', default=YABS_SERVER_CI_PROJECT_ID)
        # total_seconds() returns a float; an Integer parameter needs an int default.
        time_to_wait = sdk2.parameters.Integer('Time to wait between checks (seconds)',
                                               default=int(CHECK_TESTS_STATUS_DEFAULT_INTERVAL.total_seconds()))
        with sdk2.parameters.Group("Other parameters") as other_group:
            ci_token_vault_key = sdk2.parameters.String(label='CI token vault key',
                                                        default=CI_TOKEN_DEFAULT_VAULT_KEY)

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1024
        # The semaphore is listed for release on the WAIT status group too, so
        # it is not held while the task sleeps between checks.
        semaphores = ctt.Semaphores(
            acquires=[
                ctt.Semaphores.Acquire(name='yabs_server_check_ft', weight=1),
            ],
            release=(ctt.Status.Group.BREAK, ctt.Status.Group.FINISH, ctt.Status.Group.WAIT)
        )

        class Caches(sdk2.Requirements.Caches):
            pass

    def _get_ci_token(self):
        """Read the CI token from the vault key given in the task parameters.

        Returns ``None`` (and logs the error) if the vault lookup fails.
        """
        vault_key = self.Parameters.ci_token_vault_key or CI_TOKEN_DEFAULT_VAULT_KEY
        try:
            return sdk2.Vault.data(vault_key)
        except VaultError as error:
            logger.error('Failed to get "%s" token from vault: %s', vault_key, error)
            return None

    def _report_check_url(self):
        """Show links to the revision's CI check and to the CI project page."""
        ci_check_url = get_ci_check_url(revision=self.Parameters.revision, test_path=self.Parameters.test_path)
        self.set_info(
            ('Checking revision <a href="{url}" target="_blank">{revision}</a>, '
             'last status is available at <a href="https://ci.yandex-team.ru/project/{project_id}" target="_blank">CI project page</a>.').format(
                url=ci_check_url,
                revision=self.Parameters.revision,
                project_id=self.Parameters.ci_project_id,
            ),
            do_escape=False
        )

    def on_execute(self):
        with self.memoize_stage.first_run(commit_on_entrance=False):
            # Absolute deadline, stored in the context so later executions
            # (after each WaitTime) compare against the same value.
            self.Context.task_run_end_timestamp = int(time.time() + TASK_TIMEOUT.total_seconds())
            self._report_check_url()

        if time.time() > self.Context.task_run_end_timestamp:
            raise TaskFailure('TIMEOUT: task is running more than {}'.format(TASK_TIMEOUT))

        session = create_session(self._get_ci_token())

        # commit_on_wait=False: the stage is re-entered after every WaitTime
        # until a finish timestamp is obtained.
        with self.memoize_stage.wait_tests_finish(commit_on_entrance=False, commit_on_wait=False):
            test_finish_timestamp = None
            try:
                test_finish_timestamp = get_tests_finish_timestamp(revision=self.Parameters.revision, session=session)
            except requests.HTTPError as e:
                # Any HTTP error is treated as "not finished yet" and retried.
                logger.warning('CI responded with "%s"', e)
                # HTTPError may be raised without an attached response.
                if getattr(e.response, 'status_code', None) == requests.codes.not_found:
                    logger.info('Tests are still not running.')

            if test_finish_timestamp is None:
                logger.info('Tests not finished. Wait for %d seconds before next check', self.Parameters.time_to_wait)
                raise sdk2.WaitTime(self.Parameters.time_to_wait)

            logger.debug('test_finish_timestamp: %s', test_finish_timestamp)
            self.Context.test_finish_timestamp = test_finish_timestamp

        with self.memoize_stage.wait_ci_sync(commit_on_entrance=False):
            # Epoch arithmetic equals the UTC datetime difference, without the
            # naive utcnow()/utcfromtimestamp() helpers deprecated in Python 3.12.
            time_passed_after_test_finish = datetime.timedelta(
                seconds=time.time() - self.Context.test_finish_timestamp
            )
            logger.debug('time_passed_after_test_finish: %s', time_passed_after_test_finish)
            if time_passed_after_test_finish < CI_SYNC_DELAY:
                time_to_wait = CI_SYNC_DELAY - time_passed_after_test_finish
                logger.debug('time_to_wait: %s', time_to_wait)
                raise sdk2.WaitTime(time_to_wait.total_seconds())

        try:
            failed_tests_count = get_failed_tests_count(revision=self.Parameters.revision,
                                                        project_id=self.Parameters.ci_project_id,
                                                        session=session)
        except NoData:
            logger.warning('No data for revision %s', self.Parameters.revision)
            raise sdk2.WaitTime(self.Parameters.time_to_wait)

        logger.info('Failed tests count: %s', failed_tests_count)

        if sum(failed_tests_count.values()):
            raise TaskFailure(
                'Some tests failed on revision {revision}: {failed_tests_count}'
                .format(
                    revision=self.Parameters.revision,
                    failed_tests_count=failed_tests_count
                )
            )

        self.set_info('All tests passed')
