# coding: utf-8
import logging
import sandbox.sdk2 as sdk2
from sandbox.common import errors
import sandbox.common.types.task as ctt
import sandbox.common.types.misc as ctm
from sandbox.projects.Afisha.base import AfishaSandboxBaseTask


class AfishaTicketSystemMdbRecovery(AfishaSandboxBaseTask):
    """Recreate a whitelisted PostgreSQL MDB cluster from a source cluster's backup.

    Workflow (driven by ``on_execute``):

    1. Delete the existing destination cluster (if any) and restore a backup of
       ``src_cluster_id`` under ``dst_cluster_name``.
    2. Wait for the MDB restore operation, then reset ``db_username``'s password.
    3. Repoint the destination's CNAME (see ``DNS_MAPPING``) at the new cluster.
    4. Run the cluster-specific SQL resource on every non-excluded database.

    With ``dry_run`` enabled, the changing steps are only logged.
    """

    BINARY_TASK_ATTR_TARGET = 'Afisha/infra/AfishaTicketSystemMdbRecovery'
    # Resource-key prefix under which the per-cluster "<dst_cluster_name>.sql" files are looked up.
    RESOURCE_BASE_PATH = "sandbox/projects/Afisha/infra/AfishaTicketSystemMdbRecovery/"
    # Destination cluster name -> DNS name whose CNAME is pointed at the cluster's rw host.
    DNS_MAPPING = {
        "ticketsystem-restored-c0": "ticketsystem-c0.afisha.tst.yandex.net",
        "ticketsystem-restored-c1": "ticketsystem-c1.afisha.tst.yandex.net"
    }
    # Plan summary logged by prepare() before anything is changed.
    MDB_RECOVERY_INFO = """
    ---------------------------------------
    Source cluster: {source_cluster}
    Source backup:  {restore_from_backup}
    Remove cluster: {remove_cluster}
    Restore to:     {restore_target_cluster}
    Dry run:        {dry_run_active}
    ---------------------------------------
    """

    class Parameters(AfishaSandboxBaseTask.Parameters):
        with sdk2.parameters.Group("Settings") as settings_block:
            src_cluster_id = sdk2.parameters.String("Source: cluster id", required=True)
            dst_cluster_folder = sdk2.parameters.String("Destination: cluster folder", required=True)
            dst_resource_preset_id = sdk2.parameters.String("Destination: resourcePresetId", required=True, default="s2.small")
            with sdk2.parameters.String("Destination: cluster name") as dst_cluster_name:
                dst_cluster_name.values.ticketsystem_restored_c0 = "ticketsystem-restored-c0"
                dst_cluster_name.values.ticketsystem_restored_c1 = "ticketsystem-restored-c1"
            db_username = sdk2.parameters.String("Username: ", required=True, default="ticketsystem")
            # Secret keys read by this task: "mdb-token", "dns-token", "dns-username", "restored_password".
            secret_id = sdk2.parameters.YavSecret("YAV: DNS/MDB tokens", required=True)
            dry_run = sdk2.parameters.Bool("Dry run", default=True)
            dst_excluded_databases = sdk2.parameters.List("Do not run queries on this databases", default=["postgres"])
            dst_cluster_whitelist = sdk2.parameters.List("Destination cluster whitelist", default=["ticketsystem-restored-c0", "ticketsystem-restored-c1"])

    def _secret_data(self):
        """Return the contents of the configured YAV secret."""
        return self.Parameters.secret_id.data()

    def _mdb_client(self, secret_data=None):
        """Build an MDB PostgreSQL API client authenticated with the ``mdb-token`` key.

        :param secret_data: already-fetched secret contents; fetched when omitted.
        """
        from afisha.infra.libs.mdb import MdbClientPostgres
        if secret_data is None:
            secret_data = self._secret_data()
        return MdbClientPostgres(token=secret_data["mdb-token"])

    @staticmethod
    def _rw_host(cluster_id):
        """Return the rw host name used for cluster ``cluster_id``."""
        return "c-{id}.rw.db.yandex.net".format(id=cluster_id)

    def prepare(self):
        """Collect everything needed to recreate the destination cluster.

        Reads the source cluster's configuration and switches it to the
        destination resource preset with 1.5x the source disk size. Finds the
        existing destination cluster(s) named ``dst_cluster_name`` (with "_"
        replaced by "-"). Looks up a source backup via
        ``cluster_backups_list``, presumably the latest one. The plan is then
        logged using ``MDB_RECOVERY_INFO``.

        :returns: ``(old_dst_cluster_id, src_backup, src_backup_created_at,
            dst_cluster_config, dst_cluster_name)``. ``old_dst_cluster_id`` is
            the id of the first matching destination cluster, or ``None`` when
            there is none.
        :raises errors.TaskError: if a matched destination cluster is not in
            ``dst_cluster_whitelist``.
        """
        mdb = self._mdb_client()
        src_cluster_id = self.Parameters.src_cluster_id
        dst_cluster_name = self.Parameters.dst_cluster_name.replace("_", "-")

        src_cluster_config = mdb.cluster_get(cluster_id=src_cluster_id)
        logging.info("Cluster configuration:\n%s", src_cluster_config)
        # The restored cluster gets the destination preset and 50% more disk than the source.
        src_host_disk_size = int(src_cluster_config.resources["diskSize"])
        src_cluster_config.resources.update({
            "resourcePresetId": self.Parameters.dst_resource_preset_id,
            "diskSize": int(src_host_disk_size * 1.5),
        })

        existing_dst_clusters = mdb.cluster_list(folder_id=self.Parameters.dst_cluster_folder,
                                                 filter=dst_cluster_name)
        # Refuse to touch anything that is not explicitly whitelisted.
        for cluster in existing_dst_clusters:
            if cluster.name not in self.Parameters.dst_cluster_whitelist:
                raise errors.TaskError("cluster {} not in whitelist".format(cluster.name))

        src_cluster_last_backup, src_backup_created_at = mdb.cluster_backups_list(cluster_id=src_cluster_id,
                                                                                  return_creation_date=True)

        old_dst_cluster_id = existing_dst_clusters[0].id if existing_dst_clusters else None

        logging.info(self.MDB_RECOVERY_INFO.format(source_cluster=src_cluster_id,
                                                   remove_cluster=old_dst_cluster_id,
                                                   restore_from_backup=src_cluster_last_backup,
                                                   restore_target_cluster=dst_cluster_name,
                                                   dry_run_active=self.Parameters.dry_run))
        return old_dst_cluster_id, src_cluster_last_backup, src_backup_created_at, src_cluster_config, dst_cluster_name

    def run_queries(self, cluster_id, port=6432, user="ticketsystem"):
        """Run the cluster-specific SQL resource on every non-excluded database.

        The SQL text is read with ``library.python.resource.find`` under
        ``RESOURCE_BASE_PATH + dst_cluster_name + ".sql"``. Each database is
        tried up to 5 times. In dry-run mode the planned executions are only
        logged.

        :param cluster_id: MDB id of the cluster to run queries on. Nothing is
            done when it is empty.
        :param port: PostgreSQL port to connect to.
        :param user: database user to connect as.
        :raises errors.TaskError: if the SQL resource is missing, or the queries
            could not be applied to a database after all attempts.
        """
        from library.python import resource
        import psycopg2

        # Validate the id before any MDB API call that depends on it.
        if not cluster_id:
            logging.info("Cannot run queries: target cluster doesn't exist.")
            return

        dry_run = self.Parameters.dry_run
        whitelist = self.Parameters.dst_cluster_whitelist
        secret_data = self._secret_data()
        mdb = self._mdb_client(secret_data)
        target_host = self._rw_host(cluster_id)

        if dry_run:
            # Dry run does not ask MDB for the name; assume the first whitelisted one.
            cluster_name = whitelist[0] if whitelist else None
        else:
            cluster_name = mdb.cluster_get(cluster_id).name

        if cluster_name not in whitelist:
            logging.info("Skip queries: cluster %s is not in whitelist", cluster_name)
            return

        resource_key = self.RESOURCE_BASE_PATH + self.Parameters.dst_cluster_name + ".sql"
        logging.info("resource key: %s", resource_key)
        queries_resource = resource.find(resource_key)
        if not queries_resource:
            raise errors.TaskError("query resource doesn't exist")

        # Passed to psycopg2 as keyword arguments rather than a hand-joined
        # "k=v" string, so values such as the password are quoted correctly.
        base_connect_kwargs = {
            "host": target_host,
            "user": user,
            "port": port,
            # Nothing connects in dry-run mode, so the real password is not needed there.
            "password": "XXXXXXXXX" if dry_run else secret_data["restored_password"],
            "options": "-c search_path=ticketsystem",
            "keepalives": 1,
            "keepalives_idle": 900,
            "keepalives_interval": 30,
            "keepalives_count": 30,
        }
        attempts = 5

        for db in mdb.database_list(cluster_id):
            if db in self.Parameters.dst_excluded_databases:
                continue
            logging.info("Will run resource file %s on cluster %s %s database: %s user: %s",
                         resource_key, cluster_name, target_host, db, user)
            if dry_run:
                # Report every database that would be touched, then move on.
                continue
            connect_kwargs = dict(base_connect_kwargs, dbname=db)
            for attempt in range(1, attempts + 1):
                conn = None
                try:
                    conn = psycopg2.connect(**connect_kwargs)
                    conn.set_session(readonly=False, autocommit=True)
                    with conn.cursor() as cursor:
                        cursor.execute(queries_resource)
                    break
                except Exception as err:
                    logging.warning("Couldnt run querry in %s try. Will try again..\nError:\n%s", attempt, err)
                finally:
                    # connect() may have failed, leaving no connection to close.
                    if conn is not None:
                        conn.close()
            else:
                # Every attempt failed: do not let the task succeed silently.
                raise errors.TaskError(
                    "Couldn't run queries on database {} of cluster {} after {} attempts".format(
                        db, cluster_name, attempts))

    def recovery_mdb(self):
        """Delete the old destination cluster and restore the source backup in its place.

        :returns: ``(operation_id, old_dst_cluster_id, new_dst_cluster_id)``. In
            dry-run mode ``operation_id`` is ``"dry_run"`` and
            ``new_dst_cluster_id`` is ``None``.
        :raises errors.TaskError: if the restored cluster cannot be found after
            the restore request.
        """
        operation_id = "dry_run"
        old_dst_cluster_id, src_backup, src_backup_created_at, src_cluster_config, dst_cluster_name = self.prepare()
        logging.info("recreating target cluster")
        new_dst_cluster_id = None
        if self.Parameters.dry_run:
            return operation_id, old_dst_cluster_id, new_dst_cluster_id

        mdb = self._mdb_client()
        if old_dst_cluster_id:
            self.delete_old_cluster(old_dst_cluster_id)
        else:
            logging.info("cannot delete: target cluster doesn't exist")
        logging.info("restoring cluster from backup id: %s created_at: %s to target cluster: %s\n%s",
                     src_backup, src_backup_created_at, dst_cluster_name, src_cluster_config)
        operation_id = mdb.restore_postgres_cluster(src_backup, src_backup_created_at, dst_cluster_name, src_cluster_config)
        logging.info("restore operation id: %s", operation_id)

        # delete_old_cluster() does not wait for the deletion to finish, so the old
        # cluster may still be listed under the same name: never pick it.
        candidates = [cluster for cluster in mdb.cluster_list(folder_id=self.Parameters.dst_cluster_folder,
                                                              filter=dst_cluster_name)
                      if cluster.id != old_dst_cluster_id]
        if not candidates:
            raise errors.TaskError("restored cluster {} not found in folder {}".format(
                dst_cluster_name, self.Parameters.dst_cluster_folder))
        new_dst_cluster_id = candidates[0].id
        logging.info("new cluster id: %s", new_dst_cluster_id)
        return operation_id, old_dst_cluster_id, new_dst_cluster_id

    def update_resource_record(self):
        """Point the destination's CNAME at the new cluster instead of the old one.

        The DNS name comes from ``DNS_MAPPING``. The old and new targets are the
        rw hosts of ``Context.OLD_CLUSTER_ID`` and ``Context.NEW_CLUSTER_ID``.
        A failure to delete the old record is logged and ignored; a failure to
        add the new one propagates. In dry-run mode nothing is changed.

        :raises errors.TaskError: outside dry-run mode, if ``dst_cluster_name``
            has no ``DNS_MAPPING`` entry.
        """
        from afisha.infra.libs.dns import DnsClient
        dry_run = self.Parameters.dry_run
        logging.info("Change DNS name. Dry run: %s", dry_run)

        secret_data = self._secret_data()
        dns = DnsClient(token=secret_data["dns-token"], username=secret_data["dns-username"])

        dst_cluster_name = self.Parameters.dst_cluster_name.replace("_", "-")
        cname_record = self.DNS_MAPPING.get(dst_cluster_name)

        old_cluster_id = self.Context.OLD_CLUSTER_ID
        new_cluster_id = self.Context.NEW_CLUSTER_ID
        old_cluster_cname = self._rw_host(old_cluster_id) if old_cluster_id else None
        new_cluster_cname = self._rw_host(new_cluster_id) if new_cluster_id else None

        logging.info("DNS: New CNAME: %s", new_cluster_cname)
        logging.info("DNS: Old CNAME: %s", old_cluster_cname)
        logging.info("CNAME: %s", cname_record)

        if dry_run:
            return
        if cname_record is None:
            # Without a mapping the DNS API would be called with name=None.
            raise errors.TaskError("No DNS mapping for cluster {}".format(dst_cluster_name))

        if old_cluster_cname:
            try:
                dns.change_request(operation='delete', name=cname_record, type='CNAME', data=old_cluster_cname, ttl=60)
            except Exception as e:
                # Best effort: a failed delete is logged and the new record is still added.
                logging.error("failed to remove CNAME record: %s: %s", old_cluster_cname, e)
        else:
            logging.info("cannot remove CNAME: old cluster_id is None")

        if new_cluster_cname:
            dns.change_request(operation='add', name=cname_record, type='CNAME', data=new_cluster_cname, ttl=60)
        else:
            logging.info("cannot add CNAME: new cluster_id is None")

    def delete_old_cluster(self, cluster_id):
        """Request deletion of cluster ``cluster_id`` if its name is whitelisted.

        :returns: the result of ``delete_postgres_cluster``, or ``None`` in
            dry-run mode.
        :raises errors.TaskError: if the cluster's name is not whitelisted.
        """
        if self.Parameters.dry_run:
            logging.info("Dry run: delete cluster: {}".format(cluster_id))
            return None

        mdb = self._mdb_client()
        cluster_name = mdb.cluster_get(cluster_id).name
        if cluster_name not in self.Parameters.dst_cluster_whitelist:
            raise errors.TaskError("Cluster {} not in whitelist".format(cluster_name))
        return mdb.delete_postgres_cluster(cluster_id=cluster_id)

    def update_password(self, cluster_id):
        """Set ``db_username``'s password on ``cluster_id`` to the ``restored_password`` secret.

        :returns: the result of ``user_update``, which ``on_execute`` stores as
            the next operation id. Returns ``None`` in dry-run mode.
        """
        if self.Parameters.dry_run:
            logging.info("Changing password: cluster_id %s  user %s", cluster_id, self.Parameters.db_username)
            return None

        secret_data = self._secret_data()
        mdb = self._mdb_client(secret_data)
        return mdb.user_update(cluster_id=cluster_id,
                               username=self.Parameters.db_username,
                               password=secret_data['restored_password'])

    def on_enqueue(self):
        """Serialize runs per destination cluster with a capacity-1 semaphore."""
        component_flow_lock = ctt.Semaphores.Acquire(
            name="{}_{}".format("afisha_ticketsystem_mdb_recovery", self.Parameters.dst_cluster_name),
            weight=1, capacity=1)
        release = (ctt.Status.Group.BREAK, ctt.Status.Group.FINISH)
        self.Requirements.semaphores = ctt.Semaphores(acquires=[component_flow_lock], release=release)

    def wait_deploy(self, operation_id):
        """Wait until MDB operation ``operation_id`` is done.

        While the operation is still running, raises ``sdk2.WaitTime(600)`` so
        the check is repeated later. Does nothing in dry-run mode.

        :raises errors.TaskFailure: if the finished operation reports an error.
        """
        logging.info("Wait deploy Operation ID: {}".format(operation_id))
        if self.Parameters.dry_run:
            return
        task_status = self._mdb_client().operation_status_get(id=operation_id)

        if not task_status["done"]:
            raise sdk2.WaitTime(600)
        if "error" in task_status:
            raise errors.TaskFailure("MDB task failed")

    def on_execute(self):
        """Drive the recovery as memoized stages that survive WaitTime re-executions."""
        with self.memoize_stage.first_step(commit_on_entrance=False):
            self.Context.OPERATION_ID, self.Context.OLD_CLUSTER_ID, self.Context.NEW_CLUSTER_ID = self.recovery_mdb()

        # Waits for the restore first; after second_step it waits for the password update.
        self.wait_deploy(self.Context.OPERATION_ID)

        with self.memoize_stage.second_step(commit_on_entrance=False):
            self.Context.OPERATION_ID = self.update_password(self.Context.NEW_CLUSTER_ID)

        with self.memoize_stage.third_step(commit_on_entrance=False):
            self.update_resource_record()

            # NOTE(review): the dns_updated flag makes the 70s pause happen only once,
            # even if this stage is re-entered after the WaitTime — confirm stage semantics.
            if self.Context.dns_updated is ctm.NotExists:
                self.Context.dns_updated = True
                raise sdk2.WaitTime(70)

        self.run_queries(self.Context.NEW_CLUSTER_ID)
