import yaml

from inspect import getframeinfo, stack

from mail.template_master.load_testing.scripts.lib.db_adaptor import DBAdaptor, DBConnectionProvider, YcDBConnectionProvider
from mail.template_master.load_testing.scripts.lib.util import get_logger, prepare_db_auth_data


def _get_shard_instances(shard):
    return ((instance['host'], 6432, instance['dc'], ShardDbAdaptor.CONSTANTS['maildb_name']) for instance in shard['instances'])


def _to_sql_literal(value):
    '''
    Format a single value as a SQL literal.

    Strings are wrapped in single quotes with embedded single quotes doubled
    (SQL-standard escaping). repr() is not a substitute: it switches to double
    quotes (an identifier in SQL) when the text contains ``'`` and doubles
    backslashes. Non-strings (the port) keep their repr(), which matches the
    old ``str(tuple)`` output.
    '''
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return repr(value)


def _to_sql_row(values):
    '''Format a tuple of values as a SQL row literal, e.g. ``('h', 6432, 'dc', 'maildb')``.'''
    return '(' + ', '.join(_to_sql_literal(value) for value in values) + ')'


def _get_shard_instances_str(shard):
    '''
    Render the shard's instances as comma-separated SQL row literals.

    The result is spliced into an ``array[ ... ]::code.instance[]`` expression
    by ``ShardDbAdaptor.process_shards``. For ordinary host/dc names it is
    identical to joining ``str(tuple)`` of each instance.
    '''
    return ','.join(_to_sql_row(instance) for instance in _get_shard_instances(shard))


def _pack_user_values(get_uids_cursor, shard_id):
    '''
    Fetch the next batch of uid rows and render them as SQL VALUES tuples.

    Reads up to ``ShardDbAdaptor.CONSTANTS['batch_size']`` rows via
    ``fetchmany`` and returns the comma-joined ``(uid, shard_id)`` tuples,
    taking the uid from the first column of each row. Returns an empty
    string when ``fetchmany`` returns no rows.
    '''
    rows = get_uids_cursor.fetchmany(ShardDbAdaptor.CONSTANTS['batch_size'])
    return ','.join([str((row[0], shard_id)) for row in rows])


