# import sys
# import os
import shutil
import logging
import yaml
import os

import sandbox.sandboxsdk.task as sdk_task
import sandbox.sandboxsdk.process as sdk_process
import sandbox.sandboxsdk.parameters as sdk_parameters
import sandbox.sandboxsdk.sandboxapi as sdk_sandboxapi
from sandbox.projects import resource_types
from sandbox.projects.common import apihelpers, utils

from sandbox.sandboxsdk.channel import channel
# from sandbox.sandboxsdk.errors import SandboxTaskFailureError


# Sandbox resource type names for the persisted flowkeeper state and for the
# queue of newly launched (not yet processed) graphs.
# TODO: fix state names
SANDBOX_STATE_NAME = "NIRVANA_ONLINE_LEARNING_STATE"
SANDBOX_NEW_GRAPHS_STATE_NAME = "NIRVANA_ONLINE_LEARNING_NEW_GRAPHS"
# NOTE(review): not referenced by any code visible in this module.
MAX_NEW_GRAPHS_FILES_COUNT = 100500
# YAML config file generated for the flowkeeper binary and passed via --conf.
FLOWKEEPER_CONFIG_FILENAME = "online_flowkeeper.conf"


# Mapping between online_flowkeeper config keys and sandbox task context
# parameter names, grouped by flowkeeper config section:
#     {section: {flowkeeper key: sandbox task context key}}
CONFIG_PARAMETERS_MAPPING = dict(
    recovery_policy=dict(
        max_fail_count="max_fail_count",
        wait_after_fail_min="wait_after_fail_min",
        max_run_time="max_run_time",
        max_leader_run_time="max_leader_run_time",
        future_trail_time_length="future_trail_time_length",
    ),
    # state_file / new_graphs_file are filled in by generate_flowkeeper_config
    # from its arguments rather than read from the task context.
    sandbox=dict(),
    flow=dict(
        pipeline_type="pipeline_type",
    ),
    model_matcher=dict(
        prev_model_table="prev_model_table",
        ml_task_id="ml_task_id",
    ),
    model_storage=dict(
        full_state_models="full_state_models",
        ml_task_id="ml_task_id",
    ),
)


