# -*- coding:utf-8 -*-

import io
import os
import re
import sys
import json
import time
import logging
import tarfile
import multiprocessing

from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk import paths
from sandbox.sandboxsdk import task
from sandbox.sandboxsdk.svn import Arcadia

from sandbox.projects import resource_types
from sandbox.projects.common import utils
from sandbox.projects.common.wizard import parameters as wp
from sandbox.projects.common.wizard import printwizard as printwzrd
from sandbox.projects.common.wizard import utils as wizard_utils
from sandbox.projects.common.search import components as sc
from sandbox.projects.common.build import parameters as bp


# Clock used for all duration measurements in this module: the monotonic
# clock when the interpreter provides one, wall-clock time otherwise.
_time = getattr(time, 'monotonic', time.time)


class Wizard(sc.Wizard):
    '''sc.Wizard context manager that also records startup and shutdown durations.

    After ``__enter__`` the instance has ``startup_time``; after a normal
    ``__exit__`` it has ``shutdown_time``. Both are differences of ``_time()``
    readings taken around the corresponding base-class call.
    '''

    def __enter__(self):
        '''Delegate to sc.Wizard.__enter__ and store how long it took in ``startup_time``.'''
        started = _time()
        ret = sc.Wizard.__enter__(self)
        self.startup_time = _time() - started
        return ret

    def __exit__(self, type, value, traceback):
        '''Delegate to sc.Wizard.__exit__ and store how long it took in ``shutdown_time``.

        Raises errors.SandboxTaskFailureError, without delegating to the base
        class, when ``self.process`` is set and ``poll()`` reports that it has
        already exited. An exception raised by the base-class ``__exit__`` is
        logged and otherwise ignored.
        '''
        if self.process:
            code = self.process.poll()
            if code is not None:
                raise errors.SandboxTaskFailureError('wizard died with exit code {}'.format(code))
        started = _time()
        try:
            ret = sc.Wizard.__exit__(self, type, value, traceback)
        except Exception as e:
            # A truthy return value from __exit__ tells the interpreter to
            # swallow the exception raised inside the `with` body. A failed
            # shutdown must not hide the real error, so report "not handled".
            ret = False
            logging.warning('Exception during wizard shutdown: %s', e)
        self.shutdown_time = _time() - started
        return ret


def _component(name, work_dir, binary, config, shard, runtime, cache, eventlog_path):
    '''Create a Wizard for the given files, always on port 8891.

    ``name`` is accepted so that all component factories share one signature;
    it is not used here.
    '''
    return Wizard(
        work_dir,
        binary,
        8891,
        sc.sconf.SearchConfig,
        config,
        shard,
        runtime,
        cache,
        eventlog_path,
    )


def _geo_component(name, work_dir, binary, config, shard, runtime, cache, eventlog_path):
    '''Component factory for the `geo` test.

    Looks for the first existing file among the known config names inside
    ``<shard>/GeoAddr`` and refuses to continue unless its text contains
    either ``#PrintAllSubmatches`` or ``PrintAllSubmatches: false``.
    On success the work is delegated to ``_component`` with unchanged arguments.

    Raises errors.SandboxTaskFailureError when no known config file exists or
    when the check on its contents fails.
    '''
    geoaddr_dir = os.path.join(shard, 'GeoAddr')
    known_configs = ['geoaddr.cfg', 'config.pb.txt']
    candidates = (os.path.join(geoaddr_dir, fname) for fname in known_configs)
    config_fname = next((path for path in candidates if os.path.exists(path)), None)
    if config_fname is None:
        raise errors.SandboxTaskFailureError('Failed to find geosearch config in GeoAddr, tried %s' % ', '.join(known_configs))

    with open(config_fname) as f:
        config_text = f.read()

    # Either marker in the text is taken as proof that the option is off.
    option_is_off = '#PrintAllSubmatches' in config_text or 'PrintAllSubmatches: false' in config_text
    if not option_is_off:
        raise errors.SandboxTaskFailureError('cannot run `geo`: PrintAllSubmatches is enabled in %s' % config_fname)
    return _component(name, work_dir, binary, config, shard, runtime, cache, eventlog_path)


def _geosearch_component(name, work_dir, binary, config, shard, runtime, cache, eventlog_path):
    '''Component factory for the `geosearch` test.

    Imports ``MakeConfig`` from the ``config`` module found via
    ``<shard>/conf`` / ``<shard>/thesaurus``, renders the ``geo.yaml`` section
    into memory, appends a ``UseTestData yes`` line after every
    ``RequestPopularityTrieFile ...`` line and writes the result to
    ``_geosearch.cfg`` (a path relative to the current directory).

    The ``config`` argument is not used: the generated file is passed to
    ``_component`` instead.
    '''
    conf_dir = os.path.join(shard, 'conf')
    thesaurus_dir = os.path.join(shard, 'thesaurus')

    # The two directories are only needed on sys.path for the import itself;
    # they are removed again whether or not the import succeeds.
    sys.path.extend([conf_dir, thesaurus_dir])
    try:
        from config import MakeConfig
    finally:
        del sys.path[-2:]

    rendered = io.BytesIO()
    MakeConfig(
        section='geo.yaml',
        port='8891',
        out=rendered,
        baseDir=conf_dir,
        resources=None,
        shardPrefix="wizard/WIZARD_SHARD",
    )

    def _with_test_data(match):
        # Keep the matched line and add the extra option on the next line.
        return match.group(0) + "\nUseTestData yes"

    cfg = re.sub(r'RequestPopularityTrieFile .*', _with_test_data, rendered.getvalue())
    with open('_geosearch.cfg', 'w') as o:
        o.write(cfg)
    return _component(name, work_dir, binary, '_geosearch.cfg', shard, runtime, cache, eventlog_path)


