# coding: utf-8
import json
import logging
import os
import re
import base64
import time
import urllib2
import functools
import itertools
import subprocess
import multiprocessing
import xml.sax.saxutils as xmlutils

import requests

from sandbox.sandboxsdk import paths
from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import process
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.channel import channel

from sandbox.projects import resource_types
from sandbox.projects.common import utils
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common.wizard import parameters as wp
from sandbox.projects.common.wizard import resources as wr
from sandbox.projects.common.wizard import utils as wizard_utils
from sandbox.projects.common.search import components as sc
from sandbox.projects.common.search.requester import Params as requester_params


def _symbolize_log(log, binary, line_mask=re.compile(r'(?m)^    #(?P<n>\d+) (?P<addr>0x[0-9a-fA-F]+)  \(.+?\)$')):
    log = re.search(r'(?ms)^==\d+==ERROR: \w+Sanitizer:.*', log)
    log = log.group(0) if log else ''
    addrs = sorted({m.group('addr') for m in line_mask.finditer(log)})
    if not addrs:
        return log
    # odd lines - function names; even lines - file names & line numbers
    results = iter(subprocess.check_output(['addr2line', '-f', '-C', '-e', binary] + addrs).split('\n'))
    results = {addr: (name, source) for addr, name, source in itertools.izip(addrs, results, results)}

    def replace(match):
        name, source = results[match.group('addr')]
        line, number = match.group(0, 'n')
        return line if name == '??' else '    #{} {}  ({})'.format(number, name, source)
    return line_mask.sub(replace, log)


def _make_full_urls(host, default_path, extra_params, paths):
    for path in paths:
        try:
            path = path.rstrip('\r\n')
            if not path.startswith('/'):
                if path.count('\t') == 1:
                    text, lr = path.split('\t', 1)
                    path = '{}?text={}&lr={}'.format(default_path, urllib2.quote(text), lr)
                else:
                    path = default_path + '?' + path.lstrip('?')
            path += ('&' if '?' in path else '?') + (extra_params or '') + '&waitall=1'
            yield 'http://{}{}'.format(host, path)
        except Exception as e:
            yield e
            break


def _make_full_urls_prwzrd(host, default_path, extra_params, paths):
    for line in paths:
        try:
            if line.startswith('@') or not line.strip():
                continue
            req, _, rules = line.partition('$print:')
            req, s, param = req.partition('$cgi:')
            if not s:
                rules, _, param = rules.partition('$cgi:')
            req = req.strip('\n')
            path = 'http://{0}{1}?text={2}&user_request={2}'.format(host, default_path, urllib2.quote(req))
            if param.strip():
                path += '&' + ''.join(c if ord(c) < 128 else '%{:02X}'.format(ord(c)) for c in param.strip())
            yield path + '&' + (extra_params or '') + '&waitall=1'
        except Exception as e:
            yield e
            break


def _get_response(url, s=requests.Session()):
    """Fetch ``url`` over HTTP with the session ``s``.

    The default session is created once, when this function is defined, and is
    reused by every call in the process. Returns ``(True, body)`` on success or
    ``(False, 'error=<message>')`` if the request fails. An ``Exception`` received
    in place of a URL (the URL generators yield their own failures that way) is
    re-raised.
    """
    if isinstance(url, Exception):
        raise url
    try:
        body = s.get(url).content
    except Exception as err:
        return False, 'error=' + str(err)
    return True, body


def _get_apphost_response(url, init, s=requests.Session()):
    """POST one base64-encoded AppHost request ``init`` to ``http://<url>/``.

    Returns ``(True, base64-encoded response body)`` on success or
    ``(False, '!' + <message>)`` if anything fails. An ``Exception`` passed as
    ``url`` is re-raised.
    """
    if isinstance(url, Exception):
        raise url
    try:
        target = 'http://{}/'.format(url)
        payload = base64.b64decode(init)
        encoded = base64.b64encode(s.post(target, data=payload).content)
    except Exception as err:
        return False, '!' + str(err)
    return True, encoded


