# -*- coding: UTF-8 -*-

from concurrent.futures import TimeoutError
import logging
import os

from sandbox import sdk2
from sandbox.sdk2.service_resources import SandboxTasksBinary
import sandbox.common.errors as ce


class YdbDropTables(sdk2.Task):
    """Drop tables (and directories) from YDB.

    The task walks the scheme tree starting at ``database`` (or at
    ``database/path`` when ``path`` is set). It drops every table it finds,
    or only the entries listed in ``tables_to_drop``. System entries whose
    name starts with ``.sys`` are skipped.

    A directory is processed recursively before being removed. When
    ``tables_to_drop`` is non-empty, a directory is removed only if it is
    listed explicitly.
    """

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.RadioGroup("YdbDropTables binary type") as release_type:
            release_type.values.stable = release_type.Value("stable", default=True)
            release_type.values.test = release_type.Value("test")

        ydb_token = sdk2.parameters.Vault(
            "YDB token from Vault",
            description='"name" or "owner:name"',
            required=True,
        )
        endpoint = sdk2.parameters.String(
            "YDB endpoint",
            description="host:port",
            default_value="ydb-ru-prestable.yandex.net:2135",
            required=True,
        )
        database = sdk2.parameters.String(
            "YDB database name",
            required=True,
        )
        path = sdk2.parameters.String(
            "Path from the root of the database",
            default_value="",
            required=False,
        )
        tables_to_drop = sdk2.parameters.List(
            "List of tables to be dropped.",
            default=[],
            description="Empty means all tables will be dropped. Unknown paths are ignored.",
        )

    def on_save(self):
        """Pin the SandboxTasksBinary resource matching the selected release type.

        Raises:
            ce.ResourceNotFound: if no binary with the expected attributes exists.
        """
        attrs = {
            "target": "sandbox/projects/iot/YdbDropTables",
            "release": self.Parameters.release_type or "stable"
        }
        res = SandboxTasksBinary.find(attrs=attrs).first()
        if res is not None:
            self.Requirements.tasks_resource = res.id
        else:
            raise ce.ResourceNotFound("Can't find binary for %(type)s task (%(res)s with attrs: %(attrs)s)" % {
                "type": self.type.name,
                "res": SandboxTasksBinary.name,
                "attrs": attrs
            })

    def on_execute(self):
        """Connect to YDB and drop the requested tables/directories.

        Raises:
            ce.TaskError: if the driver does not become ready within 15 seconds,
                or if any scheme/table operation fails during the walk.
        """
        import ydb

        database = self.Parameters.database
        driver_config = ydb.DriverConfig(
            self.Parameters.endpoint, database, auth_token=bytes(self.Parameters.ydb_token.data()),
        )
        with ydb.Driver(driver_config) as driver:
            try:
                driver.wait(timeout=15)
            except TimeoutError:
                raise ce.TaskError("Failed to connect to YDB: %s" % driver.discovery_debug_details())

            scheme_client = driver.scheme_client
            session = driver.table_client.session().create()

            # Requested entries are relative to the database root. Strip leading
            # slashes, because os.path.join would otherwise discard the database
            # prefix. Normalize so that "dir/" or "a//b" match the paths built
            # while walking. Empty entries are kept on purpose: they keep the
            # filter non-empty, so [""] drops nothing instead of everything.
            tables_to_drop = {
                os.path.normpath(os.path.join(database, table.lstrip("/")))
                for table in (self.Parameters.tables_to_drop or [])
            }

            if tables_to_drop:
                logging.info("Will drop these tables: %s", ", ".join(sorted(tables_to_drop)))
            else:
                logging.info("Will drop all tables")

            def is_selected(path):
                # An empty filter selects every entry.
                return not tables_to_drop or os.path.normpath(path) in tables_to_drop

            def walk(depth, prefix):
                indent = " " * depth * 2
                # Log before listing so a failing prefix is visible in the task log.
                logging.info("%sInspecting %s", indent, prefix)
                directory = scheme_client.list_directory(prefix)
                for child in directory.children:
                    child_path = os.path.join(prefix, child.name)
                    if child.type == ydb.scheme.SchemeEntryType.DIRECTORY:
                        if child.name.startswith(".sys"):
                            continue
                        # Process the contents first so the directory can be emptied.
                        walk(depth + 1, child_path)
                        if not is_selected(child_path):
                            continue
                        logging.info("%sRemoving directory %s", indent, child_path)
                        scheme_client.remove_directory(child_path)
                    elif child.type == ydb.scheme.SchemeEntryType.TABLE:
                        if not is_selected(child_path):
                            continue
                        logging.info("%sRemoving table %s", indent, child_path)
                        session.drop_table(child_path)
                    else:
                        logging.info("%sSkipping %s %s", indent, child.type, child_path)

            target_path = database
            if self.Parameters.path:
                target_path = os.path.join(database, self.Parameters.path.lstrip("/"))

            try:
                walk(0, target_path)
            except Exception as e:
                # Keep the full traceback in the log; TaskError carries only the text.
                # Format the exception itself: `.message` is missing on many
                # exceptions and would raise AttributeError here.
                logging.exception("Failed to drop tables")
                raise ce.TaskError("Failed to drop tables: %s" % e)
