import os.path

from sandbox.sandboxsdk.parameters import (
    SandboxBoolParameter,
    SandboxStringParameter,
    SandboxFloatParameter,
    SandboxIntegerParameter,
    SandboxArcadiaUrlParameter,
    ResourceSelector
)

from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.sandboxsdk.paths import make_folder, get_unique_file_name
import sandbox.sandboxsdk.util as sdk_util

from sandbox.projects import resource_types

from sandbox.projects.common import apihelpers
from sandbox.projects.common.ProcessPool import ProcessPool
from sandbox.projects.common.userdata import mr_base_task, util
from sandbox.projects.common.utils import get_or_default
from sandbox.projects.common.userdata.packages_installer import PackagesInstaller


class MrSrcPrefix(SandboxStringParameter):
    # Prefix of the MR tables to sample from; the empty default means
    # "sample production" (see description below).
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'mr_src_prefix'
    default_value = ""
    required = False
    description = 'MR source dir (leave empty to sample production):'


class ScriptsArcadiaUrl(SandboxArcadiaUrlParameter):
    # Arcadia URL (optionally suffixed with "@<rev>") from which the
    # stratification configs are checked out; defaults to trunk.
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'scripts_arcadia_url'
    default_value = Arcadia.trunk_url()
    required = False
    description = 'Use sampling configs from this branch:'


class MaxValueSize(SandboxIntegerParameter):
    # Upper bound on a row value size; 0 (the default) disables the
    # heavy-record filtering step entirely.
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'max_value_size'
    default_value = 0
    required = False
    description = 'Max row value size:'


class SamplingSalt(SandboxStringParameter):
    # Salt passed to sample_by_uid (-salt) to vary the hash-based sample.
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'sampling_salt'
    default_value = ""
    required = False
    description = 'Salt for hashing during sampling:'


class StratScale(SandboxFloatParameter):
    # Multiplier applied to the stratification config (sample_by_uid -scale).
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'strat_scale'
    default_value = 1.0
    required = True
    description = 'Scale stratification config by this factor:'


class CleanupTmp(SandboxBoolParameter):
    # When true, Task.finalize_mr drops every table under the task's
    # temporary prefix.
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'cleanup_tmp'
    default_value = True
    required = True
    description = 'Drop sampled tables after saving them to resource'


class QueueId(SandboxStringParameter):
    """
        Persistent id for scheduled tasks. If an OK resource with the same
        queue id and period id already exists, the task fails in
        check_params instead of producing a duplicate.
    """
    name = 'queue_id'
    description = 'Queue id (if specified, ensures uniqueness of queue_id/period pair):'
    required = False
    group = mr_base_task.MISC_PARAMS_GROUP_NAME


class UserfeatSamplersPackage(ResourceSelector):
    # Package providing the sampling binaries; when left empty, on_enqueue
    # picks the last released USERFEAT_SAMPLERS_PACKAGE.
    group = "Packages"
    resource_type = resource_types.USERFEAT_SAMPLERS_PACKAGE
    name = 'userfeat_samplers_resid'
    required = False
    description = 'Resource with USERFEAT_SAMPLERS_PACKAGE'


class PreparedSamplesPrefix(SandboxStringParameter):
    # Optional MR prefix whose tables are copied verbatim into the task's
    # output prefix before sampling runs.
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = "prepared_samples_prefix"
    required = False
    description = "Copy all tables from this prefix to resulting resource"