# Whether the AppHost query file is already in binary+base64 form.
# on_execute consults it only for a local wizard with wp.UseAppHost set: when it is
# False there, the raw queries are piped through the wizard's `--apphost-stdin` mode
# instead of being base64-decoded and POSTed over HTTP.
class AppHostBinaryFormat(parameters.SandboxBoolParameter):
    name = "wizard_query_ah_binary_fmt"
    description = "AppHost queries are in binary+base64 format (*always true for remote servers; use apphost/tools/converter*)"
    default_value = False


# When set, on_execute copies the local wizard config to `_patched_config.cfg`,
# dropping every line that (after stripping whitespace) starts with `ExcludeFromBegemot:`,
# and starts the wizard with that copy.
class DisableBegemotExclusions(parameters.SandboxBoolParameter):
    name = "disable_begemot_exclusions"
    description = "Remove ExcludeFromBegemot from the config"
    default_value = False


# host:port of the wizard queried when ResponseProvider is 'remote'.
# AppHost requests go to the same host on port + 1 (GetWizardResponses._fetch_ah).
class WizardAddress(parameters.SandboxStringParameter):
    name = "wizard_host"
    description = 'Remote wizard (host:port)'
    default_value = 'reqwizard.yandex.net:8891'


# Task kill timeout in minutes, kept as a string; on_execute converts it with int()
# to seconds and stores it in ctx['kill_timeout'].
class ForcedKillTimeout(parameters.SandboxStringParameter):
    name = "forced_kill_timeout"
    description = 'Kill timeout (minutes)'
    default_value = '90'
    required = False


# Where a local wizard gets its runtime data (see get_wizard_runtime_data):
#   'resource'   - the wp.RuntimeData resource, if one is set (otherwise none);
#   'production' - the latest 'stable' WIZARD_RUNTIME_PACKAGE release.
# sub_fields presumably makes the UI show RuntimeData only for 'resource' — confirm.
class RuntimeProvider(parameters.SandboxSelectParameter):
    name = 'runtime_provider'
    description = 'Get runtime from'
    choices = [
        ('resource', 'resource'),
        ('production', 'production')
    ]
    default_value = 'resource'
    sub_fields = {
        'resource': [wp.RuntimeData.name]
    }


# Where responses come from (stored under the ctx key "wizard_type"):
#   'local'  - start a wizard from wp.Binary / wp.Config / wp.Shard / runtime data;
#   'remote' - query the already running wizard at WizardAddress.
# choices are (label, value) pairs; sub_fields presumably lists the parameters
# shown for each value — confirm against the Sandbox SDK.
class ResponseProvider(parameters.SandboxSelectParameter):
    name = "wizard_type"
    description = 'Fetch responses from'
    default_value = 'local'
    choices = [
        ('local Wizard instance', 'local'),
        ('remote Wizard instance', 'remote'),
    ]
    sub_fields = {
        'local': [wp.Binary.name, wp.Config.name, DisableBegemotExclusions.name, wp.Shard.name, wp.RuntimeData.name, wp.CacheEnable.name, RuntimeProvider.name],
        'remote': [WizardAddress.name],
    }


# When set, _fetch_stat calls eh.check_failed if the wizard's action=stat answer is empty.
# NOTE(review): _fetch_stat runs in a separate multiprocessing.Process, so that failure
# may not propagate to the task itself — confirm.
class BadStats(parameters.SandboxBoolParameter):
    name = 'bad_stats'
    description = 'Fail if action=stat is broken'
    default_value = False
    required = False


