from contextlib import closing, contextmanager

from ora2pg.tools.find_master_helpers import find_sharddb, get_sharddb_pooled_conn
from ora2pg.transfer_data import DbEndpoint
from ora2pg.sharpei import get_shard_by_uid, get_shard_by_deleted_uid, get_shard_id_by_name

from mail.pypg.pypg.query_conf import load_from_package


class ShardIdResolveException(RuntimeError):
    """Raised when a shard name cannot be resolved to a shard id."""


class GetUserShardIdError(RuntimeError):
    """Raised when the shard of an existing user cannot be determined."""


class GetDeletedUserShardIdError(RuntimeError):
    """Raised when the shard of a deleted user cannot be determined."""


class MoveUserToDeletedError(RuntimeError):
    """Raised when sharddb does not report success for moving a user to deleted."""


class SharddbAdaptor(object):
    """Resolve shards through the sharpei helpers and run queries against sharddb.

    ``app_args`` must expose a ``sharpei`` attribute, plus whatever
    ``find_sharddb`` reads from it (NOTE(review): exact shape is defined by
    the caller).
    """

    # Named queries accessed as ``self.queries.<name>.query``; presumably
    # loaded from SQL files packaged next to this module — TODO confirm.
    queries = load_from_package(__package__, __file__)

    def __init__(self, app_args):
        self.sharpei = app_args.sharpei
        self.sharddb = find_sharddb(app_args)
        # Not referenced by any method of this class; kept so that any
        # external code reaching into it keeps working.
        self.__pools = dict()

    @contextmanager
    def _sharddb_query(self, query, **query_args):
        """Execute ``query`` on a pooled autocommit sharddb connection.

        Yields the cursor after ``execute``. The cursor is closed and the
        connection is released when the ``with`` block exits, so results
        must be consumed inside that block.
        """
        with get_sharddb_pooled_conn(self.sharddb, autocommit=True) as conn:
            with closing(conn.cursor()) as cur:
                cur.execute(query.query, query_args)
                yield cur

    def get_shard_id(self, shard_name):
        """Return the sharpei shard id for ``shard_name``.

        Raises:
            ShardIdResolveException: if sharpei knows no shard with that name.
        """
        shard_id = get_shard_id_by_name(self.sharpei, shard_name)
        if shard_id is None:
            raise ShardIdResolveException(f'Cannot find shard by name {shard_name}')
        return shard_id

    def resolve_endpoint_to_id(self, input_db_endpoint):
        """Return an endpoint whose Postgres database is addressed by shard id.

        A Postgres endpoint whose ``db`` is not all digits is treated as a
        shard name and replaced by ``DbEndpoint.make_pg(<shard id>)``. Any
        other endpoint is returned unchanged.
        """
        if input_db_endpoint.postgre and not input_db_endpoint.db.isdigit():
            return DbEndpoint.make_pg(self.get_shard_id(input_db_endpoint.db))
        return input_db_endpoint

    def resolve_endpoint(self, input_db_endpoint):
        """Coerce ``input_db_endpoint`` to ``DbEndpoint`` and resolve it to a shard id."""
        if not isinstance(input_db_endpoint, DbEndpoint):
            input_db_endpoint = DbEndpoint(input_db_endpoint)
        return self.resolve_endpoint_to_id(input_db_endpoint)

    def get_user_shard_id(self, uid):
        """Return the id of the shard holding the (non-deleted) user ``uid``.

        Raises:
            GetUserShardIdError: if the lookup fails for any reason. The
                original error is chained as ``__cause__``.
        """
        try:
            shard = get_shard_by_uid(self.sharpei, uid)
            return shard['id']
        except Exception as exc:
            # ``Exception`` instead of a bare ``except`` so that
            # KeyboardInterrupt/SystemExit propagate; chain to keep the cause.
            raise GetUserShardIdError(f'Cannot find user shard, uid: {uid}') from exc

    def get_deleted_user_shard_id(self, uid):
        """Return the id of the shard holding the deleted user ``uid``.

        Raises:
            GetDeletedUserShardIdError: if the lookup fails for any reason.
                The original error is chained as ``__cause__``.
        """
        try:
            shard = get_shard_by_deleted_uid(self.sharpei, uid)
            return shard['id']
        except Exception as exc:
            raise GetDeletedUserShardIdError(f'Cannot find deleted user shard, uid: {uid}') from exc

    def delete_user(self, uid):
        """Run the ``move_user_to_deleted`` query for ``uid``.

        Returns the cursor's ``rowcount``.

        Raises:
            MoveUserToDeletedError: unless the first column of the first
                returned row equals ``'success'``, including when no row is
                returned at all.
        """
        with self._sharddb_query(
            self.queries.move_user_to_deleted,
            uid=uid,
        ) as cur:
            row = cur.fetchone()
            # fetchone() returns None when there are no rows; report that as a
            # domain error rather than an opaque TypeError from indexing None.
            result = row[0] if row is not None else None
            if result != 'success':
                raise MoveUserToDeletedError('Cannot move user (uid:%r) to deleted: %s' % (uid, result))
            return cur.rowcount

    def _get_shard_users(self, shard_name, is_deleted, chunk_size):
        """Yield lists of uids from ``shard_name`` in chunks of up to ``chunk_size``.

        Keyset pagination: each request passes the last uid seen so far.
        NOTE(review): assumes the query orders by uid ascending and limits
        the result to ``chunk_size`` rows — confirm against the SQL.

        Raises:
            ValueError: if ``chunk_size`` is not positive. Otherwise the stop
                condition ``len(uids) < chunk_size`` could never hold.
        """
        if chunk_size <= 0:
            raise ValueError(f'chunk_size must be positive, got {chunk_size!r}')
        query = self.queries.deleted_users_by_shard if is_deleted else self.queries.users_by_shard
        last_uid = 0
        while True:
            with self._sharddb_query(
                query,
                shard_name=shard_name,
                last_uid=last_uid,
                chunk_size=chunk_size,
            ) as cur:
                uids = [row[0] for row in cur.fetchall()]
            # Yield only after the pooled connection has been released, so a
            # slow or abandoned consumer does not keep it checked out.
            if uids:
                last_uid = uids[-1]
                yield uids
            if len(uids) < chunk_size:
                return

    def get_shard_users(self, shard_name, with_deleted=False, chunk_size=10000):
        """Yield uid chunks of active users on ``shard_name``.

        When ``with_deleted`` is true, chunks of deleted users follow.
        """
        yield from self._get_shard_users(shard_name, False, chunk_size)
        if with_deleted:
            yield from self._get_shard_users(shard_name, True, chunk_size)