def generate_flowkeeper_config(sandbox_task, state_file, new_graphs_file, outfile):
    """
    Build the online_flowkeeper config from a sandbox task and dump it as YAML.

    Most values are copied from ``sandbox_task.ctx`` according to
    ``CONFIG_PARAMETERS_MAPPING``; Nirvana and graphite settings are fixed
    here, and the Nirvana OAuth token is read from the sandbox vault.

    :param sandbox_task: running task; must expose ``ctx``, ``get_vault_data``
        and the parameter classes ``NirvanaToken``, ``NirvanaTokenOwner``,
        ``MaxNonLeaderDelay`` and ``MaxNonLeaderTasksInQueue``
    :param state_file: local path of the state file (``sandbox`` section)
    :param new_graphs_file: local path pattern of new graphs files (``sandbox`` section)
    :param outfile: path the YAML config is written to
    :return: None
    """
    logging.info("Started generating flowkeeper config")

    config = {"online_flowkeeper": {
        "recovery_policy": {},
        "nirvana": {},
        "sandbox": {},
        "flow": {},
        "model_matcher": {},
        "model_storage": {},
        "graphite": {},
        "non_leader_filters": {}
    }}
    flowkeeper = config["online_flowkeeper"]

    logging.info("Started reading non leader filters config")
    non_leader_filters_config = flowkeeper["non_leader_filters"]
    for param in (sandbox_task.MaxNonLeaderDelay, sandbox_task.MaxNonLeaderTasksInQueue):
        value = sandbox_task.ctx.get(param.name)
        # An unset filter may be stored as the string "None"; emit a real null instead.
        non_leader_filters_config[param.name] = None if value == "None" else value

    logging.info("Started reading recovery policy")
    recovery_config = flowkeeper["recovery_policy"]
    for conf_param, sandbox_param in CONFIG_PARAMETERS_MAPPING["recovery_policy"].items():
        recovery_config[conf_param] = sandbox_task.ctx.get(sandbox_param)

    logging.info("Started reading sandbox")
    sandbox_config = flowkeeper["sandbox"]
    for conf_param, sandbox_param in CONFIG_PARAMETERS_MAPPING["sandbox"].items():
        sandbox_config[conf_param] = sandbox_task.ctx.get(sandbox_param)
    sandbox_config["state_file"] = state_file
    sandbox_config["new_graphs_file"] = new_graphs_file

    logging.info("Started reading flow")
    flow_config = flowkeeper["flow"]
    for conf_param, sandbox_param in CONFIG_PARAMETERS_MAPPING["flow"].items():
        flow_config[conf_param] = sandbox_task.ctx.get(sandbox_param)

    logging.info("Started reading nirvana")
    nirvana_config = flowkeeper["nirvana"]
    nirvana_config["tags"] = ['online']
    nirvana_config["url"] = "https://nirvana.yandex-team.ru"
    nirvana_config["oauth_token"] = str(sandbox_task.get_vault_data(
        sandbox_task.ctx[sandbox_task.NirvanaTokenOwner.name],
        sandbox_task.ctx[sandbox_task.NirvanaToken.name],
    ))
    # NOTE(review): the task class also declares NirvanaRequestRetries and
    # NirvanaRequestDelay parameters, but these values are hard-coded and
    # ignore them — confirm which is intended.
    nirvana_config["request_retries"] = 10
    nirvana_config["request_delay"] = 10
    nirvana_config["backoff"] = 1.1

    logging.info("Started getting model connector (matcher) params")
    model_matcher = flowkeeper["model_matcher"]
    for conf_param, sandbox_param in CONFIG_PARAMETERS_MAPPING["model_matcher"].items():
        model_matcher[conf_param] = str(sandbox_task.ctx.get(sandbox_param))

    logging.info("Started getting model storage params")
    model_storage = flowkeeper["model_storage"]
    for conf_param, sandbox_param in CONFIG_PARAMETERS_MAPPING["model_storage"].items():
        model_storage[conf_param] = str(sandbox_task.ctx.get(sandbox_param))

    logging.info("Started configuring graphite sender")
    graphite = flowkeeper["graphite"]
    graphite["graphite_prefix"] = "one_hour.online_learning.graph_runtime_info"
    graphite["attempts"] = 2
    graphite["delay"] = 10
    graphite["backoff"] = 1.0

    with open(outfile, 'wt') as f:
        yaml.dump(config, f)

    # Log the generated config for debugging, but with the OAuth token masked
    # so the secret never ends up in the task logs.
    masked_config = {"online_flowkeeper": dict(
        flowkeeper, nirvana=dict(nirvana_config, oauth_token="<hidden>"))}
    logging.info("{}".format(masked_config))


def get_latest_resource(resource_type, all_attrs=None):
    """
    Return the READY resource of ``resource_type`` with the highest id.

    :param resource_type: sandbox resource type name
    :param all_attrs: optional attribute filter forwarded to ``list_resources``
    :return: the matching resource, or None when nothing matches (a warning is logged)
    """
    logging.info("Trying to get latest resource")
    found = channel.sandbox.list_resources(
        order_by="-id",
        limit=1,
        status="READY",
        resource_type=resource_type,
        all_attrs=all_attrs or {},
    )
    if not found:
        logging.warning("Can't find latest resource: %s", resource_type)
        return None
    return found[0]


def get_resource_logging_description(resource_type):
    """
    Return the description template for a resource created by this task.

    The template contains one ``{}`` placeholder for the ml_task_id.

    :param resource_type: resource type (compared with ``==`` against the known types)
    :return: description format string
    :raises ValueError: if ``resource_type`` is not one of the known types
    """
    # Ordered (type, template) pairs; compared with == in order, like an if/elif chain.
    known_descriptions = (
        (resource_types.NIRVANA_ONLINE_LEARNING_STATE,
         "Nirvana online learning state file for ml_task_id {}"),
        (resource_types.NIRVANA_ONLINE_LEARNING_NEW_GRAPHS,
         "Nirvana online learning new graphs file for ml_task_id {}"),
    )
    for known_type, template in known_descriptions:
        if resource_type == known_type:
            return template
    raise ValueError("Unknown resource")


