from datetime import datetime
import json
import logging
import os.path
import sys
import traceback

from sandbox import common

from sandbox.common.types.client import Tag
from sandbox.common.types.misc import RamDriveType, Installation

from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.parameters import (
    SandboxStringParameter,
    SandboxBoolParameter,
    ResourceSelector,
    DictRepeater
)

from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.paths import get_unique_file_name
from sandbox.sandboxsdk.process import run_process

from sandbox.projects import resource_types
from sandbox.projects.common import apihelpers
from sandbox.projects.common.yt_runner import YTRunner, REQUIRED_TMPFS_SIZE, YTPackage
from sandbox.projects.common.utils import get_or_default
from sandbox.projects.common.rem import REMRunner
from sandbox.projects.common.mapreduce_stored_tables import MapreduceClient, MapreduceTablesIO, SortMode
from sandbox.projects.common.userdata import util, rem_client


# Group labels assigned to the `group` attribute of the parameter classes below.
INPUT_PARAMS_GROUP_NAME = "Input params"
OUTPUT_PARAMS_GROUP_NAME = "Output params"
# NOTE(review): "Mics" looks like a typo for "Misc". Left unchanged because the
# label is a runtime string that other code may match on -- confirm before renaming.
MISC_PARAMS_GROUP_NAME = "Mics params"


class MrServer(SandboxStringParameter):
    """
        Mapreduce cluster the task works with.

        The special value "yt_local" makes the task bring up its own
        local YT instance instead of talking to a remote server.
    """
    group = INPUT_PARAMS_GROUP_NAME
    name = "mr_server"
    description = 'Mapreduce server ("yt_local" to locally set up MR):'
    required = True
    default_value = None


class XDebugString(SandboxStringParameter):
    """
        Optional free-form string used to run tests with custom behaviour
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "x_debug_string"
    description = "String to debug testing system:"
    required = False
    default_value = None


class YtLocal(ResourceSelector):
    """
        YT package resource (type YT_LOCAL).

        Optional: when it is not set, Task.on_enqueue stores the id of the
        last available YT_LOCAL resource in the context instead.
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "yt_local_resource_id"
    description = "YT package:"
    resource_type = resource_types.YT_LOCAL
    required = False


class YtEnableDebugLogs(SandboxBoolParameter):
    """
        Turn on debug logging of the local YT instance
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "yt_enable_debug_logs"
    description = "Debug logs for local YT instance:"
    required = False
    default_value = False


class SaveFailureTables(SandboxBoolParameter):
    """
        Keep all intermediate tables when the run fails
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "save_failure_tables"
    description = "Save tables in case of failure:"
    required = False
    default_value = False


class ExtraAttrs(DictRepeater, SandboxStringParameter):
    """
        Extra attributes for the resulting tables resource.

        Keys and string values are passed through str.format with the
        task context before being attached (see Task._download_tables).
    """
    group = OUTPUT_PARAMS_GROUP_NAME
    name = "extra_attrs"
    description = "Add attributes to resulting resource (will be formatted via task context):"
    required = False


class YtTokenVaultName(SandboxStringParameter):
    """
        Sandbox vault item that holds the yt token (used for non-local mode only)
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "yt_token_vault_name"
    description = "Name of vault item with yt token:"
    required = False
    default_value = "yt-token"


class YtPool(SandboxStringParameter):
    """
        YT pool the process should run in (for non-local mode only)
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "yt_pool"
    description = "Use specified yt pool (remote only):"
    required = False
    default_value = ""


class YtSpec(SandboxStringParameter):
    """
        Extra YT_SPEC to pass to the MR client.

        Task.on_enqueue parses the value with json.loads, so a non-empty
        value has to be valid JSON.
    """
    group = MISC_PARAMS_GROUP_NAME
    name = "yt_spec"
    description = "Use specified YT_SPEC:"
    multiline = True
    required = False
    default_value = ""