class ShardDbAdaptor(object):
    '''
    Registers maildb shards in sharddb.

    For each configured shard, ``process_shards`` calls ``code.add_cluster`` in
    sharddb, links the returned shard_id to every scope in ``CONSTANTS['scopes']``
    and copies the uids from the shard's ``mail.users`` into ``shards.users``.

    NOTE(review): every query here is built with f-strings from config values
    and CONSTANTS without escaping; only feed this class trusted configuration.
    '''
    # Created lazily by _sharddb_connection_provider(); a single provider is
    # reused because the sharddb params never change.
    __sharddb_connection_provider = None
    # Created lazily by _maildb_connection_provider(): dict mapping a tuple of
    # maildb hosts to the provider built for those hosts.
    __maildb_connection_providers = None

    CONSTANTS = {
        'batch_size': 1000,
        'load_type': 'dbaas_hot',
        'maildb_name': 'maildb',
        'reg_weight': 10,
        'scopes': [
            {
                'id': 0,
                'name': 'default',
            }
        ],
        'sharddb_name': 'sharddb',
    }

    def __init__(self, shards, sharddb_adaptor_params, logger):
        '''
        :param shards: shards description with the following structure:
            [
                {
                    'shard_name': str,
                    'instances': [
                        {
                            'host': str,
                            'dc': str,
                        },
                    ]
                }
            ]
        :param sharddb_adaptor_params: keyword arguments for YcDBConnectionProvider
            (sharddb); also the base for the maildb connection params.
        :param logger: logger used for debug output (queries, commits, shard ids).
        '''
        self._shards = shards
        self._sharddb_adaptor_params = sharddb_adaptor_params
        self._log = logger

    def _get_sharddb_adaptor_params(self):
        '''Return the connection params used for sharddb.'''
        return self._sharddb_adaptor_params

    def _get_maildb_adaptor_params(self, hosts):
        '''
        Build maildb connection params from the sharddb ones.

        Assumes that user 'sharpei' with the same password exists in maildb.

        :param hosts: maildb hosts to connect to. When empty/None, the sharddb
            'cluster' and 'vault_secret_version' keys are kept as they are.
        :return: a new dict; the sharddb params dict is not modified.
        '''
        params = dict(self._sharddb_adaptor_params)
        params['dbname'] = ShardDbAdaptor.CONSTANTS['maildb_name']
        if hosts:
            params['hosts'] = hosts
            # pop() instead of del: params lacking these keys must not raise KeyError.
            params.pop('cluster', None)
            params.pop('vault_secret_version', None)
        return params

    def _get_cursor(self, query: str, conn_provider):
        '''Log ``query`` and return ``DBAdaptor(conn_provider).get_cursor(query)``.'''
        db_adaptor = DBAdaptor(conn_provider)
        self._log.debug(query)
        return db_adaptor.get_cursor(query)

    def _sharddb_connection_provider(self, params: dict):
        '''Return the sharddb provider, creating it from ``params`` on first use.'''
        if not self.__sharddb_connection_provider:
            self.__sharddb_connection_provider = YcDBConnectionProvider(**params)
        return self.__sharddb_connection_provider

    def _get_sharddb_cursor(self, query: str):
        '''Return a sharddb cursor for ``query``.'''
        return self._get_cursor(query, self._sharddb_connection_provider(self._get_sharddb_adaptor_params()))

    def _maildb_connection_provider(self, params: dict):
        '''
        Return the maildb provider for ``params['hosts']``, creating it on first use.

        Providers are cached per host list. Every shard lives on its own hosts;
        a single cached provider would make every shard after the first one
        read users from the first shard's hosts.
        '''
        hosts = params.get('hosts')
        key = tuple(hosts) if hosts else None
        if self.__maildb_connection_providers is None:
            self.__maildb_connection_providers = {}
        provider = self.__maildb_connection_providers.get(key)
        if provider is None:
            provider = DBConnectionProvider(**params)
            self.__maildb_connection_providers[key] = provider
        return provider

    def _get_maildb_cursor(self, query: str, hosts):
        '''Return a maildb cursor for ``query`` on the given ``hosts``.'''
        return self._get_cursor(query, self._maildb_connection_provider(self._get_maildb_adaptor_params(hosts)))

    def _commit(self, cursor):
        '''Commit the cursor's connection, logging the caller's line number.'''
        # stack(0) skips loading source-context lines, which stack() does for every frame.
        caller = stack(0)[1]
        self._log.debug('COMMIT on line ' + str(caller.lineno))
        cursor.connection.commit()

    def _prepare_shard_for_registration(self, shard_id: int):
        '''Ensure each configured scope exists and is linked to ``shard_id`` with ``reg_weight``.'''
        reg_weight = ShardDbAdaptor.CONSTANTS['reg_weight']
        for scope in ShardDbAdaptor.CONSTANTS['scopes']:
            ADD_SCOPE_Q = f'''
                INSERT INTO shards.scopes (scope_id, name)
                VALUES ({scope['id']}, '{scope['name']}')
                ON CONFLICT DO NOTHING;
            '''
            with self._get_sharddb_cursor(ADD_SCOPE_Q) as cursor:
                self._commit(cursor)

            ADD_SCOPE_FOR_SHARD_Q = f'''
                INSERT INTO shards.scopes_by_shards (reg_weight, scope_id, shard_id)
                VALUES ({reg_weight}, {scope['id']}, {shard_id})
                ON CONFLICT DO NOTHING;
            '''
            with self._get_sharddb_cursor(ADD_SCOPE_FOR_SHARD_Q) as cursor:
                self._commit(cursor)

    def _get_shard_users(self, host: list):
        '''
        Return a maildb cursor selecting every uid from ``mail.users``.

        :param host: list of maildb hosts of the shard (despite the singular name).
        '''
        GET_USERS_Q = '''
            SELECT uid
            FROM mail.users;
        '''
        return self._get_maildb_cursor(GET_USERS_Q, host)

    def _populate_shard(self, shard_id: int, shard: dict):
        '''Copy every uid of ``shard`` into ``shards.users``, one batch per insert and commit.'''
        mdb_hosts = [instance['host'] for instance in shard['instances']]
        with self._get_shard_users(mdb_hosts) as get_uids_cursor:
            while True:
                values = _pack_user_values(get_uids_cursor, shard_id)
                if not values:
                    break
                INSERT_ARGUMENTS_Q = f'''
                    INSERT INTO shards.users (uid, shard_id)
                    VALUES {values}
                    ON CONFLICT DO NOTHING;
                '''
                with self._get_sharddb_cursor(INSERT_ARGUMENTS_Q) as insert_cursor:
                    self._commit(insert_cursor)

    def process_shards(self):
        '''
        Register every configured shard in sharddb and populate its users.

        Re-running for an existing shard_name is intended to have no effect.
        The scope/user inserts use ON CONFLICT DO NOTHING; code.add_cluster is
        presumably idempotent as well (not visible here, verify).
        '''
        load_type = ShardDbAdaptor.CONSTANTS['load_type']
        for shard in self._shards:
            shard_name = shard['shard_name']
            shard_instances = _get_shard_instances_str(shard)
            ADD_CLUSTER_Q = f'''
                SELECT code.add_cluster(
                    i_shard_name := '{shard_name}',
                    i_load_type := '{load_type}',
                    i_instances := array[ {shard_instances} ]::code.instance[]
                );
            '''
            with self._get_sharddb_cursor(ADD_CLUSTER_Q) as cursor:
                # Read the returned shard_id before committing, so the result is
                # consumed while the transaction that produced it is still open.
                shard_id = cursor.fetchone()[0]
                self._commit(cursor)
            self._log.debug('assigned shard_id: %d', shard_id)
            self._prepare_shard_for_registration(shard_id)
            self._populate_shard(shard_id, shard)


