# coding: utf-8

import os
import json
import yaml
import errno
import tempfile
from logging import info, error
from sandbox.projects import resource_types
from subprocess import check_output
from datetime import datetime

import sandbox.common.types.client as ctc

from sandbox import sandboxsdk

from sandbox.sandboxsdk.errors import SandboxSubprocessError
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.parameters import (
    SandboxStringParameter,
    SandboxIntegerParameter,
    SandboxRadioParameter,
    LastReleasedResource,
    ResourceSelector
)
from sandbox.sandboxsdk.process import run_process
from sandbox.projects.common.yabs.graphite import Graphite, YABS_SERVERS
import time


class LinearModelId(SandboxIntegerParameter):
    """Required integer parameter holding the linear model id."""
    required = True
    name = "linear_model_id"
    description = "Linear model id"


class LinearModelDump(ResourceSelector):
    """Required selector of the text linear model dump resource.

    Accepts resources of type ML_ENGINE_DUMP or ONLINE_LEARNING_DUMP_TXT.
    """
    required = True
    resource_type = (
        resource_types.ML_ENGINE_DUMP,
        resource_types.ONLINE_LEARNING_DUMP_TXT,
    )
    name = 'linear_model_dump_resource'
    description = 'Linear model dump resource'


class TaskId(SandboxStringParameter):
    """String parameter kept for compatibility; see its description text."""
    description = 'Task id (don`t fill this field, used for compatibility)'
    name = "task_id"


class Location(SandboxStringParameter):
    """Required model location parameter; offered choices are 'stat' and 'meta'."""
    required = True
    choices = [
        ('stat', 'stat'),
        ('meta', 'meta'),
    ]
    name = "location"
    description = "Model location"


class ShardsNumber(SandboxIntegerParameter):
    """Optional integer parameter: total number of shards."""
    required = False
    name = "total_shards"
    description = "Number of shards"


class TruncateOptions(SandboxStringParameter):
    """Optional string with model truncating options; defaults to '{}'."""
    required = False
    default_value = "{}"
    name = "truncate_options"
    description = "Model truncating options"


class CorrectionOptions(SandboxStringParameter):
    """Optional string with model correction options; defaults to '{}'."""
    required = False
    default_value = "{}"
    name = "correction_options"
    description = "Model correction options"


class ValidationOptions(SandboxStringParameter):
    """Optional string with model validation options; defaults to '{}'."""
    required = False
    default_value = "{}"
    name = "validation_options"
    description = "Model validation options"


class GraphitePathPrefix(SandboxStringParameter):
    """Optional prefix of the Graphite metric path used by the task."""
    required = False
    default_value = "five_min.bsmr-server-hahn-02_haze_yandex_net.bin_dump_gen"
    name = "graphite_path_prefix"
    description = "Graphite path prefix"


class GenerateLmDumpsBinary(LastReleasedResource):
    """Resource parameter for the dumps_generate binary (GENERATE_LM_DUMPS_BINARY)."""
    resource_type = "GENERATE_LM_DUMPS_BINARY"
    name = "dumps_generate_binary_resource_id"
    description = "dumps_generate binary"


class VwDumpBinary(LastReleasedResource):
    """Resource parameter for the vwdump binary (VWDUMP_BINARY)."""
    resource_type = "VWDUMP_BINARY"
    name = "vwdump_binary_resource_id"
    description = "vwdump binary"


class UnpackDumpBinary(LastReleasedResource):
    """Resource parameter for the unpack_dump binary (ARCADIA_PROJECT).

    Defaults to a fixed resource id.
    """
    resource_type = "ARCADIA_PROJECT"
    default_value = 232647012
    name = "unpack_dump_binary_resource_id"
    description = "unpack_dump binary"


class DumpType(SandboxStringParameter):
    """Required string parameter with the dump type; defaults to 'vw'."""
    required = True
    default_value = "vw"
    name = 'dump_type'
    description = 'type of dump'


class DumpName(SandboxStringParameter):
    """Required string parameter: name of the learn task."""
    required = True
    name = 'dump_name'
    description = 'name of learn task'