class Task(SandboxTask, object):
    """
        Runs any mr-related program.
        Also provides rem if required

        on_execute drives a fixed pipeline: check_params, init_files,
        MR/REM set-up, prepare_mr, process_mr_data, finalize_mr, teardown.
        check_params / init_files / prepare_mr / process_mr_data /
        check_for_yt_failure_heuristic are no-ops in this class.
    """

    client_tags = (Tag.LINUX_PRECISE | Tag.LINUX_TRUSTY | Tag.LINUX_XENIAL) & ~Tag.AMD6176
    execution_space = 30000
    node_chunk_store_quota = None  # forwarded to YTRunner as-is
    forbid_chunk_storage_in_tmpfs = False  # forwarded to YTRunner as-is
    cores = 17

    # Switches read by the methods of this class.
    need_rem = True  # _setup_rem_server starts a REM server only when set
    store_resulting_tables = True  # reserve a tables resource and download tables in finalize_mr
    sort_resulting_tables = True  # configure a sort mode for downloaded tables
    yt_testable = False  # passed to YTRunner / SortMode, exported as YT_STRICTLY_TESTABLE
    force_enable_debug_logs = False  # enables YTRunner debug logging regardless of the parameter

    @common.utils.singleton_classproperty
    def required_ram(self):
        """RAM requirement: 1 << 10 on a LOCAL installation, 32 << 10 otherwise."""
        return (
            1 << 10
            if common.config.Registry().common.installation == Installation.LOCAL else
            32 << 10
        )

    @common.utils.singleton_classproperty
    def use_ramdrive(self):
        """False on a LOCAL installation, True everywhere else."""
        return (
            False
            if common.config.Registry().common.installation == Installation.LOCAL else
            True
        )

    environment = (
        PipEnvironment('requests', use_user=True),
        PipEnvironment('remclient', use_wheel=True),
    )

    binaries = []
    binaries_provider = None

    # NOTE(review): YtTokenVaultName is read in _setup_yt_package (through
    # get_or_default) but is not listed here -- confirm that is intentional.
    input_parameters = [
        MrServer,
        ExtraAttrs,
        XDebugString,
        YtLocal,
        YtEnableDebugLogs,
        SaveFailureTables,
        YtPool,
        YtSpec,
    ]

    @staticmethod
    def _environ_to_str(environ):
        """Render a dict as space-separated ``KEY=VALUE`` pairs."""
        return " ".join(["{}={}".format(k, v) for k, v in environ.items()])

    def is_yt_local(self):
        """Return True when the task is configured to start its own local YT."""
        return self.ctx["mr_server"] == "yt_local"

    def on_enqueue(self):
        """
            Reserve everything the run needs: the ramdrive (local YT only),
            the output tables resource and the YT package resource id, and
            validate/normalise the YT_SPEC parameter.
        """
        SandboxTask.on_enqueue(self)
        channel.task = self
        self.ramdrive = None

        if self.use_ramdrive and self.is_yt_local():
            self.ramdrive = self.RamDrive(
                RamDriveType.TMPFS,
                REQUIRED_TMPFS_SIZE / 1024 / 1024,
                None
            )

        if self.store_resulting_tables and not self.ctx.get("tables_resource_id"):
            tables_resource = util.get_or_create_resource(
                self,
                self.ctx.get('tables_resource_id'),
                description=self.get_output_resource_descr(),
                resource_path='XXX',
                resource_type=resource_types.USERDATA_TABLES_ARCHIVE,
                arch='any',
            )
            self.ctx['tables_resource_id'] = tables_resource.id
        if not self.ctx.get(YtLocal.name):
            resource = apihelpers.get_last_resource(resource_type=resource_types.YT_LOCAL)
            if resource is None:
                # Fail with a readable reason instead of an AttributeError on None.
                raise common.errors.SandboxException(
                    "No YT_LOCAL resource found and '{}' is not set".format(YtLocal.name)
                )
            self.ctx[YtLocal.name] = resource.id
        if self.ctx.get(YtSpec.name):
            # Round-trip through json: rejects invalid input early and
            # collapses the multi-line value to a single line.
            s = json.dumps(json.loads(self.ctx[YtSpec.name]))
            self.ctx["yt_spec_parsed"] = "'{}'".format(s)

    def on_prepare(self):
        """Kick off the build sub-task when a binaries provider is configured."""
        SandboxTask.on_prepare(self)

        # you can't run another task in on_enqueue
        if self.binaries_provider:
            self.binaries_provider.set_subtask_description("ymake for -- " + self.descr)
            self.binaries_provider.run_build_task(self)

    def on_execute(self):
        """
            Run the whole pipeline.

            Any exception is logged and re-raised; if the local MR runner
            reports a master failure, a SandboxException carrying that
            reason is raised instead. Diagnostics and REM/MR teardown
            always run afterwards.
        """
        self.rem_runner = None
        self.mr_runner = None
        self.mr_client = None
        self.yt_package = None

        self.set_info("tasks archive version is {}".format(self.ctx["tasks_version"]))
        logging.info("checking input params")
        self.check_params()

        if self.binaries_provider:
            logging.info("fetching prebuild binaries")
            self.binaries_provider.join_build_task(self)
            self.binaries_provider.fetch_prebuild_binaries(self)

        has_failure = False

        try:
            logging.info("initializing file system")
            self.set_info("initializing files")
            self._setup_yt_package()
            self.init_files()

            logging.info("setting up MR and REM servers")
            self._setup_mr_server()
            self._setup_rem_server()

            logging.info("preparing mr tables")
            self.set_info("preparing source tables")
            self.prepare_mr()
            self._mark_source_tables()

            logging.info(str(self.ctx))

            logging.info("processing data")
            self.set_info("processing mr data")
            self.process_mr_data()

            logging.info("finalizing data")
            self.set_info("finalizing")
            self.finalize_mr()

            self.set_info("done")
            logging.info("done")
        except:
            # Bare except is deliberate: the exception is always re-raised.
            (t, e, trace) = sys.exc_info()
            logging.error('Caught %s exception: "%s" at %s', t, e, traceback.format_tb(trace))
            if self.mr_runner is not None:
                reason = self.mr_runner.master_failure()
                if reason is not None:
                    self.mr_runner = None  # to avoid attempts to teardown it
                    raise common.errors.SandboxException(reason)

                self.check_for_yt_failure_heuristic()
                has_failure = True
            raise
        finally:
            # Nested so that a failure while stopping REM can no longer
            # prevent the MR server from being stopped.
            try:
                logging.info("saving diagnostics")
                self._print_diagnostics()
                logging.info("shutting down REM server")
                self._teardown_rem()
            finally:
                logging.info("shutting down MR server")
                self._teardown_mr(has_failure)
                logging.info("bye")

    def on_break(self):
        """Try to save the local YT logs before delegating to the base class."""
        if self.is_yt_local():
            logging.info("Trying to save logs for yt")
            YTRunner.save_yt_logs(self.abs_path(), self.log_path("yt-local"))
            logging.info("Done")
        SandboxTask.on_break(self)

    def on_failure(self):
        """Publish links to the log files collected in ctx["failure_log_files"]."""
        # NOTE(review): unlike on_break, the base-class handler is not called
        # here -- confirm that is intentional.
        refs = []
        base_url = self._log_resource.proxy_url
        for name in self.ctx.get("failure_log_files", []):
            refs.append("<a href='{base}/{name}' target='_blank'>{name}</a>".format(
                base=base_url,
                name=name
            ))
        if refs:
            self.set_info("See " + " / ".join(refs), do_escape=False)

    def get_tables_prefix(self):
        """Prefix of the MR tables this task lists and downloads."""
        return "sandbox/"

    def check_params(self):
        """Hook called first in on_execute; no-op here."""
        pass

    def init_files(self):
        """Hook called after the YT package is installed; no-op here."""
        pass

    def prepare_mr(self):
        """Hook called once MR and REM are up; no-op here."""
        pass

    def process_mr_data(self):
        """Hook called after the source tables were listed; no-op here."""
        pass

    def updated_result_attrs(self, attrs):
        """Hook to adjust attributes of the tables resource; returns attrs unchanged here."""
        return attrs

    def check_for_yt_failure_heuristic(self):
        """Hook called from the on_execute error path while the MR runner is alive; no-op here."""
        pass

    def finalize_mr(self):
        """Download the resulting tables when store_resulting_tables is set."""
        if self.store_resulting_tables:
            self._download_tables()

    def get_common_pythonpaths(self):
        """Collect python paths from the REM runner and the YT package, whichever exist."""
        res = []
        if self.rem_runner:
            res += self.rem_runner.get_pythonpaths()

        if self.yt_package:
            res += self.yt_package.get_pythonpaths()
        return res

    def _setup_yt_package(self):
        """Sync and install the YT package; in remote mode also write the YT token to a file."""
        self.ctx['yt_local_resource'] = self.sync_resource(self.ctx['yt_local_resource_id'])
        self.yt_package = YTPackage(self.ctx['yt_local_resource'], self.abs_path())
        self.yt_package.install()
        if not self.is_yt_local():
            yt_token = self.get_vault_data(self.owner, get_or_default(self.ctx, YtTokenVaultName))
            self.ctx["yt_token_path"] = self.abs_path("./.yt-token")
            # NOTE(review): the token file is created with default permissions;
            # consider restricting it to the owner.
            with open(self.ctx["yt_token_path"], "w") as f:
                f.write(yt_token)

    def _setup_mr_server(self):
        """
            Start a local YT (or point at the remote server), fill the
            cluster-related context keys and create the MR client objects.
        """
        logging.info("Have server: '{}'".format(self.ctx['mr_server']))
        if self.is_yt_local():
            logging.info("Setting up YT")
            yt_local_resource = channel.sandbox.get_resource(self.ctx['yt_local_resource_id'])
            yt_local_version = yt_local_resource.attributes["yt_local_version"]
            logging.info("Using local yt of version %s", yt_local_version)
            self.mr_runner = YTRunner(
                self.yt_package,
                self.client_info,
                log_dir=self.log_path('yt-local'),
                debug_logging=self.force_enable_debug_logs or get_or_default(self.ctx, YtEnableDebugLogs),
                yt_testable=self.yt_testable,
                node_chunk_store_quota=self.node_chunk_store_quota,
                forbid_chunk_storage_in_tmpfs=self.forbid_chunk_storage_in_tmpfs,
                tmpfs=self.ramdrive.path if self.ramdrive else None,
                yt_local_version=yt_local_version,
            )
            self.mr_runner.start()
            self.ctx['real_mr_server'] = self.mr_runner.get_proxy_string()
        else:
            logging.info("REMOTE SERVER")
            self.ctx['real_mr_server'] = self.ctx['mr_server']

        # Cluster name = first host label without port and trailing digits.
        mr_cluster = self.ctx['real_mr_server'].split(':')[0].split('.')[0].rstrip('0123456789')
        self.ctx['mr_cluster'] = mr_cluster
        self.ctx['mr_cluster_info'] = '{mr_cluster},{real_mr_server},{real_mr_server},,{suffix},0,0,False'.format(
            mr_cluster=mr_cluster,
            real_mr_server=self.ctx['real_mr_server'],
            suffix="-yt",
        )
        self.mr_client = MapreduceClient(
            self.yt_package.path_mr_client(),
            self.ctx["real_mr_server"],
            self._environ_to_str(self._get_environ_for_mr_client()),
            log_dir=self.log_path()
        )
        self.mr_tables_io = MapreduceTablesIO(self.mr_client)

        if self.sort_resulting_tables:
            self.mr_tables_io.set_sort_output_tables_mode(SortMode.get_testable(self.yt_testable))

        if self.yt_testable and self.is_yt_local():
            self.mr_client.run("-sort //sys/empty_yamr_table -sortby key -sortby subkey -sortby value")

    def _get_environ_for_mr_client(self):
        """Build the environment dict handed to the MR client."""
        res = {
            'MR_RUNTIME': 'YT',
            'YT_CONFIG_PATCHES': '{yamr_mode={create_tables_outside_of_transaction=%true}}',
            'YT_USE_YAMR_STYLE_PREFIX': '1',
            'YT_PREFIX': '//',
            'YT_STRICTLY_TESTABLE': "1" if self.yt_testable else "",
        }
        if "yt_token_path" in self.ctx:
            res['YT_TOKEN_PATH'] = self.ctx["yt_token_path"]
        # Presence check only: an empty pool value is exported as well.
        if YtPool.name in self.ctx:
            res['YT_POOL'] = self.ctx[YtPool.name]
        if "yt_spec_parsed" in self.ctx:
            res["YT_SPEC"] = self.ctx["yt_spec_parsed"]
        return res

    def get_client_environ(self):
        """Public accessor for the MR client environment dict."""
        return self._get_environ_for_mr_client()

    def get_client_environ_str(self):
        """Return get_client_environ() rendered as ``KEY=VALUE`` pairs."""
        return self._environ_to_str(self.get_client_environ())

    def _setup_rem_server(self):
        """Start a REM server and create its client; does nothing unless need_rem is set."""
        if not self.need_rem:
            return

        env = self.get_client_environ()
        env["MR_CLUSTER_INFO"] = self.ctx['mr_cluster_info']
        rem_rev = "2972181"  # XXX see https://st.yandex-team.ru/USERFEAT-477
        self.rem_runner = REMRunner(
            arcadia_root='svn+ssh://arcadia.yandex.ru/arc/trunk/arcadia@{}'.format(rem_rev),
            root_dir=get_unique_file_name('', 'rem-runtime'),
            log_dir=get_unique_file_name(self.log_path(), 'rem-server'),
            environ=env,
        )
        self.rem_runner.setup_server()
        self.rem_client = rem_client.RemClient(self.rem_runner.rem_url, self.rem_runner.rem_tool_path())

    def _mark_source_tables(self):
        """Save the table listing taken before processing and remember it in the context."""
        self.mr_client.save_tables_list(self.get_tables_prefix(), False, 'mr_ls.before.short')
        self.mr_client.save_tables_list(self.get_tables_prefix(), True, 'mr_ls.before')
        self.ctx['source_tables_list'] = self.mr_client.get_tables_list(self.get_tables_prefix(), json_format=True)

    def dump_rem_status(self):
        """Dump the REM history into a fresh timestamped directory under the log path."""
        log_dir = os.path.join(self.log_path(), datetime.now().strftime("rem-status-%Y%m%d-%H%M%S"))
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logging.info("Dumping REM status to %s", log_dir)
        self.rem_client.dump_history(log_dir, self.abs_path())

    def _print_diagnostics(self):
        """Best-effort capture of netstat and lsof output; failures are only logged."""
        # Each command is attempted on its own so that one failing tool
        # does not suppress the other.
        for command in ("netstat", "lsof"):
            try:
                run_process(command, shell=True, check=True, log_prefix="diagnostics." + command)
            except Exception:
                logging.exception("")

    def _teardown_rem(self):
        """Dump the REM history (best effort) and stop the REM server, if one was created."""
        if self.rem_runner is None:
            return

        try:
            # The client is created only after setup_server() succeeded,
            # so it may be missing when set-up failed half-way.
            client = getattr(self, "rem_client", None)
            if client is None:
                logging.warning("REM client was never created, skipping history dump")
            else:
                try:
                    files = client.dump_history(self.log_path(), self.abs_path())
                    self.ctx.setdefault("failure_log_files", []).extend(files)
                except Exception:
                    logging.exception("")
        finally:
            self.rem_runner.teardown_server()

    def _teardown_mr(self, has_failure):
        """
            Stop the local MR server, if one is running.

            Before stopping, optionally save the tables of a failed run
            (has_failure and the SaveFailureTables parameter) and save the
            final table listing; both steps are best effort.
        """
        if self.mr_runner is None:
            return

        try:
            if has_failure and self.ctx.get(SaveFailureTables.name):
                try:
                    self._download_tables(True)
                except Exception:
                    logging.exception("")

            try:
                self._dump_mr_history()
            except Exception:
                logging.exception("")
        finally:
            self.mr_runner.teardown_server()

    def get_output_resource_descr(self):
        """Description used for the tables resource."""
        return 'tables by {} run ({})'.format(self.type, self.descr)

    def _download_tables(self, for_failure=False):
        """
            Download MR tables into a resource.

            With for_failure=True a new resource is used (resource_id=None),
            its description is prefixed with "FAILURE CONTEXT: " and the
            tables IO is switched to safe mode; otherwise the resource
            reserved in ctx["tables_resource_id"] is used.
        """
        attrs = {}
        attrs['tables_prefix'] = self.get_tables_prefix()

        if self.ctx.get(ExtraAttrs.name):
            for k, v in self.ctx[ExtraAttrs.name].items():
                if isinstance(v, (str, unicode)):
                    attrs[k.format(**self.ctx)] = v.format(**self.ctx)
                else:
                    attrs[k.format(**self.ctx)] = v

        attrs = self.updated_result_attrs(attrs)

        skip_unchanged = self.ctx.get('source_tables_list')
        if skip_unchanged:
            # do not keep too much info in ctx - it is not readable
            self.ctx["source_tables_list"] = [table["name"] for table in skip_unchanged]

        table_to_file = {}
        if for_failure:
            resource_id = None
            descr = "FAILURE CONTEXT: "
            self.mr_tables_io.set_safe_mode()
        else:
            resource_id = self.ctx.get('tables_resource_id')
            descr = ""
        descr += self.get_output_resource_descr()
        resource = self.mr_tables_io.download_tables(
            self,
            descr,
            prefix=attrs['tables_prefix'],
            new_prefix="sandbox/",
            attrs=attrs,
            skip_unchanged=skip_unchanged,
            resource_id=resource_id,
            table_to_file=table_to_file
        )
        self.ctx["table_to_file"] = table_to_file
        self.ctx["tables_resource_url"] = resource.proxy_url

    def _dump_mr_history(self):
        """Save the table listing taken after processing; skipped without an MR client."""
        if self.mr_client is not None:
            self.mr_client.save_tables_list(self.get_tables_prefix(), False, 'mr_ls.after.short')
            self.mr_client.save_tables_list(self.get_tables_prefix(), True, 'mr_ls.after')
