# coding: utf-8
import logging
from collections import Counter, namedtuple
from typing import Union

import sandbox.sdk2 as sdk2
from sandbox.common import errors
from sandbox.projects.music.deployment.MusicRestoreMdb.YandexCloudDatabase.DnsApi import DNSApi
from sandbox.projects.music.deployment.MusicRestoreMdb.YandexCloudDatabase.MongoClient import MDBMongoClient
from sandbox.projects.music.deployment.MusicRestoreMdb.YandexCloudDatabase.MySQLClient import MDBMySQLClient
from sandbox.projects.music.deployment.helpers.MusicBaseTask import MusicBaseTask
from sandbox.projects.music.deployment.helpers.TaskHelper import TaskHelper

# Maps the ``database_type`` task parameter value to the MDB client class that
# is instantiated (with the MDB token) to manage clusters of that type.
supported_databases = dict(
    mysql=MDBMySQLClient,
    mongo=MDBMongoClient,
)


def resources_flavours():
    """Return MDB resource preset names as a tuple of ``"<family>.<size>"`` strings.

    Both the ``m2`` and ``s2`` families get the sizes ``nano`` .. ``4xlarge``;
    the sizes ``5xlarge`` .. ``7xlarge`` are listed for ``m2`` only.
    Order: shared ``m2`` sizes, shared ``s2`` sizes, then the ``m2``-only sizes.
    """
    shared_sizes = ('nano', 'micro', 'small', 'medium', 'large', 'xlarge', '2xlarge', '3xlarge', '4xlarge')
    extra_m2_sizes = ('5xlarge', '6xlarge', '7xlarge')

    names = []
    for family in ('m2', 's2'):
        names.extend('{}.{}'.format(family, size) for size in shared_sizes)
    names.extend('m2.{}'.format(size) for size in extra_m2_sizes)
    return tuple(names)