def dump_previous_state_to_file(sandbox_task, filename, explicit_state_parameter, attrs=None):
    """
    Copy the previous flowkeeper state resource to ``filename``.

    The resource id is taken from the task context key ``explicit_state_parameter``
    when it is set; otherwise the latest READY ``SANDBOX_STATE_NAME`` resource
    matching ``attrs`` is used.

    :param sandbox_task: running task (provides ``ctx`` and ``sync_resource``)
    :param filename: local destination path
    :param explicit_state_parameter: context key of an explicitly selected state resource
    :param attrs: attribute filter used when searching for the latest state
    :raises ValueError: if no state resource could be determined
    """
    logging.info("Trying to dump prevous state to file")
    resource_id = sandbox_task.ctx.get(explicit_state_parameter)
    if resource_id:
        logging.info("State file is explicitly set in sandbox task. Taking it")
    else:
        logging.info("State file is not explicitly set in sandbox task. "
                     "Will try to fetch latest state file from sandbox")
        latest = get_latest_resource(resource_type=SANDBOX_STATE_NAME, all_attrs=attrs)
        resource_id = latest.id if latest else resource_id

    if not resource_id:
        raise ValueError("There are no state files in sandbox storage, something is definetely wrong")
    local_path = sandbox_task.sync_resource(resource_id)
    shutil.copyfile(local_path, filename)


def dump_new_graphs_resources_to_files(sandbox_task, filename_regex, explicit_new_graphs_parameter, attrs=None):
    """
    Download "new graphs" resources into local files.

    ``filename_regex`` is a filename pattern, not a regular expression. Every
    ``*`` in it is replaced by "" for a single explicit file or for the empty
    placeholder file, or by the resource's index when several unprocessed
    resources are fetched.

    :param sandbox_task: running task (provides ``ctx`` and ``sync_resource``)
    :param filename_regex: destination filename pattern containing ``*``
    :param explicit_new_graphs_parameter: context key of an explicitly selected resource
    :param attrs: extra attribute filter for the search; not modified
    :return: list of resource ids that were dumped; empty when none were found
        (an empty JSON list "[]" is written in that case)
    """
    def save_to_file(resource_id, suffix=""):
        # Copy one resource to the pattern with '*' replaced by `suffix`.
        target = filename_regex.replace("*", suffix)
        shutil.copyfile(sandbox_task.sync_resource(resource_id), target)
        logging.info("Saved new graph file id = {} to file {}".format(resource_id, target))

    logging.info("Trying to dump new graphs files")
    explicit_id = sandbox_task.ctx.get(explicit_new_graphs_parameter)
    if explicit_id:
        logging.info("New graphs file is explicitly set in sandbox task. Taking it")
        save_to_file(explicit_id, "")
        return [explicit_id]

    logging.info("New graphs file is not explicitly set in sandbox task. "
                 "Trying to fetch all yet not processed files from sandbox")
    # Copy the filter so the caller's `attrs` dict is not mutated.
    all_attrs = dict(attrs or {})
    all_attrs["is_processed"] = "False"
    # NOTE(review): no explicit `limit` is passed; if list_resources caps the
    # result size by default, remaining files are left for a later run.
    # MAX_NEW_GRAPHS_FILES_COUNT may have been intended for this — confirm.
    resources = channel.sandbox.list_resources(
        resource_type=SANDBOX_NEW_GRAPHS_STATE_NAME,
        status="READY",
        all_attrs=all_attrs)

    if not resources:
        logging.info("There are no new new graphs files on sandbox; "
                     "will create empty file and process only active graphs")
        with open(filename_regex.replace("*", ""), "wt") as f:
            f.write("[]")
        return []

    logging.info("Successfully fetched {} yet not processed files from sandbox".format(len(resources)))
    for index, resource in enumerate(resources):
        save_to_file(resource.id, str(index))
    return [resource.id for resource in resources]


def save_state(sandbox_task, filename, resource_type, owner):
    """
    Register ``filename`` as a new sandbox resource tagged with the task's ml_task_id.

    :param sandbox_task: running task (provides ``ctx``, ``MLTaskID`` and ``create_resource``)
    :param filename: local file to publish
    :param resource_type: resource type; must be known to get_resource_logging_description
    :param owner: owner of the created resource
    """
    logging.info("Trying to dump save state to file")
    ml_task_id = sandbox_task.ctx[sandbox_task.MLTaskID.name]
    description = get_resource_logging_description(resource_type).format(ml_task_id)
    resource_attributes = {"ml_task_id": ml_task_id}
    sandbox_task.create_resource(description, filename, resource_type,
                                 owner=owner, attributes=resource_attributes)


