# coding: utf-8

import os
import errno
from logging import info
from subprocess import check_output, check_call

from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.paths import get_logs_folder
from sandbox.sandboxsdk.parameters import (
    LastReleasedResource,
    ResourceSelector,
    SandboxStringParameter,
    SandboxIntegerParameter
)


class MakeLmShardingBaseBinary(LastReleasedResource):
    """Input parameter: resource id of the make_lm_sharding_base binary.

    The task syncs this resource and runs it as the base-building executable.
    """
    resource_type = 'MAKE_LM_SHARDING_BASE_BINARY'
    name = 'make_lm_sharding_base_resource_id'
    description = 'make_lm_sharding_base binary'


class YtPool(SandboxStringParameter):
    """Optional input parameter: YT pool name.

    When non-empty, the task exports it as the YT_POOL environment variable.
    """
    name = 'yt_pool'
    required = False
    default_value = ''
    description = 'YT pool'


class YtTokenSecret(SandboxStringParameter):
    """Required input parameter: name of the vault secret with the YT token.

    The task reads the secret via get_vault_data and exports its value as
    the YT_TOKEN environment variable.
    """
    name = 'yt_token'
    required = True
    default_value = 'robot_ml_engine_hahn_yt_token'
    description = 'YT token secret'


class ResourceTTL(SandboxIntegerParameter):
    """Required input parameter: value for the "ttl" attribute of created resources.

    NOTE(review): the unit is not stated here (presumably days) -- confirm.
    """
    name = 'resource_ttl'
    required = True
    default_value = 5
    description = 'resource ttl'


class MapreduceBinaries(ResourceSelector):
    """Input parameter: resource id of the archive with mapreduce binaries.

    The task syncs this resource and unpacks it with tar into the working
    directory.
    """
    description = 'mapreduce binaries'
    name = 'mapreduce_binaries_resource_id'


# Result file name -> namespace. Files in the result directory whose name is
# a key here are published as resources tagged with the matching namespace;
# any other file is ignored.
base_file_namespace_map = dict([
    ("order.pibf", "OrderID"),
    ("client.pibf", "ClientID"),
    ("banner.pibf", "BannerID"),
    ("groupbanner.pibf", "GroupBannerID"),
    ("targetdomain.pibf", "TargetDomainID"),
])


class GenerateLmShardingBase(SandboxTask):
    """Sandbox task that runs the make_lm_sharding_base binary and publishes its output.

    The task unpacks the mapreduce binaries archive, runs the sharding-base
    binary with the launch context saved by a previously finished task of the
    same type, stores the binary's stdout as the new launch context and
    registers every recognised output file as a LINEAR_MODEL_SHARDING_BASE
    resource.
    """
    type = 'GENERATE_LINEAR_MODEL_SHARDING_BASE'

    input_parameters = [
        MakeLmShardingBaseBinary,
        MapreduceBinaries,
        YtTokenSecret,
        YtPool,
        ResourceTTL,
    ]

    @staticmethod
    def mkdir_if_not_exist(dirname):
        """Create directory `dirname`, tolerating the case where it already exists.

        Raises:
            OSError: for any mkdir failure other than EEXIST.
        """
        try:
            os.mkdir(dirname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def get_prev_launch_context(self):
        """Return the ctx of a FINISHED task of this type, or {} if there is none.

        At most one task is requested from the Sandbox API (limit=1).
        NOTE(review): which task comes back when several exist depends on the
        ordering of list_tasks (presumably the latest one) -- confirm.
        """
        tasks = channel.sandbox.list_tasks(task_type=self.type, status="FINISHED", limit=1)
        if tasks:
            return tasks[0].ctx
        return {}

    def create_base_resources(self, result_path):
        """Register the known base files found in `result_path` as resources.

        Only files whose name is a key of base_file_namespace_map are
        published; every other directory entry is skipped. Each resource gets
        the "namespace" attribute from the map and the "ttl" attribute from
        the ResourceTTL task parameter.
        """
        for base_filename in os.listdir(result_path):
            ns = base_file_namespace_map.get(base_filename)
            if ns is None:
                # Not one of the known base files -- nothing to publish.
                continue
            self.create_resource(
                "Sharding base for namespace %s" % ns,
                os.path.join(result_path, base_filename),
                "LINEAR_MODEL_SHARDING_BASE",
                owner="ML-ENGINE",
                attributes={
                    "namespace": ns,
                    "ttl": self.ctx[ResourceTTL.name]
                }
            )

    def on_execute(self):
        """Run the sharding-base binary and publish its results.

        Steps:
          1. Fetch the context of a previously finished task of this type.
          2. Sync the binary and the mapreduce binaries archive; unpack the
             archive into the current directory.
          3. Export YT settings (YT_DT, YT_TOKEN and, when set, YT_POOL) via
             the process environment so the child process inherits them.
          4. Run the binary; its stdout is saved as ctx["launch_context"] and
             its stderr is appended to debug.log in the task logs folder.
          5. Publish the produced base files as resources.

        Raises:
            subprocess.CalledProcessError: if tar or the binary exits non-zero.
        """
        prev_ctx = self.get_prev_launch_context()

        binary_path = self.sync_resource(self.ctx[MakeLmShardingBaseBinary.name])
        mapreduce_binaries_archive_path = self.sync_resource(self.ctx[MapreduceBinaries.name])

        debug_log_path = os.path.join(get_logs_folder(), "debug.log")

        # The arguments are passed as a list without a shell, so an archive
        # path containing spaces or shell metacharacters is handled verbatim.
        with open(debug_log_path, "a") as debug_log:
            check_call(
                ["tar", "-xvf", mapreduce_binaries_archive_path, "-C", "./"],
                stderr=debug_log, stdout=debug_log
            )

        os.environ["YT_DT"] = "1"
        os.environ["YT_TOKEN"] = self.get_vault_data("ML-ENGINE", self.ctx[YtTokenSecret.name])

        # An empty pool value means "do not set YT_POOL at all".
        if self.ctx[YtPool.name]:
            os.environ["YT_POOL"] = self.ctx[YtPool.name]

        result_path = "./result_dir"
        self.mkdir_if_not_exist(result_path)

        cmd = [
            binary_path,
            "--dir", result_path,
            # When no previous launch context is available, pass "{}".
            "--prev-context", prev_ctx.get("launch_context", "{}"),
            "--mr-executables-path", "./bin",
        ]
        info("Exec cmd: %s", cmd)
        with open(debug_log_path, "a") as debug_log:
            out = check_output(cmd, stderr=debug_log)

        # The binary's stdout becomes the context handed to a later launch.
        self.ctx["launch_context"] = out
        info("new context: %s", out)
        self.create_base_resources(result_path)


# Module-level alias for the task class defined in this file.
# NOTE(review): presumably the Sandbox task loader looks up `__Task__` to
# discover the task -- confirm against the SDK.
__Task__ = GenerateLmShardingBase
