# coding: utf-8
import os
import os.path
import json
import tempfile
import requests
import datetime
import re
import logging
import errno
import string
import copy
import dateutil.parser
import hashlib

import sandbox.common.types.misc as ctm
from sandbox.sandboxsdk.parameters import SandboxStringParameter
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sandboxsdk.channel import channel

from sandbox.projects.yadi.FeedDeploy import YadiFeedDeploy

# Matches a "left-right" version range: an optional leading "v", a dotted
# numeric version of one to three components, arbitrary text, a "-" and a
# second dotted numeric version. Used by clear_version().
RANGE_VER_RE = re.compile(r'^v?(?P<left>\d+(?:\.\d+)?(?:\.\d+)?).*?-(?P<right>\d+(?:\.\d+)?(?:\.\d+)?).*$')
# Matches a single leading dotted numeric version (optional "v" prefix);
# anything after the version is ignored. Used by clear_version().
SINGLE_VER_RE = re.compile(r'^v?(?P<left>\d+(?:\.\d+)?(?:\.\d+)?)')
# Arcadia location that clone_repo() checks out.
SVN_PATH = 'arcadia:/arc/trunk/arcadia/security/yadi/db@HEAD'
# User name passed to Arcadia.commit() in on_execute().
SVN_USER = 'zomb-sandbox-rw'
# Wiki grid endpoint requested by fetch_wiki().
WIKI_GRID = 'https://wiki-api.yandex-team.ru/_api/frontend/product-security/yadi/vulns/db/.grid?format=json'
# Maps the "__key__" of a wiki grid cell to the advisory field it fills in
# (see fetch_wiki()).
WIKI_KEY_MAP = {
    '100': 'title',
    '101': 'language',
    '102': 'issue',
    '103': 'module_name',
    '104': 'vulnerable_versions',
    '105': 'patched_versions'
}


class MailTo(SandboxStringParameter):
    """Task parameter holding the e-mail recipients for update notifications.

    The value is read in ``YadiFeedUpdate.notify`` and split on commas, so
    several addresses may be given as a comma-separated string.
    """
    # Key under which the value is looked up in the task context.
    name = 'mail_to'
    description = 'Send mail with database updates'
    default_value = 'buglloc@yandex-team.ru'