# Connection parameters for sharddb, passed to ShardDbAdaptor by main().
# 'password' and 'sslrootcert_path' start as None; main() fills them in place
# from prepare_db_auth_data(vault_secret_version).
# NOTE(review): main() mutates this module-level dict.
sharddb_adaptor_params = {
    'cluster': 'mdb89eo15641i2ooopcf',
    'dbname': ShardDbAdaptor.CONSTANTS['sharddb_name'],
    'password': None,                   # is taken from the vault secret
    'sslrootcert_path': None,           # is taken from the vault secret
    'user': 'sharpei',
    'vault_secret_version': 'ver-01e734b46mxkw40wgepaxr1kp3',
}


def main(config_path='config.yml'):
    '''
    Resolve sharddb credentials and register the shards listed in the config.

    :param config_path: path to the YAML config whose top-level 'shards' key
        has the structure documented in ShardDbAdaptor.__init__. Defaults to
        'config.yml' relative to the current working directory.
    '''
    logger = get_logger()
    # Credentials are written into the module-level params dict in place.
    sharddb_adaptor_params['password'], sharddb_adaptor_params['sslrootcert_path'] = \
        prepare_db_auth_data(sharddb_adaptor_params['vault_secret_version'])

    # Binary mode lets PyYAML detect the stream encoding (UTF-8/UTF-16) itself;
    # the with-block closes the file, which the old open(...).read() leaked.
    with open(config_path, 'rb') as config_file:
        cfg = yaml.safe_load(config_file)
    sharddb = ShardDbAdaptor(cfg['shards'], sharddb_adaptor_params, logger)
    sharddb.process_shards()