class DumpCreateTime(SandboxIntegerParameter):
    """Required integer parameter: creation time of the text dump."""
    required = True
    name = 'dump_create_time'
    description = 'create time of dump txt'


class DumpLastLogDate(SandboxStringParameter):
    """Required string parameter: last log date of the text dump."""
    required = True
    name = 'dump_last_log_date'
    description = 'last log date of dump txt'


class UpdateReasonId(SandboxIntegerParameter):
    """Required integer parameter: identifier of the update reason."""
    required = True
    name = 'update_reason_id'
    description = 'update reason identifier'


class UpdateReasonMessage(SandboxStringParameter):
    """Required string parameter: textual description of the update reason."""
    required = True
    name = 'update_reason_message'
    description = 'update reason description'


class Ttl(SandboxStringParameter):
    """Required string parameter with the binary dump ttl; defaults to '30'."""
    required = True
    default_value = "30"
    name = "ttl"
    description = "Ttl of binary dump"


class Released(SandboxRadioParameter):
    """Required radio parameter with the release status; defaults to 'none'."""
    name = 'released'
    description = 'Release status'
    required = True
    default_value = 'none'
    # Each choice is a pair built from the same status string.
    choices = [(_, _) for _ in ("none", "prestable", "stable", "testing", "unstable")]


class CompressionLevel(SandboxStringParameter):
    """Optional string parameter: compression level of the output dump."""
    required = False
    default_value = None
    name = "compression_level"
    description = "Compression level of output dump"


class Token(SandboxStringParameter):
    """Optional string parameter: name of the Yt token in the sandbox vault."""
    required = False
    default_value = 'robot_ml_engine_hahn_yt_token'
    name = 'yt_token'
    description = 'Yt token name (in sandbox vault)'


class TokenOwner(SandboxStringParameter):
    """Optional string parameter: owner of the Yt token in the sandbox vault."""
    required = False
    default_value = 'ML-ENGINE'
    name = 'token_owner'
    description = 'Yt token owner (in sandbox vault)'


class YtProxy(SandboxStringParameter):
    """Optional string parameter with the Yt proxy; defaults to 'hahn'."""
    required = False
    default_value = 'hahn'
    name = 'yt_proxy'
    description = 'Yt proxy'


class YtFolder(SandboxStringParameter):
    """Optional string parameter: Yt folder where models are stored."""
    required = False
    default_value = None
    name = 'yt_folder'
    description = 'Yt folder where to store models '


class ClusterType(SandboxStringParameter):
    """Optional string parameter: yabs cluster the model is transported to."""
    required = False
    default_value = None
    name = 'cluster_type'
    description = "yabs cluster where model will be transported. enum('yabs','bs','all')"


def get_sharding_base_resource_ids(ns_list):
    """Map each namespace to the id of its latest sharding base resource.

    For every namespace in *ns_list* the latest LINEAR_MODEL_SHARDING_BASE
    resource with a matching "namespace" attribute is looked up. Namespaces
    without such a resource are logged as errors and left out of the result.

    :param ns_list: iterable of namespace names
    :return: dict {namespace: resource id}
    """
    found_ids = {}
    for namespace in ns_list:
        resource = get_latest_resource(
            resource_type="LINEAR_MODEL_SHARDING_BASE",
            all_attrs={"namespace": namespace},
        )
        if resource is None:
            error("No SHARDING_BASE resource for ns=%s", namespace)
            continue
        found_ids[namespace] = resource.id
    return found_ids


def get_latest_resource(*args, **kwargs):
    """Return the newest READY resource matching the given filters, or None.

    All positional and keyword arguments are forwarded to
    ``channel.sandbox.list_resources`` except ``sort_by``: when given, up to
    10 newest resources are fetched and the one with the greatest
    ``(attributes[sort_by], id)`` pair is returned. Otherwise only the
    resource with the highest id is requested.

    :return: resource object, or None (after logging an error) if nothing
        was found
    """
    sort_by = kwargs.pop("sort_by", None)
    fetch_limit = 1
    if sort_by:
        fetch_limit = 10

    candidates = channel.sandbox.list_resources(
        order_by="-id", limit=fetch_limit, status="READY", *args, **kwargs
    )
    if not candidates:
        error("Can't find latest resource: %s", kwargs)
        return None

    if sort_by is None:
        return candidates[0]
    # The id is part of the key, so the maximum is unique and equals the first
    # element of a descending sort.
    return max(candidates, key=lambda item: (item.attributes.get(sort_by), item.id))


