# -*- coding: utf-8 -*-

import os
import json
import shutil
import logging
import tarfile
import requests
import itertools
import multiprocessing

from sandbox import sdk2
from sandbox.sandboxsdk import svn
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

import sandbox.common.types.task as ctt
from sandbox.projects import resource_types
from sandbox.projects.common.search.components import get_begemot, get_wizard, DEFAULT_BEGEMOT_PORT
from sandbox.projects.common.search.components.apphost import AppHost, DEFAULT_APPHOST_PORT
from sandbox.projects.common.wizard import utils as wizard_utils
from sandbox.projects.websearch.begemot import parameters as bp
from sandbox.projects.websearch.begemot import resources as br
from sandbox.projects.websearch.begemot.common import ApphostTestGraph, Begemots
from sandbox.projects.websearch.begemot.common.fast_build import ShardSyncHelper


def _get_ah_response(it, session=requests.Session()):
    """Send one request to the local app host and return its decoded JSON answer.

    ``it`` is a ``(request, apphost_port)`` pair, so the function can be mapped
    over a ``multiprocessing.Pool``. The default ``session`` is evaluated once,
    when the module is imported, so calls within one process reuse it.

    Failures are returned as data, never raised (for ``Exception`` subclasses):
    an empty body yields ``{'error': 'empty response'}``. Any exception yields a
    dict with a ``'Python exception'`` description plus, when an HTTP response
    object was obtained, its body under ``'error'`` and its ``'HTTP code'``.
    """
    template = '{0.__class__.__module__}.{0.__class__.__name__}: {0}'
    payload, port = it
    reply = None
    try:
        reply = session.post('http://localhost:{}/_json/test'.format(port), json=payload)
        reply.raise_for_status()
        reply.encoding = 'utf-8'
        if reply.text:
            return reply.json()
        return {'error': 'empty response'}
    except Exception as exc:
        if reply is not None:
            return {
                'error': reply.text,
                'HTTP code': reply.status_code,
                'Python exception': template.format(exc),
            }
        return {'Python exception': template.format(exc)}


