import ast
import logging
import os
import re
from typing import List
from sandbox import sdk2
from sandbox.common.errors import TaskFailure
import sandbox.common.types.client as ctc
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt

from sandbox.projects.mobile_apps.teamcity_sandbox_runner.runner_stage import TeamcitySandboxRunnerStage
from sandbox.projects.mobile_apps.teamcity_sandbox_runner.utils import vcs
from sandbox.projects.mobile_apps.teamcity_sandbox_runner.utils.artifacts_processor import ArtifactsProcessor
from sandbox.projects.mobile_apps.teamcity_sandbox_runner.utils.parameters import TeamcitySandboxRunnerUpperLevelParameters
from sandbox.projects.mobile_apps.teamcity_sandbox_runner.utils.string_preparer import StringPreparer


_logger = logging.getLogger('runner')

# Multislot preset name -> CPU cores requested for a (non-iOS) stage subtask;
# RAM is derived from the core count in _set_subtask_requirements.
REQUIREMENTS = dict(
    SMALL=1,
    MEDIUM=4,
    LARGE=8,
    XLARGE=16,
)

# Description template for stage subtasks: stage name, then the parent task's description.
SUBTASK_DESCRIPTION = "<span class='teamcity teamcity_branch status status_unknown'>{}</span>{}"
# Client tag used for stage subtasks whose config does not set client_tags.
STAGE_DEFAULT_CLIENT_TAG = "USER_MONOREPO"


class TeamcitySandboxRunnerParameters(TeamcitySandboxRunnerUpperLevelParameters):
    # Parameters of the top-level runner task: the upper-level parameters plus
    # config templates and options for reusing the parent task's resources.

    # Forwarded to vcs.prepare_repository_and_read_config when the config is read
    # (in TeamcitySandboxRunner._prelaunch_preparation).
    dependent_templates = sdk2.parameters.List(
        'Dependent templates',
        value_type=sdk2.parameters.String, )
    with sdk2.parameters.Group('Teamcity sandbox runner artifacts parameters') as tsr_artifacts_params:
        use_parent_resources = sdk2.parameters.Bool(
            'Use resources from parent task',
            default=False, )
        # The resource parameters below are declared under
        # use_parent_resources.value[True], i.e. tied to that flag being True.
        # NOTE(review): their consumer is not in this file (presumably the
        # artifacts processing code) — confirm.
        with use_parent_resources.value[True]:
            teamcity_artifacts_resource = sdk2.parameters.ParentResource(
                'Teamcity artifacts resource', )
            teamcity_messages_resource = sdk2.parameters.ParentResource(
                'Teamcity messages resource', )