def get_meta(task):
    """Collect a summary dict describing *task* and its binary dump resources.

    Lists up to 100 LINEAR_MODEL_BINARY_DUMP resources of the task and
    combines their statistics with task timestamps and context values.

    :param task: task object providing ``id``, ``ctx``, ``new_status`` and the
        ``timestamp*`` attributes read below
    :return: dict with task-level fields and a "shards_stats" list holding
        one dict per resource
    """
    def spaced_iso(unix_time):
        # ISO timestamp with a space between the date and the time.
        return datetime.fromtimestamp(unix_time).isoformat(' ')

    dump_resources = channel.sandbox.list_resources(
        resource_type=resource_types.LINEAR_MODEL_BINARY_DUMP,
        task_id=task.id,
        limit=100
    )
    ctx = task.ctx
    last_log_date = last_log_date_to_datetime(ctx[DumpLastLogDate.name]).isoformat(' ')

    meta = {
        "task_id": task.id,
        "task_created": spaced_iso(task.timestamp),
        "task_started": spaced_iso(task.timestamp_start),
        "task_finished": spaced_iso(task.timestamp_finish),
        "lm_id": ctx[LinearModelId.name],
        "location": ctx[Location.name],
        "learn_task_name": ctx[DumpName.name],
        "last_log_date": last_log_date,
        "dump_txt_type": ctx[DumpType.name],
        "dump_txt_resource_id": ctx[LinearModelDump.name],
        # This one field uses the default 'T' separator.
        "dump_txt_resource_created": datetime.fromtimestamp(ctx[DumpCreateTime.name]).isoformat(),
        "shards_num": ctx.get(ShardsNumber.name, 1),
        "status": task.new_status,
        "reason_id": ctx[UpdateReasonId.name],
        "reason_message": ctx[UpdateReasonMessage.name],
    }

    meta["shards_stats"] = [
        {
            "resource_id": resource.id,
            "resource_created": spaced_iso(resource.timestamp),
            "resource_size": resource.size,
            # A resource without a 'shard' attribute is reported as shard 1.
            "shard": int(resource.attributes.get('shard', 1)),
            "status": resource.status,
            "released": resource.attributes['released']
        }
        for resource in dump_resources
    ]

    return meta


def last_log_date_to_datetime(dump_last_log_date):
    """Parse a last-log-date string into a ``datetime``.

    Dashes are dropped and the remaining string is right-padded with zeros to
    14 characters before being parsed as ``%Y%m%d%H%M%S``, so a value such as
    "2017-03-05" yields midnight of that day.

    :param dump_last_log_date: date string, optionally dash-separated
    :return: datetime.datetime
    """
    compact = "".join(dump_last_log_date.split('-'))
    return datetime.strptime(compact.ljust(14, '0'), "%Y%m%d%H%M%S")


