# -*- coding: UTF-8 -*-

import json
import logging
import os

import jinja2

from sandbox import sdk2
from sandbox.projects.music.deployment.helpers.MusicBaseTask import MusicBaseTask
from sandbox.projects.music.deployment.helpers.YdbHelper import YdbHelper
from sandbox.sandboxsdk.environments import PipEnvironment


class MusicImportYdbFromYt(MusicBaseTask, sdk2.Task):
    """
    Imports YDB tables from a YT dump.

    Every table found recursively under ``source_path`` is imported into YDB; the YDB path of a table is
    its YT path relative to ``source_path``.
    The destination table must already exist: its schema (columns and primary key) is read from YDB,
    then the table is dropped and re-created with that schema before the data is loaded by a YQL query
    rendered from ``import_ydb_from_yt.jinja``.
    Tables are imported one at a time; between status checks the task waits ``check_interval_time`` seconds.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = (
            PipEnvironment('yandex-yt', use_wheel=False),
            PipEnvironment('yandex-yt-yson-bindings-skynet', use_wheel=False),
            PipEnvironment('yql', use_wheel=False),
        )

    class Parameters(sdk2.Task.Parameters):

        proxy = sdk2.parameters.String(
            "YT proxy (cluster)",
            default="hahn",
            required=True,
        )

        source_path = sdk2.parameters.String(
            "Path to the yt backup dir",
            description="There must not be a slash at the end of the path!"
                        " Tables are searched recursively in this directory;"
                        " the path of each table relative to it is used as the table path in YDB",
            default_value="//home/music/noobgam/ydb_backup",
            required=True,
        )

        check_interval_time = sdk2.parameters.Integer(
            "Check progress status interval time (sec.)",
            default=300,
            required=False,
        )

        delete_tables_after_import = sdk2.parameters.Bool(
            "Delete tables after import",
            default=False,
            required=False,
        )

        ydb_endpoint = sdk2.parameters.String(
            "YDB endpoint",
            description="host:port",
            default_value="ydb-ru-prestable.yandex.net:2135",
            required=True,
        )

        ydb_database = sdk2.parameters.String(
            "YDB database name",
            default_value="/ru-prestable/musicbackend/qa/music/",
            required=True,
        )

        ydb_from_yt_token = sdk2.parameters.YavSecret(
            "YAV secret with ydb-yt migration tokens",
            default='sec-01ej69jtmv3t8g2675y81dc5bv',
            required=True,
            description="Yav secret with ydb and yt tokens key")

        yt_token_name = sdk2.parameters.String(
            "yt token name",
            default_value='yt_prod_token',
            description="Token name to extract from YAV, "
                        "necessary to access tables on yt directly"
        )

        yql_token_name = sdk2.parameters.String(
            "yql token name",
            default_value='yql_prod_token',
            description="Token name to extract from YAV, "
                        "necessary if you want to import to an environment different from QA."
                        "Keep in mind that there is a nested ydb token in YQL SecureParam"

        )

        ydb_token_name = sdk2.parameters.String(
            "ydb token name",
            default_value='ydb_qa_token',
            description="Token name to extract from YAV, "
                        "necessary if you want to extract to an environment different from QA"
        )

    @staticmethod
    def __extract_columns(data):
        """
        Parse column descriptions out of a YDB ``describe`` JSON dump.

        :param data: JSON string as returned by ``YdbHelper.describe``.
        :return: list of ``{'name': ..., 'type': ...}`` dicts, where the type is the optional column's
                 item ``type_id`` with a trailing ``?``.
        """
        # NOTE(review): every column is assumed to be optional; a non-optional column raises KeyError here.
        return [
            {'name': col['name'], 'type': col['type']['optional_type']['item']['type_id'] + '?'}
            for col in json.loads(data)['columns']
        ]

    @staticmethod
    def __extract_primary_key(data):
        """Return the ``primary_key`` entry of a YDB ``describe`` JSON dump."""
        return json.loads(data)['primary_key']

    @staticmethod
    def __extract_tables_recursive(prefix, yt):
        """
        Collect absolute paths of all YT tables under ``prefix``, descending into map nodes.

        :param prefix: YT directory to scan.
        :param yt: the ``yt.wrapper`` module (already configured with proxy and token).
        :return: list of absolute table paths; other node types are ignored.
        """
        level = yt.list(
            prefix,
            absolute=True,
            attributes=['type'],
        )
        res = []
        for node in level:
            node_type = node.attributes['type']
            node_value = str(node)
            if node_type == 'map_node':
                res += MusicImportYdbFromYt.__extract_tables_recursive(node_value, yt)
            elif node_type == 'table':
                res.append(node_value)
        return res

    @staticmethod
    def __load_template():
        """Read and compile the YQL import query template stored in ``deployment/helpers``."""
        # TODO: improve this (resolve the template location without relying on the source tree layout)
        template_path = os.path.join(
            os.path.dirname(__file__), '..', 'deployment', 'helpers', 'import_ydb_from_yt.jinja'
        )
        # Binary read + explicit decode behaves the same on Python 2 and 3; `with` closes the file.
        with open(template_path, 'rb') as template_file:
            return jinja2.Template(template_file.read().decode('utf-8'))

    def start_table_export(self, table, template, yh, yql_client):
        """
        Re-create one YDB table and submit the YQL query that loads it from YT.

        The YDB schema is read before the table is dropped, then the table is re-created with it.
        Advances ``Context.current_table_index`` and stores the query's operation id in
        ``Context.last_operation_id`` so that ``on_execute`` can poll it later.

        :param table: absolute YT path of a table under ``source_path``.
        :param template: compiled jinja2 template of the YQL import query.
        :param yh: ``YdbHelper`` bound to the destination database.
        :param yql_client: ``YqlClient`` used to submit the query.
        :return: the submitted YQL request.
        """
        # Relative path under source_path (which must not end with '/') doubles as the YDB table path.
        path_in_ydb = table[len(self.Parameters.source_path) + 1:]
        describe_raw = yh.describe(path_in_ydb)
        columns = MusicImportYdbFromYt.__extract_columns(
            describe_raw
        )
        primary_key = MusicImportYdbFromYt.__extract_primary_key(
            describe_raw
        )
        yql_request = template.render(
            # TODO: https://st.yandex-team.ru/KIKIMR-6875
            # something is very wrong with the way SecureParam works
            ydb_token_name=str(self.Parameters.ydb_token_name),
            ydb_table=path_in_ydb,
            ydb_database=str(self.Parameters.ydb_database),
            ydb_endpoint=str(self.Parameters.ydb_endpoint),
            columns=columns,
            yt_cluster=str(self.Parameters.proxy),
            yt_dump_path=os.path.join(str(self.Parameters.source_path), path_in_ydb)
        )
        logging.info("Recreating table before restoring")
        yh.drop_table(path_in_ydb)
        yh.create_table(path_in_ydb, columns, primary_key)
        logging.info("Executing:\n{}".format(yql_request))
        req = yql_client.query(yql_request, syntax_version=1)
        req.run()
        self.Context.current_table_index += 1
        self.Context.last_operation_id = req.operation_id
        return req

    def is_operation_done(self, operation_id):
        """
        Check whether a YQL operation has completed successfully.

        :param operation_id: id of the YQL operation to check.
        :return: True when the status is COMPLETED, False while it is queued or running.
        :raises Exception: for any other status (e.g. an error or abort).
        """
        from yql.client.operation import YqlOperationResultsRequest
        request = YqlOperationResultsRequest(operation_id)
        request.run()
        status = request.status
        if status == 'COMPLETED':
            return True
        # Queued operations (IDLE/PENDING) have not failed yet; keep waiting for them like RUNNING ones.
        if status in ('IDLE', 'PENDING', 'RUNNING'):
            return False
        raise Exception('Unexpected operation {} status "{}"'.format(operation_id, status))

    def on_execute(self):
        """
        Drive the import: each execution waits for the previous YQL import, then starts the next one.

        The task re-enters this method after every ``sdk2.WaitTime``. Progress is kept in
        ``Context.current_table_index`` (next table to import) and ``Context.last_operation_id``.
        """
        from yt import wrapper as yt
        from yql.api.v1.client import YqlClient

        # Fetch the secret once and take all tokens from it.
        secret_data = self.Parameters.ydb_from_yt_token.data()
        yt_token = secret_data[self.Parameters.yt_token_name]
        yql_token = secret_data[self.Parameters.yql_token_name]
        ydb_token = secret_data[self.Parameters.ydb_token_name]
        template = MusicImportYdbFromYt.__load_template()

        yql_client = YqlClient(
            db=str(self.Parameters.proxy),
            token=yql_token,
        )
        yt.config.set_proxy(self.Parameters.proxy)
        yt.config["token"] = yt_token

        database = str(self.Parameters.ydb_database)
        yh = YdbHelper(
            ydb_token,
            self.Parameters.ydb_endpoint,
            database,
            yt_token,
            self.Parameters.proxy,
            self
        )

        with self.memoize_stage.find_tables_in_dump:
            self.Context.tables_in_dump = MusicImportYdbFromYt.__extract_tables_recursive(
                self.Parameters.source_path, yt
            )
            # Index of the next table to import; start_table_export advances it.
            self.Context.current_table_index = 0
            # Operation id of the most recently submitted import; None until the first one is started.
            self.Context.last_operation_id = None
            logging.info('Found these tables to extract from: {}'.format(self.Context.tables_in_dump))

        # Never start the next table, or declare the import finished (and possibly delete the source),
        # while the previously submitted import is still in progress.
        last_operation_id = self.Context.last_operation_id
        if last_operation_id is not None and not self.is_operation_done(last_operation_id):
            raise sdk2.WaitTime(self.Parameters.check_interval_time)

        tables = self.Context.tables_in_dump
        current_index = self.Context.current_table_index
        if current_index < len(tables):
            self.start_table_export(tables[current_index], template, yh, yql_client)
            raise sdk2.WaitTime(self.Parameters.check_interval_time)

        logging.info("Congratulations! Import is finished.")
        if self.Parameters.delete_tables_after_import:
            logging.info("Deleting tables in {} now.".format(self.Parameters.source_path))
            yt.remove(self.Parameters.source_path, recursive=True)