class GetWizardResponses(SandboxTask):
    """Save Wizard responses into a resource while monitoring its memory usage."""
    # Responses come either from a remote wizard (ResponseProvider == 'remote') or from
    # a local wizard started from the given binary, config, shard and runtime data.
    # They are written one per line into wizard_responses.txt, registered on enqueue as
    # a WIZARD_RESPONSES_RESULT resource. While a local wizard is queried over
    # HTTP/AppHost, its action=stat answers are polled by a side process and logged
    # into a stats.txt TASK_LOGS resource.
    type = 'GET_WIZARD_RESPONSES'
    required_ram = 50 * 1024
    execution_space = 300 * 1024
    input_parameters = (
        ResponseProvider,
        RuntimeProvider,
        WizardAddress,
        wp.Binary,
        wp.Config,
        DisableBegemotExclusions,
        wp.Shard,
        wp.RuntimeData,
        wp.CacheEnable,
        wp.UseAppHost,
        wp.Queries,
        wp.ExtraQueryParams,
        AppHostBinaryFormat,
        requester_params.WorkersCount,
        BadStats,
        ForcedKillTimeout,
    )

    client_tags = wizard_utils.ALL_SANDBOX_HOSTS_TAGS

    @property
    def footer(self):
        """Return footer content with the captured sanitizer trace, or None if there is none."""
        if 'asan_trace' in self.ctx:
            return [{"content": "<pre>{}</pre>".format(xmlutils.escape(self.ctx['asan_trace']))}]

    def on_enqueue(self):
        """Run the wizard enqueue hooks and register the output resource up front."""
        SandboxTask.on_enqueue(self)
        wizard_utils.setup_hosts_sdk1(self)
        self.ctx['out_resource_id'] = self.create_resource(
            self.descr,
            self.abs_path('wizard_responses.txt'),
            resource_types.WIZARD_RESPONSES_RESULT,
            arch='any'
        ).id
        wizard_utils.on_enqueue(self)

    def _fetch(self, urls, f=_get_response):
        """Apply ``f`` to every item of ``urls`` in a process pool and save the results.

        :param urls: iterable of request items, each passed to ``f``.
        :param f: callable returning ``(ok, data)``; only ``data`` is written.
        Results are written to wizard_responses.txt in input order (``imap`` keeps
        ordering), one per line with surrounding CR/LF stripped.
        """
        pool = multiprocessing.Pool(
            processes=utils.get_or_default(self.ctx, requester_params.WorkersCount),
            maxtasksperchild=10000,
        )
        try:
            with open(self.abs_path('wizard_responses.txt'), 'w') as output:
                for num, (_, data) in enumerate(pool.imap(f, urls, 10)):
                    output.write(data.strip('\r\n') + '\n')
                    if num % 100 == 0:
                        logging.debug('-- %s responses', num)
        finally:
            pool.close()
            pool.join()

    def _fetch_ah(self, url, inits):
        """Send AppHost requests ``inits`` to the wizard at ``url`` (host:port).

        AppHost is queried on the wizard's port + 1; each item of ``inits`` is a
        base64-encoded request (see _get_apphost_response).
        """
        host, _, port = url.partition(':')
        return self._fetch(inits, functools.partial(_get_apphost_response, '{}:{}'.format(host, int(port) + 1)))

    @staticmethod
    def _get_yasm_stat(port, timeout=20):
        """Return the JSON answer of ``/wizard?action=stat`` on localhost:``port`` as a dict."""
        full_url = "http://localhost:{}/wizard?action=stat".format(port)
        logging.info("Fetching url {} with timeout={}".format(full_url, timeout))
        response = urllib2.urlopen(full_url, timeout=timeout)
        return dict(json.loads(response.read()))

    def get_last_released_runtime(self):
        """
        Fetches last released stable WIZARD_RUNTIME_PACKAGE resource
        :return: last released into 'stable' synced WIZARD_RUNTIME_PACKAGE, or None if
            there is no such release
        """
        try:
            last_released_id = channel.sandbox.list_releases(
                resource_type=resource_types.WIZARD_RUNTIME_PACKAGE,
                release_status='stable',
                limit=1)[0].resources[0].id
            logging.info('Fetching WIZARD_RUNTIME_DATA from {}'.format(last_released_id))
            return self.sync_resource(last_released_id)
        except IndexError:
            return None

    def get_wizard_runtime_data(self):
        """Return the synced runtime data path chosen by RuntimeProvider (None if unset).

        :raises SandboxTaskFailureError: 'production' was requested but no stable
            release could be found.
        """
        if self.ctx.get(RuntimeProvider.name) == 'production':
            runtime_data = self.get_last_released_runtime()
            if runtime_data is None:
                raise errors.SandboxTaskFailureError(
                    'Could not get wizard runtime data from last released sandbox resource'
                )
            return runtime_data
        else:
            if self.ctx[wp.RuntimeData.name]:
                return self.sync_resource(self.ctx[wp.RuntimeData.name])
            return None

    def _fetch_stat(self, port, output, limit=1000):
        """Poll ``action=stat`` of the wizard on ``port`` and log every answer into ``output``.

        Runs in a separate process (started by _execute_with_queries). Polling goes
        on, roughly every 5 seconds, until the answer is empty or ``limit`` polls
        were made. Signal names without a recognised suffix are reported via set_info.
        """
        unknown_signals = set()
        check = 0
        with open(output, 'wb') as fd:
            stats = self._get_yasm_stat(port)
            while stats and check < limit:
                time.sleep(5)
                # Check all signals (BEGEMOT-254): each name must end with one of these suffixes.
                for k in stats:
                    if not re.search(r"_([hmetxnvc0-9]{4}|hgram|max|summ)$", k):
                        unknown_signals.add(k)
                fd.write('[{}]  {}\n'.format(time.ctime(), stats))
                stats = self._get_yasm_stat(port)
                check += 1
            # NOTE(review): this runs inside the child process; confirm that a failure
            # raised here actually fails the task.
            if not stats and utils.get_or_default(self.ctx, BadStats):
                eh.check_failed('action=stat answer is empty.')
        self.set_info("Perform {} requests. action=stat answer contains signals.".format(check))
        for s in unknown_signals:
            self.set_info("Unknown signal: {}".format(s))

    def on_execute(self):
        """Collect wizard responses for the query resource into wizard_responses.txt."""
        os.environ["MKL_CBWR"] = "COMPATIBLE"
        # If the task is running >90 mins, it is most likely hung in process_pool.join (BEGEMOT-526).
        self.ctx['kill_timeout'] = int(utils.get_or_default(self.ctx, ForcedKillTimeout)) * 60
        open('wizard_responses.txt', 'w').close()
        extra_params = utils.get_or_default(self.ctx, wp.ExtraQueryParams).lstrip('&')
        make_urls = _make_full_urls
        if channel.sandbox.get_resource(self.ctx[wp.Queries.name]).type == str(wr.WizardQueries):
            make_urls = _make_full_urls_prwzrd
        # The fetchers stream the query file lazily; `with` closes it on every exit path.
        with open(self.sync_resource(self.ctx[wp.Queries.name])) as queries:
            return self._execute_with_queries(queries, make_urls, extra_params)

    def _execute_with_queries(self, queries, make_urls, extra_params):
        """Fetch responses for the open ``queries`` file from a remote or a local wizard.

        :param queries: open query file, iterated line by line.
        :param make_urls: URL generator (_make_full_urls or _make_full_urls_prwzrd).
        :param extra_params: CGI parameters appended to every HTTP request.
        """
        if utils.get_or_default(self.ctx, ResponseProvider) == 'remote':
            address = utils.get_or_default(self.ctx, WizardAddress)
            if utils.get_or_default(self.ctx, wp.UseAppHost):
                return self._fetch_ah(address, queries)
            return self._fetch(make_urls(address, '/wizard', extra_params, queries))

        binary = self.sync_resource(self.ctx[wp.Binary.name])
        runtime_data = self.get_wizard_runtime_data()

        config = self.sync_resource(self.ctx[wp.Config.name])
        if self.ctx.get(DisableBegemotExclusions.name):
            with open(config) as inf, open('_patched_config.cfg', 'w') as outf:
                for line in inf:
                    if not line.strip().startswith('ExcludeFromBegemot:'):
                        outf.write(line)
            config = '_patched_config.cfg'

        wizard = sc.get_wizard(
            binary,
            config,
            self.sync_resource(self.ctx[wp.Shard.name]),
            runtime_data,
            cache=self.ctx[wp.CacheEnable.name],
            eventlog_path=os.path.join(str(self.log_path()), 'wizard.evlog')
        )

        if utils.get_or_default(self.ctx, wp.UseAppHost) and not utils.get_or_default(self.ctx, AppHostBinaryFormat):
            return self._fetch_ah_stdio(wizard, queries)

        try:
            with wizard:
                stat_res = self.create_resource(
                    'action=stat logs',
                    self.abs_path('stats.txt'),
                    resource_types.TASK_LOGS,
                    arch='any'
                )
                stats_proc = multiprocessing.Process(
                    target=self._fetch_stat, args=(wizard.get_port(), str(stat_res.path))
                )
                stats_proc.start()
                try:
                    if utils.get_or_default(self.ctx, wp.UseAppHost):
                        self._fetch_ah('localhost:' + str(wizard.get_port()), queries)
                    else:
                        self._fetch(make_urls('localhost:' + str(wizard.get_port()), '/wizard', extra_params, queries))
                finally:
                    # Stop and reap the stat poller even when fetching failed.
                    stats_proc.terminate()
                    stats_proc.join()
        except errors.SandboxSubprocessError:
            # In Python 2 a bare `raise` re-raises the exception most recently handled in
            # the current frame. The best-effort sanitizer handling therefore lives in its
            # own method, so the errors it swallows cannot replace the original one here.
            self._report_sanitizer_failure(binary)
            raise

    def _report_sanitizer_failure(self, binary):
        """Publish a symbolized sanitizer report from the wizard log and fail the task.

        :raises SandboxTaskFailureError: the wizard log contains a sanitizer report.
        Returns normally when the log is missing, has no sanitizer report, or cannot
        be symbolized, so the caller can re-raise its own error.
        """
        # May have failed due to `--sanitize=address`. The resulting log
        # is unreadable, however, because `llvm-symbolizer` is not guaranteed to exist.
        try:
            with open(os.path.join(paths.get_logs_folder(), 'run_wizard.out.txt')) as fd:
                symbolized = _symbolize_log(fd.read(), binary)
            if not symbolized:
                # No sanitizer report in the log: the failure has another cause.
                return
            resource = self.create_resource(
                'sanitizer log | {}'.format(self.descr),
                'sanitizer.err',
                resource_types.OTHER_RESOURCE,
                arch='any'
            )
            with open('sanitizer.err', 'w') as log:
                log.write(symbolized)
            self.mark_resource_ready(resource.id)
        except (IOError, OSError, subprocess.CalledProcessError):
            # Missing log or unavailable/failed addr2line: keep the original error.
            return
        self.ctx['asan_trace'] = symbolized[:2048] + ('...' if len(symbolized) > 2048 else '')
        raise errors.SandboxTaskFailureError('sanitizer found an error; see resource #{}'.format(resource.id))

    def _fetch_ah_stdio(self, component, queries):
        """Run the wizard binary in ``--apphost-stdin`` mode, feeding it ``queries``.

        The queries are first copied into ah-req-tmp.txt, each stripped of surrounding
        CR/LF. The wizard's stdout becomes wizard_responses.txt and its stderr
        wizard_stderr.txt.
        """
        # NOTE(review): relies on the component's private _get_run_cmd/_generate_config helpers.
        argv = [component.binary] + component._get_run_cmd(component._generate_config()) + ['--apphost-stdin']
        with open('ah-req-tmp.txt', 'w') as out:
            for query in queries:
                # TODO: rwr=off
                out.write(query.strip('\r\n') + '\n')
        with open('ah-req-tmp.txt', 'r') as stdin_fd, open('wizard_responses.txt', 'w') as output, \
                open('wizard_stderr.txt', 'w') as err:
            process.run_process(argv, wait=True, outputs_to_one_file=False, stdin=stdin_fd, stdout=output, stderr=err, log_prefix='wizard')


# Module-level alias for the task class; presumably the name the Sandbox SDK1
# loader looks up to find this module's task — confirm against the loader.
__Task__ = GetWizardResponses
