# coding: utf-8

import collections
import json
import logging
import os

import requests
from requests.packages import urllib3

from sandbox import sdk2
from sandbox.common.types import client as ctc
from sandbox.common.types import notification as ctn
from sandbox.common.types import task as ctt
from sandbox.projects.common import link_builder as lb
from sandbox.projects.common import task_env
from sandbox.projects.common.arcadia import sdk as arcadiasdk
from sandbox.projects.common.vcs import arc
from sandbox.projects.market.report.common import helpers


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Upper bound passed to `.limit()` when searching for active tasks
# in get_all_active_tasks().
TASKS_MAX_NUMBER = 900


def make_html_list(*args):
    """Render the positional arguments as an HTML unordered list.

    Every argument is formatted with ``str.format`` into its own ``<li>``
    element; no escaping is applied. With no arguments an empty ``<ul></ul>``
    is returned.
    """
    body = ''.join('<li>{}</li>'.format(entry) for entry in args)
    return '<ul>{}</ul>'.format(body)


def create_retry_session(retries=5, backoff_factor=0.3, status_forcelist=(500, 502, 504)):
    """Build a ``requests.Session`` whose HTTP and HTTPS adapters retry requests.

    :param retries: value used for the ``total``, ``read`` and ``connect``
        limits of the urllib3 ``Retry`` policy.
    :param backoff_factor: forwarded to ``Retry`` as ``backoff_factor``.
    :param status_forcelist: forwarded to ``Retry`` as ``status_forcelist``.
    :return: the configured session.
    """
    retry_policy = urllib3.util.retry.Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    http_adapter = requests.adapters.HTTPAdapter(max_retries=retry_policy)

    session = requests.Session()
    # The same adapter instance serves both schemes.
    for scheme in ('http://', 'https://'):
        session.mount(scheme, http_adapter)
    return session


def contains_common_prefix(prefix, paths):
    """Tell whether at least one of ``paths`` lies at or under ``prefix``.

    The comparison is path-component aware: ``market/report`` matches
    ``market/report`` and ``market/report/main.cpp`` but NOT
    ``market/report_other/main.cpp``. (A plain ``os.path.commonprefix``
    comparison works character by character and would accept the latter.)

    Paths are expected to use ``/`` as the separator.

    :param prefix: directory or file path to look for; an empty prefix
        matches any path.
    :param paths: iterable of path strings.
    :return: True if any path equals ``prefix`` or is located under it.
    """
    if not prefix:
        # An empty prefix is a prefix of every path; only an empty `paths` fails.
        return any(True for _ in paths)

    # Ensure the prefix ends on a component boundary before startswith().
    directory_prefix = prefix if prefix.endswith('/') else prefix + '/'
    return any(
        path == prefix or path.startswith(directory_prefix)
        for path in paths
    )


def get_all_active_tasks(task_type, author=None, owner=None):
    """Query non-child tasks of ``task_type`` in a queue, execute or wait status.

    :param task_type: task type forwarded to ``sdk2.Task.find``.
    :param author: optional author filter forwarded to the search.
    :param owner: optional owner filter forwarded to the search.
    :return: the ``sdk2.Task.find`` query limited to ``TASKS_MAX_NUMBER`` items.
    """
    active_statuses = ctt.Status.Group.QUEUE + ctt.Status.Group.EXECUTE + ctt.Status.Group.WAIT
    query = sdk2.Task.find(
        task_type=task_type,
        author=author,
        owner=owner,
        status=active_statuses,
        children=False,
    )
    return query.limit(TASKS_MAX_NUMBER)


def get_ci_context_value(context, *args):
    """Fetch a nested value from the CI context by a chain of keys.

    The walk is tolerant of incomplete data: if a key is missing, or an
    intermediate value is not mapping-like (for example ``None`` coming from
    a JSON ``null``), ``None`` is returned instead of raising.

    :param context: nested mapping or ``None`` (treated as an empty mapping).
    :param args: keys to follow, outermost first. With no keys the context
        itself (or ``{}`` when it is falsy) is returned.
    :return: the value found at the end of the key chain, or ``None``.
    """
    node = context or {}
    for key in args:
        # Duck-typed check so any mapping-like object with `.get` is accepted.
        getter = getattr(node, 'get', None)
        if getter is None:
            return None
        node = getter(key)
    return node


