import datetime
import getpass
import logging
import os
import re
import shutil
import six
import tarfile
import time
import zipfile
from hashlib import sha256
from subprocess import PIPE

import requests
import sandbox.projects.sandbox.resources as sb_resources
from sandbox import sdk2
from sandbox.common.types.task import Status
from sandbox.projects.common import constants as consts
from sandbox.projects.common.arcadia import sdk as arcadiasdk
from sandbox.projects.common.vcs import aapi
from sandbox.projects.security.ReportFuzzing.ReportCollectCorpusFromYQL import ReportCollectCorpusFromYQL
from sandbox.projects.security.ReportFuzzing.common import \
    export_into_zip, chmod_exec, drop_lite_deps, get_latest_resource_path, create_dir, add_tag, rm_tags, patch_file, \
    SHARED_DIR, REVISION_TAG, MARKET_REPORT_TAG, CORPUS_TAG, INDEX_TAG, MOUNTED_NAMESPACE
from sandbox.projects.security.ReportFuzzing.resources import ReportCollectedIndex, \
    ReportCollectedCorpus, ReportFiredCorpus, CorpusFromLogs, ReportCollectedLiteIndex, MARKET_REPORT
from sandbox.projects.security.ReportFuzzing.startrek import BrokenTicket
from sandbox.sandboxsdk import process, errors
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sdk2 import WaitTask

logger = logging.getLogger(__name__)

# Prebuilt executables taken from the MARKET_REPORT tarball:
# path inside the tarball -> destination path inside the arcadia checkout.
BIN_MAPPINGS = {
    'bin/report': 'market/report/report_bin/report_bin',
}
WORK_DIR = os.environ.get('TASK_DIR', '.')
MAX_FIRED_RES = 100  # upper bound on how many previously fired corpus resources are fetched

# Index / corpus collection via the report "lite" test runner.
LITE_DIR = os.path.join('market', 'report', 'lite')
LITE_RUN_SCRIPT = os.path.join(LITE_DIR, 'run.py')
TEST_PRIME_RUN_SCRIPT = os.path.join(LITE_DIR, 'test_prime.py')
MEMCACHE_PATH = 'market/pylibrary/memcached/bin'
PYTHON2 = 'python'

# Command-line flags understood by run.py.
COLLECT_INDEX_FLAG = '--collect-index'
COLLECT_CORPUS_FLAG = '--collect-requests'
MANGLED_IDS_FLAG = '--mangled-ids'
MANGLED_IDS_JSON = 'mangled_ids.json'
REPORT_GEN_CONFIG_PATH = '/dev/shm/lite-%s/test_prime/config' % getpass.getuser()

# Arc mount retry policy used by on_execute.
MAX_MOUNT_RETRIES = 5
MOUNT_RETRY_DELAY = 60  # seconds