class YadiFeedUpdate(SandboxTask):
    """Sandbox task that refreshes the Yadi vulnerability database.

    The task checks out the database from Arcadia, merges advisories fetched
    from the external sources into the per-language feeds, commits the
    result and, when files under ``yadi`` changed, schedules a
    ``YadiFeedDeploy`` subtask.
    """
    type = 'YADI_FEED_UPDATE'
    dns = ctm.DnsType.DNS64
    input_parameters = (MailTo,)

    # OAuth token for the Wiki API; loaded from the vault in on_execute().
    wiki_token = None
    # Path of the Arcadia checkout; set in on_execute() from clone_repo().
    work_dir = None
    # "<work_dir>/yadi": directory holding the merged per-language feeds.
    database_dir = None

    def on_execute(self):
        """Run one feed update: check out, sync, commit and schedule deploy.

        When the task is resumed after waiting for the deploy subtask the
        ``done`` context flag is set and the method returns immediately.
        """
        if self.ctx.get('done'):
            # Second entry after the deploy subtask finished: nothing to do.
            return

        self.wiki_token = self.get_vault_data(self.owner, 'YADI_TOKEN')

        self.work_dir = self.clone_repo()
        self.database_dir = os.path.join(self.work_dir, 'yadi')
        make_sure_dir_exists(self.database_dir)
        logging.info('Repo cloned into: %s' % self.work_dir)

        for lang in ('nodejs', 'python'):
            self.sync_feed(lang, [('sc', self.fetch_sc), ('wiki', self.fetch_wiki)])

        need_commit = False
        need_deploy = False
        for entry in Arcadia.status(self.work_dir).split('\n'):
            if not entry.strip():
                continue
            # Any non-blank status line means the working copy has changes.
            need_commit = True

            # The first character is the status flag, the rest is the path.
            changed_path = entry[1:].strip()
            logging.info('Commited: %s', changed_path)
            if entry.startswith('?'):
                # Unversioned file: put it under version control first.
                changed_path = os.path.join(self.work_dir, changed_path)
                Arcadia.add(changed_path)
            if 'yadi' in changed_path:
                need_deploy = True

        if need_commit:
            today = datetime.datetime.now().strftime('%Y-%m-%d')
            comment = 'SKIP_CHECK Update Yadi database from %s' % today
            Arcadia.commit(self.work_dir, comment, SVN_USER)
            logging.info('Commited updated advisories')

        if not need_deploy:
            return

        # Schedule the deploy and suspend until it completes.
        subtask = self.create_subtask(
            task_type=YadiFeedDeploy.type,
            description='%s subtask for #%d (%s)' % (YadiFeedDeploy.type, self.id, self.descr))
        self.ctx['done'] = True
        self.wait_task_completed(subtask, state='Waiting for feed deploy to complete')

    def sync_feed(self, lang, sources=[]):
        feed_keys, feed = self.get_advisories(lang)
        logging.info('Loaded %d %s advisories', len(feed_keys), lang)

        actual_advisories = []
        new_advisories = []
        removed_advisories = []
        for source in sources:
            s, f = source
            removed_adv, new_adv, actual_adv = self.sync_ext_feed(source=s, fetcher=f, lang=lang)
            logging.info('Synced %s advisories from %s: new %d, removed %s', lang, s, len(new_adv), len(removed_adv))
            new_advisories += new_adv
            removed_advisories += removed_adv
            actual_advisories += actual_adv

        actual_advisories_map = {}
        for issue in actual_advisories:
            actual_advisories_map[issue['id']] = issue

        new_issues = []
        for issue in new_advisories:
            if issue['id'] in feed_keys:
                # Already known about it
                continue
            feed.append(issue)
            new_issues.append(issue)

        actual_feed = []
        removed_issues = []
        for issue in feed:
            if issue['id'] not in removed_issues:
                actual = actual_advisories_map.get(issue['id'], {})
                issue['desc'] = actual.get('desc', issue['desc'])
                issue['cvss_score'] = actual.get('cvss_score', issue['cvss_score'])
                issue['patched_versions'] = actual.get('patched_versions', issue['patched_versions'])
                issue['vulnerable_versions'] = actual.get('vulnerable_versions', issue['vulnerable_versions'])
                issue['cvss_score'] = actual.get('cvss_score', issue['cvss_score'])
                issue['cvss_score'] = actual.get('cvss_score', issue['cvss_score'])

                actual_feed.append(issue)
                continue
            removed_issues.append(issue)

        self.update_advisories(lang, actual_feed)
        self.notify(lang, new_issues, removed_issues)

    def get_advisories(self, lang):
        feed_path = os.path.join(self.database_dir, '%s.json' % lang)

        yadi_feed = []
        if os.path.exists(feed_path):
            with open(feed_path, 'r') as f:
                yadi_feed = json.loads(f.read())
        yadi_keys = set([i['id'] for i in yadi_feed])
        return yadi_keys, yadi_feed

    def update_advisories(self, lang, advisories):
        feed_path = os.path.join(self.database_dir, '%s.json' % lang)
        with open(feed_path, 'w') as f:
            f.write(json.dumps(advisories, indent=2, separators=(',', ': '), sort_keys=True))

    def sync_ext_feed(self, source=None, fetcher=None, lang=None):
        """Fetch one external source and diff it against its stored snapshot.

        :param source: source name; also the sub-directory of
            ``self.work_dir`` holding the snapshot ``<lang>.json``.
        :param fetcher: callable without arguments returning the advisories.
        :param lang: keep only advisories of this language; a falsy value
            keeps all of them.
        :return: ``(removed, new, actual)``: ids present in the snapshot but
            not fetched any more, fetched advisories missing from the
            snapshot, and fetched advisories already in the snapshot.

        The snapshot file is overwritten with the filtered advisories.
        """
        fetched = fetcher()
        logging.info('Fetched %s advisories: %d', source, len(fetched))

        feed_path = os.path.join(self.work_dir, source, '%s.json' % lang)
        known_ids = []
        if os.path.exists(feed_path):
            with open(feed_path, 'r') as snapshot:
                known_ids = {entry['id'] for entry in json.load(snapshot)}

        advisories = [
            issue for issue in fetched
            if not lang or issue['language'] == lang
        ]
        current_ids = {issue['id'] for issue in advisories}

        new_advisories = []
        actual_advisories = []
        for issue in advisories:
            # Ids already in the snapshot are "actual", the rest are new.
            if issue['id'] in known_ids:
                actual_advisories.append(issue)
            else:
                new_advisories.append(issue)

        removed_advisories = [key for key in known_ids if key not in current_ids]

        make_sure_dir_exists(os.path.dirname(feed_path))
        with open(feed_path, 'w') as snapshot:
            snapshot.write(json.dumps(advisories, indent=2, separators=(',', ': '), sort_keys=True))

        return removed_advisories, new_advisories, actual_advisories

    def clone_repo(self):
        """Check out ``SVN_PATH`` into a fresh temporary path and return it."""
        # Only the generated name is kept; the NamedTemporaryFile object
        # itself is dropped right away.
        # NOTE(review): presumably the temporary file is therefore already
        # deleted when the checkout runs and just its unique name is reused.
        # Confirm; tempfile.mkdtemp() would be the race-free alternative.
        repo_path = tempfile.NamedTemporaryFile().name
        Arcadia.checkout(SVN_PATH, repo_path)
        return repo_path

    def nsp_to_yadi(self, issue):
        disclosed = issue.get('created_at', None)
        return {
            'id': 'nsp:%d' % issue.get('id'),
            'reference': 'https://yadi.yandex-team.ru/vulns/vuln/nsp:%d' % issue.get('id'),
            'external_references': [{
                'title': None,
                'url': 'https://nodesecurity.io/advisories/%d' % issue.get('id')
            }],
            'language': 'nodejs',
            'desc': issue.get('overview', ''),
            'title': issue.get('title'),
            'module_name': issue.get('module_name'),
            'cvss_score': issue.get('cvss_score') if issue.get('cvss_score') else 0.0,
            'patched_versions': issue.get('patched_versions', None),
            'disclosed': dateutil.parser.parse(disclosed).strftime('%Y-%m-%d') if disclosed else None,
            'vulnerable_versions': issue.get('vulnerable_versions') if issue.get('vulnerable_versions') else '*'
        }

    def fetch_nsp(self):
        """Download all Node Security advisories, page by page.

        :return: advisories converted with ``nsp_to_yadi``, sorted by ``id``.
        :raises Exception: when the API answers with a non-200 status.
        """
        user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36'
        advisories = []
        offset = 0
        while True:
            resp = requests.get(
                'https://api.nodesecurity.io/advisories',
                params={'offset': offset},
                headers={'User-Agent': user_agent},
                verify=False  # NOTE: TLS certificate verification is disabled here
            )
            if resp.status_code != 200:
                raise Exception('Failed to fetch NSP feed (%d): %s' % (resp.status_code, resp.content))

            page = resp.json()
            results = page.get('results')
            if not results:
                # An empty page marks the end of the listing.
                break

            for raw_issue in results:
                advisories.append(self.nsp_to_yadi(raw_issue))
            offset += page.get('count', 100)

        advisories.sort(key=lambda advisory: advisory['id'])
        return advisories

    def fetch_wiki(self):
        """Download the advisories kept in the wiki grid.

        The request is authorised with ``self.wiki_token``. Every grid row
        becomes one advisory: cells are mapped to field names through
        ``WIKI_KEY_MAP`` and the row is converted with ``wiki_to_yadi``.

        :return: converted advisories sorted by ``id``.
        :raises Exception: when the wiki answers with a non-200 status.
        """
        session = requests.Session()
        session.headers['Authorization'] = 'OAuth {}'.format(self.wiki_token)
        session.headers['Content-Type'] = 'application/json'
        resp = session.get(
            WIKI_GRID,
            verify=False  # NOTE: TLS certificate verification is disabled here
        )
        if resp.status_code != 200:
            raise Exception('Failed to fetch Wiki feed (%d): %s' % (resp.status_code, resp.content))

        rows = resp.json().get('data', {}).get('rows', [])
        advisories = []
        for row in rows:
            issue = {'id': row[0]['row_id']}
            for cell in row:
                field_name = WIKI_KEY_MAP[cell['__key__']]
                raw_value = cell['raw']
                if field_name == 'language':
                    # The language cell holds a list; take its first entry
                    # and keep only lower-case letters.
                    issue[field_name] = re.sub('[^a-z]', '', raw_value[0].lower())
                else:
                    issue[field_name] = raw_value
            advisories.append(self.wiki_to_yadi(issue))

        advisories.sort(key=lambda advisory: advisory['id'])
        return advisories

    def wiki_to_yadi(self, issue):
        return {
            'id': 'ya:%s' % issue.get('id'),
            'reference': 'https://st.yandex-team.ru/%s' % issue.get('issue'),
            'external_references': [{
                'title': None,
                'url': 'https://st.yandex-team.ru/%s' % issue.get('issue')
            }],
            'language': issue.get('language'),
            'desc': '',  # TODO:?
            'title': issue.get('title'),
            'module_name': issue.get('module_name'),
            'cvss_score': 8.0,  # TODO:?
            'disclosed': None,  # TODO:?
            'patched_versions': issue.get('patched_versions', None),
            'vulnerable_versions': issue.get('vulnerable_versions') if issue.get('vulnerable_versions') else '*'
        }

    def fetch_sc(self):
        """Download the SourceClear advisories dump and convert it.

        Only advisories whose ``language`` is ``JAVA``, ``PYTHON`` or ``JS``
        are kept; the language is rewritten in place to the Yadi name before
        the advisory is handed to ``sc_to_yadi``.

        :return: converted advisories sorted by ``id``.
        :raises Exception: when the server answers with a non-200 status.
        """
        # SourceClear language code -> Yadi language name.
        language_map = {
            'JAVA': 'java',
            'PYTHON': 'python',
            'JS': 'nodejs',
        }

        session = requests.Session()
        session.headers['Content-Type'] = 'application/json'
        resp = session.get(
            'https://www.buglloc.com/static/sc.json',
            verify=False  # NOTE: TLS certificate verification is disabled here
        )
        if resp.status_code != 200:
            # Name the right feed: the message used to be a copy of the
            # wiki one and pointed at the wrong source.
            raise Exception('Failed to fetch SourceClear feed (%d): %s' % (resp.status_code, resp.content))

        advisories = []
        for issue in resp.json().get('advisories', []):
            lang = language_map.get(issue.get('language'))
            if lang is None:
                # Unsupported or missing language: nothing to convert.
                continue
            issue['language'] = lang
            advisories.extend(self.sc_to_yadi(issue))

        return sorted(advisories, key=lambda k: k['id'])

    def sc_to_yadi(self, issue):
        result = []
        lang = issue['language']
        repos = []
        if lang == 'java':
            repos = ['maven']
        elif lang == 'python':
            repos = ['pypi']
        elif lang == 'nodejs':
            repos = ['npm']

        score = issue.get('srcclrCvssScore', 0.0)
        if not score:
            score = issue.get('nvdCvssScore', 0.0)

        disclosed = issue.get('disclosureDate', None)
        base_issue = {
            'id': 'sc:%d' % issue.get('id'),
            'reference': 'https://yadi.yandex-team.ru/vulns/vuln/sc:%d' % issue.get('id'),
            'language': lang,
            'desc': issue.get('overview', ''),
            'title': issue.get('title'),
            'cvss_score': float(score),
            'disclosed': dateutil.parser.parse(disclosed).strftime('%Y-%m-%d') if disclosed else None,
            'external_references': []
        }

        for r in issue.get('artifactLinks', []):
            if r['type'] in ('RELATED_ARTIFACT', 'FOUND_BY'):
                continue
            base_issue['external_references'].append({
                'title': r.get('title', None),
                'url': r.get('url')
            })

        base_issue['external_references'].append({
            'title': 'SourceClear',
            'url': 'https://sourceclear.com/registry/vulnerabilities/%d' % issue.get('id')
        })

        names = []
        for component in issue.get('artifactComponents', []):
            coord = component.get('coordHash', None)
            if coord:
                coord, _ = coord.split(':', 1)
                if coord not in repos:
                    continue

            name = component.get('coordOne', '')
            if component.get('coordTwo', None):
                if lang != 'java':
                    # FIX ME!
                    logging.info('Skip: %s:%s', component.get('coordOne'), component.get('coordTwo'))
                    continue
                name += ':%s' % component.get('coordTwo')

            name = name.lower()
            if name in names:
                continue
            names.append(name)

            item = copy.copy(base_issue)
            item['module_name'] = name
            issue_id = int(hashlib.sha1(name).hexdigest(), 16) % (10 ** 8)
            item['id'] += ':%d' % issue_id
            item['reference'] += ':%d' % issue_id

            affected_vers = []
            patched_vers = []
            for version in component.get('versionRanges'):
                ver_range = version.get('versionRange')
                affected_ver = []
                for v in ver_range.split(','):
                    af = clear_version(v.strip())
                    if not af:
                        logging.info('Failed to parse version: %s', v)
                        continue
                    affected_ver.append(af)
                affected_vers.append(' || '.join(affected_ver))
                patched_ver = version.get('updateToVersion')
                if patched_ver:
                    patched_vers.append('>%s' % clear_version(patched_ver))

            if not affected_vers:
                continue

            item['vulnerable_versions'] = ' || '.join(affected_vers)
            if patched_vers:
                item['patched_versions'] = ' || '.join(patched_vers)
            else:
                item['patched_versions'] = None
            result.append(item)
        return result

    def notify(self, lang, new_issues, removed_issues):
        """Mail a summary of added and removed advisories for ``lang``.

        :param lang: language the update belongs to (used in the subject).
        :param new_issues: advisory dicts added to the feed.
        :param removed_issues: advisory dicts dropped from the feed.

        Nothing is sent when both lists are empty or when the ``mail_to``
        task parameter contains no address.
        """
        if not new_issues and not removed_issues:
            # Nothing to notify
            return

        # Splitting an empty string gives [''], so blank entries have to be
        # filtered out explicitly for the "no recipients" check to work.
        raw_recipients = self.ctx.get(MailTo.name) or ''
        recipients = [addr.strip() for addr in raw_recipients.split(',') if addr.strip()]
        if not recipients:
            return

        subj = 'Yadi DB update for %s from %s' % (lang, datetime.datetime.now().strftime('%Y-%m-%d'))
        message = []
        sections = (
            ('New vulnerabilities:', new_issues),
            ('Removed vulnerabilities:', removed_issues),
        )
        for header, issues in sections:
            if not issues:
                continue
            message.append(header)
            for issue in issues:
                message.append('\t- %s in %s, details: %s' % (issue['title'], issue['module_name'], issue['reference']))
            message.append('')

        channel.sandbox.send_email(recipients, [], subj, '\n'.join(message))


