# -*- coding: utf-8 -*-

import logging
import os
import random

from sandbox import common
from sandbox import sdk2
import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
from sandbox.projects.rtmr.clusters import RTMR_CLUSTERS, RtmrClustersInfo
import sandbox.projects.rtmr.common as rtmr_common
from sandbox.projects.rtmr.resources import RtmrPushTool
from sandbox.sandboxsdk import environments
from sandbox.sdk2.helpers import subprocess as sp


class RtmrUploadFromYt(sdk2.Task):
    """Upload table from YT

    Reads a YT table (optionally a row range of it) with ``yt read-table`` and
    pipes it into the RTMR push tool pointed at a random host of the selected
    RTMR cluster. With ``split_upload`` enabled the row range is instead split
    into chunks of ``rows_to_upload`` rows, and up to ``parallel_uploads`` child
    RtmrUploadFromYt tasks (each a non-split upload of one chunk) run at a time.
    """

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.Group.LINUX
        # Presumably provides the `yt` CLI invoked by update_rows_count() and helper.sh.
        environments = [environments.PipEnvironment("yandex-yt")]
        cores = 2
        disk_space = 2 * 1024  # 2Gb

    class Parameters(sdk2.Task.Parameters):
        description = "Upload table from YT"
        kill_timeout = 23 * 3600  # 23 hours

        with sdk2.parameters.Group("Source") as source_block:
            yt_cluster_name = sdk2.parameters.String("YT Cluster name", required=True)
            yt_path = sdk2.parameters.String("YT table path", required=True)
            yt_token_name = sdk2.parameters.String(
                "Vault secret name with YT token",
                default_value="robot-rtmr-build-yt",
                required=True
            )
            # Optional row range [yt_row_from, yt_row_to) of the source table.
            yt_row_from = sdk2.parameters.Integer("From row", default_value=None)
            yt_row_to = sdk2.parameters.Integer("To row", default_value=None)

        split_upload = sdk2.parameters.Bool("Split upload", default_value=False)
        with split_upload.value[True]:
            parallel_uploads = sdk2.parameters.Integer("Number of parallel uploads", required=True, default_value=1)
            rows_to_upload = sdk2.parameters.Integer("Number of rows to upload", required=True, default_value=1000000)

        with sdk2.parameters.Group("Destination") as destination_block:
            # Selector over all known RTMR clusters; the first one is the default.
            with sdk2.parameters.String("RTMR Cluster name", multiline=True, required=True) as cluster_name:
                _first = True
                for _name in RTMR_CLUSTERS:
                    if _first:
                        cluster_name.values[_name] = cluster_name.Value(default=True)
                        _first = False
                    else:
                        cluster_name.values[_name] = None
            rtmr_table = sdk2.parameters.String("RTMR Table", required=True)
            rtmr_push_resource = rtmr_common.LastResource(
                "RTMR Push Tool",
                resource_type=RtmrPushTool,
                required=False
            )

    class Context(sdk2.Task.Context):
        rtmr_push = None  # local path of the synced RTMR push tool
        helper_path = None  # absolute path of the generated helper.sh
        rows_count = None  # @row_count of the source YT table
        # Chunk start row -> chunk end row; the end is replaced by -1 once a child task
        # for the chunk has been started.
        # NOTE(review): the context is persisted between runs, so these keys may come
        # back as strings; start_tasks() normalizes them with int().
        upload_tasks = dict()
        active_tasks = list()  # ids of started child tasks that have not completed yet

    def get_cluster_hosts(self):
        """Resolve the skynet host specification of the selected RTMR cluster into hosts."""
        from library.sky.hostresolver import Resolver
        return Resolver().resolveHosts(RtmrClustersInfo().clusters[self.Parameters.cluster_name].skynet)

    def load_rtmr_push(self):
        """Sync the RTMR push tool resource and store its local path in ``Context.rtmr_push``.

        When the parameter is unset (None or '0') the most recent RtmrPushTool
        resource is used.

        Raises:
            common.errors.TaskError: no RtmrPushTool resource could be found.
        """
        rtmr_push_resource = self.Parameters.rtmr_push_resource
        if rtmr_push_resource is None or rtmr_push_resource == '0':
            rtmr_push_resource = sdk2.Resource.find(resource_type=RtmrPushTool).order(-sdk2.Resource.id).first()
            if rtmr_push_resource is None:
                raise common.errors.TaskError("No RTMR Push Tool resource found")
        self.Context.rtmr_push = str(sdk2.ResourceData(rtmr_push_resource).path)

    def make_helper(self):
        """Write the executable ``helper.sh`` that pipes ``yt read-table`` into the push tool.

        Script arguments: $1 - YT path (optionally with a row range), $2 - push
        tool binary, $3 - RTMR table, $4 - RTMR host.
        """
        self.Context.helper_path = str(sdk2.Path("helper.sh").absolute())
        with open(self.Context.helper_path, "w") as fd:
            # pipefail: without it the pipeline's status is only the push tool's, so a
            # failing `yt read-table` could be masked and the upload reported as successful.
            fd.write("""#!/bin/bash\nset -ex\nset -o pipefail\n"""
                     """yt read-table --format "<lenval=%true;has_subkey=%true>yamr" "$1" | """
                     """ "$2" -F lenval -T "$3" -s "$4" """)
            os.fchmod(fd.fileno(), 0o755)

    def get_yt_env(self):
        """Return a copy of the process environment with YT_PROXY and YT_TOKEN (from Vault) set."""
        env = os.environ.copy()
        env["YT_PROXY"] = self.Parameters.yt_cluster_name
        env["YT_TOKEN"] = sdk2.Vault.data(self.Parameters.yt_token_name)
        return env

    def update_rows_count(self):
        """Store the source table's ``@row_count`` in ``Context.rows_count``.

        Raises:
            common.errors.TaskError: the ``yt get`` call exited with a non-zero code.
        """
        cmd = ["yt", "get", self.Parameters.yt_path + "/@row_count"]
        proc = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, env=self.get_yt_env())
        stdout, _ = proc.communicate()
        if proc.returncode != 0:
            # stderr is merged into stdout, so this carries the tool's error message.
            logging.error("yt get failed, output: %s", stdout)
            self.set_info("YT-tool return error code: " + str(proc.returncode))
            raise common.errors.TaskError("yt-tool return code: " + str(proc.returncode))
        self.Context.rows_count = int(stdout)
        self.set_info("Number of rows in input table: " + str(self.Context.rows_count))

    def schedule_tasks(self):
        """Split the requested row range into chunks of ``rows_to_upload`` rows.

        Fills ``Context.upload_tasks`` with ``start -> end`` pairs covering
        ``[yt_row_from, min(yt_row_to, rows_count))``; the bounds default to 0
        and ``rows_count``.

        Raises:
            common.errors.TaskError: ``rows_to_upload`` is not positive (the chunking
                loop would never advance).
        """
        chunk = self.Parameters.rows_to_upload
        if chunk is None or chunk <= 0:
            raise common.errors.TaskError("Number of rows to upload must be positive, got {}".format(chunk))
        idx = 0 if self.Parameters.yt_row_from is None else self.Parameters.yt_row_from
        upto = self.Context.rows_count if self.Parameters.yt_row_to is None else self.Parameters.yt_row_to
        upto = min(self.Context.rows_count, upto)
        while idx < upto:
            torow = idx + chunk
            self.Context.upload_tasks[idx] = min(torow, upto)
            idx = torow
        self.Context.save()

    def start_task(self, rfrom, rto):
        """Create and enqueue a non-split child upload of rows ``[rfrom, rto)``.

        The child's id is appended to ``Context.active_tasks`` and a link to it
        is posted to the task info.
        """
        params = dict(
            yt_cluster_name=self.Parameters.yt_cluster_name,
            yt_path=self.Parameters.yt_path,
            yt_token_name=self.Parameters.yt_token_name,
            split_upload=False,
            cluster_name=self.Parameters.cluster_name,
            rtmr_table=self.Parameters.rtmr_table,
            rtmr_push_resource=self.Parameters.rtmr_push_resource,
        )
        if rfrom is not None:
            params["yt_row_from"] = rfrom
        if rto is not None:
            params["yt_row_to"] = rto
        task = RtmrUploadFromYt(
            self,
            description="Upload " + self.get_source_yt_path(rfrom, rto),
            **params
        )
        task.save().enqueue()
        self.Context.active_tasks.append(task.id)
        self.set_info(
            "Upload {rfrom} .. {rto} task <a href=\"https://sandbox.yandex-team.ru/task/{taskid}/view\">{taskid}</a>".format(
                rfrom=rfrom,
                rto=rto,
                taskid=task.id,
            ),
            do_escape=False,
        )

    def start_tasks(self, number_tasks):
        """Start up to ``number_tasks`` not-yet-started chunks, lowest row first, and save the context."""
        # Iterate a sorted snapshot: dict order is arbitrary and, after the context has
        # been persisted, keys may be strings, so sort and pass them on as ints.
        pending = sorted(self.Context.upload_tasks.items(), key=lambda item: int(item[0]))
        for rfrom, rto in pending:
            if number_tasks <= 0:
                break  # not `return`: the context must still be saved below
            if rto is not None and rto <= 0:
                continue  # child already started for this chunk
            self.Context.upload_tasks[rfrom] = -1
            self.start_task(int(rfrom), rto)
            number_tasks -= 1
        self.Context.save()

    def check_active_tasks(self):
        """Drop completed children from ``Context.active_tasks``.

        Raises:
            common.errors.TaskError: a child failed; all tracked children are stopped first.
        """
        tasks_to_wait = list()
        logging.info("Wait tasks %r", self.Context.active_tasks)
        tasks = rtmr_common.get_tasks_by_id(self.Context.active_tasks)
        for task in tasks:
            logging.info("Task %r status is %r", task.id, task.status)
            if rtmr_common.is_task_failed(task):
                rtmr_common.stop_tasks(tasks)
                raise common.errors.TaskError("Subtask id {} failed".format(task.id))
            if not rtmr_common.is_task_completed(task):
                tasks_to_wait.append(task)
        self.Context.active_tasks = [t.id for t in tasks_to_wait]
        self.Context.save()

    def launch_uploads(self):
        """Run the split upload: schedule chunks once, then keep ``parallel_uploads`` children running.

        Raises ``sdk2.WaitTask`` while children are active (the task is woken when
        any of them finishes) and returns once no child is active.

        Raises:
            common.errors.TaskError: ``parallel_uploads`` is not positive, or a child failed.
        """
        parallel = self.Parameters.parallel_uploads
        if parallel is None or parallel < 1:
            # Otherwise no child is ever started and the task finishes without uploading.
            raise common.errors.TaskError("Number of parallel uploads must be positive, got {}".format(parallel))

        # NOTE: the stage name contains a typo; it is left unchanged because renaming it
        # would reset the memoized state of tasks that are already in progress.
        with self.memoize_stage.update_rows_couunt(commit_on_entrance=False):
            self.update_rows_count()
        with self.memoize_stage.schedule_tasks(commit_on_entrance=False):
            self.schedule_tasks()

        self.check_active_tasks()
        if len(self.Context.active_tasks) < parallel:
            self.start_tasks(parallel - len(self.Context.active_tasks))

        tasks_to_wait = rtmr_common.get_tasks_by_id(self.Context.active_tasks)
        if tasks_to_wait:
            raise sdk2.WaitTask(
                tasks_to_wait,
                list(ctt.Status.Group.FINISH + ctt.Status.Group.BREAK),
                wait_all=False
            )

    def get_source_yt_path(self, rfrom, rto):
        """Return ``yt_path``, suffixed with a ``[#from:#to]`` row range when either bound is set."""

        def yt_limit_format(param):
            if param is None:
                return ""
            return "#" + str(param)

        if rfrom is not None or rto is not None:
            return "{path}[{rfrom}:{rto}]".format(
                path=self.Parameters.yt_path,
                rfrom=yt_limit_format(rfrom),
                rto=yt_limit_format(rto),
            )
        return self.Parameters.yt_path

    def upload(self):
        """Upload the (optionally row-limited) table to a random RTMR cluster host via helper.sh.

        Raises:
            common.errors.TaskError: no cluster host was resolved, or the helper failed.
        """
        with self.memoize_stage.load_rtmr_push(commit_on_entrance=False):
            self.load_rtmr_push()
        with self.memoize_stage.make_helper(commit_on_entrance=False):
            self.make_helper()

        hosts = list(self.get_cluster_hosts())
        if not hosts:
            raise common.errors.TaskError("No hosts resolved for RTMR cluster " + str(self.Parameters.cluster_name))

        with sdk2.helpers.ProcessLog(self, logger=logging.getLogger("helper")) as pl:
            cmd = [
                self.Context.helper_path,
                self.get_source_yt_path(self.Parameters.yt_row_from, self.Parameters.yt_row_to),
                self.Context.rtmr_push,
                self.Parameters.rtmr_table,
                random.choice(hosts)
            ]
            proc = sp.Popen(cmd, stdout=pl.stdout, stderr=sp.STDOUT, env=self.get_yt_env())
            proc.wait()
            if proc.returncode != 0:
                raise common.errors.TaskError("Helper return code: " + str(proc.returncode))

    def on_execute(self):
        """Dispatch to the split (child-task) upload or to a single direct upload."""
        if self.Parameters.split_upload:
            self.launch_uploads()
        else:
            self.upload()