class ReportCollectCorpusAndIndex(sdk2.Task):
    """Build market report and collect fuzzing inputs for it.

    Depending on parameters the task collects a "huge" index (``run.py --collect-index``),
    a lite index (``test_prime.py``) and a corpus merged from balancer logs, previously
    fired fuzz inputs and report tests, and exports them as Sandbox resources.
    """

    class Requirements(sdk2.Task.Requirements):
        # Some of them from:
        # https://a.yandex-team.ru/arc/trunk/arcadia/market/report/lite/requirements.txt
        environments = [
            # Building:
            PipEnvironment('wheel', version='0.30.0'),
        ]

    class Parameters(sdk2.Parameters):

        with sdk2.parameters.Group('General') as general:
            revision = sdk2.parameters.Integer(
                'Revision (0 means latest)',
                default_value=0,
                hint=True
            )
            drop_from_resources = sdk2.parameters.Bool(
                'Do not compile report. Drop executables from sandbox resources',
                default_value=False
            )
            force_update_logs = sdk2.parameters.Bool(
                'Execute YQL task to fetch updated corpus from logs',
                default_value=False
            )
            collect_lite_index = sdk2.parameters.Bool(
                'Collect lite index',
                default_value=False
            )
            collect_huge_index = sdk2.parameters.Bool(
                'Collect huge index '
                '(ids are mapped to corresponding corpus, so not compatible with lite)',
                default_value=True
            )
            with collect_huge_index.value[True]:
                index_pattern = sdk2.parameters.String(
                    'run.py --pattern parameter'
                )
            report_broken_collection = sdk2.parameters.Bool(
                'Report broken collection into startreck', default_value=False
            )
            with report_broken_collection.value[True]:
                report_broken_queue = sdk2.parameters.String(
                    'Report broken collection into startreck queue',
                    default_value='REPORTFUZZING', required=True
                )
        with sdk2.parameters.Group('Sources and filtering') as sources:
            collect_corpus_from_logs = sdk2.parameters.Bool(
                'Collect corpus from balancer logs', default_value=True
            )
            collect_fired_corpus = sdk2.parameters.Bool(
                'Collect fired corpus on previous fuzz runs', default_value=True
            )
            collect_corpus_from_tests = sdk2.parameters.Bool(
                'Collect corpus from tests', default_value=True
            )
            corpus_ignore_re = sdk2.parameters.String(
                'Filter out requests by this regexp',
                default_value=''
            )
        with sdk2.parameters.Group('Credentials') as credentials:
            yav_tokens = sdk2.parameters.YavSecret(
                'Yav tokens'
            )
            arc_yav_token_key = sdk2.parameters.String(
                'Yav ARC_TOKEN key',
                default_value='ARC_TOKEN'
            )
            yql_yav_token_key = sdk2.parameters.String(
                'Yav YQL_TOKEN key',
                default_value='YQL_TOKEN'
            )
            with report_broken_collection.value[True]:
                startreck_yav_token_key = sdk2.parameters.String(
                    'Yav STARTRECK_TOKEN key',
                    default_value='STARTRECK_TOKEN'
                )
        _container = sdk2.parameters.Container(
            'Report fuzzing environment',
            resource_type=sb_resources.LXC_CONTAINER,
            attrs=dict(target='report-fuzzing-env-v5'),
            required=True
        )
        with sdk2.parameters.Output():
            with sdk2.parameters.Group('Output') as results_group:
                out_revision = sdk2.parameters.Integer('Revision')

    def _build_targets(self, arc_root, *targets):
        """Build ``targets`` with the semi-distbuild system, downloading distbuild artifacts.

        :return: Absolute paths of the targets inside ``arc_root``
        """
        extra_params = {
            consts.DOWNLOAD_ARTIFACTS_FROM_DISTBUILD: True,
            consts.FORCE_BUILD_DEPENDS: True
        }
        arcadiasdk.do_build(
            build_system=consts.SEMI_DISTBUILD_BUILD_SYSTEM,
            build_type=consts.RELEASE_BUILD_TYPE,
            add_result=['pb2.py', '.a', '.h', 'grpc.py'],
            source_root=arc_root,
            results_dir=arc_root,
            targets=targets,
            clear_build=False,
            sandbox_token=self.server.task_token,
            **extra_params
        )
        return [os.path.join(arc_root, target) for target in targets]

    def _build_locally(self, arc_root, *targets, **params):
        """Build ``targets`` with ymake in release mode; extra ``params`` go to ``do_build``.

        :return: ``arc_root`` (build results are placed there)
        """
        arcadiasdk.do_build(
            build_system=consts.YMAKE_BUILD_SYSTEM,
            build_type=consts.RELEASE_BUILD_TYPE,
            results_dir=arc_root,
            source_root=arc_root,
            targets=targets,
            clear_build=False,
            sandbox_token=self.server.task_token,
            **params
        )
        return arc_root

    def basic_checks(self):
        """Reject parameter combinations that cannot work.

        :raises Exception: corpus from tests needs the huge index (its mangled ids)
        """
        if self.Parameters.collect_corpus_from_tests and not self.Parameters.collect_huge_index:
            raise Exception('It is impossible to collect corpus from tests without huge index')

    def on_execute(self):
        """Task entry point: resolve the revision, then collect everything, retrying arc mount."""
        if self.Parameters.revision == 0:
            self._revision = aapi.ArcadiaApi.svn_head()
        else:
            self._revision = self.Parameters.revision

        # Non-numeric revisions (hotfix branches) are not supported by the fuzzing pipeline
        if isinstance(self._revision, six.string_types) and not self._revision.isdigit():
            self.set_info('Fuzzing does not work with hotfix branches')
            return

        self.Parameters.out_revision = self._revision

        self.corpus_hashes = set()
        self._corpus_skip_re = None
        if str(self.Parameters.corpus_ignore_re) != '':
            self._corpus_skip_re = re.compile(str(self.Parameters.corpus_ignore_re))
        self._skipped_corpus = set()
        self.basic_checks()

        logger.info('Latest revision is {}'.format(self._revision))
        rm_tags(self)
        add_tag(self, REVISION_TAG.format(self._revision))

        last_error = None
        for attempt in range(1, MAX_MOUNT_RETRIES + 1):
            try:
                self.collect_corpus_and_index()
                return  # arc mount succeeded
            except AssertionError as e:
                last_error = e
                logger.warning('Attempt {}/{} failed: {}'.format(attempt, MAX_MOUNT_RETRIES, e))
            if attempt < MAX_MOUNT_RETRIES:
                # Sleeping after the final attempt would only delay the failure
                time.sleep(MOUNT_RETRY_DELAY)
                logger.info('Retrying arc mount one more time')
        six.raise_from(
            Exception('Unable to mount arc r{} after {} retries (retry delay = {})'.format(
                self._revision, MAX_MOUNT_RETRIES, MOUNT_RETRY_DELAY)
            ),
            last_error
        )

    def collect_corpus_and_index(self):
        """Mount arc at the task revision, build report and collect the requested artifacts."""
        mount_params = dict()
        if self.Parameters.yav_tokens:
            arc_token = self.Parameters.yav_tokens.data()[self.Parameters.arc_yav_token_key]
            os.environ.update({'ARC_TOKEN': arc_token})
            mount_params.update({'arc_oauth_token': arc_token})

        with arcadiasdk.mount_arc_path(
                self._get_arc_url(),
                **mount_params
        ) as arc_root:
            if self.Parameters.drop_from_resources:
                drop_lite_deps(arc_root)
                self._drop_report_executables(arc_root)
                self._build_locally(
                    arc_root,
                    MEMCACHE_PATH
                )
            else:
                self._build_targets(
                    arc_root,
                    'market/report',
                    'contrib/libs/googleapis-common-protos',
                    'contrib/libs/protoc',
                    'contrib/libs/grpc',
                    MEMCACHE_PATH
                )

            # Corpus from tests is keyed by the huge index mangled ids, so only this
            # index directory is handed to corpus collection (None when not collected).
            huge_index_dir = None
            if self.Parameters.collect_huge_index:
                huge_index_dir = self._collect_index(arc_root, os.path.join(SHARED_DIR, 'index'))
                if ReportCollectCorpusAndIndex._is_valid_index(huge_index_dir):
                    self.export_huge_index(huge_index_dir)
                else:
                    logger.error('Invalid huge index created. Check index collection error log for details')
                    if self.Parameters.report_broken_collection:
                        self._report_broken_collection()

            if self.Parameters.collect_lite_index:
                # Lite index ids are not compatible with the corpus; it is exported on its own
                self._collect_lite_index(arc_root)

            if any((
                    self.Parameters.collect_corpus_from_logs,
                    self.Parameters.collect_fired_corpus,
                    self.Parameters.collect_corpus_from_tests
            )):
                corpus_dir = self._collect_corpus(
                    arc_root, huge_index_dir,
                    from_logs=self.Parameters.collect_corpus_from_logs,
                    previously_fired=self.Parameters.collect_fired_corpus,
                    from_tests=self.Parameters.collect_corpus_from_tests,
                )
                self.export_corpus(corpus_dir)

    def _get_startrek_session(self):
        """Return a requests session authorized with the Startrek OAuth token (Yav, then Vault)."""
        session = requests.Session()
        token = None
        if self.Parameters.yav_tokens:
            token = self.Parameters.yav_tokens.data().get(self.Parameters.startreck_yav_token_key)
        if token is None:
            token = sdk2.Vault.data(self.owner, self.Parameters.startreck_yav_token_key)
        session.headers.update({
            'Authorization': 'OAuth {}'.format(token)
        })
        return session

    def _report_broken_collection(self):
        """Create a BrokenTicket in the configured Startrek queue."""
        queue = self.Parameters.report_broken_queue
        BrokenTicket(self, queue=queue).create(self._get_startrek_session())

    @staticmethod
    def _patch_lite_dir(arc_root, output_path):
        """Point lite's ``_INMEMORY_ROOT_TEMPLATE`` (core/paths.py) at ``output_path``.

        :return: ``output_path/test_prime``, the directory test_prime.py is expected to fill
        """
        lite_paths_file = os.path.join(arc_root, LITE_DIR, 'core', 'paths.py')
        create_dir(output_path)
        substitute_str = '_INMEMORY_ROOT_TEMPLATE = \'{}\''.format(output_path)
        patch_file(lite_paths_file, '_INMEMORY_ROOT_TEMPLATE = \'.+\'', substitute_str)
        return os.path.join(output_path, 'test_prime')

    def _is_skipped_corpus(self, data):
        """Return True when ``data`` matches the ``corpus_ignore_re`` filter.

        The filter is compiled from a text pattern; on Python 3 such a pattern cannot
        search ``bytes``, so raw corpus files are decoded (invalid bytes replaced) first.
        On Python 2 both are byte strings and are searched as-is.
        """
        skip_re = self._corpus_skip_re
        if skip_re is None:
            return False
        if isinstance(data, bytes) and not isinstance(skip_re.pattern, bytes):
            data = data.decode('utf-8', 'replace')
        return skip_re.search(data) is not None

    def _gen_corpus_out_of_requests_file(self, all_requests_file, output_dir):
        """Turn every request line of ``all_requests_file`` into an HTTP GET corpus item.

        Items are named ``<dd.mm.yy>-tests-<sha256>.cov``; duplicates (by hash), blank
        lines and requests matching ``corpus_ignore_re`` are skipped.
        :return: ``output_dir``
        """
        create_dir(output_dir)
        time_mark = datetime.datetime.now().strftime("%d.%m.%y")
        place = 'tests'
        with open(all_requests_file) as reqs_f:
            for i, req in enumerate(reqs_f):
                req = req.strip()
                if not req:
                    continue  # a blank line would become a junk 'GET  HTTP/1.1' item
                data = 'GET {} HTTP/1.1\nHost: localhost\n\n'.format(req)
                data_hash = sha256(data.encode('utf-8')).hexdigest()
                if data_hash in self.corpus_hashes:
                    continue
                if self._is_skipped_corpus(data):
                    self._skipped_corpus.add('{}:{}'.format(all_requests_file, i))
                    continue
                self.corpus_hashes.add(data_hash)
                f_name = '{}-{}-{}.cov'.format(time_mark, place, data_hash)
                with open(os.path.join(output_dir, f_name), 'w') as out_f:
                    out_f.write(data)
        return output_dir

    def _get_arc_url(self):
        """Arc URL of the task revision."""
        return 'arcadia-arc:/#r{}'.format(self._revision)

    @staticmethod
    def _extract_corpus_zip(zip_path, write_dir):
        """Extract a zipped corpus into ``write_dir``.

        :return: Directory with the extracted items (folder of the first archive entry,
            or ``write_dir`` for a root-level entry), or None when the archive is empty
        """
        with zipfile.ZipFile(zip_path) as archive:
            entries = archive.infolist()
            if not entries:
                return None
            # Both 'corpus/' and 'corpus/item' give 'corpus'; a root-level file gives ''.
            # Inspect the archive entry itself, not the local filesystem.
            corpus_folder_name = os.path.dirname(entries[0].filename)
            archive.extractall(path=write_dir)
        return os.path.join(write_dir, corpus_folder_name)

    def _collect_from_logs(self, write_dir, corpus_dir, resource):
        """Merge the zipped logs corpus stored in ``resource`` into ``corpus_dir``."""
        logs_corpus_zipfile = str(sdk2.ResourceData(resource).path)
        extracted_corpus_dir = self._extract_corpus_zip(logs_corpus_zipfile, write_dir)
        if extracted_corpus_dir is None:
            logger.warning('Logs corpus {} is empty'.format(logs_corpus_zipfile))
            return
        self._collect_unique_corpus_items(extracted_corpus_dir, corpus_dir)

    @staticmethod
    def _find_latest_logs_corpus():
        """Return the newest READY CorpusFromLogs resource, or None."""
        return sdk2.Resource.find(
            type=CorpusFromLogs, state='READY'
        ).order(-sdk2.Resource.id).first()

    def _enqueue_logs_corpus_update(self):
        """Start a ReportCollectCorpusFromYQL child task; return its id."""
        child = ReportCollectCorpusFromYQL(
            self,
            description='[Autorun:{}] Collect corpus from YQL logs'.format(self.id),
            owner=self.owner,
            yql_yav_token=self.Parameters.yav_tokens,
            yav_token_key=self.Parameters.yql_yav_token_key
        ).enqueue()
        return child.id

    def _wait_for_task(self, task_id, timeout, poll_interval=60):
        """Block until task ``task_id`` reaches a FINISH or BREAK status.

        A bare ``WaitTask(...)`` object has no effect unless raised, and raising it would
        restart ``on_execute`` (rebuilding and re-exporting everything), so poll instead.
        :param task_id: Sandbox id of the task to wait for
        :param timeout: Seconds to wait; after that the last seen status is returned
        :param poll_interval: Seconds between status requests
        :return: Last observed status of the task
        """
        final_statuses = list(Status.Group.FINISH + Status.Group.BREAK)
        deadline = time.time() + timeout
        while True:
            status = self.server.task[task_id].read()['status']
            if status in final_statuses:
                return status
            if time.time() >= deadline:
                logger.warning('Task #{} is still {} after {} s, not waiting any longer'.format(
                    task_id, status, timeout))
                return status
            time.sleep(poll_interval)

    def _collect_fired_corpus(self, write_dir, corpus_dir):
        """Merge up to MAX_FIRED_RES newest ReportFiredCorpus resources into ``corpus_dir``."""
        logger.info('Collecting previously fired corpus')
        resources = sdk2.Resource.find(
            type=ReportFiredCorpus, state='READY'
        ).order(-sdk2.Resource.id).limit(MAX_FIRED_RES)
        for res in resources:
            fired_corp_filepath = str(sdk2.ResourceData(res).path)
            logger.info('Processing resource {}'.format(fired_corp_filepath))
            # One directory per resource so items of earlier resources are not rescanned
            res_dir = create_dir(write_dir, str(res.id))
            extracted_corpus_dir = self._extract_corpus_zip(fired_corp_filepath, res_dir)
            if extracted_corpus_dir is None:
                logger.info('Fired corpus {} is empty'.format(fired_corp_filepath))
                continue
            self._collect_unique_corpus_items(extracted_corpus_dir, corpus_dir)

    def _collect_tests_corpus(self, arc_root, index_dir, write_dir, corpus_dir, workdir):
        """Run lite run.py to dump test requests and convert them into corpus items."""
        logger.info('Collecting corpus from tests')
        all_req = os.path.join(write_dir, 'all_requests.txt')
        cmd = [
            PYTHON2, os.path.join(arc_root, LITE_RUN_SCRIPT), COLLECT_CORPUS_FLAG, all_req,
            MANGLED_IDS_FLAG, os.path.join(index_dir, MANGLED_IDS_JSON)
        ]
        if self.Parameters.index_pattern:
            cmd += ['--pattern', self.Parameters.index_pattern]
        with open(os.path.join(workdir, 'collect_corpus_log.txt'), 'w') as corpus_log_file:
            process.run_process(
                cmd, log_prefix='corpus_logs', work_dir=workdir,
                outputs_to_one_file=False, stdout=corpus_log_file, wait=True, check=True
            )
        self._gen_corpus_out_of_requests_file(all_req, corpus_dir)

    def _collect_corpus(
            self,
            arc_root,
            index_dir,
            from_logs=True,
            previously_fired=True,
            from_tests=True,
            workdir=WORK_DIR,
            tmp='tmp'
    ):
        """
        Collects resulting corpus into one folder
        :param arc_root: Mounted arcadia root (lite run.py is executed from it)
        :param index_dir: Huge index directory holding mangled_ids.json; may be None
            unless from_tests is set
        :param from_logs: Get corpus items from balancer logs (collected by REPORT_COLLECT_CORPUS task)
        :param previously_fired: Corpus collected from fuzzer previous findings
        :param from_tests: Corpus collected from tests via python script
        :param workdir: Where the 'corpus' folder and collection logs are created
        :param tmp: Scratch directory for extracted archives
        :return: Path to folder, containing collected corpus
        """
        if from_tests and index_dir is None:
            raise Exception('Corpus from tests requires the huge index (mangled ids)')
        tmp = os.path.abspath(tmp)
        logger.info('Collecting corpus ...')
        corpus_dir = create_dir(workdir, 'corpus')
        tmp_dir = create_dir(tmp)

        logs_update_task_id = None
        if from_logs:
            logger.info('Collecting corpus from logs')
            resource = self._find_latest_logs_corpus()
            if resource is None or self.Parameters.force_update_logs:
                # The child task runs while the other sources are being collected
                logs_update_task_id = self._enqueue_logs_corpus_update()
            else:
                self._collect_from_logs(create_dir(tmp_dir, 'logs_corpus'), corpus_dir, resource)

        if previously_fired:
            self._collect_fired_corpus(create_dir(tmp_dir, 'fired_corpus'), corpus_dir)

        if from_tests:
            self._collect_tests_corpus(
                arc_root, os.path.abspath(index_dir), create_dir(tmp_dir, 'tests_corpus'), corpus_dir, workdir
            )

        if logs_update_task_id is not None:
            status = self._wait_for_task(logs_update_task_id, timeout=60 * 60)  # seconds
            logger.info('Logs corpus task #{} status: {}'.format(logs_update_task_id, status))
            resource = self._find_latest_logs_corpus()
            if resource is None:
                raise Exception('No READY {} resource available after task #{} ({})'.format(
                    CorpusFromLogs, logs_update_task_id, status))
            self._collect_from_logs(create_dir(tmp_dir, 'logs_corpus'), corpus_dir, resource)

        if self._skipped_corpus:
            logger.info('Skipped {} corpus items matching corpus_ignore_re'.format(len(self._skipped_corpus)))
        logger.info('Collecting corpus DONE')
        return corpus_dir

    @staticmethod
    def _is_valid_index(index_path):
        """Return True when ``index_path`` has mangled ids and the report config."""
        if not os.path.exists(index_path):
            logger.error('Path not exists {}'.format(index_path))
            return False
        must_contain = (
            MANGLED_IDS_JSON,
            os.path.join('report_meta', 'config', 'report_config.conf')
        )
        for item in must_contain:
            if not os.path.exists(os.path.join(index_path, item)):
                logger.error('No {} found'.format(os.path.join(index_path, item)))
                return False
        return True

    def _collect_index(self, arc_root, into_dir, workdir=WORK_DIR):
        """Run ``run.py --collect-index`` into ``into_dir``.

        :return: Absolute path of ``into_dir``
        :raises: process error from ``run_process(check=True)`` on non-zero exit
        """
        create_dir(into_dir)
        into_dir = os.path.abspath(into_dir)
        logger.info('Collecting index into {}'.format(into_dir))
        cmd = [
            PYTHON2, os.path.join(arc_root, LITE_RUN_SCRIPT), COLLECT_INDEX_FLAG, into_dir
        ]
        if self.Parameters.index_pattern:
            cmd += ['--pattern', self.Parameters.index_pattern]
        with open(os.path.join(workdir, 'collect_index_log.txt'), 'w') as index_log_file:
            process.run_process(
                cmd, log_prefix='index_logs', work_dir=workdir,
                outputs_to_one_file=False, stdout=index_log_file,
                wait=True, check=True
            )
        logger.info('Collected index {}: {}'.format(into_dir, os.listdir(into_dir)))
        return into_dir

    def _collect_lite_index(self, arc_root, workdir=WORK_DIR):
        """Generate the lite index with ``test_prime.py -s`` into a ReportCollectedLiteIndex resource.

        The report config's PidDir is pointed at MOUNTED_NAMESPACE and RunningFlagFilePath
        is cleared before the resource is marked ready.
        :return: Path of the lite index resource data
        :raises errors.SandboxSubprocessError: when test_prime.py exits with non-zero code
        """
        logger.info('Collecting lite index')
        cmd = (
            PYTHON2, os.path.join(arc_root, TEST_PRIME_RUN_SCRIPT), '-s'
        )
        lite_index = ReportCollectedLiteIndex(
            self,
            'Lite index',
            'lite_index_r{}'.format(self._revision),
            revision=self._revision
        )
        lite_index_res_data = sdk2.ResourceData(lite_index)
        lite_index_path = str(lite_index_res_data.path)
        default_outp_path = ReportCollectCorpusAndIndex._patch_lite_dir(arc_root, lite_index_path)
        with open(os.path.join(workdir, 'collect_lite_index_log.txt'), 'w') as index_log_file:
            proc = process.run_process(
                cmd, stdin=PIPE, log_prefix='lite_index_logs', work_dir=workdir,
                outputs_to_one_file=False, stdout=index_log_file,
                wait=False
            )
            # Two empty lines on stdin; presumably answers interactive prompts of test_prime.py
            proc.stdin.write(b'\n\n')
            proc.stdin.close()
            proc.communicate()
            if proc.returncode != 0:
                msg = 'Lite index collection "{0}" died with exit code {1}'
                raise errors.SandboxSubprocessError(
                    message=msg.format(proc.saved_cmd, proc.returncode),
                    cmd_string=proc.saved_cmd,
                    returncode=proc.returncode,
                    stderr_path=proc.stderr_path_filename,
                    stdout_path=proc.stdout_path_filename,
                    stderr_full_path=proc.stderr_path,
                    stdout_full_path=proc.stdout_path
                )
        logger.info('Collected lite index {}: {}'.format(default_outp_path, os.listdir(default_outp_path)))
        report_config = os.path.join(default_outp_path, 'report_meta', 'config', 'report_config.conf')
        logger.info('Patching PidDir in {} (mounted ns: {})'.format(report_config, MOUNTED_NAMESPACE))
        patch_file(report_config, r'PidDir\s+.+', 'PidDir {}'.format(MOUNTED_NAMESPACE))
        patch_file(report_config, r'\bRunningFlagFilePath\b.*', 'RunningFlagFilePath')
        lite_index_res_data.ready()
        return lite_index_path

    def _collect_unique_corpus_items(self, collect_from, accumulate_into):
        """Recursively copy corpus files not seen before (by sha256) into ``accumulate_into``.

        Files matching ``corpus_ignore_re`` are recorded in ``_skipped_corpus`` and skipped.
        """
        for item in os.listdir(collect_from):
            item_filepath = os.path.join(collect_from, item)
            # The directory check must use the full path, not the bare (cwd-relative) name
            if os.path.isdir(item_filepath):
                self._collect_unique_corpus_items(item_filepath, accumulate_into)
                continue
            with open(item_filepath, 'rb') as f:
                bin_data = f.read()
            if self._is_skipped_corpus(bin_data):
                self._skipped_corpus.add(item_filepath)
                continue  # Skipping this corpus item

            item_hash = sha256(bin_data).hexdigest()
            if item_hash in self.corpus_hashes:
                continue
            self.corpus_hashes.add(item_hash)
            dst = os.path.join(accumulate_into, item)
            if os.path.exists(dst):
                # Same name but different content: keep both instead of overwriting
                dst = os.path.join(accumulate_into, '{}-{}'.format(item, item_hash))
            shutil.copy(item_filepath, dst)
        logger.info('Unique corpus items count = {}'.format(len(self.corpus_hashes)))

    def _drop_report_executables(self, arc_root, tmp_dir='execs'):
        """Copy prebuilt report binaries (MARKET_REPORT of this revision) into ``arc_root``.

        :raises Exception: when no MARKET_REPORT resource exists for the revision
        """
        create_dir(tmp_dir)
        logger.info('Dropping executables ...')
        report_execs, res_id = self._get_resource(MARKET_REPORT)
        if report_execs is None:
            raise Exception('No {} resource found with the revision {}'.format(MARKET_REPORT, self._revision))
        # Same helper as every other task tag, so tag writes are not mixed between APIs
        add_tag(self, MARKET_REPORT_TAG.format(res_id))
        with tarfile.open(report_execs) as tar:
            tar.extractall(path=tmp_dir)
        for f_from, f_to in BIN_MAPPINGS.items():
            src = os.path.join(tmp_dir, f_from)
            dst = os.path.join(arc_root, f_to)
            shutil.copy(src, dst)
            chmod_exec(dst)
        logger.info('Dropping report executable DONE')

    def _get_resource(self, res_type):
        """Return ``(path, id)`` of the latest ``res_type`` resource built from the task revision."""
        return get_latest_resource_path(res_type, attrs={'svn_revision': self._revision})

    def export_corpus(self, filepath, workdir=WORK_DIR):
        """Zip the corpus folder and publish it as a ReportCollectedCorpus resource."""
        corpus_filepath, n = export_into_zip(filepath, 'corpus_r{}.zip'.format(self._revision), workdir=workdir)
        res = ReportCollectedCorpus(
            self,
            'Corpus ({} items)'.format(n),
            corpus_filepath,
            revision=self._revision,
            minimized=False,
            from_logs=self.Parameters.collect_corpus_from_logs,
            from_tests=self.Parameters.collect_corpus_from_tests,
            from_fuzz_runs=self.Parameters.collect_fired_corpus
        )
        add_tag(self, CORPUS_TAG.format(res.id))
        sdk2.ResourceData(res).ready()

    def export_huge_index(self, filepath, workdir=WORK_DIR):
        """Zip the huge index folder and publish it as a ReportCollectedIndex resource."""
        index_filepath, _ = export_into_zip(filepath, 'index_r{}.zip'.format(self._revision), workdir=workdir)
        res = ReportCollectedIndex(
            self,
            'Index',
            index_filepath,
            revision=self._revision
        )
        add_tag(self, INDEX_TAG.format(res.id))
        sdk2.ResourceData(res).ready()