# -*- coding: utf-8 -*-
import os
import sys
import time
import base64
import hashlib
import tempfile
import subprocess
import unicodedata
import json
import re
from urllib.parse import urlparse
from logging import getLogger
import string
import traceback
import datetime

import requests
import psutil
import dns.resolver
import dns.reversename

from .notification.errors import catch_error

from tvm2 import TVM2
from ticket_parser2_py3 import BlackboxClientId

import xml.etree.ElementTree as etree
from lxml.etree import XMLSyntaxError

from django.conf import settings


# Module-level logger; used by the Splunk logging helpers for error reports.
logger = getLogger(__name__)


def get_oauth_token(code):
    """Exchange an OAuth authorization *code* for an access token.

    Posts an ``authorization_code`` grant with the configured client
    credentials to ``settings.OAUTH_URL``.

    :param code: authorization code received from the OAuth provider.
    :return: the ``access_token`` string, or ``''`` when the request fails,
        the reply is not a JSON object, or it carries no token.
    """
    try:
        req = {
            'grant_type': 'authorization_code',
            'code': code,
            'client_id': settings.OAUTH_CLIENT_ID,
            'client_secret': settings.OAUTH_CLIENT_SECRET
        }
        resp = requests.post(settings.OAUTH_URL, data=req).json()
    except Exception:
        catch_error()
        return ''
    # A JSON array/scalar reply would otherwise crash on .get() below.
    if not isinstance(resp, dict):
        return ''
    return resp.get('access_token', '')