class RunNirvanaOnlineLearning(sdk_task.SandboxTask):
    """
    Sandbox task that runs one iteration of the Nirvana online learning flowkeeper.
    """

    cores = 1

    type = "RUN_NIRVANA_ONLINE_LEARNING"

    # ---------- Sandbox resources holding the flowkeeper state ----------
    class NirvanaOnlineLearningState(sdk_parameters.ResourceSelector):
        name = 'state_file'
        description = 'online learning state'
        resource_type = resource_types.NIRVANA_ONLINE_LEARNING_STATE
        required = False

    class NirvanaOnlineLearningNewGraphs(sdk_parameters.ResourceSelector):
        name = 'new_graphs_file'
        description = 'recently launched graphs, not processed yet'
        resource_type = resource_types.NIRVANA_ONLINE_LEARNING_NEW_GRAPHS
        required = False

    # ---------- Nirvana parameters ----------
    class NirvanaWorkflowID(sdk_parameters.SandboxStringParameter):
        name = "basic_workflow_id"
        description = "Basic workflow id"
        required = True
        default_value = "76f1d00e-a4e3-11e6-98ff-0025909427cc"

    class MLTaskID(sdk_parameters.SandboxStringParameter):
        name = "ml_task_id"
        description = "ml task, used to retrieve state"
        required = True

    # NOTE(review): generate_flowkeeper_config hard-codes the nirvana request
    # retries/delay and does not read the two parameters below.
    class NirvanaRequestRetries(sdk_parameters.SandboxIntegerParameter):
        name = "request_retries"
        description = "Number of retries when doing nirvana request"
        required = True
        default_value = 30

    class NirvanaRequestDelay(sdk_parameters.SandboxIntegerParameter):
        name = "request_delay"
        description = "Delay between consequent nirvana requests"
        required = True
        default_value = 300

    class NirvanaToken(sdk_parameters.SandboxStringParameter):
        name = "nirvana_oauth_token"
        description = "nirvana_oauth_token: Nirvana OAuth token name (in sandbox vault)"
        required = True
        default_value = "robot_ml_engine_hahn_yt_token"

    class NirvanaTokenOwner(sdk_parameters.SandboxStringParameter):
        name = "nirvana_token_owner"
        description = "nirvana_token_owner: Owner of Nirvana OAuth token"
        required = True
        default_value = "ML-ENGINE"

    # ---------- MR parameters ----------
    # NOTE(review): MRFullStateModelsPath and MRPreviousModelTable are not listed
    # in input_parameters below.
    class MRFullStateModelsPath(sdk_parameters.SandboxStringParameter):
        name = "mr_full_state_models_path"
        description = "MR Path with full state models"
        required = True

    class MRPreviousModelTable(sdk_parameters.SandboxStringParameter):
        name = "mr_prev_model_table"
        description = "MR table with reference to previous models"
        required = True

    class YTTokenOwner(sdk_parameters.SandboxStringParameter):
        name = "yt_token_owner"
        description = "YT token owner"
        required = True
        default_value = "ML-ENGINE"

    class YTTokenName(sdk_parameters.SandboxStringParameter):
        name = "yt_token_name"
        description = "YT token name"
        required = True
        default_value = "robot_ml_engine_hahn_yt_token"

    # ---------- Recovery policy parameters ----------
    class GraphMaxFailCount(sdk_parameters.SandboxIntegerParameter):
        name = "max_fail_count"
        description = "Maximum number of failures for nirvana graph"
        required = True
        default_value = 10

    class WaitTimeBeforGraphRestart(sdk_parameters.SandboxIntegerParameter):
        name = "wait_after_fail_min"
        description = "Wait time before restarting failed graph"
        required = True
        default_value = 60

    class GraphMaxRunTime(sdk_parameters.SandboxIntegerParameter):
        name = "max_run_time"
        description = "Maximum allowed run time for graph"
        required = True
        default_value = 432000

    class LeaderMaxRunTime(sdk_parameters.SandboxIntegerParameter):
        name = "max_leader_run_time"
        description = "Maximum allowed run time for leader"
        required = True
        default_value = 3600

    # ---------- Model connector (matcher) parameters ----------
    class PreviousModelTable(sdk_parameters.SandboxStringParameter):
        name = "prev_model_table"
        description = ("Mapreduce table used to store meta information about tasks: "
                       "previous models path and result model path")
        required = True
        default_value = "users/alxmopo3ov/online_learning/prevous_model_mapping/Task1"

    # ---------- Model storage parameters ----------
    class FullStateModelsMapreducePath(sdk_parameters.SandboxStringParameter):
        name = "full_state_models"
        description = "Mapreduce directory used to store learned models for all task graphs"
        required = True
        default_value = "users/fram/online_learning/models/Task1"

    # ---------- Flow parameters ----------
    class OnlinePipelineType(sdk_parameters.SandboxStringParameter):
        name = "pipeline_type"
        description = "Type of pipeline for processing graphs"
        required = True
        default_value = "SandboxedUserPipeline"

    # ---------- Trail length ----------
    class FutureTrailLength(sdk_parameters.SandboxIntegerParameter):
        name = "future_trail_time_length"
        description = "Number of graphs we keep running ahead of the leader last log date (for the offline speedup)"
        required = True
        default_value = 20

    # ---------- Binary selection and non-leader filters ----------
    class Stable(sdk_parameters.SandboxBoolParameter):
        name = "is_stable"
        description = "When true use stable binary from sandbox, when false use testing"
        default_value = True
        required = False

    class MaxNonLeaderTasksInQueue(sdk_parameters.SandboxIntegerParameter):
        name = "max_non_leader_tasks_in_queue"
        description = "Maximum number of non leader tasks in queue, when None filter doesn't apply"
        required = False
        default_value = None

    class MaxNonLeaderDelay(sdk_parameters.SandboxIntegerParameter):
        name = "max_non_leader_delay"
        description = "Maximum delay of non leader task, when None filter doesn't apply"
        required = False
        default_value = None

    # ---------- Parameters exposed by the task (order preserved) ----------
    input_parameters = [
        MLTaskID,
        NirvanaWorkflowID,  # workflow parameters
        NirvanaRequestDelay, NirvanaRequestRetries,  # nirvana
        NirvanaOnlineLearningState, NirvanaOnlineLearningNewGraphs,  # sandbox resources with state
        GraphMaxFailCount, WaitTimeBeforGraphRestart, GraphMaxRunTime, LeaderMaxRunTime,
        FutureTrailLength,  # recovery policy
        OnlinePipelineType,  # flow
        PreviousModelTable,  # model mapping
        FullStateModelsMapreducePath,  # full state models directory
        YTTokenOwner, YTTokenName,  # YT configuration
        Stable,
        MaxNonLeaderTasksInQueue,
        MaxNonLeaderDelay,
        NirvanaToken,
        NirvanaTokenOwner,
    ]

    def on_execute(self):
        """
        Run one flowkeeper iteration.

        Steps: write the flowkeeper config, dump the previous state and the new
        graphs files, run the released flowkeeper binary (stable or testing,
        per ``is_stable``) with the YT token in its environment, then publish
        the updated state and mark the consumed new-graphs resources as processed.
        """
        state_file = "run_nirvana_online_learning_state.json"
        # '*' is replaced per file when the new graphs resources are dumped.
        new_graphs_file = "run_nirvana_online_learning_new_graphs_queue*.json"

        generate_flowkeeper_config(self, state_file, new_graphs_file, FLOWKEEPER_CONFIG_FILENAME)

        ml_task_id = self.ctx[self.MLTaskID.name]
        dump_previous_state_to_file(
            self, state_file, self.NirvanaOnlineLearningState.name,
            attrs={"ml_task_id": ml_task_id},
        )
        new_graphs_resource_ids = dump_new_graphs_resources_to_files(
            self, new_graphs_file, self.NirvanaOnlineLearningNewGraphs.name,
            attrs={"ml_task_id": ml_task_id},
        )

        if self.ctx[self.Stable.name]:
            release_status = sdk_sandboxapi.RELEASE_STABLE
        else:
            release_status = sdk_sandboxapi.RELEASE_TESTING
        logging.info('Release status: {release_status}'.format(release_status=release_status))
        flowkeeper_resource = apihelpers.get_last_released_resource(
            resource_types.ONLINE_FLOWKEEPER_BINARY, release_status)
        flowkeeper_binary = self.sync_resource(flowkeeper_resource)

        yt_token = self.get_vault_data(
            self.ctx[self.YTTokenOwner.name],
            self.ctx[self.YTTokenName.name],
        )
        process_env = os.environ.copy()
        process_env['YT_TOKEN'] = yt_token

        command = [
            flowkeeper_binary,
            "--conf", FLOWKEEPER_CONFIG_FILENAME,
            "--state", state_file,
        ]
        sdk_process.run_process(command, wait=True, log_prefix='run_flowkeeper', environment=process_env)

        save_state(self, state_file, SANDBOX_STATE_NAME, self.owner)
        for resource_id in new_graphs_resource_ids:
            utils.set_resource_attributes(resource_id, {"ml_task_id": ml_task_id, "is_processed": "True"})


__Task__ = RunNirvanaOnlineLearning
