from dataclasses import dataclass
from typing import Optional

from mail.shiva.stages.api.roles.shard_worker import TaskStats
from mail.shiva.stages.api.props.errors import BadRequest
from mail.shiva.stages.api.props.shard.cursor_provider import create_cursor_provider
from mail.shiva.stages.api.settings.freeze_settings import FreezeSettings
from mail.shiva.stages.api.props.shard.task import HuskydbEngine
from mail.shiva.stages.api.props.shard.purge_user import purge_user
from mail.shiva.stages.api.props.shard.task import TaskParams


@dataclass
class PurgeTransferredUserParams(TaskParams):
    """Parameters for the ``purge_transferred_user`` utility task.

    Extends ``TaskParams`` with the user to purge, the freeze settings that
    gate which uids may be modified, and the inputs used to pick the shards
    that get scanned.
    """

    # Id of the user to purge; validated by ``check_uid_allowed``.
    uid: Optional[str] = None
    # Freeze settings; only ``allowed_uids`` is consulted in this class.
    freeze_settings: Optional[FreezeSettings] = None
    # Shard ``load_type`` used when selecting candidate shards.
    load_type: str = 'dbaas_hot'
    # Engine whose ``connection()`` is used to read the shard catalogue.
    huskydb: Optional[HuskydbEngine] = None
    # Maximum number of candidate shards to select.
    limit: int = 5

    def check_uid_allowed(self) -> bool:
        """Return True if ``uid`` may be modified under ``freeze_settings``.

        ``allowed_uids`` being ``None`` means there is no restriction.
        Otherwise ``uid`` must be contained in it. ``freeze_settings`` must
        be set, because a ``None`` value raises ``AttributeError`` here.
        """
        return self.freeze_settings.allowed_uids is None or self.uid in self.freeze_settings.allowed_uids


async def is_user_in_shard(conn, uid, is_here=False):
    """Check whether ``mail.users`` on this connection has a matching row.

    A row matches when its ``uid`` equals ``uid`` and its ``is_here`` column
    equals ``is_here``. With the default ``is_here=False``, this looks for a
    user row that is not flagged as "here". Presumably that marks a user
    transferred to another shard; confirm against the schema.

    Args:
        conn: Shard connection whose ``cursor()`` is an async context manager.
        uid: User id to look up.
        is_here: Value the ``is_here`` column must equal.

    Returns:
        True if at least one matching row exists, otherwise False.
    """
    sql = '''
                SELECT uid
                  FROM mail.users
                 WHERE uid = %(uid)s
                   AND is_here = %(is_here)s
                 LIMIT 1
            '''
    bind_args = {'uid': uid, 'is_here': is_here}
    async with conn.cursor() as cur:
        await cur.execute(sql, bind_args)
        row = await cur.fetchone()
    return row is not None


async def get_shard_ids(params):
    """Fetch candidate shard ids from the ``shiva.shards`` catalogue.

    Selects shards where ``can_transfer_to`` is true and ``load_type``
    equals ``params.load_type``, returning at most ``params.limit`` rows.

    NOTE(review): the query uses ``LIMIT`` without ``ORDER BY``, so the
    subset of shards returned is not pinned down by the query itself.
    Confirm that an arbitrary subset is acceptable here.

    Args:
        params: Object providing ``huskydb`` (with an async ``connection()``),
            ``load_type`` and ``limit``.

    Returns:
        List of ``shard_id`` values in the order the cursor yields them.
    """
    async with params.huskydb.connection() as huskydb_conn, huskydb_conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT shard_id
                  FROM shiva.shards
                 WHERE can_transfer_to
                   AND load_type = %(load_type)s
                 LIMIT %(limit)s
                ''',
            {'load_type': params.load_type, 'limit': params.limit},
        )
        return [record['shard_id'] async for record in cur]


async def purge_transferred_user(params: PurgeTransferredUserParams):
    """Purge ``params.uid`` from each candidate shard that still holds it.

    First verifies ``params.check_uid_allowed()``. It then iterates over the
    shard ids returned by ``get_shard_ids(params)``. For each shard it opens a
    cursor provider and calls ``purge_user`` if ``is_user_in_shard`` finds a
    row for the uid with ``is_here = false``. Every candidate shard is
    visited; the loop does not stop after the first purge.

    Side effect: ``params.shard_id`` is overwritten on each iteration, so
    afterwards it holds the last shard scanned (if any).

    Raises:
        BadRequest: if the uid is not allowed to be modified.
    """
    if not params.check_uid_allowed():
        raise BadRequest('uid not allowed to be modified')

    task_stats = TaskStats('util_purge_transferred_user')
    for candidate_shard_id in await get_shard_ids(params):
        # Presumably create_cursor_provider picks its target shard from
        # params.shard_id, so point params at the candidate first.
        params.shard_id = candidate_shard_id
        async with create_cursor_provider(params, task_stats) as shard_conn:
            user_found = await is_user_in_shard(shard_conn, params.uid, is_here=False)
            if user_found:
                await purge_user(conn=shard_conn, uid=params.uid)
