# coding: utf-8
import os
import logging
import tarfile
import shutil
import time
import threading
import psutil
from Queue import Queue

import sandbox.sandboxsdk.parameters as parameters
from sandbox.sandboxsdk.process import run_process, check_process_return_code
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.paths import make_folder

import sandbox.projects.common.wizard.parameters as wizard_parameters
from sandbox.projects import resource_types
from sandbox.projects.common.wizard.utils import check_wizard_build
from sandbox.projects.common.wizard.wizard_builder import WizardBuilder
from sandbox.projects.common.wizard.current_production import get_current_production_task_id


# Service name handed to get_current_production_task_id() when the task has to
# look up the current production build/runtime task id by itself.
nanny_service = 'wizard_web_production_vla'


def _thread_wrap_func(func):
    """Adapt ``func`` so that its outcome is delivered through a queue-like object.

    The returned callable takes a ``result`` object (anything with a ``put``
    method) as its first argument, followed by the arguments for ``func``.
    The return value of ``func`` is put into ``result``; if an ``Exception``
    is raised along the way, the exception instance is put instead, so the
    wrapper itself does not propagate it.
    """
    def deliver(result, *args, **kwargs):
        try:
            result.put(func(*args, **kwargs))
        except Exception as error:
            result.put(error)
    return deliver


class _ThreadResult(Queue, object):
    """Queue that remembers the first item taken from it.

    The first ``get()`` blocks until an item is available and stores it;
    every later ``get()`` answers from the stored item. If the stored item
    is an ``Exception`` instance it is raised instead of being returned.
    """

    def __init__(self):
        super(_ThreadResult, self).__init__()
        self.ready = False
        self.result = None

    def get(self):
        """Return the stored item, or raise it when it is an ``Exception``."""
        if not self.ready:
            fetched = super(_ThreadResult, self).get()
            self.result, self.ready = fetched, True
        outcome = self.result
        if isinstance(outcome, Exception):
            raise outcome
        return outcome


class ResourceManager(object):
    """Cache of resource synchronisations started in background threads.

    Indexing the manager with a resource id returns a ``_ThreadResult``.
    The first lookup of an id starts a daemon thread that runs
    ``channel.task.sync_resource`` for it; later lookups return the same
    ``_ThreadResult`` object.
    """

    def __init__(self):
        # resource id -> _ThreadResult of the synchronisation started for it
        self.resources = dict()

    def __getitem__(self, resource_id):
        if resource_id in self.resources:
            return self.resources[resource_id]
        return self.__prepare_resource(resource_id)

    def __prepare_resource(self, resource_id):
        """Start synchronising ``resource_id`` in a daemon thread and cache its result holder."""
        pending = _ThreadResult()
        worker = threading.Thread(
            target=_thread_wrap_func(channel.task.sync_resource),
            args=(pending, resource_id),
        )
        worker.daemon = True
        worker.start()
        self.resources[resource_id] = pending
        return pending


