import sandbox.common as sandbox_common
import sandbox.projects.common.binary_task as binary_task
import sandbox.sdk2 as sdk2
import sandbox.projects.quasar.utils.vcs as quasar_vcs

import re
import subprocess
import logging
import os
import json
import shutil
import multiprocessing
import xml.etree.ElementTree as ET


def get_master_branch(repository_config):
    """Return the tracked branch marked as the development branch.

    :param repository_config: parsed ``repository.json``; its
        ``tracked_branches`` entry maps branch name -> branch settings dict.
    :return: the first branch (in mapping order) whose settings contain a
        truthy ``development_branch`` value, or ``None`` if no branch does.
    """
    tracked = repository_config['tracked_branches']
    return next(
        (name for name, settings in tracked.items() if settings.get('development_branch')),
        None,
    )


class CreateGerritBranch(binary_task.LastBinaryTaskRelease, sdk2.Task, quasar_vcs.RepoCheckoutMixin):
    """Create a gerrit release branch for a platform and track it in Arcadia.

    Workflow (see ``on_execute``):

    1. Check out the Arcadia directory that contains ``repository.json``.
    2. Sync the system repository described by that config with ``repo``.
    3. Either create a new gerrit release branch (in every project and in the
       manifest repository) and register it in ``tracked_branches``, or, when
       the checkout has no diff against the last release branch and creation
       is not forced, add the Arcadia release branch to that branch's
       ``mergeto``.
    4. Save the config and, unless ``dry_run`` is set, commit it to Arcadia.
    """

    # NOTE: "MAINFEST" is a misspelling of "MANIFEST"; the attribute names are
    # kept as-is because code outside this class may reference them.
    MAINFEST_NAME = 'default.xml'
    MAINFESTS_PATH = '.repo/manifests'
    REVISION_ATTR = 'revision'
    REPOSITORY_CONFIG = 'repository.json'
    # Local and Arcadia file name for the `repo manifest -r` output
    COPY_MANIFEST_NAME = 'manifest.xml'

    class Parameters(binary_task.LastBinaryReleaseParameters):
        ssh_user = sdk2.parameters.String('Ssh user name')
        ssh_key_vault_owner = sdk2.parameters.String('Sandbox vault item owner for ssh key')
        ssh_key_vault_name = sdk2.parameters.String('Sandbox vault item name for ssh key')
        dry_run = sdk2.parameters.Bool('Do no commit', default=False)
        platform = sdk2.parameters.String('Platform (yandexstation_2, yandexmini_2, etc)', required=True)
        repository_config_url = sdk2.parameters.ArcadiaUrl('Path to the repository.json in arcadia (arcadia:arc/trunk/arcadia/smart_devices/platforms/$PLATFORM/firmware)')
        arcadia_release_branch_name = sdk2.parameters.String('Release branch name', default='smart_devices')
        arcadia_release_branch_number = sdk2.parameters.Integer('Release branch number', required=True)
        gerrit_release_branch_prefix = sdk2.parameters.String('Gerrit release branch prefix (e.g. "factory-" for factory branches)', default='')
        gerrit_release_branch_number = sdk2.parameters.Integer('Gerrit release branch number. If None repository config will be used to determine the next branch', default=None)
        keep_last_x_branches = sdk2.parameters.Integer('Remove all, but last X branches. If None, none branches will be removed', default=None)
        create_new_branch_force = sdk2.parameters.Bool('Create new gerrit branch even if there is no diff with previous branch', default=False)
        repository_checkout_threads = sdk2.parameters.Integer('Number of threads used in repo sync.', default=8)

        with sdk2.parameters.Output():
            revision = sdk2.parameters.String("Commited revision")
            gerrit_branch_name = sdk2.parameters.String("Gerrit branch name")

        _container = sdk2.parameters.Container(
            'lxc container with repo installed',
            default_value=quasar_vcs.Containers.REPO_CONTAINER_ID,
            platform='linux_ubuntu_18.04_bionic',
            required=True,
        )

    def setup_vcs(self):
        """Set the checkout settings used by ``repo_checkout``.

        ``repo_checkout`` is presumably provided by
        ``quasar_vcs.RepoCheckoutMixin``. The repository is checked out into
        the current working directory.
        """
        self.repository_user = self.Parameters.ssh_user
        self.checkout_path = os.getcwd()
        self.ssh_private_key_vault_name = self.Parameters.ssh_key_vault_name
        self.ssh_private_key_vault_owner = self.Parameters.ssh_key_vault_owner
        self.RepoParameters.current_branch = False
        self.repo_custom_manifest_path = None
        self.repository_checkout_threads = self.Parameters.repository_checkout_threads

    def on_execute(self):
        """Task entry point: create or reuse a gerrit release branch and record it in Arcadia.

        :raises TaskError: if the repository config uses a VCS other than ``repo``.
        """
        super(CreateGerritBranch, self).on_execute()
        self.setup_vcs()

        with sdk2.ssh.Key(
            self,
            key_owner=self.Parameters.ssh_key_vault_owner,
            key_name=self.Parameters.ssh_key_vault_name
        ):
            # Check out the Arcadia directory holding repository.json into a
            # local directory named after the last component of the URL.
            arcadia_directory = os.path.basename(str(self.Parameters.repository_config_url).rstrip('/'))
            sdk2.svn.Arcadia.checkout(
                url=self.Parameters.repository_config_url,
                path=arcadia_directory,
            )
            repository_config_filename = os.path.join(arcadia_directory, self.REPOSITORY_CONFIG)
            with open(repository_config_filename) as repository_config_file:
                repository_config = json.load(repository_config_file)

            if repository_config['vcs'] != quasar_vcs.VCS.REPO:
                # Everything below relies on `repo`, so stop here for other VCS types.
                raise sandbox_common.errors.TaskError('VCS {} is not supported'.format(repository_config['vcs']))

            # Checkout system repository at its development branch
            self.repository_url = repository_config['url']
            self.repository_tag = get_master_branch(repository_config)
            self.repo_checkout()

            create_new_branch, last_gerrit_branch, new_gerrit_branch = self.get_gerrit_release_branch_config(repository_config)

            if create_new_branch or self.Parameters.create_new_branch_force:
                self.create_gerrit_branch(repository_config['url'], self.repository_tag, new_gerrit_branch)
                self.add_branch_to_config(repository_config, new_gerrit_branch,
                                          self.Parameters.arcadia_release_branch_name,
                                          self.Parameters.arcadia_release_branch_number)
                self.copy_manifest(arcadia_directory, new_gerrit_branch)
                self.Parameters.gerrit_branch_name = new_gerrit_branch
            else:
                logging.info('No need to create new branch, use the last one: %s', last_gerrit_branch)
                self.add_mergeto_to_config(repository_config, last_gerrit_branch,
                                           self.Parameters.arcadia_release_branch_name,
                                           self.Parameters.arcadia_release_branch_number)
                self.Parameters.gerrit_branch_name = last_gerrit_branch

            self.save_repository_config(repository_config, repository_config_filename)
            if not self.Parameters.dry_run:
                self.Parameters.revision = self.commit(arcadia_directory, self.Parameters.gerrit_branch_name,
                                                       self.Parameters.arcadia_release_branch_name,
                                                       self.Parameters.arcadia_release_branch_number)

    def get_gerrit_release_branch_config(self, repository_config):
        """Decide whether a new gerrit release branch is needed and how to name it.

        Scans ``tracked_branches`` for this platform's release branches with
        the configured prefix. The next number is one past the largest one
        found, or 1 if none is found. An explicit
        ``gerrit_release_branch_number`` parameter overrides it.

        :return: tuple ``(create_new_branch, last_gerrit_branch_name,
            new_gerrit_branch_name)``.
            ``create_new_branch`` is False only when no explicit number was
            given, a previous release branch exists, and ``find_diff`` reports
            no difference against it.
            ``last_gerrit_branch_name`` is ``''`` when no release branch was
            found.
        """
        prefix = self.Parameters.gerrit_release_branch_prefix
        create_new_branch = True
        last_gerrit_branch_name = ''
        new_gerrit_branch_number = 1

        # Looking for the largest branch number (numbers < 1 are ignored)
        for branch in repository_config['tracked_branches']:
            branch_number = self.get_branch_number(branch, prefix)
            if branch_number is not None and new_gerrit_branch_number < branch_number + 1:
                new_gerrit_branch_number = branch_number + 1
                last_gerrit_branch_name = branch

        if self.Parameters.gerrit_release_branch_number:
            new_gerrit_branch_number = self.Parameters.gerrit_release_branch_number
        elif new_gerrit_branch_number != 1:
            # A previous release branch exists: reuse it if nothing changed since it was cut
            if not self.find_diff(last_gerrit_branch_name):
                create_new_branch = False

        return create_new_branch, last_gerrit_branch_name, self.get_release_branch_name(new_gerrit_branch_number)

    def find_diff(self, branch_name):
        """List projects whose local HEAD differs from ``branch_name`` on their remote.

        For every project, ``repo forall`` compares ``git rev-parse HEAD`` with
        the commit ``git ls-remote`` reports for ``branch_name``. When they
        differ, it echoes ``$REPO_PROJECT``. A project where the branch is
        missing on the remote therefore also counts as different.

        :return: the raw stdout of the command; empty output means no difference.
        :raises TaskError: if the command exits with a non-zero code.
        """
        diff_command = 'repo forall -p -v -c \'if [ "$(git rev-parse HEAD)" != "$(git ls-remote $REPO_REMOTE {} | cut -f1)" ]; then echo "$REPO_PROJECT"; fi\''.format(branch_name)

        logging.info('Running %s', diff_command)
        p = subprocess.Popen(diff_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
        output, error = p.communicate()
        if p.returncode != 0:
            raise sandbox_common.errors.TaskError('Command exited with code {}, {}'.format(p.returncode, error))

        return output

    def create_gerrit_branch(self, gerrit_url, master_branch, release_branch):
        """Create ``release_branch`` in all projects and in the manifest repository.

        ``gerrit_url`` and ``master_branch`` are not used here; they are
        accepted for interface compatibility. The branches are created from
        the currently checked-out state.
        """
        self.create_release_branches(release_branch)
        self.create_manifest_release_branch(release_branch)
        self.fix_manifest_file(release_branch)
        self.push_manifest_release_branch(release_branch)

    def add_branch_to_config(self, repository_config, gerrit_release_branch, arcadia_release_branch_name, arcadia_release_branch_number):
        """Register a new gerrit release branch in ``repository_config`` (in place).

        Drops release branches older than the last ``keep_last_x_branches``
        from ``tracked_branches``. Then it adds (or overwrites) the entry for
        ``gerrit_release_branch`` with ``mergeto`` set to
        ``"<arcadia_release_branch_name>:<arcadia_release_branch_number>"``,
        and sets the config's ``ref`` to the new branch.
        """
        # Keep only the last `keep_last_x_branches` release branches (including the new one)
        remove_branches = self.get_removable_branches(repository_config, gerrit_release_branch)
        for branch in remove_branches:
            logging.debug('Branch %s removed from the config', branch)
            repository_config['tracked_branches'].pop(branch, None)

        # Add new branch to the config or modify mergeto for the existing
        logging.info('Adding branch %s to the config', gerrit_release_branch)
        repository_config['tracked_branches'][gerrit_release_branch] = {
            'mergeto': '{}:{}'.format(arcadia_release_branch_name, arcadia_release_branch_number)
        }

        # Set the latest ref to the config
        logging.info('Set %s system branch to the ref', gerrit_release_branch)
        repository_config['ref'] = gerrit_release_branch

    def add_mergeto_to_config(self, repository_config, gerrit_release_branch, arcadia_release_branch_name, arcadia_release_branch_number):
        """Add an Arcadia release branch to an existing gerrit branch's ``mergeto`` (in place).

        If ``mergeto`` is already set, ``",<arcadia_release_branch_number>"``
        is appended to it. Otherwise it is set to
        ``"<arcadia_release_branch_name>:<arcadia_release_branch_number>"``.

        :raises KeyError: if ``gerrit_release_branch`` is not tracked.
        """
        logging.info('Adding release branch %s to mergeto', arcadia_release_branch_number)
        branch_config = repository_config['tracked_branches'][gerrit_release_branch]
        if branch_config.get('mergeto'):
            # NOTE(review): appends only the number, assuming the existing
            # mergeto already targets arcadia_release_branch_name — confirm.
            branch_config['mergeto'] = '{},{}'.format(branch_config['mergeto'], arcadia_release_branch_number)
        else:
            branch_config['mergeto'] = '{}:{}'.format(arcadia_release_branch_name, arcadia_release_branch_number)

    def save_repository_config(self, repository_config, repository_config_filename):
        """Write ``repository_config`` as indented JSON with sorted keys."""
        logging.debug('Saving repository config to the file')
        with open(repository_config_filename, 'w') as repository_config_file:
            # Explicit separators avoid trailing spaces after commas with
            # `indent` on Python 2 (they are the Python 3 default already).
            json.dump(repository_config, repository_config_file, indent=4, sort_keys=True, separators=(',', ': '))

    def commit(self, arcadia_directory, gerrit_branch_name, arcadia_release_branch_name, arcadia_release_branch):
        """Commit ``arcadia_directory`` to Arcadia and return the new revision.

        :return: the committed revision number, as a string.
        :raises TaskError: if the commit output contains no "Committed revision" line.
        """
        logging.info('Commiting the new repository config and the manifest')

        commit_output = sdk2.svn.Arcadia.commit(
            path=arcadia_directory,
            message='Tracking system branch {} for {}/stable-{} SKIP_CHECK'.format(
                gerrit_branch_name, arcadia_release_branch_name, arcadia_release_branch
            ),
            user=self.Parameters.ssh_user,
        )

        revision_match = re.search(r'Committed revision (\d+)', commit_output)
        if revision_match is None:
            raise sandbox_common.errors.TaskError(
                'Cannot find the committed revision in the commit output: {}'.format(commit_output)
            )
        return revision_match.group(1)

    def create_release_branches(self, release_branch):
        """Create ``release_branch`` from HEAD in every project and push it.

        The push (``git push --force``) goes to the remote selected by
        ``git remote | grep -v aosp``. This presumably assumes exactly one
        non-aosp remote per project. The push is skipped in dry-run mode.
        """
        # Create the release branch for every project from the manifest using `repo forall`
        with sdk2.helpers.ProcessLog(self, logger='process_log_2.create_release_branches') as pl:
            checkout_command = 'repo forall -p -v -j{} -c "git checkout -b {}"'.format(multiprocessing.cpu_count(), release_branch)
            logging.info('Running %s', checkout_command)
            self.safe_run_command(checkout_command, stdout=pl.stdout, stderr=pl.stderr, shell=True)

            if not self.Parameters.dry_run:
                push_command = 'repo forall -p -v -j{} -c \'git push --force "$(git remote | grep -v aosp)" {}\''.format(multiprocessing.cpu_count(), release_branch)
                logging.info('Running %s', push_command)
                self.safe_run_command(push_command, stdout=pl.stdout, stderr=pl.stderr, shell=True)

    def create_manifest_release_branch(self, release_branch):
        """Create and check out ``release_branch`` in the local manifest repository."""
        with sdk2.helpers.ProcessLog(self, logger='process_log_3.create_manifest_release_branch') as pl:
            checkout_list = ['git', 'checkout', '-b', release_branch]
            logging.info('Running %s', ' '.join(checkout_list))
            self.safe_run_command(checkout_list, stdout=pl.stdout, stderr=pl.stderr, cwd=self.MAINFESTS_PATH)

    def fix_manifest_file(self, release_branch):
        """Point the manifest at ``release_branch``.

        In ``<default>`` and ``<remote>`` elements that have a revision, the
        revision is replaced with ``release_branch``. Per-``<project>``
        revisions are removed so projects fall back to the remote/default
        revision.
        """
        logging.info('Fixing manifest file')
        manifest_path = os.path.join(self.MAINFESTS_PATH, self.MAINFEST_NAME)
        tree = ET.parse(manifest_path)
        root = tree.getroot()

        # Replace revision in `default` and `remote` tags with the release branch
        for tag in ('default', 'remote'):
            for element in root.iter(tag):
                if element.get(self.REVISION_ATTR):
                    element.set(self.REVISION_ATTR, release_branch)
        # Delete revision from `project` tags
        for project in root.iter('project'):
            if project.get(self.REVISION_ATTR):
                project.attrib.pop(self.REVISION_ATTR)

        tree.write(manifest_path, encoding='utf-8', xml_declaration=True)

    def push_manifest_release_branch(self, release_branch):
        """Commit the rewritten manifest and push the manifest release branch.

        Commands are passed to ``safe_run_command`` as argument lists,
        presumably without a shell, since callers pass ``shell=True``
        explicitly when they need one. Values therefore must not carry
        shell-style quotes, or the quotes would end up verbatim in git.
        The push is skipped in dry-run mode.
        """
        with sdk2.helpers.ProcessLog(self, logger='process_log_4.push_manifest_release_branch') as pl:
            email_list = ['git', 'config', 'user.email', '{}@yandex-team.ru'.format(self.Parameters.ssh_user)]
            logging.info('Running %s', ' '.join(email_list))
            self.safe_run_command(email_list, stdout=pl.stdout, stderr=pl.stderr, cwd=self.MAINFESTS_PATH)

            name_list = ['git', 'config', 'user.name', '{}'.format(self.Parameters.ssh_user)]
            logging.info('Running %s', ' '.join(name_list))
            self.safe_run_command(name_list, stdout=pl.stdout, stderr=pl.stderr, cwd=self.MAINFESTS_PATH)

            add_list = ['git', 'add', self.MAINFEST_NAME]
            logging.info('Running %s', ' '.join(add_list))
            self.safe_run_command(add_list, stdout=pl.stdout, stderr=pl.stderr, cwd=self.MAINFESTS_PATH)

            commit_list = ['git', 'commit', '-m', 'Create release branch {}'.format(release_branch)]
            logging.info('Running %s', ' '.join(commit_list))
            self.safe_run_command(commit_list, stdout=pl.stdout, stderr=pl.stderr, cwd=self.MAINFESTS_PATH)

            if not self.Parameters.dry_run:
                push_list = ['git', 'push', 'origin', release_branch]
                logging.info('Running %s', ' '.join(push_list))
                self.safe_run_command(push_list, stdout=pl.stdout, stderr=pl.stderr, cwd=self.MAINFESTS_PATH)

    def get_release_branch_name(self, branch_number):
        """Return the gerrit release branch name ``<platform>/<prefix><branch_number>``."""
        return "{}/{}{}".format(
            self.Parameters.platform,
            self.Parameters.gerrit_release_branch_prefix,
            branch_number
        )

    def get_branch_number(self, branch_name, branch_prefix=''):
        """
        Parse the release branch number out of ``branch_name``.

        The name must start with the task's platform. Separators ``/``, ``-``
        and ``_`` after the platform are skipped, then ``branch_prefix`` must
        follow, and the rest must be an integer. Otherwise ``None`` is
        returned.

        Some examples of branch names:
            name='yandexstation_2/1', prefix='' -> result=1
            name='yandexstation/42', prefix='' -> result=42
            name='yandexmidi/factory-1', prefix='factory-' -> result=1
            name='yandexmidi/factory-1', prefix='' -> result=None
        """
        if branch_name.startswith(self.Parameters.platform):
            branch_number = branch_name[len(self.Parameters.platform):].lstrip('/-_')
            if branch_number.startswith(branch_prefix):
                try:
                    return int(branch_number[len(branch_prefix):])
                except ValueError as err:
                    logging.warning('Cannot parse branch number %s. Error: %s', branch_name, err)
            else:
                logging.info('Skipping branch: "%s" due to mismatching prefix: "%s"', branch_name, branch_prefix)
        return None

    def get_removable_branches(self, repository_config, new_gerrit_branch):
        """Return tracked release branches that fall outside the last ``keep_last_x_branches``.

        A branch numbered ``n`` is removable when
        ``n + keep_last_x_branches <= number(new_gerrit_branch)``. The last X
        branches are counted including the new one. Nothing is removed when
        ``keep_last_x_branches`` is unset or below 1, or when the new branch
        number cannot be parsed.
        """
        keep_last_x_branches = self.Parameters.keep_last_x_branches
        if not keep_last_x_branches or keep_last_x_branches < 1:
            return []
        prefix = self.Parameters.gerrit_release_branch_prefix
        new_gerrit_release_branch_number = self.get_branch_number(new_gerrit_branch, prefix)
        if not new_gerrit_release_branch_number:
            return []
        removable = []
        for branch in repository_config['tracked_branches']:
            branch_number = self.get_branch_number(branch, prefix)
            if branch_number and branch_number + keep_last_x_branches <= new_gerrit_release_branch_number:
                removable.append(branch)
        return removable

    def copy_manifest(self, arcadia_directory, gerrit_branch_name):
        """Store a ``repo manifest -r`` snapshot under ``<arcadia_directory>/<gerrit_branch_name>/``.

        The manifest is first written to ``COPY_MANIFEST_NAME`` in the current
        working directory. It is then copied into the Arcadia checkout, and
        the branch directory is scheduled for addition via
        ``sdk2.svn.Arcadia.add``.
        """
        with sdk2.helpers.ProcessLog(self, logger='repo_manifest') as pl:
            repo_manifest_list = ['repo', 'manifest', '-r', '-o', self.COPY_MANIFEST_NAME]
            logging.info('Running %s', ' '.join(repo_manifest_list))
            self.safe_run_command(repo_manifest_list, stdout=pl.stdout, stderr=pl.stderr)

        branch_directory = os.path.join(arcadia_directory, gerrit_branch_name)
        arcadia_manifest_filename = os.path.join(branch_directory, self.COPY_MANIFEST_NAME)

        # gerrit_branch_name contains '/', so intermediate directories may be missing too
        if not os.path.exists(branch_directory):
            os.makedirs(branch_directory)

        logging.info('Copying manifest to %s', arcadia_manifest_filename)
        shutil.copyfile(self.COPY_MANIFEST_NAME, arcadia_manifest_filename)
        sdk2.svn.Arcadia.add(branch_directory)
