# -*- coding: UTF-8 -*-

import logging
import os

import sandbox.common.errors as ce
from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment


class MediaYdbFromYtRestore(sdk2.Task):
    """
    Imports ydb table from yt.
    Path to destination table is inferred from yt path, whereas schema is inferred from destination table.
    Keep in mind that table is not created, it is only truncated.
    """
    class Requirements(sdk2.Task.Requirements):
        # Pip packages needed at runtime (YT client, YSON bindings, YQL client),
        # each requested with use_wheel=False; order is preserved.
        environments = tuple(
            PipEnvironment(package_name, use_wheel=False)
            for package_name in ('yandex-yt', 'yandex-yt-yson-bindings-skynet', 'yql')
        )

    class Parameters(sdk2.Task.Parameters):

        # Where the YT backup tables are read from.
        with sdk2.parameters.Group("Source settings", collapse=True) as source_settings:
            # Passed as `db` to the YQL client.
            cluster = sdk2.parameters.String(
                "cluster",
                default="hahn",
                description="YT cluster (without domain) to work with",
                required=True,
            )

            # Each backup table is read from os.path.join(yt_folder, <table name>).
            yt_folder = sdk2.parameters.String(
                "YT folder",
                description="Base YT folder where backup tables placed",
                required=True
            )

            # NOTE(review): yt_tables, exclude_tables and restore_all are not read by the
            # methods shown here; presumably they drive get_tables_list — confirm.
            yt_tables = sdk2.parameters.List(
                "tables",
                description="List of YT static tables to restore"
            )

            exclude_tables = sdk2.parameters.List(
                "exclude tables",
                description="Do not restore these tables"
            )

            restore_all = sdk2.parameters.Bool("Restore all tables", default=False)

            # When True the `_other` column is excluded from the restored rows
            # (SELECT * WITHOUT _other); when False all columns are selected.
            strict_schema = sdk2.parameters.Bool("YT strict schema_mode", default=True)

        # Where the rows are written in YDB.
        with sdk2.parameters.Group("Destination settings", collapse=True) as destination_settings:

            ydb_endpoint = sdk2.parameters.String(
                "YDB endpoint",
                required=True
            )

            # Also used as the prefix of the destination table path.
            ydb_database = sdk2.parameters.String(
                "YDB database",
                required=True
            )

            # Optional subfolder placed between the database path and the table name.
            ydb_folder = sdk2.parameters.String(
                "YDB folder",
                description="Subfolder where to restore the tables (default root database path)",
                default="",
                required=False
            )

        with sdk2.parameters.Group("Other settings", collapse=True) as other_settings:
            # YAV secret; the value under yt_secret_key is the token for the YQL client.
            secret = sdk2.parameters.YavSecret("YAV secret", required=True)
            yt_secret_key = sdk2.parameters.String("YT secret key", required=True)
            # NOTE(review): ydb_secret_key is not read by the methods shown here — confirm it is still needed.
            ydb_secret_key = sdk2.parameters.String("YDB secret key", required=True)
            # Name used in SecureParam("token:<name>"), passed as the "token" option to YDB::BulkPushData.
            yql_token_name = sdk2.parameters.String("YQL token name", required=True, default="yt_default")
            # When set, restore queries are only logged and never executed.
            dry_run = sdk2.parameters.Bool("Dry run", default=True)
            # Log level for the 'yql.client.request' logger: DEBUG if set, FATAL otherwise.
            debug_yql = sdk2.parameters.Bool("Debug YQL", default=False)

    def restore_table(self, client, backup_table_name):
        """
        Push all rows of one YT backup table into the matching YDB table.

        Builds a YQL query that reads ``<yt_folder>/<backup_table_name>`` and feeds
        its rows to ``YDB::BulkPushData`` targeting
        ``<ydb_database>/<ydb_folder>/<backup_table_name>``. When ``dry_run`` is
        set, the query is only logged and nothing is executed.

        :param client: YQL client used to run the query.
        :param backup_table_name: table name, relative to ``yt_folder``.
        :raises ce.TaskError: if the YQL query fails (raised by ``check_result``).
        """
        backup_table_path = os.path.join(self.Parameters.yt_folder, backup_table_name)
        # ydb_folder is optional; a missing value means "database root".
        restore_table_path = os.path.join(
            self.Parameters.ydb_database,
            self.Parameters.ydb_folder or "",
            backup_table_name,
        )

        # The backup path is substituted into the input query here, before it is
        # embedded below: the outer template is %-formatted only once, so a
        # placeholder nested inside an inserted value would stay verbatim.
        if self.Parameters.strict_schema:
            input_sql = "SELECT * WITHOUT _other FROM `%s`" % backup_table_path
        else:
            input_sql = "SELECT * FROM `%s`" % backup_table_path

        # $input_data must be a named subquery (parenthesized SELECT), not a quoted
        # string: in YQL a string in a PROCESS/FROM position names a table.
        sql = '''
                PRAGMA yt.InferSchema;
                PRAGMA yt.InferSchema = '1';
                PRAGMA yt.QueryCacheMode = "disable";
                PRAGMA yt.DataSizePerJob = "500485760";
                PRAGMA yt.UserSlots = '200';
                PRAGMA yt.DefaultMaxJobFails = "1";
                $input_data = (%(input_query)s);
                SELECT
                    SUM(Bytes) AS TotalBytes,
                    SUM(Rows) AS TotalRows,
                    SUM(Batches) AS TotalBatches,
                    SUM(Retries) AS TotalRetries
                FROM (
                    PROCESS
                        $input_data
                    USING YDB::BulkPushData(
                        TableRows(),
                        "%(endpoint)s",
                        "%(database)s",
                        "%(restore_table_path)s",
                        AsTuple("token", SecureParam("token:%(yql_token_name)s"))
                    )
                );
                '''

        query_str = sql % {
            'endpoint': self.Parameters.ydb_endpoint,
            'database': self.Parameters.ydb_database,
            'restore_table_path': restore_table_path,
            'yql_token_name': self.Parameters.yql_token_name,
            'input_query': input_sql,
        }

        logging.info("Restoring table: %s", backup_table_name)
        if self.Parameters.dry_run:
            logging.info(query_str)
            return

        q = client.query(query_str, syntax_version=1)
        q.run()
        result = q.get_results()
        self.check_result(result)

        # Columns follow the SELECT above: TotalBytes, TotalRows, TotalBatches, TotalRetries.
        result_row = result.table.rows[0]
        logging.info('\tTotal bytes: %s, total rows: %s', result_row[0], result_row[1])

    def check_result(self, result):
        """
        Fail the task if a YQL request finished unsuccessfully.

        :param result: object returned by ``query.get_results()``; its
            ``is_success``, ``status`` and ``errors`` attributes are read.
        :raises ce.TaskError: message is the status followed by each formatted
            issue, one per line.
        """
        if result.is_success:
            return

        lines = [result.status] + [issue.format_issue() for issue in result.errors]
        raise ce.TaskError("\n".join(lines))

    def read_table_list(self, client, folder):
        """
        List the names of the tables in a YT folder.

        Runs ``FOLDER(folder, "schema")`` through YQL, keeps entries whose
        ``Type`` is ``"table"`` and aggregates their ``TableName(Path)`` into
        a single list.

        :param client: YQL client used to run the query.
        :param folder: YT folder path to list.
        :returns: list of table names; an empty list when the query yields no
            row or a NULL aggregate.
        :raises ce.TaskError: if the YQL query fails (raised by ``check_result``).
        """
        query_str = '''
            SELECT
                AGGREGATE_LIST(TableName(Path))
            FROM FOLDER("%s", "schema")
                WHERE
            Type = "table";
            ''' % folder

        q = client.query(query_str, syntax_version=1)
        q.run()
        result = q.get_results()

        self.check_result(result)

        # Defensive: a missing row or a NULL aggregate (possibly when the folder
        # holds no tables — unverified) becomes [] instead of an IndexError or
        # None reaching the caller.
        rows = result.table.rows
        if not rows or rows[0][0] is None:
            return []
        return rows[0][0]

    def on_execute(self):
        """
        Task entry point: restore every selected YT table into YDB.

        Creates a YQL client authenticated with the YT token stored under
        ``yt_secret_key`` in the YAV secret. It then attaches a console handler
        to the ``yql.client.request`` logger, at DEBUG when ``debug_yql`` is set
        and FATAL otherwise. Finally, it restores each table returned by
        ``get_tables_list``.

        :raises ce.TaskError: when there is nothing to restore.
        """
        from yql.api.v1.client import YqlClient

        params = self.Parameters
        token = params.secret.data()[params.yt_secret_key]
        client = YqlClient(db=params.cluster, token=token)

        if params.debug_yql:
            request_log_level = logging.DEBUG
        else:
            request_log_level = logging.FATAL
        request_logger = logging.getLogger('yql.client.request')
        request_logger.setLevel(request_log_level)
        request_logger.addHandler(logging.StreamHandler())

        tables = self.get_tables_list(client)
        if not tables:
            raise ce.TaskError("Nothing to do - empty table list.")

        if params.dry_run:
            logging.info("Dry run enabled")

        for table_name in tables:
            self.restore_table(client, table_name)
