# -*- coding: UTF-8 -*-

import sandbox.common.errors as ce
from sandbox.projects.common import binary_task
from sandbox.projects.metrika.utils import CommonParameters
from sandbox.projects.metrika.utils.base_metrika_task import BaseMetrikaTask, with_parents

from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox import sdk2
from sandbox.projects.kikimr.resources import YdbCliBinary
from sandbox.sdk2.helpers import subprocess as sp, ProcessLog
from sandbox.projects.metrika.utils.mixins.juggler_reporter import JugglerReporterMixin
import os
import logging


@with_parents
class MetrikaVisitsYdbRestoreFromYt(BaseMetrikaTask, JugglerReporterMixin):
    """Restore YDB tables from backup tables stored in a YT directory.

    Every node of type ``table`` directly inside ``yt_folder`` is pushed into
    YDB by a YQL query (run on the ``proxy`` cluster) that calls
    ``YDB::BulkPushData``. The destination of a backup table ``<name>`` is
    ``os.path.join(ydb_database, ydb_tables_path, <name>)``.
    """

    PROGRESS_DONE = "PROGRESS_DONE"
    SUCCESS = "SUCCESS"
    # Runtime state, populated during on_execute.
    ydb_cli = None  # local path to the ydb CLI binary (set by get_ydb_cli)
    yt_token = None  # YT OAuth token read from Vault
    ydb_token = None  # YDB token read from Vault
    yt_tables = None  # sorted absolute YT paths of tables to restore (set by get_yt_tables)

    class Requirements(BaseMetrikaTask.Requirements):
        environments = (PipEnvironment('yandex-yt'), PipEnvironment('yandex-yt-yson-bindings-skynet'))

    class Parameters(CommonParameters):
        with sdk2.parameters.Group("Step 1. YT Prepare") as yt_prepare:
            yt_token = sdk2.parameters.Vault(
                "YT token from Vault",
                description='"name" or "owner:name"',
                required=True,
            )

            proxy = sdk2.parameters.String(
                "YT proxy (cluster)",
                default="hahn",
                required=True,
            )

            yt_folder = sdk2.parameters.String(
                "Path to the dir with stored backup",
                description="There must not be a slash in the end of destination!",
                default_value="//home/metrika/disaev/ydb_backup",
                required=True,
            )

        with sdk2.parameters.Group("Step 2. YDB Prepare") as ydb_prepare:
            # NOTE(review): ydb_secret is not read anywhere in this class; the
            # YQL query authenticates via SecureParam("token:<ydb_secret_key>").
            ydb_secret = sdk2.parameters.YavSecret(
                "YAV secret",
                description="YAV secret with ydb token",
                required=True
            )

            # Name of the YQL-side token passed to SecureParam in restore_table.
            ydb_secret_key = sdk2.parameters.String(
                "YDB secret key",
                description="ydb token name",
                default_value="token",
                required=True
            )

            # NOTE(review): fetched in on_execute but not otherwise used here.
            ydb_token = sdk2.parameters.Vault(
                "YDB token from Vault",
                description='"name" or "owner:name"',
                required=True,
            )

            ydb_endpoint = sdk2.parameters.String(
                "YDB endpoint",
                description="host:port",
                default_value="ydb-ru-prestable.yandex.net:2135",
                required=True,
            )

            ydb_database = sdk2.parameters.String(
                "YDB database name",
                default_value="/ru-prestable/metrika/testing/visits",
                required=True,
            )

            # Joined onto ydb_database with os.path.join; an absolute value here
            # would replace the database prefix entirely.
            ydb_tables_path = sdk2.parameters.String(
                "Path to a tables in ydb to be restored up",
                description="Path to output tables in ydb",
                required=True,
            )

            check_interval_time = sdk2.parameters.Integer(
                "Check progress status interval time (sec.)",
                default=10,
                required=False,
            )

        with sdk2.parameters.Group("Step 3. Juggler Notifications") as juggler_prepare:
            juggler_host_name = sdk2.parameters.String(
                "juggler-host (for schedulers)",
                default="metrika-sandbox"
            )
            # NOTE(review): default service name mentions the ydb-to-yt backup
            # task; confirm it is intended for this restore task.
            juggler_service_name = sdk2.parameters.String(
                "juggler-service (for schedulers)",
                default="metrika-visits4d-ydb-to-yt-backup"
            )

        _binary = binary_task.binary_release_parameters_list(stable=True)

    def on_prepare(self):
        """Override the Juggler host/service with the task parameters when set."""
        # Truthiness also tolerates a None parameter, on which len() would raise.
        if self.Parameters.juggler_host_name:
            self.juggler_host = self.Parameters.juggler_host_name
        if self.Parameters.juggler_service_name:
            self.juggler_service = self.Parameters.juggler_service_name

    def get_common_cmd_part(self):
        """Return the ydb CLI argv prefix: binary, --endpoint and --database."""
        args = [self.ydb_cli]
        args += ["--endpoint={}".format(self.Parameters.ydb_endpoint)]
        args += ["--database={}".format(self.Parameters.ydb_database)]
        return args

    def get_ydb_cli(self):
        """Locate the stable linux YdbCliBinary resource and set ``self.ydb_cli``.

        If the resource is a ``.tgz`` archive it is unpacked into the current
        working directory and ``self.ydb_cli`` points at the extracted file
        (archive basename without the ``.tgz`` suffix).

        Raises:
            ce.TaskError: no matching resource was found.
        """
        ydb_cli_resource = YdbCliBinary.find(
            attrs=dict(released="stable", platform="linux")
        ).first()

        if ydb_cli_resource is None:
            raise ce.TaskError("Cannot find {} resource".format(YdbCliBinary.name))
        self.ydb_cli = str(sdk2.ResourceData(ydb_cli_resource).path)

        postfix = ".tgz"
        if self.ydb_cli.endswith(postfix):
            with ProcessLog(self, logger="ydb_cli_unpack") as pl:
                work_dir = os.getcwd()
                sp.check_call(["tar", "-zxf", self.ydb_cli], shell=False, stdout=pl.stdout, stderr=pl.stderr, cwd=work_dir)
                pl.logger.info("Data after extraction : {}".format(os.listdir(work_dir)))
                self.ydb_cli = os.path.join(work_dir, os.path.basename(self.ydb_cli)[:-len(postfix)])

    def get_yt_tables(self):
        """Collect the tables in ``yt_folder`` into ``self.yt_tables``.

        Only direct children whose ``type`` attribute is ``table`` are kept;
        they are stored as full YT paths, sorted.

        Raises:
            ce.TaskError: the folder contains no tables.
        """
        logging.info("YT target directory preparing started.")

        from yt import wrapper as yt
        yt.config.set_proxy(self.Parameters.proxy)
        yt.config["token"] = self.yt_token

        logging.info("Check data for restore in YT directory {}".format(self.Parameters.yt_folder))

        all_files_list = yt.list(self.Parameters.yt_folder)
        logging.info("List of all files in YT directory {}: {}".format(self.Parameters.yt_folder, all_files_list))

        self.yt_tables = []
        for node in all_files_list:
            node_path = os.path.join(self.Parameters.yt_folder, node)
            if yt.get_attribute(node_path, "type") == "table":
                self.yt_tables.append(node_path)
        self.yt_tables.sort()
        logging.info("List of backup dirs in YT directory {}: {}".format(self.Parameters.yt_folder, self.yt_tables))

        if not self.yt_tables:
            raise ce.TaskError("there are no tables in backup folder {}".format(self.Parameters.yt_folder))

    def check_operation_status(self, parsed_output):
        """Raise unless ``parsed_output["status"]`` equals ``SUCCESS``.

        Raises:
            ce.TaskError: the status differs from ``SUCCESS``.
        """
        status = parsed_output["status"]
        if status != self.SUCCESS:
            raise ce.TaskError("Output status on backup starting isn't SUCCESS. Current status: {}. Output: {}".format(status, parsed_output))

    def start_restore(self):
        """Restore every table in ``self.yt_tables`` sequentially via YQL."""
        logging.info("YDB restore initializing.")

        from yql.api.v1.client import YqlClient

        yql_client = YqlClient(db=self.Parameters.proxy, token=self.yt_token)
        for yt_table in self.yt_tables:
            logging.info("Restore table: {}".format(yt_table))
            self.restore_table(yql_client, yt_table)

        logging.info("Restore is finished")

    def restore_table(self, client, backup_table_name):
        """Push one YT backup table into YDB with ``YDB::BulkPushData``.

        Args:
            client: YqlClient used to run the query.
            backup_table_name: backup table, either a full YT path (as stored
                by get_yt_tables) or a name relative to ``yt_folder``.

        Raises:
            ce.TaskError: the YQL query did not succeed (see check_result).
        """
        # os.path.join keeps an absolute second component as-is, so this works
        # for both full YT paths and names relative to yt_folder.
        backup_table_path = os.path.join(self.Parameters.yt_folder, backup_table_name)
        # The YDB destination must be built from the bare table name: joining
        # an absolute "//home/..." path would discard the database prefix and
        # make the destination equal to the YT source path.
        table_name = os.path.basename(backup_table_name)
        restore_table_path = os.path.join(self.Parameters.ydb_database, self.Parameters.ydb_tables_path, table_name)

        # $input_data must be a subquery over the backup table. The previous
        # version assigned a *string* and nested a second %-placeholder inside
        # a substituted value; %-formatting is single-pass, so the table path
        # was never substituted into the query.
        sql = '''
                PRAGMA yt.InferSchema;
                PRAGMA yt.InferSchema = '1';
                PRAGMA yt.QueryCacheMode = "disable";
                PRAGMA yt.DataSizePerJob = "500485760";
                PRAGMA yt.UserSlots = '200';
                PRAGMA yt.DefaultMaxJobFails = "1";
                $input_data = (SELECT * WITHOUT _other FROM `%(backup_table_path)s`);
                SELECT
                    SUM(Bytes) AS TotalBytes,
                    SUM(Rows) AS TotalRows,
                    SUM(Batches) AS TotalBatches,
                    SUM(Retries) AS TotalRetries
                FROM (
                    PROCESS
                        $input_data
                    USING YDB::BulkPushData(
                        TableRows(),
                        "%(endpoint)s",
                        "%(database)s",
                        "%(restore_table_path)s",
                        AsTuple("token", SecureParam("token:%(ydb_token_name)s"))
                    )
                );
                '''

        query_str = sql % {
            'endpoint': self.Parameters.ydb_endpoint,
            'database': self.Parameters.ydb_database,
            'backup_table_path': backup_table_path,
            'restore_table_path': restore_table_path,
            'ydb_token_name': self.Parameters.ydb_secret_key,
        }

        logging.info("Restoring table: %s", backup_table_name)
        # The query references the token only by its SecureParam name.
        logging.info(query_str)

        q = client.query(query_str, syntax_version=1)
        q.run()
        result = q.get_results()
        self.check_result(result)

        # Single aggregate row: (TotalBytes, TotalRows, TotalBatches, TotalRetries).
        result_row = result.table.rows[0]
        logging.info('\tTotal bytes: %s, total rows: %s' % (result_row[0], result_row[1]))

    def check_result(self, result):
        """Raise with the status and formatted issues if ``result`` failed.

        Raises:
            ce.TaskError: ``result.is_success`` is false.
        """
        if not result.is_success:
            msgs = [result.status]
            for e in result.errors:
                msgs.append(e.format_issue())

            raise ce.TaskError("\n".join(msgs))

    def on_execute(self):
        """Read tokens, fetch the ydb CLI, list backup tables and restore them."""
        self.yt_token = self.Parameters.yt_token.data()
        self.ydb_token = self.Parameters.ydb_token.data()
        self.get_ydb_cli()
        self.get_yt_tables()
        self.start_restore()
        logging.info("Congratulations! Restore is finished.")