class TeamcitySandboxRunner(sdk2.Task):
    """Run the stages of a teamcity-sandbox-runner config as Sandbox subtasks.

    The config is obtained via ``vcs.prepare_repository_and_read_config``.
    Every stage becomes a ``TeamcitySandboxRunnerStage`` subtask that is
    enqueued once all stages it depends on have finished. Between launch steps
    this task waits for any running subtask to finish or break.
    """

    # arcadia mounts to task_dir/arcadia, folders are exported to task_dir/export
    arc_export_prefix = 'export'
    arc_mount_prefix = 'arcadia'

    class Requirements(sdk2.Requirements):
        cores = 1
        disk_space = 1 * 1024  # 1 GB
        ram = 1 * 1024  # 1 GB
        client_tags = 'USER_MONOREPO'

        class Caches(sdk2.Requirements.Caches):
            # No caches are declared for this task.
            pass

    class Parameters(TeamcitySandboxRunnerParameters):
        pass

    def on_save(self):
        """Select the binary tasks resource matching ``release_type``.

        'custom' leaves ``tasks_resource`` untouched and 'none' clears it.
        Any other value looks up the MOBDEVTOOLS ``SandboxTasksBinary`` resource
        with that release type (and with ``version``, when one is given).

        :raises TaskFailure: if no matching binary resource is found.
        """
        if self.Parameters.release_type == 'custom':
            return
        if self.Parameters.release_type == 'none':
            self.Requirements.tasks_resource = None
        else:
            attrs = {"target": "teamcity_sandbox_runner/bin",
                     "release": self.Parameters.release_type,
                     "tasks_bundle": "TEAMCITY_SANDBOX_RUNNER",
                     }
            if self.Parameters.version != "":
                attrs['version'] = self.Parameters.version
            binary_resource = sdk2.service_resources.SandboxTasksBinary.find(
                owner="MOBDEVTOOLS",
                attrs=attrs
            ).first()
            if not binary_resource:
                raise TaskFailure("Could not find binary task")
            self.Requirements.tasks_resource = binary_resource.id

    def _add_counter_to_extra_env_impl(self, client, counter_path):
        """Read and increment the YT counter at ``counter_path`` inside a locked transaction.

        The value read (before incrementing) is stored in ``Context.build_number``.
        """
        _logger.info('Try to get build counter.')
        with client.Transaction():
            client.lock(counter_path, waitable=True, wait_for=60)
            counter_value = client.get(counter_path)
            self.Context.build_number = counter_value
            client.set(counter_path, counter_value + 1)
        _logger.info('Build counter successfully received. Value: {}'.format(counter_value))

    def _add_counter_to_extra_env(self):
        """Obtain ``Context.build_number`` from the YT counter named in ``config.build_counter``.

        ``build_counter`` is either ``<name>`` (the project defaults to
        'mobdevtools') or ``<project>:<name>``. It maps to
        ``//home/<project>/build_counters/<name>`` on the 'locke' cluster.
        Does nothing when the config has no build_counter.
        """
        # creating BUILD_NUMBER
        build_counter = self.Context.config.get('config', {}).get('build_counter', None)
        if build_counter:
            # all import outside sandbox should be hidden
            import yt.wrapper as yt
            import library.python.retry as retry

            token = sdk2.Vault.data(self.owner, 'yt_token')
            yt_config = {
                'proxy': {'url': 'locke'},
                'token': token}
            client = yt.YtClient(config=yt_config)

            build_counter = build_counter.split(':')
            build_counter = ['mobdevtools'] + build_counter if len(build_counter) == 1 else build_counter
            counter_path = '//home/{}/build_counters/{}'.format(*build_counter)

            retry.retry_call(
                self._add_counter_to_extra_env_impl,
                (client, counter_path),
                conf=retry.RetryConf().waiting(delay=2., backoff=2., jitter=1., limit=3600.).upto(minutes=15.))

    def _delete_stage_contexts_duplicates(self):
        """De-duplicate the stage bookkeeping lists/dicts stored in Context and log them."""
        self.Context.depends_on = remove_duplicates(self.Context.depends_on)
        self.Context.waiting = remove_duplicates(self.Context.waiting)
        self.Context.fail_fast_tasks = remove_duplicates(self.Context.fail_fast_tasks)
        self.Context.not_running_stages = remove_duplicates(self.Context.not_running_stages)
        _logger.debug('Context.depends_on: {}'.format(self.Context.depends_on))
        _logger.debug('Context.waiting: {}'.format(self.Context.waiting))
        _logger.debug('Context.fail_fast_tasks: {}'.format(self.Context.fail_fast_tasks))
        _logger.debug('Context.not_running_stages: {}'.format(self.Context.not_running_stages))

    def _process_config(self):
        """Normalize the config's stages and build the stage dependency graph.

        For every stage this method:

        * adds artifact rules for its junit directories;
        * registers it as fail-fast if requested;
        * marks it as not yet running;
        * records its dependencies in ``Context.depends_on`` (stage -> stages it
          waits for) and ``Context.waiting`` (stage -> stages waiting for it).

        Dependencies are either explicit (``depends_on``) or implicit
        (``%internal.<stage>...%`` placeholders in ``cmd``).

        :raises TaskFailure: if a stage references a stage name that is not
            defined in ``stages``, either in ``cmd`` or in ``depends_on``.
        """
        _logger.info('Begin to process config.')
        config = self.Context.config
        self.Context.config_name = config.get('config', {}).get('name', 'unknown_name')
        self.Context.runner_version = config.get('config', {}).get('runner_version', None)
        # Every dependency reference below is validated against this set.
        known_stages = set(config.get('stages', {}).keys())

        for stage_name, stage_params in config.get('stages', {}).iteritems():
            config['stages'][stage_name].setdefault('artifacts', {})
            # '+'/'-' prefixed artifact rules for each junit dir (presumably
            # include everything / exclude '*/binary') — confirm in ArtifactsProcessor.
            for test in stage_params.get('junit', []):
                config['stages'][stage_name]['artifacts']['+{}'.format(os.path.join(test, '*'))] = os.path.join('junit/{}'.format(stage_name), test)
                config['stages'][stage_name]['artifacts']['-{}'.format(os.path.join(test, '*', 'binary'))] = None
            if stage_params.get('fail_fast', False):
                self.Context.fail_fast_tasks.append(stage_name)
            self.Context.not_running_stages.append(stage_name)
            if stage_name not in self.Context.depends_on:
                self.Context.depends_on[stage_name] = []
            stage_depends_on = []
            for cmd in config['stages'].get(stage_name, {}).get('cmd', []):
                for param in re.findall(r'%internal\..*?%', cmd):
                    # Strip the leading '%internal.' (10 chars) and the trailing '%';
                    # the stage name is the part before the first ':'.
                    extracted_internal_stage = param[10:-1].split(':')[0]
                    _logger.debug('Extracted stage name from cmd: {}'.format(extracted_internal_stage))
                    if extracted_internal_stage not in known_stages:
                        raise TaskFailure('Extracted stage name {} from cmd hasn\'t been found in stages. Check cmd for typo'.format(extracted_internal_stage))
                    stage_depends_on.append(extracted_internal_stage)
            for depends_on in stage_params.get('depends_on', []):
                # An unknown dependency could never finish, so the stage would
                # never start; fail early with a clear message instead.
                if depends_on not in known_stages:
                    raise TaskFailure(
                        'Stage {} depends on stage {} which hasn\'t been found in stages. '
                        'Check depends_on for typo'.format(stage_name, depends_on))
                self.Context.depends_on[stage_name].append(depends_on)
                if depends_on not in self.Context.waiting:
                    self.Context.waiting[depends_on] = []
                self.Context.waiting[depends_on].append(stage_name)
            for depends_stage_name in stage_depends_on:
                if depends_stage_name not in self.Context.waiting:
                    self.Context.waiting[depends_stage_name] = []
                self.Context.depends_on[stage_name].append(depends_stage_name)
                self.Context.waiting[depends_stage_name].append(stage_name)

        self._delete_stage_contexts_duplicates()
        _logger.info('Config processing successfully finished.')

    def _get_stage_work_dir_path(self, stage_params):
        """Return the stage working directory relative to the task directory.

        For arcadia repositories the path is prefixed with the export directory
        when ``arc_exported_paths`` is set, and with the mount directory
        otherwise. Other repositories get no prefix. '.' and './' mean the
        repository root.
        """
        work_dir = stage_params.get('work_dir', '')
        if work_dir == './' or work_dir == '.':
            work_dir = ''
        if self.Parameters.repo_url == 'arcadia':
            if stage_params.get('arc_exported_paths', '') != '':
                _repo_path = self.arc_export_prefix
            else:
                _repo_path = self.arc_mount_prefix
        else:
            _repo_path = ''
        _logger.debug('set repo_path to {}'.format(_repo_path))
        stage_work_dir_path = os.path.join(_repo_path, work_dir)
        _logger.debug("_get_stage_work_dir_path returned {}".format(stage_work_dir_path))
        return stage_work_dir_path

    def _expand_stage_artifacts(self, stage_params):
        """Return the stage's artifact rules with paths prefixed by its work dir.

        Keys are '<sign><path>'; the one-character sign is kept in front of the
        expanded path. Values are passed through unchanged.
        """
        full_artifacts = {}
        full_work_dir = self._get_stage_work_dir_path(stage_params)
        for key, value in stage_params.get('artifacts', {}).iteritems():
            sign, artifact_path = key[0], key[1:]
            full_artifact_path = os.path.join(full_work_dir, artifact_path)
            full_artifacts[sign + full_artifact_path] = value
        return full_artifacts

    def _expand_stage_internal_artifacts(self, stage_params):
        """Return ``internal_artifacts`` with keys prefixed by the stage work dir."""
        full_work_dir = self._get_stage_work_dir_path(stage_params)
        return {
            os.path.join(full_work_dir, key): value for key, value in
            stage_params.get('internal_artifacts', {}).iteritems()
        }

    @staticmethod
    def _set_sandbox_environment(stage_params):
        """Collect the toolchain versions (xcode, android-sdk, rvm+ruby, jdk, certs) set in the stage."""
        environments = {}
        if stage_params.get('xcode'):
            _logger.info('Xcode found with version {}.'.format(stage_params.get('xcode')))
            environments['xcode'] = stage_params.get('xcode')
        if stage_params.get('android-sdk'):
            _logger.info('android-sdk found with version {}.'.format(stage_params.get('android-sdk')))
            environments['android-sdk'] = stage_params.get('android-sdk')
        if stage_params.get('rvm+ruby'):
            _logger.info('RVM+RUBY found with version {}.'.format(stage_params.get('rvm+ruby')))
            environments['rvm+ruby'] = stage_params.get('rvm+ruby')
        if stage_params.get('jdk'):
            _logger.info('JDK found with version {}.'.format(stage_params.get('jdk')))
            environments['jdk'] = stage_params.get('jdk')
        if stage_params.get('certs'):
            environments['certs'] = stage_params.get('certs')
        return environments

    def _prepare_stage_params(self, stage_name, stage_params):
        """Fill ``stage_params`` in place with everything the stage subtask needs.

        Adds repository/branch/commit, merges task-level ``secrets`` and ``env``,
        sets BUILD_NUMBER / BUILD_BRANCH and expands artifact paths. The expanded
        artifacts are also stored in ``Context.artifacts[stage_name]``.

        :return: the same (mutated) ``stage_params`` dict.
        """
        stage_params['name'] = '{}:{}'.format(self.Context.config_name, stage_name)
        stage_params['ssh_key'] = str(self.Parameters.ssh_key)
        stage_params['repo_url'] = self.Parameters.repo_url
        stage_params['branch'] = self.Parameters.branch
        stage_params['commit'] = self.Parameters.commit
        stage_params.setdefault('kill_timeout', 30 * 60)
        stage_params.setdefault('secrets', {})
        # secrets/env parameters are parsed as Python dict literals; task-level
        # values override the stage's own entries.
        for key, value in ast.literal_eval(self.Parameters.secrets).iteritems():
            stage_params['secrets'][key] = value
        stage_params.setdefault('env', {})
        for key, value in ast.literal_eval(self.Parameters.env).iteritems():
            stage_params['env'][key] = value
        if self.Context.build_number:
            stage_params['env']['BUILD_NUMBER'] = self.Context.build_number
            _logger.debug("Set BUILD_NUMBER = {}".format(self.Context.build_number))
        stage_params['env']['BUILD_BRANCH'] = self.Parameters.branch
        stage_params['internal_artifacts'] = self._expand_stage_internal_artifacts(stage_params)
        stage_params['dependency_files'] = self.Context.sub_tasks_ids
        if self.Parameters.release_type != 'custom':
            stage_params['release_type'] = self.Parameters.release_type
        stage_params['arc_export_prefix'] = self.arc_export_prefix
        _stage_artifacts = self._expand_stage_artifacts(stage_params)
        self.Context.artifacts[stage_name] = _stage_artifacts
        stage_params['artifacts'] = _stage_artifacts.keys()
        stage_params['version'] = self.Context.runner_version
        stage_params['sandbox_environments'] = self._set_sandbox_environment(stage_params)
        if stage_params.get('ios'):
            stage_params['env']['COCOAPODS_DISABLE_STATS'] = 'true'
            stage_params['env']['FASTLANE_DISABLE_COLORS'] = 1
        if self.Parameters.teamcity_build_id:
            _logger.info('set stage parameter teamcity_build_id {}.'.format(self.Parameters.teamcity_build_id))
            stage_params['teamcity_build_id'] = self.Parameters.teamcity_build_id
        # if teamcity parameter is True -> override config value.
        # if teamcity parameter is False -> use config value
        if self.Parameters.force_clean_build:
            stage_params['force_clean_build'] = self.Parameters.force_clean_build
        stage_params['config_hash'] = self.Context.config_hash
        return stage_params

    def _set_subtask_requirements(self, build_sub_task, stage_params):
        """Set host, client tags, cores and RAM of a stage subtask, then save it.

        A ``mac_agent`` (iOS stages) or ``linux_agent`` (other stages) parameter
        pins the subtask to that host; nothing else is configured in that case.
        Otherwise:

        * client tags come from the stage config (default
          STAGE_DEFAULT_CLIENT_TAG), combined with ``ctc.Tag.MOBILE_MONOREPO``
          for iOS or ``ctc.Tag.Group.LINUX`` otherwise;
        * non-iOS stages get a ``multislot`` preset from REQUIREMENTS
          (cores, and ``cores * 4 - 2`` GB RAM) or a whole machine.

        :raises TaskFailure: if ``multislot`` is not a key of REQUIREMENTS.
        """
        build_sub_task.Requirements.dns = ctm.DnsType.DNS64

        if stage_params.get('ios'):
            if self.Parameters.mac_agent:
                build_sub_task.Requirements.host = self.Parameters.mac_agent
                build_sub_task.save()
                _logger.debug("Set mac host requirement {}".format(self.Parameters.mac_agent))
                return
        else:
            if self.Parameters.linux_agent:
                build_sub_task.Requirements.host = self.Parameters.linux_agent
                build_sub_task.save()
                _logger.debug("Set linux host requirement {}".format(self.Parameters.linux_agent))
                return

        if stage_params.get('client_tags') is None:
            build_sub_task.Requirements.client_tags = STAGE_DEFAULT_CLIENT_TAG
        else:
            build_sub_task.Requirements.client_tags = stage_params.get('client_tags')
        if stage_params.get('ios'):
            build_sub_task.Requirements.client_tags &= ctc.Tag.MOBILE_MONOREPO
        else:
            build_sub_task.Requirements.client_tags &= ctc.Tag.Group.LINUX
            multislot = stage_params.get('multislot')
            if multislot is not None:
                _logger.debug("stage_params.get('multislot') is {} ".format(multislot))
                if multislot in REQUIREMENTS:
                    cores = REQUIREMENTS[multislot]
                    build_sub_task.Requirements.cores = cores
                    build_sub_task.Requirements.ram = (cores * 4 - 2) * 1024
                else:
                    raise TaskFailure('Multislot type {} not found.'.format(multislot))
            else:
                # captures the entire machine for exclusive use
                build_sub_task.Requirements.cores = 17
                build_sub_task.Requirements.ram = 31 * 1024
                _logger.debug("Capture machine for exclusive use")
        _logger.debug("Set requirements: {}, {}, {}".format(
            build_sub_task.Requirements.client_tags,
            build_sub_task.Requirements.cores,
            build_sub_task.Requirements.ram))
        build_sub_task.save()

    def _enqueue_stage(self, stage_name, stage_params):
        """Create, configure and enqueue the subtask for ``stage_name``, updating Context bookkeeping."""
        _logger.info('Try to enqueue subtask for {} stage.'.format(stage_name))

        stage_params = self._prepare_stage_params(stage_name, stage_params)
        # NOTE(review): this logs ssh_key and secrets values; confirm they are
        # only references (not secret material) before relying on this log.
        _logger.info('subtask params: {}'.format(stage_params))
        build_sub_task = TeamcitySandboxRunnerStage(
            self,
            description=SUBTASK_DESCRIPTION.format(stage_name, self.Parameters.description),
            **stage_params
        )
        self._set_subtask_requirements(build_sub_task, stage_params)
        build_sub_task.enqueue()

        self.Context.sub_tasks_ids[stage_name] = build_sub_task.id
        self.Context.not_running_stages.remove(stage_name)
        self.Context.running_tasks.append(build_sub_task.id)
        _logger.info('Task {} for stage: {} enqueued'.format(build_sub_task.id, stage_name))

    def _stop_task(self, task_id):
        """Try to stop subtask ``task_id``; a failing ``stop()`` is logged, not raised."""
        _logger.info('Try to stop task {}.'.format(task_id))
        task = sdk2.Task[task_id]
        try:
            task.stop()
        except Exception:
            # Best effort: a failure to stop one subtask must not abort the caller.
            _logger.exception('Failed to stop task {}.'.format(task_id))
        else:
            _logger.info('Stopped task {}.'.format(task_id))

    def _stop_all_task(self, failed_stage):
        """Stop all running subtasks after fail-fast stage ``failed_stage`` failed.

        Also drops all not yet started stages and records a failure message.
        """
        _logger.info('Try to stop all running tasks.')
        for task_id in self.Context.running_tasks:
            self._stop_task(task_id)
        self.Context.not_running_stages = []
        self.Context._failed_messages.append('Task failed on force stage "{}".'.format(failed_stage))
        _logger.info('All running tasks are stopped.')

    def _fail_blocked_stages(self):
        """Record stages that can never start as a failure and drop them from the queue.

        Used when no subtask is running but some stages are still waiting. Only
        stages with unfinished dependencies remain after ``_run_tasks``, and with
        nothing running those dependencies can never finish (e.g. a dependency
        cycle). Waiting would only re-launch this task until it times out.
        """
        blocked = sorted(self.Context.not_running_stages)
        for stage_name in blocked:
            _logger.error('Stage {} can not be started, it still waits for: {}'.format(
                stage_name, self.Context.depends_on.get(stage_name, [])))
        self.Context._failed_messages.append(
            'Stages [{}] can not be started: their dependencies can never be satisfied. '
            'Check depends_on and %internal.*% references for cycles.'.format(', '.join(blocked)))
        self.Context.not_running_stages = []

    def _finalize(self):
        """Raise TaskFailure describing failed subtasks/messages, if there are any."""
        _logger.info('Finalize task.')
        failed_message = prepare_failed_message(
            failed_task_ids=self.Context.failed_tasks,
            failed_messages=self.Context._failed_messages
        )
        if len(failed_message) > 0:
            raise TaskFailure(failed_message)

    def _check_runner_version(self):
        """Fail unless the ``version`` parameter is set."""
        error_message = 'runner_version required. more information in <a href="https://docs.yandex-team.ru/teamcity-sandbox-runner/#versioning_policy">documentation</a>.'
        if self.Parameters.version == '':
            raise TaskFailure(error_message)

    def _start_artifacts_processor(self):
        """Run ArtifactsProcessor for this (parent) task with a StringPreparer built from ``env``."""
        env = ast.literal_eval(self.Parameters.env)
        string_preparer = StringPreparer(env)
        artifacts_processor = ArtifactsProcessor(self, stage_task=False, string_preparer=string_preparer)
        artifacts_processor.start_processing()

    def _process_running_tasks(self):
        """Retire finished subtasks and unblock the stages waiting for them.

        A subtask is finished when its status is in the FINISH or BREAK groups.
        Its stage is removed from the ``depends_on`` list of every waiting stage,
        regardless of success. A non-successful subtask is recorded in
        ``failed_tasks``; if its stage is fail-fast, all running subtasks are
        stopped.
        """
        # mechanism for running dependent tasks sequentially
        finished_statuses = list(ctt.Status.Group.FINISH | ctt.Status.Group.BREAK)
        for stage_name, task_id in self.Context.sub_tasks_ids.iteritems():
            if task_id not in self.Context.running_tasks:
                continue
            # Read the status once: a second read costs another API call and
            # could observe a different status than the first.
            status = self.server.task[task_id].read()["status"]
            if status not in finished_statuses:
                continue
            self.Context.running_tasks.remove(task_id)
            for waiting_stage in self.Context.waiting.get(stage_name, []):
                _logger.debug('Delete stage {} from {}\'s depends list'.format(stage_name, waiting_stage))
                self.Context.depends_on[waiting_stage].remove(stage_name)
            if status not in ctt.Status.Group.SUCCEED:
                self.Context.failed_tasks.append(task_id)
                if stage_name in self.Context.fail_fast_tasks:
                    self._stop_all_task(stage_name)

    def _run_tasks(self):
        """Enqueue every not-yet-started stage whose dependencies have all finished."""
        for stage_name, stage_params in self.Context.config.get('stages', {}).iteritems():
            if len(self.Context.depends_on[stage_name]) == 0 and stage_name in self.Context.not_running_stages:
                self._enqueue_stage(stage_name, stage_params)

    def _prelaunch_preparation(self):
        """Initialize Context bookkeeping, read the config and build the dependency graph."""
        self.Context.launch_step = 0
        self.Context.failed_tasks = []
        self.Context.running_tasks = []
        self.Context.fail_fast_tasks = []
        self.Context.not_running_stages = []
        self.Context.depends_on = dict()
        self.Context.waiting = dict()
        self.Context._failed_messages = []
        self.Context.config, self.Context.config_hash, self.Context._repo_path = vcs.prepare_repository_and_read_config(
            self,
            self.Parameters.ssh_key,
            self.Parameters.repo_url,
            self.Parameters.branch,
            self.Parameters.commit,
            self.Parameters.config_from_repository,
            self.Parameters.config_path,
            self.Parameters.config,
            False,
            self.Parameters.dependent_templates
        )
        self._process_config()
        self._add_counter_to_extra_env()
        self.Context.sub_tasks_ids = dict()
        self.Context.artifacts = dict()

    def on_execute(self):
        """Advance the stage graph by one launch step per execution.

        While stages remain, each execution retires finished subtasks,
        enqueues the stages that became runnable and waits for the running
        subtasks. If nothing is running but stages remain, they are deadlocked:
        they are recorded as a failure instead of waiting forever. At the end,
        the artifacts processor runs and ``_finalize`` raises TaskFailure if
        anything failed.
        """
        _logger.debug("self.Parameters: {}".format(vars(self.Parameters)))
        self._check_runner_version()

        with self.memoize_stage.prelaunch_preparation:
            self._prelaunch_preparation()

        if len(self.Context.not_running_stages) > 0 or len(self.Context.running_tasks) > 0:
            self._process_running_tasks()
            self.Context.launch_step += 1

            with self.memoize_stage['run_builders_{}_step'.format(self.Context.launch_step)]:
                self._run_tasks()
            if len(self.Context.running_tasks) > 0:
                with self.memoize_stage['wait_builders_{}_step'.format(self.Context.launch_step)]:
                    raise sdk2.WaitTask(self.Context.running_tasks, ctt.Status.Group.FINISH | ctt.Status.Group.BREAK, wait_all=False)
            elif len(self.Context.not_running_stages) > 0:
                # Nothing runs and nothing could be enqueued: waiting on an
                # empty task list would just re-launch this task forever.
                self._fail_blocked_stages()

        with self.memoize_stage['enqueue_artifacts_processor']:
            self._start_artifacts_processor()

        self._finalize()

    def _stop_child_process(self):
        """Stop every still-running subtask and record it as failed."""
        _logger.info("Try to stop child tasks")
        for task_id in self.Context.sub_tasks_ids.itervalues():
            if task_id in self.Context.running_tasks:
                self._stop_task(task_id)
                self.Context.failed_tasks.append(task_id)
        _logger.info("Stopped child tasks")

    def on_terminate(self):
        """Stop running subtasks and report the failures."""
        self._stop_child_process()
        self._finalize()

    def on_before_timeout(self, seconds):
        """Stop running subtasks and report the failures before the task times out."""
        self._stop_child_process()
        self._finalize()


