# -*- coding: utf-8 -*-
import os
import json
import logging
import re
import requests
import shutil
import traceback

from sandbox import common
from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types.client import Tag
from sandbox.projects.common.arcadia import sdk as arcadiasdk
from sandbox.projects.common.search.components import get_begemot, DEFAULT_BEGEMOT_PORT, DEFAULT_START_TIMEOUT
from sandbox.projects.common.wizard import printwizard as pw
from sandbox.projects.common.wizard import utils as wizard_utils
from sandbox.projects.websearch.begemot import parameters as bp
from sandbox.projects.websearch.begemot import resources as br
from sandbox.projects.websearch.begemot.common.fast_build import ShardSyncHelper
from sandbox.sandboxsdk.svn import Arcadia, Svn
from sandbox.sdk2.helpers import subprocess
import sandbox.common.types.notification as ctn
import sandbox.common.types.resource as ctr


class GetBegemotResponses(sdk2.Task):
    """
    Get begemot responses.

    Starts begemot on the shard synced from ``fast_build_config`` (plus an
    optional fresh), sends it every request of the prepared requests plan and
    writes the answers into the BEGEMOT_RESPONSES_RESULT resource created in
    ``on_enqueue``. CGI plans are sent over HTTP to the ``/wizard`` handler;
    other plans are sent with the app host servant client, whose JSON output is
    normalized afterwards by ``_postprocess_responses``.
    """
    class Requirements(sdk2.Requirements):
        client_tags = wizard_utils.ALL_SANDBOX_HOSTS_TAGS & ~wizard_utils.BEGEMOT_INVALID_HARDWARE
        cores = 6

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        begemot_binary = bp.BegemotExecutableResource()

        fast_build_config = bp.FastBuildConfigResource(required=True)
        begemot_fresh = bp.FreshResource()
        requests_plan = bp.BegemotQueriesResource(required=False)
        requests_attrs = sdk2.parameters.Dict("Attributes to find requests_plan resource", default=None)
        arcadia_requests = sdk2.parameters.String("Arcadia path to additional requests", description='e.g. tools/printwzrd/tests/wizardreqs.txt')
        arcadia_url = sdk2.parameters.String("Arcadia url to checkout", description='Required if arcadia_requests param is set, useless otherwise')
        cache_size = sdk2.parameters.Integer('Begemot cache size in bytes', default=0)
        jobs = sdk2.parameters.Integer('Number of worker threads', default=None)
        additional_jobs = sdk2.parameters.Integer('Number of worker additiona threads', default=None)
        test_jobs = sdk2.parameters.Integer('Number of --test requests to process simultaneously (less requests = more threads for each one)', default=None)
        mlock = sdk2.parameters.Bool('Memory lock', default=False)
        fail_on_any_error = False
        verbose = sdk2.parameters.Bool('Check the box only if you want to debug', default=False)
        no_cache = sdk2.parameters.Bool('Disable all caches, add &nocache=da to requests', default=False)
        max_stderr_size_kb = sdk2.parameters.Integer('Fail task if stderr size in KB is larger than this number (0 to skip check)', default=200)

        with sdk2.parameters.Output():
            output_requests_plan = sdk2.parameters.Integer("Shard requests bundle resource id")

    def verbose(self, *args, **kwargs):
        """Log a debug message, but only when the ``verbose`` parameter is checked."""
        if self.Parameters.verbose:
            logging.debug(*args, **kwargs)

    def on_save(self):
        wizard_utils.setup_hosts(self, additional_restrictions=~wizard_utils.BEGEMOT_INVALID_HARDWARE)

    def _get_cache_size_bytes(self):
        """Return the ``cache_size`` parameter as an int (bytes)."""
        return int(self.Parameters.cache_size)

    @staticmethod
    def _find_ready_resource_id(resource_type, shard):
        """Return the id of the first READY ``resource_type`` with attribute ``shard``, or None."""
        resource = resource_type.find(
            state=ctr.State.READY,
            attrs={'shard': shard}
        ).first()
        return resource.id if resource is not None else None

    def _find_requests(self):
        """
        Fill ``requests_plan`` from ``requests_attrs`` when no plan was given explicitly.

        ``requests_attrs`` is read for the keys ``shard`` and ``type``:
        ``cgi`` looks up BEGEMOT_CGI_QUERIES with ``shard=<shard>-cgi``;
        ``apphost`` looks up BEGEMOT_APPHOST_QUERIES with ``shard=<shard>``
        and falls back to ``shard=Bravo``. Nothing is done when the shard is
        empty or the type is unknown.

        :raises TaskFailure: if no READY resource matches.
        """
        if self.Parameters.requests_plan:
            return
        attrs = self.Parameters.requests_attrs or {}
        shard = attrs.get('shard')
        requests_type = attrs.get('type')
        if not shard:
            return

        if requests_type == 'cgi':
            plan_id = self._find_ready_resource_id(br.BEGEMOT_CGI_QUERIES, '{}-cgi'.format(shard))
        elif requests_type == 'apphost':
            try:
                plan_id = self._find_ready_resource_id(br.BEGEMOT_APPHOST_QUERIES, shard)
            except Exception:
                # Lookup errors fall through to the "Bravo" fallback, as before.
                logging.exception("Failed to find apphost requests for shard %s", shard)
                plan_id = None
            if plan_id is None:
                plan_id = self._find_ready_resource_id(br.BEGEMOT_APPHOST_QUERIES, "Bravo")
        else:
            return

        if plan_id is None:
            raise TaskFailure("No READY {} requests found for shard {}".format(requests_type, shard))
        self.Parameters.requests_plan = plan_id

    def on_enqueue(self):
        fresh_size = self.Parameters.begemot_fresh.size if self.Parameters.begemot_fresh else 0

        shard_size = ShardSyncHelper(self.Parameters.fast_build_config).get_shard_size()

        data_and_binary_size_mb = (
            self.Parameters.begemot_binary.size + shard_size + fresh_size
        ) >> 20
        # RAM and disk both get binary + shard + fresh plus 5 GiB of headroom.
        self.Requirements.disk_space = self.Requirements.ram = data_and_binary_size_mb + 5 * 1024
        self.Requirements.disk_space += shard_size >> 20  # Because of rules copying
        # NOTE(review): Context.ShardName is not set in this class; presumably the
        # task creator fills it in — confirm.
        self.Context.out_resource_id = br.BEGEMOT_RESPONSES_RESULT(
            self,
            'GetBegemotResponses output',
            'output.txt',
            Shard=self.Context.ShardName,
        ).id
        self.Parameters.tags += [self.Context.ShardName]
        if self.Parameters.requests_attrs:
            self._find_requests()

    def _setup_environ(self):
        os.environ["MKL_CBWR"] = "COMPATIBLE"

    def _patch_requests_params(self, requests, destination=None):
        """
        Rewrite a CGI requests file (one request per line) for local shooting.

        Every ``bg_timeout`` parameter is removed and ``&format=json`` is appended
        to requests that do not already contain ``format=json``. Blank lines are
        dropped.

        :param requests: path of the requests file to read.
        :param destination: path to write the result to; defaults to
            ``requests`` (patched in place).
        """
        tmp_path = "new_requests"
        with open(tmp_path, "w") as new_requests, open(requests, "r") as old_requests:
            for line in old_requests:
                # Drop the line terminator first, so the appended parameter stays on
                # this request instead of starting the next line.
                patched = line.rstrip('\r\n')
                if not patched:
                    continue
                patched = re.sub(r"^bg_timeout=\d+&", "", patched)
                patched = re.sub(r"&bg_timeout=\d+", "", patched)
                if 'format=json' not in patched:
                    patched += '&format=json'
                new_requests.write(patched + '\n')
        shutil.move(tmp_path, destination or requests)

    @staticmethod
    def _write_wizard_requests(out, path, first_reqid, **parse_kwargs):
        """
        Append the wizard requests read from ``path`` to ``out`` as CGI strings.

        Each request gets ``&reqid=<n>``, with ``n`` counting up from ``first_reqid``.
        ``parse_kwargs`` are forwarded to ``pw.parse_query``.

        :returns: the number of requests written.
        """
        written = 0
        for reqid, line in enumerate(pw.read_wizard_requests(path), first_reqid):
            cgi = pw.parse_query(line, **parse_kwargs)[0].partition('?')[2]
            out.write(cgi + '&reqid={}\n'.format(reqid))
            written += 1
        return written

    def _write_arcadia_requests(self, out, first_reqid):
        """
        Append the requests from ``arcadia_requests`` (taken from ``arcadia_url``) to ``out``.

        :returns: the number of requests written.
        """
        arcadia_url = self.Parameters.arcadia_url
        if arcadia_url.startswith(Arcadia.ARCADIA_ARC_SCHEME):
            # The arc mount only exists inside this ``with``, so read the file here.
            with arcadiasdk.mount_arc_path(arcadia_url) as arc_arcadia:
                path = os.path.join(arc_arcadia, self.Parameters.arcadia_requests)
                return self._write_wizard_requests(out, path, first_reqid, maybe_full=False)

        arcadia_path = Arcadia.checkout(arcadia_url, 'arcadia', depth=Svn.Depth.IMMEDIATES)
        path = os.path.join(arcadia_path, self.Parameters.arcadia_requests)
        Arcadia.update(path, depth=Svn.Depth.IMMEDIATES, parents=True)
        return self._write_wizard_requests(out, path, first_reqid, maybe_full=False)

    def _prepare_requests(self):
        """
        Resolve the requests to shoot and return the path of the requests file.

        Sets ``self.is_cgi`` and ``self.apphost``. CGI requests are
        ``reqid``-numbered and patched by ``_patch_requests_params``:
        - if ``requests_plan`` already is a BEGEMOT_CGI_PLAN, a local patched copy is used;
        - otherwise a new BEGEMOT_CGI_PLAN is built from ``requests_plan`` and/or
          ``arcadia_requests``.
        Non-CGI plans are returned unchanged.

        :raises TaskFailure: if no requests source is configured, or
            ``arcadia_requests`` is set without ``arcadia_url``.
        """
        requests_plan = self.Parameters.requests_plan
        if requests_plan is None:
            if not self.Parameters.arcadia_requests or not self.Parameters.arcadia_url:
                raise TaskFailure("No requests. Please choose requests_plan or set arcadia_requests path and arcadia_url")
            requests_path = None
            self.is_cgi = True
            self.apphost = False
        else:
            requests_path = str(sdk2.ResourceData(requests_plan).path)
            self.is_cgi = isinstance(requests_plan, (br.BEGEMOT_CGI_QUERIES, br.BEGEMOT_CGI_PLAN))
            self.apphost = isinstance(requests_plan, br.BEGEMOT_APPHOST_QUERIES)
            self.verbose("requests_plan type: {}".format(requests_plan))
        self.verbose("is_cgi {}".format(self.is_cgi))
        self.verbose("apphost {}".format(self.apphost))

        if not self.is_cgi and not self.apphost:
            self.server.notification(
                subject="wrong type of resources for GetBegemotRespones",
                body="type for requests: {}".format(requests_plan.type),
                recipients=["pipeknight@yandex-team.ru"],
                transport=ctn.Transport.EMAIL,
                urgent=True
            )
        if not self.is_cgi:
            return requests_path

        if requests_plan is not None and requests_plan.type == br.BEGEMOT_CGI_PLAN:
            # Never modify synced resource data in place: patch a local copy instead.
            patched_path = 'begemot_plan_patched.txt'
            self._patch_requests_params(requests_path, destination=patched_path)
            return patched_path

        if self.Parameters.arcadia_requests and not self.Parameters.arcadia_url:
            raise TaskFailure("arcadia_requests param is set. Please set arcadia_url")

        plan_res = br.BEGEMOT_CGI_PLAN(
            self,
            'Prepared requests for begemot',
            'begemot_plan.txt',
        )
        plan_path = str(plan_res.path)
        with open(plan_path, 'w') as out:
            written = 0
            if requests_path is not None:
                written = self._write_wizard_requests(out, requests_path, 1)
            if self.Parameters.arcadia_requests:
                # Continue numbering after the requests_plan entries.
                self._write_arcadia_requests(out, written + 1)
        self._patch_requests_params(plan_path)
        sdk2.ResourceData(plan_res).ready()
        return plan_path

    def _prepare_shard(self, config, out_dir):
        """Sync the shard described by ``config`` into ``out_dir`` and return what ShardSyncHelper reports."""
        shard_helper = ShardSyncHelper(config)
        data_path = str(self.path(out_dir))
        shard = shard_helper.sync_shard(data_path, add_relev=True)
        return shard

    def _prepare_fresh(self):
        """Return the fresh data path (syncing fast-build fresh as a shard), or None when no fresh is set."""
        if self.Parameters.begemot_fresh:
            if 'FAST_BUILD' in self.Parameters.begemot_fresh.type.name:
                return self._prepare_shard(self.Parameters.begemot_fresh, 'fresh')
            return str(sdk2.ResourceData(self.Parameters.begemot_fresh).path)
        return None

    def _get_additional_cmd_args(self):
        """Return the extra begemot command-line arguments derived from task parameters."""
        additional_cmd_args = []
        if self.Parameters.mlock:
            additional_cmd_args += ('--mlock', 'yes')
        if self.Parameters.additional_jobs:
            additional_cmd_args += ('--additional-jobs', str(self.Parameters.additional_jobs))
        if self.Parameters.test_jobs:
            additional_cmd_args += ('--network-threads', str(self.Parameters.test_jobs))
        if self.Parameters.no_cache:
            additional_cmd_args += ('--additional-cgi', 'nocache=da')
        return additional_cmd_args

    def _prepare_log_file(self):
        """Create ``begemot_logger.out`` in the task log dir and (re)point ``current-profile-log-`` at it."""
        log_file = os.path.join(str(self.log_path()), 'begemot_logger.out')
        open(log_file, 'a').close()
        symlink_name = 'current-profile-log-'
        try:
            os.remove(symlink_name)
        except OSError:
            # No previous symlink to replace.
            pass
        os.symlink(log_file, symlink_name)

    @staticmethod
    def _find_stable_resource_path(resource_type_name):
        """
        Sync a READY resource of ``resource_type_name`` released as ``stable`` and return its path.

        :raises TaskFailure: if no such resource exists.
        """
        resource = sdk2.Resource[resource_type_name].find(
            state='READY', attrs={'released': 'stable'}
        ).order(-sdk2.Task.id).first()
        if resource is None:
            raise TaskFailure("No stable READY {} resource found".format(resource_type_name))
        return str(sdk2.ResourceData(resource).path)

    def _get_grpc_client(self):
        self.grpc_client_path = self._find_stable_resource_path("APP_HOST_GRPC_CLIENT_EXECUTABLE")

    def _get_servant_client(self):
        self.servant_client_path = self._find_stable_resource_path("APP_HOST_SERVANT_CLIENT_EXECUTABLE")

    def _run_servant_client(self, request_plan_path, output_file):
        """Shoot ``request_plan_path`` at local begemot with the servant client; its stdout goes to ``output_file``."""
        cmd = (self.servant_client_path, '--plan', request_plan_path, 'localhost:' + str(DEFAULT_BEGEMOT_PORT))
        try:
            subprocess.check_call(cmd, stdout=output_file)
        except Exception as e:
            raise TaskFailure("Begemot failed: {}".format(e))

    def _shoot_cgi(self, request_plan_path, output_file):
        """
        Send each CGI request of ``request_plan_path`` to the local begemot ``/wizard`` handler.

        Raw response bodies (bytes) are written to ``output_file``, which must be
        opened in binary mode. Blank lines are skipped.
        """
        url_prefix = 'http://localhost:' + str(DEFAULT_BEGEMOT_PORT) + '/wizard?'
        # One session reuses the connection to begemot across requests.
        with requests.Session() as session, open(request_plan_path, 'r') as reqs:
            for line in reqs:
                cgi = line.rstrip('\r\n')
                if not cgi:
                    continue
                response = session.get(url_prefix + cgi)
                output_file.write(response.content)

    def _postprocess_responses(self, responses_file_path):
        """
        Normalize servant client output in place: one compact JSON value per line.

        When a line decodes to a dict with an ``answers`` key, only that value is
        kept. Blank lines are skipped. Invalid UTF-8 is reported via ``set_info``.
        """
        if not wizard_utils.validate_utf8(responses_file_path):
            self.set_info(wizard_utils.INVALID_UNICODE_FAILURE_MESSAGE)

        processed_responses_file_path = "processed_responses"
        with open(processed_responses_file_path, "w") as processed_responses, open(responses_file_path, "r") as responses:
            for line in responses:
                if not line.strip():
                    continue
                response = json.loads(line)
                if isinstance(response, dict) and 'answers' in response:
                    response = response['answers']
                processed_responses.write(json.dumps(response, separators=(',', ':')))
                processed_responses.write('\n')
        shutil.move(processed_responses_file_path, responses_file_path)

    def _prepare_config(self):
        """Create an empty begemot config file and remember its path."""
        self.begemot_config_path = 'begemot.cfg'
        open(self.begemot_config_path, 'w').close()

    def _prepate_eventlog(self):
        """Register the BEGEMOT_EVENTLOG resource and remember its path."""
        self.begemot_eventlog_path = 'begemot.evlog'
        br.BEGEMOT_EVENTLOG(self, 'Begemot eventlog', self.begemot_eventlog_path)

    def on_execute(self):
        if self.Parameters.requests_plan:
            self.Parameters.output_requests_plan = self.Parameters.requests_plan
        self._setup_environ()
        requests_path = self._prepare_requests()
        shard_path = self._prepare_shard(self.Parameters.fast_build_config, 'data')
        fresh_path = self._prepare_fresh()
        if self.Parameters.fast_build_config.type.name == 'BEGEMOT_FAST_BUILD_CONFIG_LOGGER':
            self._prepare_log_file()
        self._prepare_config()
        self._prepate_eventlog()

        begemot = get_begemot(
            binary_path=str(sdk2.ResourceData(self.Parameters.begemot_binary).path),
            config_path=self.begemot_config_path,
            worker_dir=shard_path,
            fresh_dir=fresh_path,
            eventlog_path=self.begemot_eventlog_path,
            cache_size=self._get_cache_size_bytes(),
            jobs_count=self.Parameters.jobs,
            additional_cmd_args=self._get_additional_cmd_args(),
            start_timeout=(DEFAULT_START_TIMEOUT * 3),
            max_stderr_size_kb=self.Parameters.max_stderr_size_kb,
        )

        out_resource = sdk2.Resource[self.Context.out_resource_id]
        out_resource.is_cgi = self.is_cgi
        output = str(out_resource.path)

        # Binary mode: CGI response bodies are bytes; the servant client writes
        # straight to the file descriptor either way.
        with open(output, 'wb') as output_file, begemot:
            if self.is_cgi:
                self._shoot_cgi(requests_path, output_file)
            else:
                self._get_servant_client()
                self._run_servant_client(requests_path, output_file)
        if not self.is_cgi:
            # Run only after the output file is closed, so the complete file is read and replaced.
            self._postprocess_responses(output)