def _test(name, config, requests, rules=None, component=_component):
    '''Build the runner function for one printwzrd test.

    name: test name; used to obtain extra CGI parameters from
        ``printwzrd.extra_cgi_parameters``.
    config: config file name, joined onto the ``generated_configs`` directory
        that the runner receives.
    requests: requests file path, joined onto the ``arcadia`` directory that
        the runner receives.
    rules: passed to ``printwzrd.printwzrd`` as ``global_rules``; an empty
        list when omitted.
    component: factory with the same signature as ``_component`` that creates
        the wizard to run against.

    Returns a function ``impl(name, task, arcadia, work_dir, binary, shard,
    generated_configs, runtime, cache, eventlog_path, more_params, **args)``
    which starts the wizard, sends the requests and returns ``(timings, data)``.
    '''
    # A list literal as the default would be one object shared by every call
    # of _test (and handed to external code), so use a None sentinel instead.
    rules = [] if rules is None else rules
    params = printwzrd.extra_cgi_parameters(name)

    def impl(name, task, arcadia, work_dir, binary, shard, generated_configs, runtime, cache, eventlog_path, more_params, **args):
        timings = {}
        wizard = component(name, work_dir, binary, os.path.join(generated_configs, config), shard, runtime, cache, eventlog_path)
        # Combine the per-test parameters with the caller-supplied ones,
        # joining with '&' only when both are non-empty.
        if params and more_params:
            all_params = params + '&' + more_params
        else:
            all_params = params or more_params
        with open(os.path.join(arcadia, requests), 'rb') as reqs, wizard:
            start = _time()
            host = 'localhost:{}'.format(wizard.get_port())
            data = dict(printwzrd.printwzrd(host, reqs, extra_params=all_params, global_rules=rules, **args))
            total = _time() - start
        timings[name] = total
        # Startup/shutdown durations are reported only for the test named 'default'.
        if name == 'default':
            timings['startup'] = wizard.startup_time
            timings['shutdown'] = wizard.shutdown_time
        return timings, data
    return impl


# Registry of runnable tests: test name -> runner produced by _test().
# Each entry names its config file, its requests file (relative to the
# arcadia checkout), the rules passed along as global rules and the factory
# that creates the wizard component.
_TESTS = {
    'geo': _test(
        name='geo',
        config='geo-printwizard-yaml.cfg',
        requests='tools/printwzrd/tests/test_geo/geo.txt',
        rules=[
            'GeoAddr', 'GeoAddrRoute', 'GeoAddrUntranslit', 'GeoRelev',
            'RelevLocale', 'Date', 'Transit', 'Qtree',
        ],
        component=_geo_component,
    ),
    'geosearch': _test(
        name='geosearch',
        config='geo-yaml.cfg',
        requests='tools/printwzrd/tests/test_geosearch/geosearch.txt',
        rules=[
            'Coord', 'GeoAddr', 'GeoAddrUntranslit', 'GeoAddrRoute', 'OrgNav',
            'BusinessNav', 'CustomThesaurus/Geo', 'TelOrg', 'GeoRelev',
            'CommercialMx', 'GeosearchStopwords', 'PPO', 'RelevLocale',
            'Transport', 'Rubrics',
        ],
        component=_geosearch_component,
    ),
}


class RunTests(parameters.SandboxBoolGroupParameter):
    '''Task parameter selecting which entries of _TESTS to run.'''
    name = 'run_tests'
    # One (t, t) pair per key of _TESTS, in sorted key order.
    choices = [(t, t) for t in sorted(_TESTS)]
    required = True
    description = 'Which printwzrd tests to run'
    # NOTE(review): 'default' is not a key of _TESTS as defined in this file,
    # and GetWizardPrintwzrdResponses._requested_tests rejects names that are
    # not in _TESTS -- confirm this default is still intended.
    default_value = 'default'


class UnpackRichTree(parameters.ResourceSelector):
    '''Task parameter: resource of type UNPACKRICHTREE (tools/unpackrichtree).

    When set, the task syncs it and passes it to the tests as ``unpackrichtree``.
    '''
    name = 'unpackrichtree'
    description = 'tools/unpackrichtree'
    resource_type = resource_types.UNPACKRICHTREE