class TsumClient(object):
    """Minimal HTTP client for the Tsum API, authorised with an OAuth token."""

    def __init__(self, token, timeout=30):
        """
        :param token: OAuth token sent in the ``Authorization`` header.
        :param timeout: per-request timeout in seconds passed to ``requests``;
            without it a stalled connection would block the caller forever.
        """
        self._api_url = 'https://tsum.yandex-team.ru/api'
        self._session = create_retry_session()
        self._token = token
        self._timeout = timeout

    def get_active_releases(self, project_id):
        """Return the decoded JSON body of ``/projects/<project_id>/active-releases``."""
        resp = self._send_request('/projects/{}/active-releases'.format(project_id))
        return resp.json()

    def _send_request(self, handle):
        """Send an authorised GET to ``handle`` (path relative to the API root).

        :return: the response object.
        :raises requests.HTTPError: via ``raise_for_status`` on a 4xx/5xx answer.
        """
        url = self._api_url + handle
        headers = {'Authorization': 'OAuth {}'.format(self._token)}
        logger.debug('Send request, url: %s', url)
        resp = self._session.get(url, headers=headers, timeout=self._timeout)
        resp.raise_for_status()
        return resp


class MarketReportTrunkGuard(sdk2.Task):
    """Hold a pull request until extra commits to the controlled paths are allowed.

    The task finishes successfully when the pull request does not touch any
    controlled path, or when all of the following hold:
      * no other active task of this type carries a smaller pull request id;
      * the number of active Tsum releases per pipeline is below the limit;
      * the number of trunk revisions newer than the latest release per
        pipeline is below the limit.
    Otherwise it goes to a wait state and re-checks later.
    """

    class Requirements(task_env.TinyRequirements):
        client_tags = (
            (ctc.Tag.MULTISLOT | ctc.Tag.GENERIC)
            & (ctc.Tag.LINUX_TRUSTY | ctc.Tag.LINUX_XENIAL | ctc.Tag.LINUX_BIONIC)
        )

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = None

        with sdk2.parameters.Group('Input parameters'):
            tsum_project_id = sdk2.parameters.String('Tsum project ID', default='report')
            tsum_pipeline_ids = sdk2.parameters.List(
                'List of tsum pipeline IDs where you want to control number of commits',
                value_type=sdk2.parameters.String,
                default=['report-cd-pipeline', 'report-meta-pipeline']
            )
            controlled_paths = sdk2.parameters.List(
                'List of paths to search commits', value_type=sdk2.parameters.String, required=True)
            max_active_release_num = sdk2.parameters.Integer('Max number of active releases', default=3)
            allowed_commit_num = sdk2.parameters.Integer('Number of allowed commits in controlled paths', default=5)
            wait_time = sdk2.parameters.Integer('Wait time between check attempts (seconds)', default=300)

        # 'strict' combines per-pipeline results with all(), 'weak' with any().
        with sdk2.parameters.RadioGroup('Check mode') as check_mode:
            check_mode.values.strict = check_mode.Value('Conditions must be met for all pipelines', default=True)
            check_mode.values.weak = check_mode.Value('Conditions must be met for at least one of pipelines')

        with sdk2.parameters.Group('Auth tokens'):
            tsum_token = sdk2.parameters.YavSecret(
                label='Yav Tsum token', description='TSUM_TOKEN from Sandbox Vault is used by default')
            arc_token = sdk2.parameters.YavSecret(
                label='Yav ARC token', description='ARC_TOKEN from Sandbox Vault is used by default')

        with sdk2.parameters.Group('Debug options'):
            debug_options = sdk2.parameters.Bool('Enable debug options', default=False)
            with debug_options.value[True]:
                # JSON string used when the task context carries no CI context.
                ci_context = sdk2.parameters.String('CI context')
                author = sdk2.parameters.String('Author for task searching')

        with sdk2.parameters.Output():
            pull_request_id = sdk2.parameters.Integer('Pull request id')

        notifications = [
            sdk2.Notification(
                statuses=[ctt.Status.FAILURE, ctt.Status.EXCEPTION, ctt.Status.TIMEOUT],
                recipients=['bmrph'],
                transport=ctn.Transport.TELEGRAM
            )
        ]

    def on_execute(self):
        """Run one check attempt; wait and retry when commits are not allowed yet."""
        with self.memoize_stage.prepare:
            self._validate_input_params()

        # Per-attempt state: reasons are collected by the individual checks.
        self._reasons = []
        self._check_fn = all if self.Parameters.check_mode == 'strict' else any
        self._arc_token = helpers.get_arc_token(self)
        self._arc_client = arc.Arc(arc_oauth_token=self._arc_token)
        self._ci_context = self._load_ci_context()
        self._set_pull_request_id()

        try:
            if not self._is_commit_allowed():
                msg = 'wait {}s'.format(self.Parameters.wait_time)
                self._print_reasons(msg)
                raise sdk2.WaitTime(self.Parameters.wait_time)
        except (sdk2.WaitTime, sdk2.WaitTask):
            # Wait requests are control flow, not errors: pass them through.
            raise
        except Exception as e:
            # Best effort: any failure of a check is retried after wait_time.
            logger.exception(e)
            self.set_info('WARN: Exception has occurred: {}. Wait {}s'.format(e, self.Parameters.wait_time))
            raise sdk2.WaitTime(self.Parameters.wait_time)

        self.set_info('INFO: Commits are allowed')

    def _load_ci_context(self):
        """Return the CI context from the task context or from the debug parameter.

        The context key is read with ``getattr`` and a string literal on
        purpose: writing ``self.Context.__CI_CONTEXT`` inside a class body is
        subject to private name mangling and would look up
        ``_MarketReportTrunkGuard__CI_CONTEXT`` instead.

        :return: the CI context mapping; ``{}`` when neither source is set.
        """
        ci_context = getattr(self.Context, '__CI_CONTEXT', None)
        if ci_context:
            return ci_context
        raw_context = self.Parameters.ci_context
        # An unset parameter yields an empty context instead of a JSON decode error.
        return json.loads(raw_context) if raw_context else {}

    def _validate_input_params(self):
        """Assert that the mandatory input parameters are filled and positive."""
        assert self.Parameters.tsum_project_id, 'tsum_project_id is empty'
        assert self.Parameters.tsum_pipeline_ids, 'tsum_pipeline_ids is empty'
        assert self.Parameters.controlled_paths, 'controlled_paths is empty'
        assert self.Parameters.max_active_release_num > 0, 'max_active_release_num must be > 0'
        assert self.Parameters.allowed_commit_num > 0, 'allowed_commit_num must be > 0'
        assert self.Parameters.wait_time > 0, 'wait_time must be > 0'

    def _set_pull_request_id(self):
        """Store the pull request id from the CI context in the output parameter (once).

        :raises ValueError: if the CI context has no pull request id.
        """
        with self.memoize_stage.set_pr_id:
            pull_request_id = get_ci_context_value(self._ci_context, 'launch_pull_request_info', 'pull_request', 'id')
            if not pull_request_id:
                raise ValueError('Pull request id not found in CI context')
            self.Parameters.pull_request_id = pull_request_id
        logger.info('Pull request id: %s', self.Parameters.pull_request_id)

    def _is_commit_allowed(self):
        """Return True when the pull request may proceed.

        May raise ``sdk2.WaitTask`` from the queue check.
        """
        with self.memoize_stage.check_controlled_paths:
            if not self._commit_contains_controlled_path():
                return True
        self._check_pull_request_queue()
        active_releases = self._get_active_releases()
        return self._check_active_release_num(active_releases) and self._check_free_revisions_num(active_releases)

    def _commit_contains_controlled_path(self):
        """Return True if any changed file lies under one of the controlled paths."""
        changed_files = self._get_changed_files()
        for controlled_path in self.Parameters.controlled_paths:
            if contains_common_prefix(controlled_path, changed_files):
                logger.info('Commit contains controlled path, %s', controlled_path)
                return True
        logger.info("Commit doesn't contain controlled path")
        return False

    def _get_changed_files(self):
        """Collect paths changed between the upstream and feature revisions.

        When the CI context has no feature branch, the first controlled path
        is returned so that the commit is treated as touching controlled code.

        :return: collection of changed paths (a set, or a one-element list
            in the fallback case).
        """
        # `or {}` guards against missing or null vcs_info in the CI context.
        vcs_info = get_ci_context_value(self._ci_context, 'launch_pull_request_info', 'vcs_info') or {}
        logger.debug('vcs_info: %s', vcs_info)
        feature_branch = vcs_info.get('feature_branch')
        if not feature_branch:
            path = [self.Parameters.controlled_paths[0]]
            logger.warning('Feature branch not found, set changed files as %s', path)
            return path
        with arcadiasdk.mount_arc_path(
            'arcadia-arc:/#{}'.format(feature_branch),
            use_arc_instead_of_aapi=True,
            arc_oauth_token=self._arc_token
        ) as arcadia_path:
            commits = self._arc_client.log(
                arcadia_path,
                start_commit=vcs_info['upstream_revision_hash'],
                end_commit=vcs_info['feature_revision_hash'],
                as_dict=True,
                name_only=True
            )

            changed_files = set()
            for commit in commits:
                # A commit entry without 'names' contributes no paths.
                changed_files.update(
                    name['path']
                    for name in commit.get('names') or []
                    if 'path' in name
                )
            logger.info('Changed files of branch %s: %s', feature_branch, ', '.join(changed_files))
            return changed_files

    def _check_pull_request_queue(self):
        """Wait for the preceding pull request if this one is not first in the queue.

        The queue consists of active tasks of this type that already have a
        pull request id, ordered by that id. Returns normally when no task
        with a smaller id exists.

        :raises sdk2.WaitTask: to subscribe to the task with the closest
            smaller pull request id.
        """
        tasks = get_all_active_tasks(self.type, owner=self.owner, author=self.Parameters.author or self.author)
        # Skip all task where pr id is not yet set
        tasks = [
            task
            for task in tasks
            if task.Parameters.pull_request_id
        ]
        tasks.sort(key=lambda t: t.Parameters.pull_request_id)
        pr_ids = [task.Parameters.pull_request_id for task in tasks]
        logger.info('Pull requests waiting in queue: %s', pr_ids)

        own_pr_id = self.Parameters.pull_request_id
        # Selecting by comparison (not list.index) also works when this task
        # itself is absent from the search result.
        preceding_tasks = [task for task in tasks if task.Parameters.pull_request_id < own_pr_id]
        if not preceding_tasks:
            return
        # Commit is not allowed
        reason = '[] pull request id {} is not the first in the queue, {} items before'.format(
            own_pr_id, len(preceding_tasks))
        self._reasons.append(reason)

        task = preceding_tasks[-1]  # task with previous pr
        msg = 'wait task {} with timeout {}s'.format(lb.task_link(task.id), self.Parameters.wait_time)
        self._print_reasons(msg)

        # Subscribe to task
        status = ctt.Status.Group.FINISH + ctt.Status.Group.BREAK
        raise sdk2.WaitTask(task, status, timeout=self.Parameters.wait_time)

    def _get_active_releases(self):
        """Fetch active Tsum releases of the configured pipelines.

        :return: mapping of pipeline id to its releases, newest first
            (ordered by ``createdDate``). Pipelines without active releases
            are absent from the mapping.
        """
        token = helpers.get_token(self.Parameters.tsum_token, self.owner, vault_item='TSUM_TOKEN')
        tsum_client = TsumClient(token)
        active_releases = collections.defaultdict(list)
        for release in tsum_client.get_active_releases(self.Parameters.tsum_project_id):
            pipe_id = release['pipeId']
            if pipe_id not in self.Parameters.tsum_pipeline_ids:
                continue
            active_releases[pipe_id].append(release)
        for pipe_id, releases in active_releases.items():
            releases.sort(key=lambda x: x['createdDate'], reverse=True)
            logger.info('Pipeline %s active releases: %s', pipe_id, ', '.join([r['id'] for r in releases]))
        return active_releases

    def _check_active_release_num(self, active_releases):
        """Check that pipelines have fewer active releases than the configured maximum.

        Per-pipeline results are combined with the check mode function
        (``all`` or ``any``); on failure the reasons are recorded.
        """
        reasons = []
        check_results = []
        for pipe_id, releases in active_releases.items():
            result = len(releases) < self.Parameters.max_active_release_num
            if not result:
                reason = '[{}] too many active releases: {}'.format(pipe_id, len(releases))
                reasons.append(reason)
            check_results.append(result)
        result = self._check_fn(check_results)
        if not result:
            self._reasons.extend(reasons)
        return result

    def _check_free_revisions_num(self, active_releases):
        """Check that few enough revisions are newer than each pipeline's latest release.

        A revision is "free" for a pipeline when it is greater than the
        revision of that pipeline's newest active release. Per-pipeline
        results are combined with the check mode function; on failure the
        reasons are recorded.

        :raises ValueError: if ``active_releases`` is empty, since there is
            no base revision to count from.
        """
        if not active_releases:
            # Explicit error instead of the opaque "min() arg is an empty sequence".
            raise ValueError('No active releases found for pipelines: {}'.format(
                ', '.join(self.Parameters.tsum_pipeline_ids)))

        # releases[0] is the newest release of a pipeline (see _get_active_releases).
        base_revisions = [
            int(releases[0]['revision'])
            for releases in active_releases.values()
        ]
        logger.info('Base revisions of all pipelines: %s', base_revisions)
        revisions = self._get_revisions(min(base_revisions))
        # Make sure every base revision can be located in the sorted list.
        revisions.update(base_revisions)
        revisions = sorted(revisions)

        reasons = []
        check_results = []
        for pipe_id, releases in active_releases.items():
            base_revision = int(releases[0]['revision'])
            base_index = revisions.index(base_revision)
            free_revision_num = len(revisions) - (base_index + 1)
            logger.info('Pipeline %s free revisions: %s', pipe_id, revisions[base_index + 1:])
            result = free_revision_num < self.Parameters.allowed_commit_num
            if not result:
                reason = '[{}] too many commits not included in any release: {}'.format(pipe_id, free_revision_num)
                reasons.append(reason)
            check_results.append(result)
        result = self._check_fn(check_results)
        if not result:
            self._reasons.extend(reasons)
        return result

    def _get_revisions(self, base_revision):
        """Return the set of trunk revisions touching any controlled path since ``base_revision``."""
        with arcadiasdk.mount_arc_path(
            'arcadia-arc:/#trunk',
            use_arc_instead_of_aapi=True,
            arc_oauth_token=self._arc_token
        ) as arcadia_path:
            all_revisions = set()
            for path in self.Parameters.controlled_paths:
                revisions = self._get_revisions_by_path(arcadia_path, base_revision, path)
                logger.debug('Revisions for path %s (total: %s): %s', path, len(revisions), revisions)
                all_revisions.update(revisions)
            return all_revisions

    def _get_revisions_by_path(self, arcadia_path, base_revision, path):
        """Return revisions from ``arc log`` for ``path`` between ``r<base_revision>`` and HEAD."""
        start_commit = 'r{}'.format(base_revision)
        commits = self._arc_client.log(
            arcadia_path, path=path, start_commit=start_commit, end_commit='HEAD', as_dict=True)
        return [
            commit['revision']
            for commit in commits
        ]

    def _print_reasons(self, message):
        """Publish the collected reasons as an HTML list in the task info.

        The text is sent with ``do_escape=False`` because it contains markup
        (the list itself and task links).
        """
        msg = 'INFO: Any extra commits are disabled, {}<br/>Reason:{}'.format(message, make_html_list(*self._reasons))
        self.set_info(msg, do_escape=False)