class Task(mr_base_task.Task):
    """
    Base class for tasks that sample MR tables into a task-owned prefix
    (see get_tables_prefix) and store the resulting tables in a resource.

    Subclasses must implement:
      * do_mr_sample() -- run the actual sampling, e.g. with the
        get_sample_by_uid_command / get_mr_sample_command helpers below;
      * do_get_dates_context() -- return a dict with at least a "descr"
        key identifying the sampled period.
    They may override get_strat_data_dir() to have a stratification config
    directory checked out from ScriptsArcadiaUrl.
    """
    input_parameters = util.smart_join_params(
        mr_base_task.Task.input_parameters,
        MrSrcPrefix,
        QueueId,
        SamplingSalt,
        StratScale,
        MaxValueSize,
        CleanupTmp,
        UserfeatSamplersPackage,
        ScriptsArcadiaUrl,
        PreparedSamplesPrefix
    )

    need_rem = False
    store_resulting_tables = True

    # USERFEAT-991
    # we can't do it because YT limits key length for production
    sort_resulting_tables = False

    # 23 << 10 == 23552; NOTE(review): presumably megabytes (~23 GiB) -- confirm SDK units.
    required_ram = 23 << 10

    def on_enqueue(self):
        """
        Pin the samplers package to use.

        If no UserfeatSamplersPackage resource was chosen, store the id of the
        last released USERFEAT_SAMPLERS_PACKAGE in the context.

        Raises:
            RuntimeError: no released package is available (or it is not OK).
        """
        mr_base_task.Task.on_enqueue(self)
        if not self.ctx.get(UserfeatSamplersPackage.name):
            resource = apihelpers.get_last_released_resource(resource_type=resource_types.USERFEAT_SAMPLERS_PACKAGE)
            if resource and resource.is_ok():
                self.ctx[UserfeatSamplersPackage.name] = resource.id
            else:
                raise RuntimeError("There are no released USERFEAT_SAMPLERS_PACKAGE")

    def init_files(self):
        """
        Prepare local files: check out strat configs (if any) and install the
        samplers package into a fresh ROOT directory.

        Sets ctx["root"] (install root) and ctx["bin_dir"] (sampler binaries).
        """
        mr_base_task.Task.init_files(self)
        self._checkout_strat_config()

        p = PackagesInstaller()
        self.ctx["root"] = get_unique_file_name(self.abs_path(""), "ROOT")
        p.install(self.ctx["root"], self, [UserfeatSamplersPackage])
        self.ctx["bin_dir"] = os.path.join(self.ctx["root"], "Berkanavt/userfeat-samplers/bin")

    def get_tables_prefix(self):
        """Return the task's output prefix: "userfeat/tmp/<type>/<id>/" (trailing slash included)."""
        return os.path.join(
            "userfeat/tmp",
            self.type,
            str(self.id)
        ) + "/"

    def check_params(self):
        """
        Record ctx["period_id"] and refuse to create duplicate queue entries.

        Raises:
            SandboxTaskFailureError: QueueId is set and an OK
                USERDATA_TABLES_ARCHIVE resource with the same queue_id and
                period_id attributes already exists.
        """
        period_id = self.ctx["period_id"] = self.get_dates_context()["descr"]
        qid = self.ctx.get(QueueId.name)
        if qid:
            resource = apihelpers.get_last_resource_with_attrs(
                resource_type=resource_types.USERDATA_TABLES_ARCHIVE,
                attrs={
                    'queue_id': qid,
                    'period_id': period_id
                },
                all_attrs=True
            )
            if resource and resource.is_ok():
                raise SandboxTaskFailureError(
                    "Not creating duplicate resource with queue id '{}' and period '{}', already have {}".format(
                        qid, period_id, resource.id
                    )
                )

    def finalize_mr(self):
        """Finish MR work and, when CleanupTmp is set, drop every table under get_tables_prefix()."""
        mr_base_task.Task.finalize_mr(self)
        if get_or_default(self.ctx, CleanupTmp):
            tables = self.mr_client.get_tables_list(self.get_tables_prefix())
            # Floor division keeps the worker count an int on Python 3 too
            # (plain `/` would yield a float there).
            workers = max(1, sdk_util.system_info()["ncpu"] // 2)
            p = ProcessPool(workers)
            p.map(lambda table: self.mr_client.run("-drop {}".format(table), log_prefix="drop"), tables)

    def process_mr_data(self):
        """
        Copy pre-sampled tables (if PreparedSamplesPrefix is set) into the
        output prefix, keeping their relative names, then run do_mr_sample().

        Raises:
            RuntimeError: the MR client listed a table outside the requested prefix.
        """
        prepared_samples_prefix = self.ctx.get(PreparedSamplesPrefix.name)
        if prepared_samples_prefix:
            pr = util.ProcessRunner()
            for prefixed in self.mr_client.get_tables_list(prepared_samples_prefix):
                # Explicit check rather than `assert`: asserts are stripped
                # under -O and the slice below would then silently produce a
                # wrong destination name.
                if not prefixed.startswith(prepared_samples_prefix):
                    raise RuntimeError(
                        "Table {!r} does not start with prefix {!r}".format(
                            prefixed, prepared_samples_prefix
                        )
                    )
                table = prefixed[len(prepared_samples_prefix):]
                pr.add(
                    "copy." + table.replace("/", ":"),
                    self.mr_client.command(
                        "-copy -src {src} -dst {dst_prefix}{dst}".format(
                            src=prefixed,
                            dst_prefix=self.get_tables_prefix(),
                            dst=table
                        )
                    )
                )
            pr.run()
        self.do_mr_sample()

    def get_output_resource_descr(self):
        """Human-readable resource description: task type, queue id (or a placeholder) and period."""
        return "Sample by {} for {}/{}".format(
            self.type,
            self.ctx.get(QueueId.name) or "**no queue**",
            self.get_dates_context()["descr"]
        )

    def get_mr_src_prefix(self):
        """Prefix of the source tables; empty means production (per MrSrcPrefix's description)."""
        return get_or_default(self.ctx, MrSrcPrefix)

    def _checkout_strat_config(self):
        """
        Check out get_strat_data_dir() from ScriptsArcadiaUrl into
        DATA/strat_data_dir and record it in ctx['strat_data_dir'].

        An "@<rev>" suffix on the URL is honoured (HEAD otherwise). No-op when
        get_strat_data_dir() returns None.
        """
        path = self.get_strat_data_dir()
        if path is None:
            return
        data_dir = self.abs_path('DATA')
        make_folder(data_dir)
        self.ctx['strat_data_dir'] = os.path.join(data_dir, 'strat_data_dir')
        svn_path = self.ctx[ScriptsArcadiaUrl.name]
        # Split off an explicit revision; only the first "@" is treated as the separator.
        if "@" in svn_path:
            svn_path, rev = svn_path.split("@", 1)
        else:
            rev = "HEAD"
        # NOTE(review): os.path.join on a URL drops svn_path if `path` is absolute.
        svn_path = os.path.join(svn_path, path) + "@" + rev
        Arcadia.checkout(svn_path, self.ctx['strat_data_dir'])

    def updated_result_attrs(self, attrs):
        """
        Extend the base result attributes with ttl, queue_id and period_id.

        ctx["period_id"] is expected to have been set by check_params().
        """
        attrs = mr_base_task.Task.updated_result_attrs(self, attrs)
        # NOTE(review): presumably days -- confirm Sandbox ttl units.
        attrs["ttl"] = 60
        attrs["queue_id"] = get_or_default(self.ctx, QueueId)
        attrs["period_id"] = self.ctx["period_id"]
        return attrs

    def do_mr_sample(self):
        """Run the actual sampling; must be implemented by subclasses."""
        raise NotImplementedError()

    def get_strat_data_dir(self):
        """Arcadia-relative dir with strat configs, or None (default) to skip checkout."""
        return None

    def do_get_dates_context(self):
        """Return a dict describing the sampled period (must contain "descr"); implemented by subclasses."""
        raise NotImplementedError()

    def get_dates_context(self):
        """Memoized do_get_dates_context(): computed once per task instance."""
        if not hasattr(self, "_dates_context"):
            self._dates_context = self.do_get_dates_context()
        return self._dates_context

    def get_client_environ(self):
        """Base client environment plus YAMR-compatibility and fail-on-empty-source options."""
        res = mr_base_task.Task.get_client_environ(self)
        res["YT_USE_YAMR_DEFAULTS"] = "1"
        res["MR_OPT"] = "failonemptysrctable=1"
        return res

    def get_sample_by_uid_command(self, table, frac, strat_config=None, key_prefix=""):
        """
        Build a shell command that samples `table` with the sample_by_uid binary.

        Args:
            table: table name relative to the source and output prefixes.
            frac: value passed as -frac.
            strat_config: optional file name inside ctx["strat_data_dir"] (-strat).
            key_prefix: value passed as -prefix.

        Returns:
            The command string, followed by the heavy-record filter when
            MaxValueSize is set.
        """
        # NOTE(review): salt and key_prefix are embedded in single quotes
        # without escaping; a value containing "'" breaks the command.
        # NOTE(review): the source prefix is read from MrSrcPrefix directly, not via
        # get_mr_src_prefix() as in get_mr_sample_command -- confirm that is intended.
        cmd = (
            "{env} {bin_dir}/sample_by_uid "
            "-server {server} -frac {frac} -salt '{salt}' -scale {scale} -prefix '{key_prefix}' "
            "-src {src_prefix}{table} -dst {dst_prefix}{table} "
        )
        if strat_config:
            cmd += "-strat {}/{} ".format(self.ctx["strat_data_dir"], strat_config)
        cmd = cmd.format(
            env=self.get_client_environ_str(),
            bin_dir=self.ctx["bin_dir"],
            server=self.ctx["mr_server"],
            frac=frac,
            salt=self.ctx[SamplingSalt.name],
            scale=self.ctx[StratScale.name],
            src_prefix=get_or_default(self.ctx, MrSrcPrefix),
            dst_prefix=self.get_tables_prefix(),
            table=table,
            key_prefix=key_prefix
        )
        cmd += self.get_max_value_filter_extra_cmd(self.get_tables_prefix() + table)
        return cmd

    def get_max_value_filter_extra_cmd(self, table):
        """
        Return a " && ..." suffix that filters heavy records of `table` in place,
        or "" when MaxValueSize is unset/zero.
        """
        if not self.ctx.get(MaxValueSize.name):
            return ""

        return (
            " && {env} {bin_dir}/userdata_filter_heavy_recs "
            "--cluster {server} --by-keys "
            "--max-value-size {max_value_size} "
            "--source {table} --dest {table}"
        ).format(
            env=self.get_client_environ_str(),
            bin_dir=self.ctx["bin_dir"],
            server=self.ctx["mr_server"],
            max_value_size=self.ctx[MaxValueSize.name],
            table=table,
        )

    def get_mr_sample_command(self, table, count, src_infix="", dst_table=None, by_keys=False):
        """
        Build a shell command that samples `count` records with the mr_sample binary.

        Reads <src_prefix><src_infix><table> and writes
        <tables_prefix><dst_table> (dst_table defaults to table); `by_keys`
        adds -k. The heavy-record filter is appended when MaxValueSize is set.
        """
        dst_table = dst_table or table
        cmd = (
            "{env} {bin_dir}/mr_sample "
            "-s {server} -sub -f -n {count} {by_keys} "
            "{src_prefix}{src_infix}{table} {dst_prefix}{dst_table} "
        ).format(
            env=self.get_client_environ_str(),
            bin_dir=self.ctx["bin_dir"],
            server=self.ctx["mr_server"],
            count=count,
            by_keys="-k" if by_keys else "",
            src_infix=src_infix,
            src_prefix=self.get_mr_src_prefix(),
            dst_prefix=self.get_tables_prefix(),
            table=table,
            dst_table=dst_table
        )
        cmd += self.get_max_value_filter_extra_cmd(self.get_tables_prefix() + dst_table)
        return cmd