class UnpackReqBundle(parameters.ResourceSelector):
    '''Task parameter: resource of type UNPACKREQBUNDLE.

    When set, the task syncs it and passes it to the tests as ``unpackreqbundle``.
    '''
    name = 'unpackreqbundle'
    description = 'quality/relev_tools/lboost_ops/unpackreqbundle'
    resource_type = resource_types.UNPACKREQBUNDLE


class GetWizardPrintwzrdResponses(task.SandboxTask):
    '''Run Wizard with printwzrd configs and requests, saving responses as json.'''
    type = 'GET_WIZARD_PRINTWZRD_RESPONSES'
    required_ram = 48 * 1024
    execution_space = 60 * 1024
    input_parameters = (
        RunTests,
        bp.ArcadiaUrl,
        bp.ArcadiaPatch,
        wp.Binary,
        wp.PrintWizardConfig,
        wp.Shard,
        wp.RuntimeData,
        wp.ExtraQueryParams,
        wp.CacheEnable,
        UnpackReqBundle,
        UnpackRichTree,
    )
    client_tags = wizard_utils.ALL_SANDBOX_HOSTS_TAGS

    @property
    def _requested_tests(self):
        tests = self.ctx.get(RunTests.name, '').strip()
        tests = set(tests.split() if tests else [])
        if tests - set(_TESTS):
            raise errors.SandboxTaskFailureError('unknown tests: {}'.format(', '.join(tests - set(_TESTS))))
        if not tests:
            raise errors.SandboxTaskFailureError('no tests requested')
        return tests

    def on_enqueue(self):
        task.SandboxTask.on_enqueue(self)
        wizard_utils.setup_hosts_sdk1(self)
        resources = {}
        for test in self._requested_tests:
            resources[test] = self.create_resource(
                '{} | {}'.format(test, self.descr),
                '{}.json'.format(test),
                resource_types.PLAIN_TEXT_QUERIES,
                arch='any',
            ).id
        self.ctx['results'] = resources
        wizard_utils.on_enqueue(self)

    def on_execute(self):
        os.environ["MKL_CBWR"] = "COMPATIBLE"
        paths.make_folder('arcadia/tools/printwzrd', delete_content=True)
        arc = Arcadia.parse_url(utils.get_or_default(self.ctx, bp.ArcadiaUrl))
        for p in ['tools/printwzrd']:
            Arcadia.export(Arcadia.replace(utils.get_or_default(self.ctx, bp.ArcadiaUrl), path=os.path.join(arc.path, p)), os.path.join('arcadia', p))
        # TODO ignore patches that point outside tools/printwzrd.
        # Arcadia.apply_patch('arcadia', self.ctx.get(bp.ArcadiaPatch.name), self.abs_path())

        urt = utils.get_or_default(self.ctx, UnpackRichTree)
        if urt is not None:
            urt = self.sync_resource(urt)
            urt = [[urt, '-urrl'], [urt, '-us']]

        urb = utils.get_or_default(self.ctx, UnpackReqBundle)
        urb = [self.sync_resource(urb), '-dj'] if urb else None

        all_timings = {}
        pool = multiprocessing.Pool(processes=16)
        runtime_path = None
        if self.ctx.get(wp.RuntimeData.name):
            runtime_tar = tarfile.open(self.sync_resource(self.ctx[wp.RuntimeData.name]), "r")
            runtime_tar.extractall()
            if os.path.isdir("wizard.runtime"):
                runtime_path = "wizard.runtime"
        try:
            for i, test in enumerate(self._requested_tests, 1):
                paths.make_folder('tmp_' + test, delete_content=True)
                eventlog_path = os.path.join(str(self.log_path()), 'wizard_{}.evlog'.format(test))
                timings, responses = _TESTS[test](
                    test, self, self.abs_path('arcadia'), self.abs_path('tmp_' + test),
                    self.sync_resource(self.ctx[wp.Binary.name]),
                    self.sync_resource(self.ctx[wp.Shard.name]),
                    self.sync_resource(self.ctx[wp.PrintWizardConfig.name]),
                    runtime_path,
                    utils.get_or_default(self.ctx, wp.CacheEnable),
                    eventlog_path,
                    utils.get_or_default(self.ctx, wp.ExtraQueryParams),
                    pool=pool, unpackrichtree=urt, unpackreqbundle=urb
                )
                with open(test + '.json', 'wb') as out:
                    json.dump(responses, out)
                self.mark_resource_ready(self.ctx['results'][test])
                for k, v in timings.items():
                    all_timings.setdefault(k, (0, 0))
                    all_timings[k] = (all_timings[k][0] + v, all_timings[k][1] + 1)
        finally:
            pool.close()
            pool.join()
        self.ctx['timings'] = {k: round(v / c, 2) for k, (v, c) in all_timings.items()}


# Module-level alias for the task class defined above.
# NOTE(review): presumably the Sandbox loader finds the task through this
# name -- confirm before renaming or removing it.
__Task__ = GetWizardPrintwzrdResponses