class ApphostConfigurator(object):
    """Prepare an on-disk app host setup (graph, backends, config) for a task.

    On construction the graph and the optional backends are written under the
    task's ``data`` directory, and a JSON config pointing at them is written to
    the task log directory. :meth:`get_apphost` then wraps everything in an
    ``AppHost`` component.
    """
    logger = logging.getLogger('ApphostConfigurator')

    def __init__(self, task, ah_resource, graph, port=DEFAULT_APPHOST_PORT, backends=None):
        """
        :param task: sdk2 task; its ``path()`` and ``log_path()`` locate output files
        :param ah_resource: filesystem path of the app host binary (a path, despite the name)
        :param graph: JSON-serializable graph, written as ``graphs/test.json``
        :param port: port passed to the config and to ``AppHost``
        :param backends: optional mapping of backend file name -> JSON-serializable content
        """
        self.port = port
        self.task = task
        self.apphost_path = ah_resource
        self.data_path = self._get_data_dir()
        self._write_graph(graph)
        self._write_backends(backends)
        self.config = self._write_config()

    def get_apphost(self):
        """Return an ``AppHost`` component using the binary and the written config."""
        self.logger.info('Instantiating apphost. Binary path: %s. Config: %s', self.apphost_path, self.config)
        return AppHost(
            port=self.port,
            task=self.task,
            binary_path=self.apphost_path,
            config_path=self.config,
            no_memlock=True
        )

    @staticmethod
    def _safe_mkdir(path):
        """Create directory ``path`` unless a directory already exists there.

        Only the "already exists" case is tolerated. Other failures (missing
        parent, permission denied, a non-directory in the way) are re-raised
        here instead of surfacing later as a confusing error when files are
        written into the directory.
        """
        try:
            os.mkdir(path)
        except OSError:
            if not os.path.isdir(path):
                raise

    def _get_data_dir(self):
        """Create (if needed) and return the task's ``data`` directory."""
        path = str(self.task.path('data'))
        self._safe_mkdir(path)
        return path

    def _write_graph(self, graph):
        """Write ``graph`` to ``<data>/graphs/test.json`` and copy it into the task logs."""
        graphs_dir = os.path.join(self.data_path, 'graphs')
        self._safe_mkdir(graphs_dir)
        graph_file = os.path.join(graphs_dir, 'test.json')
        self.logger.info('Writing begemot graph to %s', graph_file)
        with open(graph_file, 'w') as f:
            f.write(json.dumps(graph, indent=4))
        # Keep a copy in the logs so the exact graph used is visible after the run.
        shutil.copyfile(graph_file, os.path.join(str(self.task.log_path()), 'apphost_graph.json'))

    def _write_backends(self, backends):
        """Write each backend to ``<data>/backends/<name>``; does nothing when ``backends`` is None."""
        if backends is None:
            return

        backends_dir = os.path.join(self.data_path, 'backends')
        self._safe_mkdir(backends_dir)

        for name in backends:
            backend_file = os.path.join(backends_dir, name)
            self.logger.info('Writing backend %s to %s', name, backend_file)
            with open(backend_file, 'w') as f:
                f.write(json.dumps(backends[name], indent=4))

    def _get_config(self):
        """Return the app host config dict.

        Graphs and backends are read from under ``self.data_path``; the event
        log goes to the task log directory.
        """
        conf_dir = os.path.join(self.data_path, 'graphs')
        # NOTE(review): when no backends were given this directory is never
        # created; confirm the app host tolerates a missing backends path.
        backends_dir = os.path.join(self.data_path, 'backends')
        return {
            'port': self.port,
            'threads': 2,
            'total_quota': 1000,
            'group_quotas': {'': 1},
            'conf_dir': conf_dir,
            'fallback_conf_dir': conf_dir,
            'backends_path': backends_dir,
            'fallback_backends_path': backends_dir,
            'protocol_options': {'post/ConnectTimeout': '10000ms'},
            'log': os.path.join(str(self.task.log_path()), 'apphost.evlog'),
            'update_from_fs_config': {}
        }

    def _write_config(self):
        """Write the config to ``<logs>/apphost.cfg`` and return that path."""
        conf_path = os.path.join(str(self.task.log_path()), 'apphost.cfg')
        with open(conf_path, 'w') as f:
            f.write(json.dumps(self._get_config(), indent=4))
        self.logger.info('Wrote apphost config to %s', conf_path)
        return conf_path