class PrintWzrd(object):
    """Runs the printwzrd binary of a wizard build over a file of requests.

    The binary, the shard and (optionally) the runtime package are requested
    from ``resource_manager`` in the constructor. Their ``get()`` is called
    only later, in ``prepare_workspace`` and ``get_run_cmd``.
    """

    def __init__(self, resource_manager, build_wizard_id, wizard_runtime_id=None, work_with_runtime=True,
                 read_only_resources=False, no_cache=True):
        """Request the needed resources and create the data folder.

        :param resource_manager: object indexed by a resource; each lookup must
            return something with a ``get()`` method yielding a local path.
        :param build_wizard_id: task id checked with ``check_wizard_build`` and
            used to locate the printwzrd binary and the wizard shard.
        :param wizard_runtime_id: task id to take the runtime package from;
            ``build_wizard_id`` is used when it is empty.
        :param work_with_runtime: when false, no runtime package is requested.
        :param read_only_resources: when true, resources that are not tar
            archives are copied into the data folder instead of symlinked.
        :param no_cache: when true, ``--no-cache`` is passed to the config
            generation script.
        """
        check_wizard_build(build_wizard_id)
        self.resource_manager = resource_manager
        self.work_with_runtime = work_with_runtime

        printwzrd_resource = WizardBuilder.printwzrd_from_task(build_wizard_id)
        self.binary_res = resource_manager[printwzrd_resource]

        shard_resource = WizardBuilder.wizard_shard_from_task(build_wizard_id)
        self.shard_res = resource_manager[shard_resource]

        if work_with_runtime:
            runtime_resource = WizardBuilder.runtime_package_from_task(wizard_runtime_id or build_wizard_id)
            self.runtime_res = resource_manager[runtime_resource]
        else:
            self.runtime_res = None

        # The timestamp keeps the names of the data folder, the config and
        # the eventlog of this instance apart from those of other instances.
        self.timestamp = time.time()

        self.data_path = make_folder('wizard_data_{}'.format(self.timestamp))

        self.eventlog_name = 'wizard_eventlog_{}'.format(self.timestamp)
        self.process = None
        self.read_only_resources = read_only_resources
        self.no_cache = no_cache

        # Filled in by prepare_workspace().
        self.config_path = None
        self.cmd = None

        # Files opened by run() for the child's stdin/stdout/stderr; they are
        # closed by wait() or when starting the process fails.
        self._open_files = []

    def _copy_untar(self, src, dst, should_copy=False):
        """Make ``src`` available at ``dst``.

        A tar archive is extracted into the data folder and its first listed
        member (assumed to be the root directory of the archive) is moved to
        ``dst``. Anything else is copied when ``should_copy`` is true and
        symlinked otherwise.
        """
        logging.info('coping {} to {}'.format(src, dst))
        if os.path.isfile(src) and tarfile.is_tarfile(src):
            with tarfile.open(src) as archive:
                old_root = archive.getnames()[0]
                archive.extractall(self.data_path)
            shutil.move(os.path.join(self.data_path, old_root), dst)
        elif should_copy:
            shutil.copytree(src, dst)
        else:
            os.symlink(src, dst)

    def _close_files(self):
        """Close and forget every file opened by ``run()``."""
        pending, self._open_files = self._open_files, []
        for handle in pending:
            handle.close()

    def prepare_workspace(self, additional_params):
        """Lay out the data folder, generate the config and build the command line.

        Waits for the shard (and the runtime package, if any) to be available.
        The results are stored in ``self.config_path`` and ``self.cmd``.

        :param additional_params: extra command line arguments or ``None``.
        """
        self._copy_untar(
            src=self.shard_res.get(),
            dst=os.path.join(self.data_path, 'wizard'),
            should_copy=self.read_only_resources
        )
        if self.runtime_res is not None:
            self._copy_untar(
                src=self.runtime_res.get(),
                dst=os.path.join(self.data_path, 'wizard.runtime'),
                should_copy=self.read_only_resources
            )
        self.config_path = self.generate_config()
        self.cmd = self.get_run_cmd(additional_params)

    def get_memory(self):
        """Return the resident set size of the started process as reported by psutil."""
        proc = psutil.Process(self.process.pid)
        # psutil 2.0 renamed get_memory_info() to memory_info() and later
        # releases dropped the old name, so prefer the new one when present.
        memory_info = getattr(proc, 'memory_info', None) or proc.get_memory_info
        return memory_info().rss

    def generate_config(self):
        """Run the config script of the unpacked shard and return the config file name.

        The script output is redirected by the shell into ``wizard_<timestamp>.cfg``.
        """
        logging.info("Generating config")
        generated_config = 'wizard_{}.cfg'.format(self.timestamp)
        config_name = 'wizard'
        wizard_cfg_script = os.path.join(self.data_path, 'wizard/conf/config.py')
        no_cache_param = '--no-cache' if self.no_cache else ''
        run_process("{bin} {name} --testing {cache} >'{out}'".format(
            bin=wizard_cfg_script, name=config_name, cache=no_cache_param, out=generated_config
        ), shell=True, log_prefix='gen_cfg_%s' % self.timestamp)
        return generated_config

    def get_run_cmd(self, additional_params):
        """Return the printwzrd command line as a single string.

        Waits for the binary resource to be available.

        :param additional_params: extra arguments appended verbatim; ``None``
            or an empty string adds nothing.
        """
        additional_params = additional_params or ''
        return '{bin} -a {data} -s {config} -e {log} {additional_params}'.format(
            bin=self.binary_res.get(), data=self.data_path, config=self.config_path,
            log=self.eventlog_name, additional_params=additional_params
        ).strip()

    def alive(self):
        """Return True while a started process has not exited yet."""
        return self.process is not None and self.process.poll() is None

    def run(self, path_to_requests, path_to_answers, error_output=None, additional_params=None):
        """Prepare the workspace and start printwzrd without waiting for it.

        :param path_to_requests: file fed to the process as stdin.
        :param path_to_answers: file the process stdout is written to.
        :param error_output: file the process stderr is written to, or ``None``
            to pass no stderr file to ``run_process``.
        :param additional_params: extra command line arguments or ``None``.
        :raises Exception: ``'Still alive'`` when a previously started process
            is still running.
        """
        # Refuse before touching the workspace, so that a rejected call does
        # not modify the files of the process that is still running.
        if self.alive():
            raise Exception('Still alive')

        logging.info('preparing workspace')
        self.prepare_workspace(additional_params)
        logging.info('workspace prepared')

        logging.info('starting printwzrd')
        # Drop handles of an earlier, already finished run that was never waited for.
        self._close_files()
        try:
            requests_file = open(path_to_requests)
            self._open_files.append(requests_file)
            answers_file = open(path_to_answers, 'w')
            self._open_files.append(answers_file)
            if error_output is not None:
                error_output = open(error_output, 'w')
                self._open_files.append(error_output)
            self.process = run_process(self.cmd.split(), outputs_to_one_file=False, wait=False,
                                       stdin=requests_file, stdout=answers_file,
                                       stderr=error_output)
        except Exception:
            # Nothing was started, so nobody would ever close these handles.
            self._close_files()
            raise
        logging.info('printwzrd started')

    def wait(self):
        """Wait for the started process and check its return code.

        The files opened by ``run()`` are closed whether or not the check
        succeeds. ``self.process`` is reset only after a successful check.

        :raises Exception: ``'PrintWzrd has not been started'`` when there is
            no process to wait for.
        """
        if self.process is None:
            raise Exception('PrintWzrd has not been started')
        logging.info('waiting for printwzrd')
        try:
            self.process.wait()
            check_process_return_code(self.process)
        finally:
            self._close_files()
        self.process = None