def make_sure_dir_exists(path):
    """Create directory ``path`` with all parents unless it already exists.

    :param path: directory to create.
    :raises OSError: when creation fails for any reason other than the
        directory already existing, including the case where ``path``
        exists but is not a directory.
    """
    try:
        os.makedirs(path)
    except OSError as exception:
        # EEXIST is also reported when a regular file occupies the path;
        # that is not "directory exists" and must not be hidden.
        if exception.errno != errno.EEXIST or not os.path.isdir(path):
            raise


def ver_padding(version):
    """Pad a dotted version with ``.0`` components up to three parts.

    :param version: version string such as ``'1'`` or ``'1.2'``.
    :return: the version with zero components appended until it has three;
        versions with three or more components are returned unchanged.
    """
    parts = version.split('.')
    missing = 3 - len(parts)
    if missing <= 0:
        return version
    return '.'.join(parts + ['0'] * missing)


def clear_version(version):
    """Normalise a raw version or version range into a comparison expression.

    :param version: raw string, e.g. ``'v1.2'`` or ``'1.0-1.4.2'``.
    :return: ``'>=L <=R'`` for a range with different bounds, ``'=L'`` for a
        single version (or a range whose bounds are equal), with every bound
        padded by ``ver_padding``; ``None`` when no version can be found.
    """
    # Try the range form first, then fall back to a single version.
    match = RANGE_VER_RE.search(version)
    if match:
        left = match.group('left')
        right = match.group('right')
    else:
        match = SINGLE_VER_RE.search(version)
        if not match:
            return None
        left = match.group('left')
        right = None

    if not left:
        return None
    if not right or left == right:
        return '=%s' % ver_padding(left)
    return '>=%s <=%s' % (ver_padding(left), ver_padding(right))