class GenerateLinearModelBinaryDump(SandboxTask):
    """Build binary linear model dump(s) from a text dump resource.

    The task syncs the text dump, unpacks it unless the dump type is
    "online", runs the dumps_generate binary with a JSON config and registers
    the produced files as LINEAR_MODEL_BINARY_DUMP resources: one per shard
    for the "stat" location, a single file for the "meta" location.
    """
    type = 'GENERATE_LINEAR_MODEL_BINARY_DUMP'

    cores = 1
    execution_space = 35 * 1024
    required_ram = 25 * 1024
    client_tags = ctc.Tag.LINUX_PRECISE

    input_parameters = [
        LinearModelId,
        LinearModelDump,
        TaskId,
        Location,
        ShardsNumber,
        TruncateOptions,
        CorrectionOptions,
        ValidationOptions,
        GraphitePathPrefix,
        GenerateLmDumpsBinary,
        VwDumpBinary,
        DumpType,
        UnpackDumpBinary,
        Ttl,
        Released,
        CompressionLevel,
        Token,
        TokenOwner,
        YtProxy,
        YtFolder,
        ClusterType
    ]

    # Only namespaces from this set get a sharding base resource.
    ns_for_sharding = {
        "OrderID",
        "ClientID",
        "BannerID",
        "GroupBannerID",
        "TargetDomainID"
    }

    # Directory (relative to the working directory) for generated dumps.
    result_path = "./binary_dumps"

    @staticmethod
    def _is_set(value):
        """Return False for the "not provided" spellings None, "None" and ""."""
        return value not in (None, "None", "")

    def _sync_resource_once(self, cache_attr, parameter):
        """Sync the resource referenced by *parameter* and memoize its path.

        The path is stored under the plain attribute name *cache_attr*.
        Double-underscore names are avoided on purpose: they are mangled on
        assignment but not inside the string passed to ``hasattr``, which
        made the previous cache check always miss.
        """
        cached_path = getattr(self, cache_attr, None)
        if cached_path is None:
            cached_path = self.sync_resource(self.ctx[parameter.name])
            setattr(self, cache_attr, cached_path)
        return cached_path

    @property
    def binary_name(self):
        """Local path of the synced dumps_generate binary."""
        return self._sync_resource_once("_binary_name", GenerateLmDumpsBinary)

    @property
    def vwdump_binary_name(self):
        """Local path of the synced vwdump binary."""
        return self._sync_resource_once("_vwdump_binary_name", VwDumpBinary)

    @property
    def unpack_dump_binary_name(self):
        """Local path of the synced unpack_dump binary."""
        return self._sync_resource_once("_unpack_dump_binary_name", UnpackDumpBinary)

    @property
    def stat_model_shard_file_pattern(self):
        """File name pattern of a "stat" shard: (linear model id, shard number)."""
        return "lm_dump_%d_shard_%d.bin"

    @property
    def meta_model_file_pattern(self):
        """File name pattern of a "meta" dump: (linear model id,)."""
        return "lm_dump_%d.bin"

    def get_ns_list_for_sharding(self, model_path):
        """Return the model namespaces that take part in sharding.

        Runs the generator binary with ``--action namespaces`` and intersects
        the whitespace-separated names it prints with ``ns_for_sharding``.
        The binary's stderr is appended to a file in the task logs folder.

        :param model_path: path to the text model files
        :return: set of namespace names
        """
        # An argv list without a shell keeps paths with spaces or shell
        # metacharacters intact.
        cmd = [self.binary_name, "--action", "namespaces", "--model-path", model_path]
        info(" ".join(cmd))
        err_log_path = os.path.join(sandboxsdk.paths.get_logs_folder(), 'get_ns_list_for_sharding.err')
        with open(err_log_path, 'a') as err:
            namespaces_str = check_output(cmd, stderr=err)

        return set(namespaces_str.strip().split()) & self.ns_for_sharding

    def sync_sharding_bases(self, ns_base_resource_ids):
        """Sync sharding base resources to the local host.

        :param ns_base_resource_ids: dict {namespace: resource id}
        :return: dict {namespace: local path of the synced resource}
        """
        bases_pathes = {}
        for ns, base_resource_id in ns_base_resource_ids.iteritems():
            info("Syncing sharding base for ns=%s", ns)
            bases_pathes[ns] = self.sync_resource(base_resource_id)
        return bases_pathes

    @staticmethod
    def shards_range(shards_number):
        """Return the 1-based shard numbers 1..shards_number inclusive."""
        return xrange(1, shards_number + 1)

    def get_shard_file_name(self, linear_model_id, shard_no):
        """Return the expected path of one shard file inside ``result_path``."""
        return os.path.join(self.result_path, self.stat_model_shard_file_pattern % (linear_model_id, shard_no))

    def assert_shards_exists(self, shards_number, linear_model_id):
        """Check that every shard file was generated.

        :return: dict {shard number: file path}
        :raises Exception: if the file of any shard is missing
        """
        shard_files = {}
        for shard_no in self.shards_range(shards_number):
            shard_filename = self.get_shard_file_name(linear_model_id, shard_no)
            if not os.path.exists(shard_filename):
                raise Exception("Shard %d for model %d is not created" % (shard_no, linear_model_id))
            shard_files[shard_no] = shard_filename
        return shard_files

    @staticmethod
    def mkdir_if_not_exist(dirname):
        """Create *dirname*; an already existing path is not an error."""
        try:
            os.mkdir(dirname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def get_generate_config(self, model_files_path, lm_dump, ns_sharding_bases, last_log_date):
        """Build the config dict passed to the generator binary.

        :param model_files_path: path to the text model files
        :param lm_dump: text dump resource; its "task_id" (or "ml_task_id")
            attribute must be present
        :param ns_sharding_bases: dict {namespace: local sharding base path}
        :param last_log_date: last log date string of the text dump
        :return: dict that is serialized to JSON by the caller
        """
        task_id = lm_dump.attributes.get("task_id") or lm_dump.attributes.get("ml_task_id")
        assert task_id is not None, "Linear model dump resource has no task_id attribute"

        external_attributes = {
            "yabs_data_time": time.mktime(last_log_date_to_datetime(last_log_date).timetuple()),
        }

        cluster_type = self.ctx.get(ClusterType.name)
        if self._is_set(cluster_type):
            external_attributes["cluster_type"] = cluster_type
        info(
            "Parsed cluster_type is {cluster_type}".format(
                cluster_type=external_attributes.get("cluster_type", "cluster type not setted")
            )
        )

        # NOTE(review): yaml.load without an explicit safe loader can build
        # arbitrary Python objects from task parameters; consider
        # yaml.safe_load if these options are plain mappings.
        conf = {
            "dump_type": self.ctx[DumpType.name],
            "linear_model_path": model_files_path,
            "sharding_bases_map": ns_sharding_bases,
            "linear_model_id": self.ctx[LinearModelId.name],
            "task_id": task_id,
            "destination_path": self.result_path,
            "location": self.ctx[Location.name],
            "total_shards": self.ctx[ShardsNumber.name],
            "stat_model_shard_file_pattern": self.stat_model_shard_file_pattern,
            "meta_model_file_pattern": self.meta_model_file_pattern,
            "truncate_options": yaml.load(self.ctx[TruncateOptions.name]),
            "validation_options": yaml.load(self.ctx[ValidationOptions.name]),
            "correction_options": yaml.load(self.ctx[CorrectionOptions.name]),
            "last_log_date": last_log_date,
            "external_attributes": external_attributes
        }

        compression_level = self.ctx.get(CompressionLevel.name)
        if self._is_set(compression_level):
            conf["compression_level"] = compression_level
        info(
            "Parsed compression level is {compression_level}".format(
                compression_level=conf.get("compression_level", "compression level not setted")
            )
        )

        return conf

    def on_execute(self):
        """Run the whole generation pipeline.

        :raises SandboxSubprocessError: if the generator exits with a
            non-zero code, including termination by a signal
        :raises Exception: if an expected shard file is missing
        """
        linear_model_id = self.ctx[LinearModelId.name]
        graphite_path_prefix = self.ctx[GraphitePathPrefix.name]
        location = self.ctx[Location.name]
        shards_number = self.ctx[ShardsNumber.name]
        dump_type = self.ctx[DumpType.name]
        ttl = self.ctx[Ttl.name]
        released = self.ctx[Released.name]

        text_dump_resource = channel.sandbox.get_resource(self.ctx[LinearModelDump.name])
        last_log_date = text_dump_resource.attributes["last_log_date"]
        info("text_dump_resource.last_log_date=%s", last_log_date)

        resource_path = self.sync_resource(text_dump_resource.id)
        info("Text dump archive = %s", resource_path)

        if dump_type == "online":
            # Online dumps are used from the resource's directory as they are.
            model_files_path = os.path.dirname(resource_path)
        else:
            # NOTE(review): mktemp only picks an unused name and does not
            # create anything; unpack_dump is given this path as the
            # destination — confirm whether it may pre-exist before
            # switching to mkdtemp.
            model_files_path = tempfile.mktemp(dir="./")
            info("Unpack dump")
            run_process(
                [
                    self.unpack_dump_binary_name,
                    resource_path,
                    model_files_path,
                ],
                wait=True,
                log_prefix='unpack_dump'
            )
            info("Text dump files = %s", os.listdir(model_files_path))

        ###
        # load sharding bases
        ###
        ns_sharding_bases = {}
        bases_resource_ids = {}
        if location == "stat":
            ns_list = self.get_ns_list_for_sharding(model_files_path)
            info("Namespaces for sharding: %s", ns_list)
            bases_resource_ids = get_sharding_base_resource_ids(ns_list)
            info("Sharding base resources: %s", bases_resource_ids)
            ns_sharding_bases = self.sync_sharding_bases(bases_resource_ids)

        self.mkdir_if_not_exist(self.result_path)

        info("Copy vwdump to cwd")
        run_process("cp %s ./vwdump" % self.vwdump_binary_name, shell=True, log_prefix='cp_vwdump')
        # Put the working directory (where vwdump was just copied) first in PATH.
        os.environ["PATH"] = "./:%s" % os.environ["PATH"]
        os.environ['YT_TOKEN'] = self.get_vault_data(self.ctx[TokenOwner.name], self.ctx[Token.name])
        os.environ['YT_PROXY'] = self.ctx[YtProxy.name]

        ###
        # call dumps_generate with config
        ###
        gen_conf = self.get_generate_config(model_files_path, text_dump_resource, ns_sharding_bases, last_log_date)
        with tempfile.NamedTemporaryFile() as conf_file:
            json.dump(gen_conf, conf_file)
            # The child process reads the file by name, so flush first.
            conf_file.flush()

            info("Start generating binary, conf=%s", gen_conf)
            cmd = [
                self.binary_name,
                "--action",
                "generate",
                "--conf",
                conf_file.name
            ]
            # Same "unset" spellings as for the other optional string
            # parameters, so "" or "None" never reaches --yt_path.
            yt_folder = self.ctx.get(YtFolder.name)
            if self._is_set(yt_folder):
                cmd += [
                    "--yt_path",
                    yt_folder
                ]
            p = run_process(
                cmd,
                wait=False,
                log_prefix='generate_dump'
            )

            retcode = p.wait()
            # The exit code is reported to Graphite before it is checked.
            graphite_path = '.'.join([graphite_path_prefix, str(linear_model_id)])
            graphite_obj = Graphite(YABS_SERVERS)
            graphite_obj.send([(graphite_path, retcode, time.time())])

            # A negative code means the process was killed by a signal; that
            # is a failure as well.
            if retcode != 0:
                raise SandboxSubprocessError("process {} died with exit code {}".format(" ".join(cmd), retcode))

        ###
        # save generated resources
        ###
        if location == "stat":
            shard_files = self.assert_shards_exists(shards_number, linear_model_id)
            for shard_no, filename in shard_files.iteritems():
                self.create_resource(
                    "LM binary dump #%s %s shard=%s/%s" % (linear_model_id, location, shard_no, shards_number),
                    filename,
                    "LINEAR_MODEL_BINARY_DUMP",
                    attributes={
                        "linear_model_id": linear_model_id,
                        "location": location,
                        "last_log_date": last_log_date,
                        "total_shards": shards_number,
                        "released": released,
                        "published": "no",
                        "shard": shard_no,
                        "ttl": ttl,
                        "sharding_bases": json.dumps(bases_resource_ids),
                    }
                )
        if location == "meta":
            filename = os.path.join(
                self.result_path,
                self.meta_model_file_pattern % linear_model_id
            )
            assert os.path.exists(filename), "File %s not generated" % filename
            self.create_resource(
                "LM binary dump #%s meta" % linear_model_id,
                filename,
                "LINEAR_MODEL_BINARY_DUMP",
                attributes={
                    "linear_model_id": linear_model_id,
                    "location": location,
                    "ttl": ttl,
                    "released": released,
                    "published": "no",
                    "last_log_date": last_log_date
                }
            )


# Module-level alias of the task class defined in this file.
# NOTE(review): presumably the Sandbox task loader discovers the task through
# this name — confirm against the loader.
__Task__ = GenerateLinearModelBinaryDump