class QueriesParam(wizard_parameters.Queries):
    # Mandatory input: the resource that is synced and fed to printwzrd as requests.
    required = True
    name = 'queries_param'
    description = 'Queries'


class WizardBuildParam(wizard_parameters.WizardBuildParameter):
    # Optional input: when it is empty the task looks the build task id up itself.
    required = False
    name = 'wizard_build_param'
    description = 'WIZARD_BUILD id'


class WizardRuntimeParam(wizard_parameters.WizardRuntimeParameter):
    # Optional input: when it is empty the task looks the runtime task id up itself.
    required = False
    name = 'wizard_runtime_param'
    description = 'WIZARD_RUNTIME id'


class VaultUserParam(parameters.SandboxStringParameter):
    # Mandatory input: the value passed to get_vault_data() together with the
    # 'nanny_oauth' key.
    required = True
    name = 'vault_user_param'
    description = "Sandbox vault user (must contain 'nanny_oauth')"


class AdditionalPrintwzrdParams(parameters.SandboxStringParameter):
    # Optional input: extra arguments appended to the printwzrd command line.
    required = False
    name = 'additional_printwzrd_params'
    description = 'Additional params'


class RunPrintwzrd(SandboxTask):
    # Task that runs printwzrd over a queries resource and publishes its
    # stdout as a PRINTWZRD_OUTPUT resource. The build and runtime task ids
    # come from the input parameters or, when those are empty, from
    # get_current_production_task_id() for the service in `nanny_service`.
    type = 'RUN_PRINTWZRD'

    input_parameters = (
        QueriesParam,
        VaultUserParam,
        AdditionalPrintwzrdParams,
        WizardBuildParam,
        WizardRuntimeParam,
    )

    def on_execute(self):
        """Run printwzrd over the queries and publish its output.

        Raises ``Exception`` when the 'nanny_oauth' vault entry is empty and
        when printwzrd wrote anything to stderr. In the latter case the output
        resource has already been marked ready.
        """
        nanny_oauth = self.get_vault_data(self.ctx[VaultUserParam.name], 'nanny_oauth')
        if not nanny_oauth:
            raise Exception('Can not get nanny_oauth from Vault')

        build_wizard_id = self.ctx.get(WizardBuildParam.name)
        if not build_wizard_id:
            build_wizard_id = get_current_production_task_id(nanny_service, resource_types.WIZARD_SHARD, nanny_oauth)
        runtime_wizard_id = self.ctx.get(WizardRuntimeParam.name)
        if not runtime_wizard_id:
            runtime_wizard_id = get_current_production_task_id(
                nanny_service,
                [resource_types.WIZARD_RUNTIME_PACKAGE, resource_types.WIZARD_RUNTIME_PACKAGE_UNPACKED],
                nanny_oauth
            )

        queries = channel.task.sync_resource(self.ctx[QueriesParam.name])
        additional_params = self.ctx.get(AdditionalPrintwzrdParams.name)

        resource_manager = ResourceManager()
        printwzrd = PrintWzrd(resource_manager, build_wizard_id, runtime_wizard_id)

        out, err = 'out.log', 'err.log'
        printwzrd.run(queries, out, err, additional_params)
        printwzrd.wait()

        self.mark_resource_ready(
            self.create_resource('printwzrd output', out, resource_types.PRINTWZRD_OUTPUT)
        )

        # Any stderr output is treated as a failure of the whole task.
        if os.path.getsize(err):
            # Read under `with` so the handle is closed before raising.
            with open(err) as err_file:
                error_text = 'Printwzrd stderr: {}'.format(err_file.read())
            logging.error(error_text)
            raise Exception(error_text)


# Module-level alias for the task class defined in this file
# (presumably the name the task loader looks for - TODO confirm).
__Task__ = RunPrintwzrd