def prepare_skynet_resource(resource_id, report_path):
    """Fetch the first file of a skynet resource into *report_path*.

    ``<SKYNET_BIN> files --json <resource_id>`` lists the resource. For the
    first listed entry whose ``type`` is ``'file'``, the resource is
    downloaded with ``<SKYNET_BIN> get -d <tmpdir> <resource_id>``. That file
    is then moved onto a new, uniquely named file inside *report_path*.

    :param resource_id: skynet resource id passed to the CLI.
    :param report_path: existing directory that receives the file.
    :return: path of the created file, or ``''`` on any failure.
    """
    import shutil

    try:
        # communicate() drains stdout and stderr until the process exits.
        # The former poll()/read() loop appended bytes to a str (TypeError
        # on Python 3) and never read stderr, which can block the child.
        sky_out, _ = subprocess.Popen(
            [settings.SKYNET_BIN, 'files', "--json", resource_id],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
    except Exception:
        catch_error()
        return ''

    if not sky_out:
        return ''

    try:
        res_info = json.loads(sky_out)
    except Exception:
        catch_error()
        return ''

    if not isinstance(res_info, list):
        return ''

    for item in res_info:
        if item.get('type', '') != 'file':
            continue
        filename = item.get('name', '')
        work_dir = tempfile.mkdtemp(suffix='sky_get_')
        try:
            subprocess.Popen(
                [settings.SKYNET_BIN, 'get', '-d', work_dir, resource_id],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
            res_fd = tempfile.NamedTemporaryFile(dir=report_path, delete=False)
            res_fd.close()
            try:
                os.rename(os.path.join(work_dir, filename), res_fd.name)
            except OSError:
                # The download did not produce the expected file: remove the
                # empty placeholder before reporting failure.
                os.unlink(res_fd.name)
                raise
        except Exception:
            catch_error()
            return ''
        finally:
            # Always remove the scratch directory, even if skynet left extra
            # files behind (os.rmdir would fail on a non-empty directory).
            shutil.rmtree(work_dir, ignore_errors=True)
        return res_fd.name
    return ''


def upload_to_elliptics(k, v, public=False, ttl=0):
    """Store *v* under key *k* through the elliptics write proxy.

    :param k: key appended to the ``write`` proxy URL.
    :param v: request body to upload.
    :param public: return a ``public_read`` URL instead of a ``read`` URL.
    :param ttl: when non-zero, sent as ``expire=<ttl>d``.
    :return: read URL for the stored key (built from the ``key`` attribute of
        the XML reply), or ``''`` when the upload or parsing fails.
    """
    url = settings.ELLIPTICS_PROXY.get('write') + k
    try:
        extra = {'params': {'expire': '{ttl:d}d'.format(ttl=ttl)}} if ttl else {}
        reply = requests.post(url, data=v,
                              headers={'Authorization': settings.ELLIPTICS_AUTH},
                              **extra)
        stored_key = etree.fromstring(reply.text).get('key')
        if stored_key is None:
            return ''
        read_base = settings.ELLIPTICS_PROXY.get('public_read' if public else 'read')
        return read_base + stored_key
    except Exception:
        catch_error()
    return ''


def read_from_mds(k):
    """Download the raw content stored under *k* from the elliptics proxy.

    :param k: either a bare key or a URL that already starts with the
        ``read`` proxy prefix.
    :return: response body (``r.content``) or ``''`` if the request raised.
    """
    read_prefix = settings.ELLIPTICS_PROXY.get('read')
    # Accept both a full read URL and a bare key.
    url = k if k.startswith(read_prefix) else read_prefix + k
    try:
        return requests.get(url, headers={'Authorization': settings.ELLIPTICS_AUTH}).content
    except Exception:
        catch_error()
    return ''


def join_by_domain(urls):
    """Group URLs by their ``scheme://netloc`` origin.

    :param urls: iterable of URL strings.
    :return: dict mapping ``'scheme://netloc'`` to the list of the remaining
        part of each URL (the text after the origin, whitespace-stripped),
        in input order.
    """
    domains = dict()
    for item in urls:
        urp = urlparse(item)
        key = '://'.join([urp.scheme, urp.netloc])
        # Remove only the first occurrence: the origin may legitimately appear
        # again later in the URL, e.g. inside a redirect query parameter.
        data = item.replace(key, '', 1).strip()
        domains.setdefault(key, []).append(data)
    return domains


def killpstree(pid):
    """Kill the process *pid* and all of its descendants.

    Children are killed first, then the parent. A child that exits between
    being listed and being killed is ignored, so the parent is still killed.

    :param pid: process id of the tree root.
    :raises psutil.NoSuchProcess: if the root process itself does not exist.
    """
    parent = psutil.Process(pid)
    for child in parent.children(recursive=True):
        if not child.is_running():
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # Race: the child finished after is_running(); nothing left to kill.
            pass
    parent.kill()


def is_l7_any(hostname):
    """Tell whether *hostname* resolves to an address whose PTR is ``any.yandex.ru``."""
    address, _ = do_resolve(hostname)
    return bool(address) and get_ptr(address) == 'any.yandex.ru'


def is_s3(url):
    """Tell whether the host of *url* resolves to an address whose PTR is ``s3.yandex.net``."""
    # Drop an optional ":port" suffix from the network location.
    host = urlparse(url).netloc.partition(':')[0]
    address, _ = do_resolve(host)
    if address:
        return get_ptr(address) == 's3.yandex.net'
    return False


def is_hamster(hostname):
    """Tell whether *hostname* ends with ``.hamster.yandex.<tld>`` for a configured TLD."""
    return any(hostname.endswith('.hamster.yandex.' + tld)
               for tld in settings.PASSPORT_TLDS)


def is_l7(url):
    """Tell whether the host of *url* is a yandex L7 front host.

    The host (after following CNAMEs when they lead elsewhere) must be one of
    ``yandex.<tld>``, ``www.yandex.<tld>``, ``hamster.yandex.<tld>`` for a
    TLD in ``settings.PASSPORT_TLDS``, or a hamster subdomain.
    """
    host = urlparse(url).netloc.partition(':')[0]
    front_hosts = set()
    for tld in settings.PASSPORT_TLDS:
        for prefix in ('yandex', 'www.yandex', 'hamster.yandex'):
            front_hosts.add(prefix + '.' + tld)
    cname_target = resolve_cname(host)
    candidate = cname_target if cname_target and cname_target != host else host
    return candidate in front_hosts or is_hamster(candidate)


def get_abc_id_by_slug(slug):
    """Look up the numeric ABC service id for *slug*.

    :return: the ``id`` of the first matching service, or ``0`` when the
        request fails or nothing matches.
    """
    headers = {'Authorization': 'OAuth ' + settings.OAUTH_TOKEN_FOR_STATIC}
    try:
        payload = requests.get(settings.ABC_URL + '/api/v4/services/',
                               params={'slug': slug, 'page_size': 1},
                               headers=headers).json()
    except Exception:
        catch_error()
        return 0
    services = payload.get('results')
    return services[0].get('id', 0) if services else 0


def get_st_queue_by_abc_id(abc_id):
    """Return the ``content`` of the first ABC contact of type 4 for service *abc_id*.

    NOTE(review): contact_type=4 presumably denotes the tracker queue
    contact, given the function name — confirm against the ABC API.

    :return: the contact content (may be ``None`` if the field is missing),
        or ``''`` when the request fails or there are no contacts.
    """
    auth_headers = {'Authorization': 'OAuth ' + settings.OAUTH_TOKEN_FOR_STATIC}
    try:
        payload = requests.get(settings.ABC_URL + '/api/v4/services/contacts/',
                               params={'service': abc_id, 'contact_type': 4},
                               headers=auth_headers).json()
    except Exception:
        catch_error()
        return ''
    contacts = payload.get('results')
    if contacts:
        return contacts[0].get('content')
    return ''


def get_abc_members(abc_id):
    """List unique logins of ABC service *abc_id* members in selected role scopes.

    Only members whose role scope is administration, development, testing
    or devops are requested.

    :return: logins in first-seen order; ``[]`` on request failure.
    """
    logins = []
    headers = {'Authorization': 'OAuth {}'.format(settings.OAUTH_TOKEN_FOR_STATIC)}
    scopes = ['administration', 'development', 'testing', 'devops']
    try:
        payload = requests.get(settings.ABC_URL + '/api/v4/services/members/',
                               params={'role__scope': scopes, 'service': abc_id},
                               headers=headers).json()
    except Exception:
        catch_error()
        return logins
    for member in payload.get('results') or []:
        login = member.get('person', {}).get('login')
        if login and login not in logins:
            logins.append(login)
    return logins


def get_abc_id_by_url(url):
    """Resolve an ABC service URL (or bare slug) to its numeric service id.

    The slug is the last non-empty segment of the URL path. A trailing slash,
    a query string or a fragment no longer yields an empty or wrong slug.

    :return: service id, or ``0`` if it cannot be determined.
    """
    slug = urlparse(url).path.rstrip('/').split('/')[-1]
    return get_abc_id_by_slug(slug)


def get_st_queues_from_abc_by_username(username, cookies):
    """Collect tracker queue names linked to non-deleted ABC projects of *username*.

    :param username: login matched against ``team.person.login``.
    :param cookies: cookies forwarded to the ABC v2 API for authentication.
    :return: list of distinct, non-empty queue names; ``[]`` when the request
        or JSON decoding fails (consistent with the other ABC helpers).
    """
    # NOTE(review): username is interpolated into the query expression
    # unescaped; a login containing "'" would break the query — confirm that
    # callers only pass validated logins.
    params = {
        'is_deleted': 'false',
        '_query': "team.person.login=='%s'" % username,
        '_fields': 'queue_links,team.role'
    }
    try:
        res = requests.get(settings.ABC_URL + '/v2/projects', params=params, cookies=cookies)
        data = res.json()
    except Exception:
        catch_error()
        return []
    queues = set()
    for item in data.get('result', []):
        for queue in item.get('queue_links') or []:
            queues.add(queue.get('queue'))
    return [q for q in queues if q]


def get_host_resp(fqdn):
    """Fetch the comma-separated responsible list for *fqdn* from golem.

    :return: stripped, non-empty entries of the response body; ``[]`` on a
        non-200 status or empty body.
    """
    response = requests.get(settings.GOLEM_HOST_RESP, params={'hostname': fqdn})
    if response.status_code != 200 or not response.text:
        return []
    owners = []
    for chunk in response.text.split(','):
        name = chunk.strip()
        if name:
            owners.append(name)
    return owners


def do_resolve(hostname, depth=0):
    """Resolve *hostname* to one address, following CNAMEs up to 5 levels.

    An AAAA query is tried first. An A query is made only if the AAAA query
    returned no answer. If the first record set in the answer is a CNAME,
    its target is resolved recursively.

    :param hostname: name to resolve.
    :param depth: current recursion depth (internal).
    :return: ``(address, has_ipv6)``. *address* is the textual address, or
        ``None`` on failure. *has_ipv6* tells whether the AAAA query succeeded
        at the recursion level that produced the result. Always a 2-tuple, so
        callers can unpack it.
    """
    has_ipv6 = False
    if depth > 5:
        return None, has_ipv6
    dns_ans = None
    # try ipv6 first
    try:
        dns_ans = dns.resolver.query(hostname, 'AAAA')
    except dns.resolver.NoAnswer:
        pass
    except Exception:
        return None, has_ipv6
    else:
        has_ipv6 = True
    if not dns_ans:
        try:
            dns_ans = dns.resolver.query(hostname, 'A')
        except Exception:  # includes dns.resolver.NoAnswer
            return None, has_ipv6
    CNAME = 5  # DNS RR type code for CNAME
    for rec in dns_ans.response.answer:
        target = str(rec).split()[-1].strip('.')
        if rec.rdtype == CNAME:
            return do_resolve(target, depth + 1)
        return target, has_ipv6
    # Empty answer section: previously fell off the end and returned a bare
    # None, which crashed callers doing ``ip, has_ipv6 = do_resolve(...)``.
    return None, has_ipv6


def resolve_cname(hostname, depth=0):
    """Follow the CNAME chain of *hostname* (at most 5 hops).

    :return: the final name of the chain when at least one CNAME was
        followed. Returns ``None`` when *hostname* has no CNAME, on lookup
        errors, or when the depth limit is exceeded.
    """
    if depth > 5:
        return None
    try:
        answer = dns.resolver.query(hostname, 'CNAME')
    except dns.resolver.NoAnswer:
        # End of the chain: report it only if we got here through an alias.
        return hostname if depth > 0 else None
    except Exception:
        return None
    first = next(iter(answer.response.answer), None)
    if first is None:
        return None
    if first.rdtype != 5:  # 5 == CNAME record type
        return hostname
    return resolve_cname(str(first).split()[-1].strip('.'), depth + 1)


def get_admins_from_golem_by_uri(uri):
    """Return golem responsibles for the host of *uri*.

    If the host is a CNAME pointing elsewhere, the CNAME target is queried
    instead.

    :return: list of responsibles as returned by ``get_host_resp``.
    """
    hostname = urlparse(uri).netloc
    if ':' in hostname:
        # maxsplit=1 like the sibling helpers; a bare split(':') raised
        # ValueError for netlocs with more than one colon.
        hostname = hostname.split(':', 1)[0]
    real_hostname = resolve_cname(hostname)
    if real_hostname and real_hostname != hostname:
        return get_host_resp(real_hostname)
    return get_host_resp(hostname)


def is_resolvable(url):
    """Tell whether the host of *url* resolves to an address via ``do_resolve``."""
    host = urlparse(url).netloc.partition(':')[0]
    address, _ = do_resolve(host)
    return address is not None


def has_ipv6(url):
    """Tell whether ``do_resolve`` reports an AAAA answer for the host of *url*."""
    host = urlparse(url).netloc.partition(':')[0]
    _, ipv6_found = do_resolve(host)
    return ipv6_found


def slugify_target_url(url):
    """Reduce a target URL to a short grouping slug such as ``'yandex/search'``.

    Steps, as implemented below:
      * lowercase the host and drop any ``:port``;
      * skip a leading ``www`` / ``m`` / ``n`` label;
      * when the host has more than three labels, rewrite a leading
        ``<name>-<digits>`` label to ``<name>-*`` (per-pull hosts);
      * return ``'any.yandex/'`` when ``is_l7_any`` says the host is the
        "any" balancer;
      * for hosts containing ``'yandex'``, strip a TLD listed in
        ``settings.PASSPORT_TLDS``; for other hosts keep the remaining labels;
      * for ``yandex`` / ``l7test.yandex`` / ``hamster.yandex`` append the
        first path segment that contains no dot.

    :param url: full URL (needs a scheme so ``urlparse`` fills ``netloc``).
    :return: slug ending either in ``'/'`` or in ``'/<first path segment>'``.
    """
    def norm_path(x):
        # Path segments containing a dot (e.g. file names) are discarded.
        if '.' in x:
            return ''
        return x

    pr = urlparse(url)
    netloc = pr.netloc.lower()
    if ':' in netloc:
        netloc = netloc.split(':')[0]
    splnetloc = netloc.split('.')
    # m is the index of the first label kept in the slug.
    m = 0
    if splnetloc[0] in ['www', 'm', 'n']:
        m = 1
    # some weird logic for per-pull domains
    if len(splnetloc) > 3 and re.match(r'(.*)-(\d+)$', splnetloc[0]):
        match = re.search(r'(.*)-(\d+)$', splnetloc[0])
        splnetloc[0] = match.group(1) + '-*'
    if is_l7_any(netloc):
        return 'any.yandex/'
    if 'yandex' in netloc:
        # NOTE(review): if no TLD below matches, netloc keeps its original
        # value, so neither the www/m/n skip nor the per-pull rewrite applies.
        for tld in settings.PASSPORT_TLDS:
            ctld = tld.split('.')
            # Two-label TLDs (e.g. "com.tr") must match the last two labels.
            if len(ctld) > 1 and len(splnetloc) > 2 and splnetloc[-1] == ctld[1] and splnetloc[-2] == ctld[0]:
                netloc = '.'.join(splnetloc[m:-2])
                break
            if splnetloc[-1] == tld:
                netloc = '.'.join(splnetloc[m:-1])
                break
    else:
       netloc = '.'.join(splnetloc[m:])
    path = pr.path
    splpath = [x for x in path.split('/') if x and norm_path(x)]
    # NOTE(review): is_l7 parses its argument with urlparse; a bare host such
    # as this netloc has an empty netloc there, so this check appears never to
    # be true (and costs a DNS lookup) — confirm intent.
    if is_l7(netloc):
        netloc = 'yandex'
    if netloc in ['yandex', 'l7test.yandex', 'hamster.yandex'] and len(splpath):
        return '{}/{}'.format(netloc, '/'.join(splpath[:1]))
    return '{}/'.format(netloc)


def resolve_sandbox_resource(resource_id):
    """Fetch the Sandbox resource description for *resource_id*.

    :return: the first entry of the ``items`` list, or ``{}`` on a non-200
        status, empty body, request/JSON failure, or no items.
    """
    try:
        response = requests.get(settings.SANDBOX_RESOURCE_API,
                                params={'id': resource_id, 'limit': 1},
                                headers={'Accept': 'application/json; charset=utf-8'})
        if response.status_code != 200 or not response.text:
            return {}
        payload = response.json()
    except Exception:
        catch_error()
        return {}
    items = payload.get('items', [])
    return items[0] if items else {}


def get_latest_aggregate(aggregate_uid):
    """Fetch one READY ``MOLLY_REQS`` Sandbox resource for *aggregate_uid*.

    The resource is selected by its ``target_id`` attribute.

    :return: the first entry of ``items``, or ``{}`` on a non-200 status,
        empty body, request/JSON failure, or no items.
    """
    try:
        query = {
            'type': 'MOLLY_REQS',
            'state': 'READY',
            'attrs': json.dumps({'target_id': aggregate_uid}),
            'limit': 1,
        }
        response = requests.get(settings.SANDBOX_RESOURCE_API, params=query,
                                headers={'Accept': 'application/json; charset=utf-8'})
        if response.status_code != 200 or not response.text:
            return {}
        payload = response.json()
    except Exception:
        catch_error()
        return {}
    items = payload.get('items', [])
    return items[0] if items else {}


def parse_molly_json(input_path):
    """Yield the elements of the JSON document stored at *input_path*."""
    with open(input_path) as source:
        yield from json.load(source)


def parse_balancer_log(input_path):
    """Yield request records from a balancer log with one JSON object per line.

    Every record has ``headers``. When the line's object is non-empty, the
    record also has ``method='GET'``, ``proto='HTTP/1.0'`` and ``body=''``.
    A ``path`` field is split into ``path`` and (if present) ``rawquery``.
    A ``headers`` field is a ``~``-separated list of ``Name: value`` pairs.
    ``X-*`` headers and entries without a colon are skipped.
    """
    with open(input_path) as log_file:
        for raw_line in iter(log_file.readline, ''):
            entry = json.loads(raw_line)
            record = {'headers': []}
            if entry:
                record.update(method='GET', proto='HTTP/1.0', body='')
            for field, value in entry.items():
                if field == 'path':
                    parsed = urlparse(value)
                    record['path'] = parsed.path
                    if parsed.query:
                        record['rawquery'] = parsed.query
                elif field == 'headers':
                    for raw_header in value.split('~'):
                        name, sep, rest = raw_header.partition(':')
                        if name.startswith('X-') or not sep:
                            continue
                        record['headers'].append({'Name': name, 'Value': rest.strip()})
            yield record


def parse_balancer_log_tmp(input_path):
    """Yield request records from a JSON array of balancer log objects.

    Produces the same record shape as ``parse_balancer_log``, but reads the
    whole file as one JSON array instead of one object per line.
    """
    with open(input_path) as log_file:
        entries = json.load(log_file)
        for entry in entries:
            record = {'headers': []}
            if entry:
                record.update(method='GET', proto='HTTP/1.0', body='')
            for field, value in entry.items():
                if field == 'path':
                    parsed = urlparse(value)
                    record['path'] = parsed.path
                    if parsed.query:
                        record['rawquery'] = parsed.query
                elif field == 'headers':
                    for raw_header in value.split('~'):
                        name, sep, rest = raw_header.partition(':')
                        if name.startswith('X-') or not sep:
                            continue
                        record['headers'].append({'Name': name, 'Value': rest.strip()})
            yield record


class BurpParser(object):
    """
    ElementTree parser target turning a Burp XML export into request dicts.

    Each ``<request>`` element becomes a dict with ``method``, ``proto``,
    ``path``, optional ``rawquery``, ``headers`` (list of
    ``{'Name': ..., 'Value': ...}``, ``X-*`` headers skipped) and ``body``.
    ``close()`` returns the list of these dicts.

    TODO: Support protocol (http|https) and port extraction.
    Now it only works with http and 80.
    """
    # Class-level defaults kept for backward compatibility only; __init__
    # gives every instance its own state so parses no longer share (and
    # accumulate into) one class-wide list.
    requests = []
    parsing_request = False
    current_is_base64 = False

    def __init__(self):
        self.requests = []
        self.parsing_request = False
        self.current_is_base64 = False
        self._chunks = []

    def start(self, tag, attrib):
        """
        <request base64="true"><![CDATA[R0VUI...4zDQoNCg==]]></request>
        or
        <request base64="false"><![CDATA[GET /molly/ HTTP/1.1
        Host: moth
        ...
        ]]></request>
        """
        if tag == 'request':
            self.parsing_request = True
            self._chunks = []
            # A missing attribute (invalid file?) is treated as plain text
            # instead of reusing the previous request's encoding.
            self.current_is_base64 = attrib.get('base64', '').lower() == 'true'

    def data(self, data):
        # The XML parser may deliver one element's text in several chunks,
        # so buffer them and parse the complete request in end().
        if self.parsing_request:
            self._chunks.append(data)

    def end(self, tag):
        if tag == 'request':
            self.parsing_request = False
            text = ''.join(self._chunks)
            self._chunks = []
            rec = self._parse_request(text, self.current_is_base64)
            if rec is not None:
                self.requests.append(rec)

    def close(self):
        return self.requests

    @staticmethod
    def _parse_request(text, is_base64):
        """Parse one raw HTTP request; return its record dict or None if empty."""
        if is_base64:
            # b64decode returns bytes; decode so the str parsing below works.
            # NOTE(review): assumes UTF-8 payloads; undecodable bytes are replaced.
            text = base64.b64decode(text).decode('utf-8', errors='replace').lstrip()
            head, _, postdata = text.partition('\r\n\r\n')
        else:
            # XML parsers normalise CRLF to LF in element text.
            head, _, postdata = text.lstrip().partition('\n\n')
        if not head.strip():
            return None

        lines = head.split('\n')
        # "METHOD URI PROTO": split at most twice so extra spaces cannot
        # produce a fourth field, and pad a short request line.
        request_line = lines[0].strip().split(' ', 2)
        request_line += [''] * (3 - len(request_line))
        method, uri, proto = request_line

        pu = urlparse(uri)
        rec = {'method': method.strip(), 'proto': proto.strip(), 'path': pu.path}
        if pu.query:
            rec['rawquery'] = pu.query
        rec['headers'] = []
        rec['body'] = postdata
        for line in lines[1:]:
            name, sep, value = line.partition(':')
            if name.startswith('X-') or not sep:
                continue
            rec['headers'].append({'Name': name, 'Value': value.strip()})
        return rec


def parse_burp_xml_log(input_path):
    """Parse a Burp XML export at *input_path* into request dicts.

    :return: list of request dicts built by ``BurpParser``, or ``[]`` when
        the file is not well-formed XML.
    """
    with open(input_path) as fd:
        content = fd.read()
    parser = etree.XMLParser(target=BurpParser())
    try:
        return etree.fromstring(content, parser)
    # ``etree`` is xml.etree.ElementTree, which raises ParseError; lxml's
    # XMLSyntaxError is kept in the tuple for backward compatibility.
    except (etree.ParseError, XMLSyntaxError):
        return []


def pg_wrap(inp, maxlen=0):
    """Prepare *inp* for storage, optionally truncated to *maxlen* items.

    :param inp: sliceable value (str or bytes).
    :param maxlen: when > 0, the result is ``inp[:maxlen]``; otherwise *inp*
        is returned unchanged.
    :return: the (possibly truncated) value. If slicing raises, the base64
        encoding of the first ``maxlen * 3 // 4`` items is returned instead,
        which keeps the encoded length roughly within *maxlen*.
    """
    if maxlen <= 0:
        return inp
    try:
        return inp[:maxlen]
    except Exception:
        # Floor division keeps the slice bound an int; ``/`` produced a float
        # on Python 3, which made this fallback itself raise TypeError.
        return base64.b64encode(inp[:(maxlen * 3 // 4)])


def pg_unwrap(inp):
    """Base64-decode *inp*; return it unchanged if decoding fails."""
    try:
        decoded = base64.b64decode(inp)
    except Exception:
        decoded = inp
    return decoded


def log_vuln_to_splunk(vuln):
    """Send a ``vulnerability`` event describing *vuln* to the Splunk HEC.

    The event carries scan/target identifiers, vulnerability type and
    severity, triage state (with tracker ticket ids when triaged) and the
    false-positive flag. Errors while posting or decoding the reply are
    reported via ``catch_error`` and the module logger.
    """
    vuln_details = dict()
    vuln_details['scan_uid'] = vuln.scan.uid
    vuln_details['event_type'] = 'vulnerability'
    vuln_details['url'] = vuln.scan.scan_url
    vuln_details['is_prod'] = vuln.scan.is_prod
    vuln_details['type'] = vuln.vuln_type.name
    vuln_details['human_name'] = vuln.vuln_type.human_name
    vuln_details['severity'] = vuln.severity
    vuln_details['human_severity'] = vuln.yseverity
    vuln_details['triaged'] = vuln.is_triaged
    if vuln_details['triaged']:
        vuln_details['ticket'] = [ticket.ticket_id for ticket in vuln.tracker_tickets.all()]
    vuln_details['target'] = vuln.scan.target.name
    vuln_details['false_positive'] = vuln.is_false_positive

    requests_url = "https://%s/services/collector" % settings.HEC_HOST
    post_data = {
        "event": vuln_details
    }
    data = json.dumps(post_data).encode('utf8')
    headers = {'Authorization': "Splunk {}".format(settings.HEC_TOKEN)}
    try:
        req = requests.post(requests_url, data=data, headers=headers, verify=settings.CA_FILE)
        response_json = json.loads(str(req.text))
        if "text" in response_json and response_json["text"] != "Success":
            logger.error("Error sending request")
    except Exception:
        catch_error()
        # sys.exc_info()[0] is an exception class; concatenating it to a str
        # raised TypeError inside this handler. Format it explicitly, as
        # log_scan_to_splunk does.
        logger.error("Unexpected error:" + str(sys.exc_info()[:2]))


def log_scan_to_splunk(scan):
    """Send a ``scan`` event (uid, user, url, prod flag, target) to the Splunk HEC.

    Errors while posting or decoding the reply are reported via
    ``catch_error`` and the module logger.
    """
    event = {
        'scan_uid': scan.uid,
        'event_type': 'scan',
        'user': scan.user.username,
        'url': scan.scan_url,
        'is_prod': scan.is_prod,
        'target': scan.target.name,
    }
    collector_url = "https://%s/services/collector" % settings.HEC_HOST
    payload = json.dumps({"event": event}).encode('utf8')
    auth_headers = {'Authorization': "Splunk {}".format(settings.HEC_TOKEN)}
    try:
        reply = requests.post(collector_url, data=payload, headers=auth_headers,
                              verify=settings.CA_FILE)
        reply_json = json.loads(str(reply.text))
        if "text" in reply_json and reply_json["text"] != "Success":
            logger.error("Error sending request")
    except Exception:
        catch_error()
        logger.error("Unexpected error:" + str(sys.exc_info()[:2]))


def mask_credentials(input):
    """Replace a secret with its SHA-256 digest, prefixed with ``'sha256:'``.

    :param input: secret as bytes, or as str (encoded as UTF-8 first;
        ``hashlib`` rejects str on Python 3).
    :return: ``'sha256:<hex digest>'``.
    """
    data = input.encode('utf-8') if isinstance(input, str) else input
    return 'sha256:{}'.format(hashlib.sha256(data).hexdigest())


def get_cgroup(host):
    """Pick a group name for *host* from the hosts2groups API.

    On 404 the host's CNAME target (if different) is tried instead. Among
    the returned groups, the first containing ``'prod'`` or ``'stable'`` wins,
    otherwise the first group is used.

    :return: group name, or *host* itself when the request fails, returns a
        non-200 status, or lists no groups.
    """
    try:
        r = requests.get('https://c.yandex-team.ru/api-cached/hosts2groups/%s' % host)
    except Exception:
        return host
    if r.status_code == 404:
        cname = resolve_cname(host)
        if cname and cname != host:
            return get_cgroup(cname)
        return host
    if r.status_code != 200:
        # Any other error: the body is an error page, not a group list.
        return host

    # Strip before filtering so whitespace-only lines can never be returned.
    c_groups = [line.strip() for line in r.text.split('\n') if line.strip()]
    if not c_groups:
        return host
    for g in c_groups:
        if 'prod' in g or 'stable' in g:
            return g
    return c_groups[0]


def get_tvm_service_ticket(client_id, client_secret, dst_id):
    """Obtain a TVM2 service ticket from *client_id* for destination *dst_id*.

    :return: the ticket, or ``''`` when none was issued.
    """
    client = TVM2(
        client_id=client_id,
        secret=client_secret,
        blackbox_client=BlackboxClientId.ProdYateam,
        destinations=[dst_id],
    )
    ticket = client.get_service_tickets(dst_id).get(dst_id)
    return ticket or ''


def get_ptr(ip):
    """Return the first PTR name for *ip* without its trailing dot, or None.

    Any lookup failure (including an empty answer) yields None. An invalid
    *ip* raises from ``dns.reversename.from_address``.
    """
    reverse_name = dns.reversename.from_address(ip)
    try:
        answer = dns.resolver.query(reverse_name, 'PTR')
    except Exception:
        return None
    return str(answer[0]).strip('.')


def get_abc_scopes(username, filter_scopes=None):
    """List ``(service_id, role_scope_slug)`` pairs for ABC memberships of *username*.

    NOTE(review): only the first page (``page_size=500``) is fetched.

    :param username: login matched against ``person__login``.
    :param filter_scopes: optional collection of scope slugs to keep; when
        empty or None all pairs are returned.
    :return: list of pairs; ``[]`` on request/JSON failure.
    """
    headers = {'Authorization': 'OAuth ' + settings.OAUTH_TOKEN_FOR_STATIC}
    try:
        r = requests.get(settings.ABC_URL + '/api/v4/services/members/',
                         params={'page_size': 500, 'person__login': username},
                         headers=headers)
        resp = r.json()
    except Exception:
        catch_error()
        return []
    user_scopes = []
    # ``or []``: a null/missing "results" previously raised TypeError here.
    for item in resp.get('results') or []:
        scope_slug = item.get('role', {}).get('scope', {}).get('slug')
        abc_id = item.get('service', {}).get('id')
        user_scopes.append((abc_id, scope_slug))
    if filter_scopes:
        return [x for x in user_scopes if x[1] in filter_scopes]
    return user_scopes


def clean_antirobot_header(val):
    """Keep only lowercase ASCII letters and '-' from *val*."""
    allowed = string.ascii_lowercase + '-'
    return ''.join(ch for ch in val if ch in allowed)


def quote_commands(words):
    """Join *words* into a POSIX shell command line, single-quoting each word.

    Embedded single quotes are written as ``'"'"'``.
    """
    quoted = []
    for word in words:
        quoted.append("'{}'".format(word.replace("'", "'\"'\"'")))
    return " ".join(quoted)


def join_envs(envs):
    """Render a mapping of env variables as ``NAME=value`` pairs joined by ';'."""
    return ";".join(name + "=" + value for name, value in envs.items())


def parse_datetime_from_startrack(datestr):
    """
    Convert a Startrek timestamp to a naive ``datetime`` expressed in UTC.

    param: 'datestr' like '2021-03-30T07:00:30.123+0000'
    The UTC offset is honoured (``%z`` also accepts ``+03:00`` and ``Z``).
    The value is shifted to UTC and returned without tzinfo, so ``+0000``
    input gives exactly what the old offset-stripping code returned.
    """
    parsed = datetime.datetime.strptime(datestr, '%Y-%m-%dT%H:%M:%S.%f%z')
    return parsed.astimezone(datetime.timezone.utc).replace(tzinfo=None)
