import json
import os.path as path
import posixpath

import sandbox.common.types.task as ctt

from sandbox import sdk2

from sandbox.projects.yql.RunYQL2 import RunYQL2

from sandbox.projects.ydo import YdoKikimrExecutable, YdoDatabaseConfig, get_last_released_resource, get_id_or_none
from sandbox.projects.ydo.config_operations import get_table_types
from sandbox.projects.ydo.backup.ToKikimr import YdoBackupToKikimr

# Tables that have a "backup_<table>" task parameter and whose source table /
# column list are substituted into prepare_backup.sql.
MAIN_TABLES = ["workers", "service_cards", "categories", "category_synonyms", "photos"]
# Every table handed to a YdoBackupToKikimr subtask: the main tables followed by
# two extra tables that have no "backup_<table>" parameter of their own.
BACKUP_TABLES = MAIN_TABLES + ["puid_to_worker_id", "workers_to_moderation"]


class YdoBackupRecover(sdk2.Task):
    """Recover YDO tables from YT backup tables via child Sandbox tasks.

    Work is split into stages driven from ``on_execute``:

    1. ``init_context`` resets progress and loads the database config JSON
       into ``self.Context.config``.
    2. ``prepare_tables`` runs ``prepare_backup.sql`` in a RunYQL2 subtask; it
       receives the backup home and tag as placeholders (presumably it
       materializes the ``<backup_home>/<backup_tag>/<table>`` tables that the
       next stage reads — confirm against the SQL file).
    3. One YdoBackupToKikimr subtask per table in BACKUP_TABLES, started either
       one at a time or all at once depending on ``parallel_recovering``.
    """

    class Parameters(sdk2.Parameters):
        use_stable_resources = sdk2.parameters.Bool(
            "Use stable resources?",
            default=True,
        )
        # The explicit resource selectors below are shown (and required) only
        # when use_stable_resources is False.
        with use_stable_resources.value[False]:
            kikimr_executable_resource = sdk2.parameters.Resource(
                "Kikimr",
                resource_type=YdoKikimrExecutable,
                required=True,
            )

            ydo_database_config = sdk2.parameters.Resource(
                "Database config",
                resource_type=YdoDatabaseConfig,
                required=True,
            )

        with sdk2.parameters.Group("YT parameters") as yt_block:
            yql_vault_token = sdk2.parameters.String("Your yt token name in vault", default="YQL_TOKEN", required=True)
            with sdk2.parameters.RadioGroup("Host") as yt_host:
                yt_host.values["hahn"] = yt_host.Value(value="Hahn", default=True)
                yt_host.values["banach"] = yt_host.Value(value="Banach")

        backup_home = sdk2.parameters.String("Backup Home", default="//home/ydo/backups/recovering", required=True)
        backup_tag = sdk2.parameters.String("Backup tag", required=True)
        # One source table per entry of MAIN_TABLES; read via
        # getattr(self.Parameters, "backup_<table>") in prepare_tables.
        with sdk2.parameters.Group("backup tables") as backup_block:
            backup_workers = sdk2.parameters.String("workers table", required=True)
            backup_service_cards = sdk2.parameters.String("service_cards table", required=True)
            backup_categories = sdk2.parameters.String("categories table", required=True)
            backup_category_synonyms = sdk2.parameters.String("category_synonyms table", required=True)
            backup_photos = sdk2.parameters.String("photos table", required=True)

        # False: recover BACKUP_TABLES one after another; True: start all at once.
        parallel_recovering = sdk2.parameters.Bool("Recover all tables parallel", default=False)

    class Requirements(sdk2.Requirements):
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    def init_context(self):
        """Reset recovery progress and load the database config into Context.

        The config resource is resolved by ``get_last_released_resource`` from
        ``use_stable_resources`` and the explicitly selected
        ``ydo_database_config``; its file is parsed as JSON.
        """
        # Index into BACKUP_TABLES of the next table to recover sequentially.
        self.Context.next_index = 0
        config_resource = get_last_released_resource(
            YdoDatabaseConfig,
            self.Parameters.use_stable_resources,
            self.Parameters.ydo_database_config,
        )
        with open(str(config_resource.path)) as config_file:
            self.Context.config = json.load(config_file)

    def prepare_tables(self):
        """Run ``prepare_backup.sql`` in a RunYQL2 subtask and wait for it.

        The query is read from ``prepare_backup.sql`` located next to this
        module. Placeholders supply the cluster, backup home and tag, plus, for
        every table in MAIN_TABLES, its source table (``%BACKUP_<TABLE>%``) and
        its comma-separated column list (``%<TABLE>_COLUMNS%``).

        Raises:
            sdk2.WaitTask: always, to suspend this task until the subtask ends.
        """
        # A local-filesystem path, so os.path is the right module here.
        query_path = path.join(path.dirname(path.abspath(__file__)), "prepare_backup.sql")
        with open(query_path) as query_file:
            query = query_file.read()

        placeholders = {
            "%CLUSTER%": self.Parameters.yt_host,
            "%BACKUP_HOME%": self.Parameters.backup_home,
            "%TAG%": self.Parameters.backup_tag,
        }

        def form_table_schema(table, schema):
            """Render the column list of ``table`` for the YQL query.

            ``schema`` is iterated as mappings with a "name" key. service_cards
            columns are emitted as ``service_cards.<col> AS <col>`` (presumably
            to disambiguate them inside a join in prepare_backup.sql — verify).
            """
            if table == "service_cards":
                format_template = "service_cards.{0} AS {0}"
            else:
                format_template = "{0}"
            return ", ".join(format_template.format(column["name"]) for column in schema)

        for table in MAIN_TABLES:
            placeholders["%BACKUP_{}%".format(table.upper())] = getattr(self.Parameters, "backup_{}".format(table))
            table_schema = get_table_types(self.Context.config, table)
            placeholders["%{}_COLUMNS%".format(table.upper())] = form_table_schema(table, table_schema)

        task = RunYQL2(
            self,
            description="Prepare backup for task {}".format(self.id),
            notifications=self.Parameters.notifications,
            create_sub_task=False,
            query=query,
            custom_placeholders=placeholders,
            trace_query=True,
            yql_token_vault_name=self.Parameters.yql_vault_token
        )
        task.enqueue()

        raise sdk2.WaitTask([task.id], ctt.Status.Group.SUCCEED, wait_all=True)

    def generate_recover_task(self, table):
        """Create a YdoBackupToKikimr subtask for ``table``; callers enqueue it.

        Resource selection and YT settings are forwarded from this task's
        parameters. The subtask reads ``<backup_home>/<backup_tag>/<table>``.
        """
        # YT (Cypress) paths are always "/"-separated, independent of the host
        # OS, so join them with posixpath rather than os.path.
        yt_table = posixpath.join(self.Parameters.backup_home, self.Parameters.backup_tag, table)
        return YdoBackupToKikimr(
            self,
            description="Recover {} for {}".format(table, self.id),
            notifications=self.Parameters.notifications,
            create_sub_task=False,
            use_stable_resources=self.Parameters.use_stable_resources,
            kikimr_executable_resource=get_id_or_none(self.Parameters.kikimr_executable_resource),
            ydo_database_config=get_id_or_none(self.Parameters.ydo_database_config),
            table=table,
            yql_vault_token=self.Parameters.yql_vault_token,
            yt_host=self.Parameters.yt_host,
            yt_table=yt_table
        )

    def recover_table(self, table):
        """Start the recovery subtask for a single ``table`` and wait for it.

        ``next_index`` is advanced before waiting so that the next
        ``on_execute`` run continues with the following table.

        Raises:
            sdk2.WaitTask: always, to suspend this task until the subtask ends.
        """
        task = self.generate_recover_task(table)
        task.enqueue()

        self.Context.next_index += 1

        # List form, consistent with the other WaitTask calls in this class.
        raise sdk2.WaitTask([task.id], ctt.Status.Group.SUCCEED, wait_all=True)

    def recover_all_tables(self):
        """Start recovery subtasks for every table in BACKUP_TABLES at once.

        Raises:
            sdk2.WaitTask: always, waiting for all subtasks (``wait_all=True``).
        """
        tasks = [self.generate_recover_task(table) for table in BACKUP_TABLES]
        for task in tasks:
            task.enqueue()

        raise sdk2.WaitTask([task.id for task in tasks], ctt.Status.Group.SUCCEED, wait_all=True)

    def on_execute(self):
        """Drive the recovery stages.

        Each stage ends by raising ``sdk2.WaitTask``, so this method is expected
        to be re-entered after every wait. That is why one-shot stages are
        wrapped in ``memoize_stage`` and sequential progress lives in
        ``self.Context.next_index``.
        """
        # NOTE(review): the final status of child tasks is never inspected after
        # waking up; confirm how WaitTask behaves when a child does not succeed.
        with self.memoize_stage.init_context:
            self.init_context()

        with self.memoize_stage.prepare_tables:
            self.prepare_tables()

        if not self.Parameters.parallel_recovering:
            # One table per wake-up, until every entry of BACKUP_TABLES is done.
            if self.Context.next_index < len(BACKUP_TABLES):
                self.recover_table(BACKUP_TABLES[self.Context.next_index])
        else:
            with self.memoize_stage.recover_all_tables:
                self.recover_all_tables()