class MusicRestoreMdb(MusicBaseTask, TaskHelper):
    """Restore databases in MDB.

    Restores a backup of ``source_cluster`` into a new cluster in
    ``target_folder_id``. Optionally:

    * if a cluster named ``destination_cluster_name`` already exists in the
      folder, it is renamed to ``<name>_old`` before the restore and deleted
      after it;
    * DNS CNAMEs are (re)pointed at the hosts of the new cluster;
    * a raw SQL statement is executed on the new MySQL master.

    Long-running MDB operations are awaited with ``sdk2.WaitTime``; their ids
    are kept in the task Context so the work can continue when the task resumes.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [TaskHelper.pydns_library_environment, TaskHelper.mysql_library_environment]

    class Context(sdk2.Task.Context):
        # MDB operation ids keyed by stage ('rename_id', 'restore_id').
        operation_ids = {}
        # Id of the cluster created by the restore operation.
        new_cluster_id = ''
        # Id of the pre-existing cluster renamed to '<name>_old'; empty if none was found.
        old_cluster_id = ''

    class Parameters(sdk2.Task.Parameters):
        description = "Restore mdb cluster"

        kill_timeout = 2 * 60

        with sdk2.parameters.Group("Cluster config") as cluster_config:
            source_cluster = sdk2.parameters.String("Source cluster id",
                                                    required=True,
                                                    description='mdb5uosi8os1jlkd65kn')

            destination_cluster_name = sdk2.parameters.String("Destination cluster name",
                                                              required=False,
                                                              description='Destination cluster human readable name')

            target_folder_id = sdk2.parameters.String("Target folder id",
                                                      required=True,
                                                      description="MDB folder to restore the database to")  # type: str

            mdb_token_yav_id = sdk2.parameters.YavSecret("YAV secret with mdb token",
                                                         required=True)

            mdb_token_yav_id_key = sdk2.parameters.String("Key with mdb token",
                                                          required=True,
                                                          default_value='mdb_token')

        database_type = sdk2.parameters.String("Database type", choices=[(x, x) for x in supported_databases.keys()],
                                               ui=sdk2.parameters.String.UI("select"),
                                               required=True,
                                               description="Supported restore databases")  # type: Union[str, sdk2.parameters.String]
        with database_type.value['mysql']:
            with sdk2.parameters.Group("SQL config") as sql_c:
                execute_sql_on_target = sdk2.parameters.Bool("Execute sql after restore", default=False)
                with execute_sql_on_target.value[True]:
                    raw_sql_statement = sdk2.parameters.String("Raw sql statement", required=False,
                                                               description="Execute sql in restore target")
                    target_yav_with_mysql_creds = sdk2.parameters.YavSecret(
                        "YAV secret with target mysql passwd and username",
                        required=True)
                    target_yav_passwd_key = sdk2.parameters.String("YAV password key", default="main.db.cloud.password")
                    target_yav_user_key = sdk2.parameters.String("YAV user key", default="main.db.username")

        with sdk2.parameters.Group("Backup selection") as bk:
            use_latest_backup = sdk2.parameters.Bool("Use latest backup",
                                                     default=True)
            with use_latest_backup.value[False]:
                backup_name = sdk2.parameters.String("Backup name",
                                                     required=False,
                                                     description='mdb5uosi8os1jlkd65kn:stream_20200816T224904Z')

        # TODO
        # Set correct buffer size for smaller flavours
        # use_custom_flavour = sdk2.parameters.Bool("Use custom size flavour", default=False)
        # with use_custom_flavour.value[True]:
        #     custom_flavour = sdk2.parameters.String("Flavour",
        #                                             required=False,
        #                                             choices=[(x, x) for x in resources_flavours()])  # type: str

        with sdk2.parameters.Group("Datacenter config") as datacenter_config:
            custom_dc_hosts_allocation = sdk2.parameters.Bool("Allocate hosts in other datacenters",
                                                              default=False)

            with custom_dc_hosts_allocation.value[True]:
                dc_sas = sdk2.parameters.Integer("SAS", default=0)  # type: int
                dc_vla = sdk2.parameters.Integer("VLA", default=0)  # type: int
                dc_man = sdk2.parameters.Integer("MAN", default=0)  # type: int

        with sdk2.parameters.Group("DNS config") as dns_config:
            create_dns_names = sdk2.parameters.Bool("Create dns CNAMEs",
                                                    default=False)

            with create_dns_names.value[True]:
                dns_api_username = sdk2.parameters.String("Username to access dns api",
                                                          description='robot-music-admin',
                                                          required=True)  # type: str

                dns_token_yav_id = sdk2.parameters.YavSecret("YAV secret with dns token",
                                                             required=True)

                dns_token_yav_id_key = sdk2.parameters.String("Key with dns token",
                                                              required=True,
                                                              default_value="dns_token")

                dns_mask = sdk2.parameters.String("Target hosts domain",
                                                  required=True,
                                                  description=".music.yandex.net")

    def restore(self, token):
        """Restore a backup of the source cluster into a new cluster and wait for it.

        Each mutating step runs inside a ``memoize_stage`` so it is executed only
        once, while ``check_operation_completed`` raises ``sdk2.WaitTime`` until
        the corresponding MDB operation is done.

        :param token: MDB API token passed to the database client constructor.
        :raises errors.TaskError: if no backup is available, no target hosts
            were requested, or an MDB operation finished with an error.
        """
        mdb = supported_databases[self.Parameters.database_type](token)  # type: Union[MDBMongoClient, MDBMySQLClient]

        source_cluster = mdb.cluster(self.Parameters.source_cluster)
        source_config = source_cluster.get()

        if self.Parameters.destination_cluster_name:
            with self.memoize_stage.rename:
                self.rename(mdb, self.Parameters.target_folder_id)

            # Only present when an existing cluster with the target name was found.
            if 'rename_id' in self.Context.operation_ids:
                self.check_operation_completed(mdb, self.Context.operation_ids['rename_id'])

        with self.memoize_stage.restore_backup:
            if self.Parameters.use_latest_backup:
                backups = source_cluster.list_backups()
                if not backups:
                    raise errors.TaskError("No backups found for cluster {}".format(self.Parameters.source_cluster))
                # NOTE(review): assumes list_backups() returns the newest backup first — confirm.
                backup_id = backups[0]
            else:
                backup_id = self.Parameters.backup_name

            if self.Parameters.custom_dc_hosts_allocation:
                requested = (("sas", self.Parameters.dc_sas),
                             ("man", self.Parameters.dc_man),
                             ("vla", self.Parameters.dc_vla))
                target_hosts = [{'zoneId': zone} for zone, count in requested for _ in range(count)]
            else:
                # Mirror the zone layout of the source cluster.
                target_hosts = [{'zoneId': host['zoneId']} for host in source_cluster.list_hosts()]

            if not target_hosts:
                raise errors.TaskError("No target hosts requested for the restored cluster")

            res = source_cluster.restore(backup_id, source_config, hosts_config=target_hosts,
                                         name=self.Parameters.destination_cluster_name,
                                         target_folder=self.Parameters.target_folder_id)

            self.Context.new_cluster_id = res['metadata']['clusterId']
            self.Context.operation_ids['restore_id'] = res['id']

        self.check_operation_completed(mdb, self.Context.operation_ids['restore_id'], 500)

    @staticmethod
    def check_operation_completed(mdb, op_id, sleep_time=100):
        """Stop task progress until MDB operation ``op_id`` is done.

        :param mdb: MDB client used to query the operation.
        :param op_id: id of the operation to check.
        :param sleep_time: seconds passed to ``sdk2.WaitTime`` while the
            operation is still running.
        :raises sdk2.WaitTime: if the operation is not done yet.
        :raises errors.TaskError: if the operation is done but reports an error.
        """
        state = mdb.operation(op_id).get()
        if not state['done']:
            logging.debug(state)
            logging.info("Operation {} is not done yet. Sleeping for {} seconds".format(op_id, sleep_time))
            raise sdk2.WaitTime(sleep_time)
        # A finished operation may still have failed. Continuing would, for example,
        # delete the old cluster even though the restore produced no new one.
        if state.get('error'):
            raise errors.TaskError("Operation {} failed: {}".format(op_id, state['error']))

    def rename(self, mdb, folder_id):
        """Rename an existing cluster named ``destination_cluster_name`` to ``<name>_old``.

        If no cluster with that name exists in ``folder_id``, nothing is done.
        Otherwise the cluster id is stored in ``Context.old_cluster_id`` and the
        rename operation id in ``Context.operation_ids['rename_id']``.

        :param mdb: MDB client.
        :param folder_id: folder whose clusters are searched.
        """
        databases = mdb.cluster(None).list(folder_id)
        target_name = self.Parameters.destination_cluster_name
        old_cluster_id = ''
        # If several clusters share the name, the last one listed wins.
        for c in databases['clusters']:
            if c['name'] == target_name:
                old_cluster_id = c['id']

        if not old_cluster_id:
            logging.info("Could not find cluster '{}'. Proceeding to create it".format(target_name))
            return

        self.Context.old_cluster_id = old_cluster_id
        old_cluster = mdb.cluster(old_cluster_id)
        operation = old_cluster.rename(target_name + "_old")
        self.Context.operation_ids['rename_id'] = operation['id']

    def create_cnames(self, mdb_token):
        """Point DNS CNAMEs at the hosts of the newly restored cluster.

        For each CNAME computed by ``_cnames`` an existing CNAME record (if any)
        is deleted and a new one to the restored host is added. The mapping is
        saved to ``Context.cname_map``.

        :param mdb_token: MDB API token.
        """
        import DNS
        mdb = supported_databases[self.Parameters.database_type](mdb_token)
        cluster = mdb.cluster(self.Context.new_cluster_id)

        cnames = self._cnames(cluster.list_hosts())

        dns_token = self.Parameters.dns_token_yav_id.data()[self.Parameters.dns_token_yav_id_key]
        dns = DNSApi(self.Parameters.dns_api_username, dns_token)
        for cname in cnames:
            try:
                existing = DNS.dnslookup(cname.left, 'CNAME')
            except DNS.ServerError as e:
                if e.rcode == DNS.Status.NXDOMAIN:
                    logging.info("{} is not pointing to any domain".format(cname.left))
                else:
                    # Best effort: log and still try to create the record below.
                    logging.error("Unexpected dns error", exc_info=True)
            else:
                # The name may exist without a CNAME answer; only delete when there is one.
                if existing:
                    logging.info("{} is pointing to {}. Removing".format(cname.left, existing[0]))
                    dns.delete_cname(cname.left, existing[0])

            logging.info("Creating cname {} -> {}".format(cname.left, cname.right))
            dns.add_cname(cname.left, cname.right)

        self.Context.cname_map = cnames

    def _cnames(self, hosts):
        """Build ``(left, right)`` CNAME pairs for ``hosts``.

        ``left`` is ``<destination_cluster_name>-<zoneId>-<n>.<domain>``, where
        ``n`` counts hosts per zone starting from 1 and ``domain`` is
        ``dns_mask`` without one leading dot; ``right`` is the host's ``name``.

        :param hosts: host dicts with at least ``zoneId`` and ``name`` keys.
        :return: list of ``cname_mapping`` namedtuples in ``hosts`` order.
        """
        per_zone = Counter()
        cnames = []
        cname_mapping = namedtuple('cname_mapping', ['left', 'right'])
        template = "{service_name}-{dc}-{num}.{domain}"
        mask = self.Parameters.dns_mask
        domain = mask[1:] if mask.startswith('.') else mask
        for host in hosts:
            zone = host['zoneId']
            per_zone[zone] += 1
            target_cname = template.format(service_name=self.Parameters.destination_cluster_name,
                                           dc=zone,
                                           num=per_zone[zone],
                                           domain=domain)
            cnames.append(cname_mapping(target_cname, host['name']))
        logging.info("CNAMEs {}".format(cnames))
        return cnames

    def remove_old_cluster(self, mdb_token):
        """Delete the cluster that was renamed to ``<name>_old`` during the restore.

        :param mdb_token: MDB API token.
        :raises errors.TaskError: if the old cluster is the source cluster.
        """
        # make sure we don't delete production cluster by accident
        logging.info("Check that old cluster {} != source cluster {}".format(self.Context.old_cluster_id,
                                                                             self.Parameters.source_cluster))
        if self.Context.old_cluster_id == self.Parameters.source_cluster:
            raise errors.TaskError("Trying to remove source cluster {}".format(self.Parameters.source_cluster))

        mdb = supported_databases[self.Parameters.database_type](mdb_token)  # noqa
        cluster = mdb.cluster(self.Context.old_cluster_id)
        logging.info("Remove old cluster '{}'".format(self.Context.old_cluster_id))
        cluster.delete()

    def execute_sql(self):
        """Execute ``raw_sql_statement`` on the master of the restored MySQL cluster.

        Credentials are read from the ``target_yav_with_mysql_creds`` secret.
        The cursor and connection are always closed, even on failure.

        :raises errors.TaskError: if the SQL statement is empty.
        """
        from MySQLdb import connect

        sql = self.Parameters.raw_sql_statement
        if not sql:
            raise errors.TaskError("Sql statement can not be empty")

        yav_data = self.Parameters.target_yav_with_mysql_creds.data()
        username = yav_data[self.Parameters.target_yav_user_key]
        password = yav_data[self.Parameters.target_yav_passwd_key]

        master_template = "c-{cluster_id}.rw.db.yandex.net"
        host = master_template.format(cluster_id=self.Context.new_cluster_id)

        con = connect(host=host, user=username, passwd=password)
        try:
            cursor = con.cursor()
            try:
                logging.info("Executing '{}' on '{}'".format(sql, host))
                cursor.execute(sql)
                con.commit()
            finally:
                cursor.close()
        finally:
            con.close()

    def on_execute(self):
        """Run the restore, then the optional DNS, cleanup and SQL steps."""
        mdb_token = self.Parameters.mdb_token_yav_id.data()[self.Parameters.mdb_token_yav_id_key]
        self.restore(mdb_token)

        if self.Parameters.create_dns_names:
            self.create_cnames(mdb_token)

        if self.Context.old_cluster_id:
            self.remove_old_cluster(mdb_token)

        if self.Parameters.execute_sql_on_target:
            self.execute_sql()