class GetBegemotWorkerAndMergerResponses(sdk2.Task):
    '''
    Get responses from a begemot worker + begemot merger pair behind a local app host.

    The worker, the merger and an app host with a generated test graph are
    started locally. Every request from ``requests_plan`` is then sent through
    the app host, and the responses are stored in a BEGEMOT_RESPONSES_RESULT
    resource, one JSON document per line.
    '''
    logger = logging.getLogger('GetBegemotWorkerAndMergerResponses')
    ARCADIA_GRAPH_PATH = 'arcadia:/arc/trunk/arcadia/apphost/conf/verticals/TEST/begemot_worker_and_merger.json'
    ARCADIA_SAMPLE_BACKEND_PATH = 'arcadia:/arc/trunk/arcadia/web/daemons/begemot/test/backend.json'

    class Requirements(sdk2.Requirements):
        client_tags = wizard_utils.ALL_SANDBOX_HOSTS_TAGS & ~wizard_utils.BEGEMOT_INVALID_HARDWARE

    class Parameters(sdk2.Parameters):
        begemot_binary = bp.BegemotExecutableResource()
        begemot_worker_binary = sdk2.parameters.Resource(
            'Begemot worker binary if different to merger binary (e.g. Bert executable)',
            required=False
        )
        begemot_config = bp.BegemotConfigResource()

        # One choice per Begemots shard that is tested with a merger and has an apphost source name.
        with sdk2.parameters.RadioGroup('Begemot worker') as worker:
            for shard in [name for name, s in Begemots if s.test_with_merger and s.apphost_source_name]:
                worker.values[shard] = worker.Value(shard)

        fast_build_config = bp.FastBuildConfigResource(required=True)
        begemot_fresh = bp.FreshResource()

        begemot_merger_fast_build_config = sdk2.parameters.Resource(
            'Begemot merger fast build config',
            resource_type=br.BEGEMOT_FAST_BUILD_CONFIG_MERGER,
            required=True
        )
        begemot_merger_fresh = sdk2.parameters.Resource(
            'Begemot merger fresh',
            resource_type=[br.BEGEMOT_FRESH_DATA_FOR_MERGER, br.BEGEMOT_FRESH_DATA_FOR_MERGER_FAST_BUILD]
        )
        requests_plan = sdk2.parameters.Resource(
            'Begemot apphost queries',
            resource_type=[resource_types.PLAIN_TEXT_QUERIES, br.BEGEMOT_APPHOST_QUERIES],
            required=True
        )
        debug_mode = sdk2.parameters.Bool(
            'Query begemot with debug mode (dbgwzr=2)',
            default=False,
        )

    def _checkout_graph(self):
        """Export the test app host graph from Arcadia to ``self.graph_path``."""
        self.logger.info('Exporting graph from %s to %s', self.ARCADIA_GRAPH_PATH, self.graph_path)
        svn.Arcadia.export(self.ARCADIA_GRAPH_PATH, self.graph_path)

    def _checkout_sample_backend(self):
        """Export the sample backend description from Arcadia to ``self.backend_path``."""
        self.logger.info('Exporting backend from %s to %s', self.ARCADIA_SAMPLE_BACKEND_PATH, self.backend_path)
        svn.Arcadia.export(self.ARCADIA_SAMPLE_BACKEND_PATH, self.backend_path)

    def _get_fresh_path(self, res, suffix=""):
        """Return a local path holding the data of fresh resource ``res``.

        Resources whose type name contains ``FAST_BUILD`` are synced with
        ``ShardSyncHelper`` into ``fresh<suffix>`` under the task directory.
        All others are synced with ``sdk2.ResourceData``.
        """
        if "FAST_BUILD" not in res.type.name:
            return str(sdk2.ResourceData(res).path)
        shard_helper = ShardSyncHelper(res)
        data_path = str(self.path("fresh{}".format(suffix)))
        return shard_helper.sync_shard(data_path)

    def init_apphost(self):
        """Generate the test graph for the worker/merger and return an ``AppHost`` component.

        Requires ``self.worker_port`` and ``self.merger_port`` to be set.

        :raises SandboxTaskFailureError: if no stable APP_HOST_DAEMON_EXECUTABLE is found
        """
        # NOTE(review): ordering key is sdk2.Task.id rather than sdk2.Resource.id;
        # confirm this picks the newest released resource as intended.
        ah_resource = sdk2.Resource["APP_HOST_DAEMON_EXECUTABLE"].find(
            status=ctt.Status.RELEASED, attrs={'released': 'stable'}
        ).order(-sdk2.Task.id).first()
        if ah_resource is None:
            raise SandboxTaskFailureError('No stable APP_HOST_DAEMON_EXECUTABLE resource found')
        ah_resource_path = str(sdk2.ResourceData(ah_resource).path)
        self.graph_path = str(self.path('testing_graph.json'))
        self.backend_path = str(self.path('sample_backend'))
        self._checkout_graph()
        self._checkout_sample_backend()
        # Backends are pointed at <begemot port> + 1; presumably the port begemot
        # serves app host traffic on (confirm against the begemot component).
        graph_generator = ApphostTestGraph(self.graph_path, [
            (self.backend_path, self.worker_port + 1, 'BEGEMOT_WORKER_TESTING.json'),
            (self.backend_path, self.merger_port + 1, 'BEGEMOT_MERGER_TESTING.json'),
        ])
        graph, backends = graph_generator.generate_graph()
        return ApphostConfigurator(task=self, ah_resource=ah_resource_path, graph=graph, backends=backends).get_apphost()

    def init_merger(self, port):
        """Sync merger data and return a begemot merger component listening on ``port``.

        Also registers a BEGEMOT_EVENTLOG resource for ``merger.evlog`` and stores
        its id in ``self.Context.merger_evlog_id``.
        """
        eventlog_path = 'merger.evlog'
        self.Context.merger_evlog_id = br.BEGEMOT_EVENTLOG(
            self,
            'Begemot merger eventlog',
            eventlog_path,
        ).id
        merger_shard_helper = ShardSyncHelper(self.Parameters.begemot_merger_fast_build_config)
        merger_shard_path = merger_shard_helper.sync_shard(str(self.path('data_merger_' + str(port))))

        return get_begemot(
            port=port,
            binary_path=str(sdk2.ResourceData(self.Parameters.begemot_binary).path),
            config_path=self.begemot_config,
            worker_dir=merger_shard_path,
            fresh_dir=self._get_fresh_path(self.Parameters.begemot_merger_fresh, "_Merger") if self.Parameters.begemot_merger_fresh else None,
            eventlog_path=eventlog_path
        )

    def init_worker(self, port=DEFAULT_BEGEMOT_PORT):
        """Sync worker data and return a begemot worker component listening on ``port``.

        Uses ``begemot_worker_binary`` when it is set, otherwise ``begemot_binary``.
        """
        shard_helper = ShardSyncHelper(self.Parameters.fast_build_config)
        shard_path = shard_helper.sync_shard(str(self.path('data_worker_port_' + str(port))))

        binary = self.Parameters.begemot_worker_binary
        if binary is None:
            binary = self.Parameters.begemot_binary

        return get_begemot(
            port=port,
            binary_path=str(sdk2.ResourceData(binary).path),
            config_path=self.begemot_config,
            worker_dir=shard_path,
            fresh_dir=self._get_fresh_path(self.Parameters.begemot_fresh, "_Worker") if self.Parameters.begemot_fresh else None,
            eventlog_path=os.path.join(str(self.log_path()), 'worker.evlog')
        )

    @staticmethod
    def prepare_request(r, debug=False):
        """Build the list of app host items to send from one dumped request ``r``.

        - ``BEGEMOT_CONFIG`` / ``BEGEMOT_CONFIG_ORIGINAL`` items are dropped.
        - A result with a truthy ``binary`` field is replaced by that field's value.
        - With ``debug``, every dict result whose ``version`` string contains
          ``INIT.settings`` gets ``['internal']['view_relevance'] = True``.

        The items of ``r`` are modified in place and reused in the returned list.
        """
        request = []
        for item in r:
            if item['name'] in ('BEGEMOT_CONFIG', 'BEGEMOT_CONFIG_ORIGINAL'):
                continue

            results = item['results']
            for i, result in enumerate(results):
                if result.get('binary'):
                    result = result['binary']
                    results[i] = result

                # Checked for every result, not just the last one, and safe for empty 'results'.
                if debug and isinstance(result, dict):
                    version = result.get('version', '')
                    # json.loads yields unicode on Python 2, so accept both text types.
                    if isinstance(version, (str, type(u''))) and 'INIT.settings' in version:
                        result['internal']['view_relevance'] = True

            request.append(item)

        return request

    @classmethod
    def get_requests(cls, requests_file, debug=False):
        """Yield prepared requests from ``requests_file`` (one JSON request per line).

        Whitespace-only lines are skipped instead of failing JSON parsing.
        """
        with open(requests_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                yield cls.prepare_request(json.loads(line), debug)

    @staticmethod
    def process_response(r):
        """Reduce a list response to the non-protobuf results of its ``WIZARD`` items.

        Non-list responses (e.g. dicts) are returned unchanged.
        """
        if isinstance(r, list):
            # protobuf comparison gives unreadable diffs
            # If there is something important in proto, we must decode it to compare (TODO)
            return [
                [res for res in item['results'] if res.get('__content_type') != 'protobuf']
                for item in r if item['name'] == 'WIZARD'
            ]
        return r

    def shoot(self):
        """Send all planned requests through the app host and write responses, one JSON per line.

        :raises SandboxTaskFailureError: for the first response dict with a truthy ``error`` field
        """
        responses_path = str(sdk2.Resource[self.Context.out_resource_id].path)
        plan = self.get_requests(str(sdk2.ResourceData(self.Parameters.requests_plan).path), self.Parameters.debug_mode)
        pool = multiprocessing.Pool(8, maxtasksperchild=1000)
        try:
            with open(responses_path, 'w') as out:
                it = zip(plan, itertools.repeat(self.apphost_port))
                for response in (self.process_response(r) for r in pool.map(_get_ah_response, it)):
                    if isinstance(response, dict) and response.get('error'):
                        raise SandboxTaskFailureError(json.dumps(response))
                    json.dump(response, out)
                    out.write('\n')
        finally:
            pool.close()
            pool.join()

    def _get_fresh_size(self, res):
        """Return the data size of fresh resource ``res``, or 0 when it is not set."""
        if res:
            if "FAST_BUILD" in res.type.name:
                return ShardSyncHelper(res).get_shard_size()
            else:
                return res.size
        return 0

    def on_enqueue(self):
        """Create the output resource, tag the task with its shard and size disk/RAM requirements."""
        self.Context.out_resource_id = br.BEGEMOT_RESPONSES_RESULT(
            self,
            'Begemot responses output',
            'output.txt',
            Shard=self.Parameters.worker
        ).id
        if self.Parameters.worker not in self.Parameters.tags:
            self.Parameters.tags += [self.Parameters.worker]

        shard_size = ShardSyncHelper(self.Parameters.fast_build_config).get_shard_size()
        merger_shard_size = ShardSyncHelper(self.Parameters.begemot_merger_fast_build_config).get_shard_size()

        data_size = shards_size = shard_size + merger_shard_size

        data_size += self._get_fresh_size(self.Parameters.begemot_fresh)
        data_size += self._get_fresh_size(self.Parameters.begemot_merger_fresh)

        # Sizes are presumably bytes, so requirements are in MiB (confirm units).
        self.Requirements.disk_space = self.Requirements.ram = data_size // 1024 // 1024 + 10 * 1024  # 10 GiB
        self.Requirements.disk_space += shards_size >> 20  # Because of rules copying

    def on_save(self):
        """Restrict host selection via ``wizard_utils.setup_hosts``, excluding invalid begemot hardware."""
        wizard_utils.setup_hosts(self, additional_restrictions=~wizard_utils.BEGEMOT_INVALID_HARDWARE)

    def on_execute(self):
        """Start the worker, merger and app host, then shoot the request plan through them."""
        config = self.Parameters.begemot_config
        if config:
            self.begemot_config = os.path.join(str(sdk2.ResourceData(config).path), 'worker.cfg')
        else:
            # Without a config resource an empty config file is passed.
            self.begemot_config = 'begemot.cfg'
            open(self.begemot_config, 'w').close()
        worker = self.init_worker()
        self.worker_port = int(worker.port)
        self.merger_port = self.worker_port + 10
        merger = self.init_merger(self.merger_port)
        apphost = self.init_apphost()
        self.apphost_port = int(apphost.port)
        self.set_info("Starting shooting")
        with worker, merger, apphost:
            self.shoot()