def prepare_failed_message(failed_task_ids, failed_messages):
    # type: (List[int], List[str]) -> str
    """Build the TaskFailure text from collected messages and failed subtasks.

    :param failed_task_ids: ids of failed subtasks; each one is rendered as a
        link to its Sandbox page.
    :param failed_messages: free-form failure messages, joined with newlines.
    :return: the combined message, or '' when there is nothing to report.
    """
    parts = []
    if failed_messages:
        parts.append('\n'.join(failed_messages))
    if failed_task_ids:
        links = ', '.join(
            '<a href="https://sandbox.yandex-team.ru/task/{id}/view">{id}</a>'.format(id=task_id)
            for task_id in failed_task_ids
        )
        parts.append('Error in tasks [{}].'.format(links))
    # Join only the parts that are present, so a message with failed tasks but
    # no free-form messages does not start with a blank line.
    return '\n'.join(parts)


def remove_duplicates(obj):
    """Drop repeated items from a list, or from every value of a dict.

    First occurrences are kept in their original order, so the result is
    deterministic (``list(set(...))`` would reorder items arbitrarily).

    :param obj: a list of hashable items, or a dict whose values are iterables
        of hashable items. Any other object is returned unchanged.
    :return: a new list for list input; for dict input, the same dict with each
        value replaced (in place) by a de-duplicated list.
    """
    def _unique(items):
        # Order-preserving de-duplication.
        seen = set()
        result = []
        for item in items:
            if item not in seen:
                seen.add(item)
                result.append(item)
        return result

    if isinstance(obj, dict):
        for key in list(obj):
            obj[key] = _unique(obj[key])
    elif isinstance(obj, list):
        obj = _unique(obj)
    return obj
