import asyncio
from dataclasses import dataclass
from datetime import timedelta
from psycopg2.extras import RealDictCursor

import aiopg
import psycopg2

from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn

from .task import TaskParams
from .purge_deleted_box import purge as purge_deleted_box
from .purge_backups import purge_backup, Backup
from .cursor_provider import ShardDbCursorProvider
from mail.shiva.stages.api.props.logger import get_uid_logger

log = get_uid_logger(__name__)
# Timeout passed to the aiopg cursor that runs each purge step call (see apply_purge_proc).
# NOTE(review): aiopg timeouts are expressed in seconds, so this is ~50 minutes per
# single step call — confirm this is intended.
PURGE_PROC_TIMEOUT = 3000


@dataclass
class User:
    """A maildb (``mail.users``) purge candidate as selected by the queries in this module."""
    # NOTE(review): annotated as str, but the SQL here treats uid as a number
    # (``uid %% jobs_count``, ``uid > last_uid`` starting from 0) — confirm the real type.
    uid: str = None
    # Not selected from maildb; need_purge_user fills it with the shard being purged.
    shard_id: int = None
    is_here: bool = None
    is_deleted: bool = None


# Purge candidates whose maildb row is no longer "here" on this shard (presumably
# moved to another shard; need_purge_user verifies that against sharddb) and whose
# purge_date is older than purge_ttl.
# Rows are restricted to this job's slice of uids (uid modulo jobs_count == job_no;
# ``%%`` is a literal ``%`` escaped for psycopg2 parameter substitution) and paged
# by uid: ``uid > last_uid ORDER BY uid LIMIT chunk_size``.
QUERY_GET_TRANSFERRED_USERS = '''
    SELECT uid, is_here, is_deleted
      FROM mail.users u
     WHERE NOT is_here
       AND purge_date < (now() - %(purge_ttl)s)
       AND uid %% %(jobs_count)s = %(job_no)s
       AND uid > %(last_uid)s
     ORDER BY uid
     LIMIT %(chunk_size)s
'''


# Purge candidates deleted in place: still "here" on this shard, flagged is_deleted,
# with purge_date older than purge_ttl. Uses the same job slicing and uid paging as
# QUERY_GET_TRANSFERRED_USERS.
QUERY_GET_DELETED_USERS = '''
    SELECT uid, is_here, is_deleted
      FROM mail.users u
     WHERE is_here
       AND is_deleted
       AND purge_date < (now() - %(purge_ttl)s)
       AND uid %% %(jobs_count)s = %(job_no)s
       AND uid > %(last_uid)s
     ORDER BY uid
     LIMIT %(chunk_size)s
'''


async def get_users_for_purge(conn, purge_ttl, chunk_size, jobs_count, job_no, query):
    """Yield purge candidates selected by ``query`` in uid-ordered chunks.

    ``query`` must accept the named parameters ``purge_ttl``, ``jobs_count``,
    ``job_no``, ``last_uid`` and ``chunk_size``, and must return rows ordered by
    uid (as QUERY_GET_TRANSFERRED_USERS / QUERY_GET_DELETED_USERS do). Paging is
    keyset-based: each round asks for uids greater than the last uid yielded.

    :param conn: maildb connection whose cursors return mapping rows
        (callers use RealDictCursor) so each row can build a ``User``.
    :param purge_ttl: minimum age of ``purge_date`` (passed through to SQL).
    :param chunk_size: maximum rows per round; a shorter round ends iteration.
    :param jobs_count: number of parallel jobs the uid space is split between.
    :param job_no: this job's slice (``uid % jobs_count == job_no``).
    :param query: SQL text to run each round.
    :yields: non-empty lists of ``User``.
    """
    last_uid = 0
    while True:
        # Fetch the whole chunk and release the cursor *before* yielding. The
        # consumer runs its own queries on this same connection, and aiopg keeps
        # only one open cursor per connection (opening another force-closes the
        # previous one with a ResourceWarning).
        async with conn.cursor() as cur:
            await cur.execute(
                query,
                dict(
                    purge_ttl=purge_ttl,
                    jobs_count=jobs_count,
                    job_no=job_no,
                    last_uid=last_uid,
                    chunk_size=chunk_size,
                )
            )
            chunk = [User(**r) async for r in cur]

        # An empty round means nothing is left. Checking for it explicitly also
        # stops the loop when chunk_size is 0, which would otherwise never end.
        if not chunk:
            return

        last_uid = chunk[-1].uid
        yield chunk

        # A short chunk means the result set is exhausted.
        if len(chunk) < chunk_size:
            return


async def apply_purge_proc(conn, uid, purge_proc):
    """Call the SQL function ``purge_proc(uid)`` repeatedly until it reports completion.

    The function name is interpolated into the SQL text (purge_user takes it from
    ``code.purge_user_steps()``), while ``uid`` is sent as a bound parameter. Each
    call runs on a fresh cursor limited by ``PURGE_PROC_TIMEOUT``. The loop repeats
    while the returned ``res`` column is falsy, and each attempt is logged.

    NOTE(review): a step that keeps returning false/NULL makes this loop forever.
    """
    statement = f'SELECT {purge_proc}(%(uid)s) as res'
    while True:
        log.info(f'Purge user, proc is: {purge_proc}', uid=uid)
        async with conn.cursor(timeout=PURGE_PROC_TIMEOUT) as cur:
            await cur.execute(statement, {'uid': uid})
            result_row = await cur.fetchone()
        if result_row['res']:
            return


async def purge_user(conn, uid):
    """Run every maildb purge step for ``uid``, in the order the database lists them.

    The step names come from ``code.purge_user_steps()``; each one is driven to
    completion by apply_purge_proc. This is best-effort: if a step raises
    IntegrityError, InternalError or asyncio.TimeoutError, the error is logged, the
    remaining steps are skipped, and the function returns normally without
    re-raising.

    NOTE(review): because failures are swallowed, callers cannot tell a partial
    purge from a complete one. shard_purge_deleted_user goes on to delete the
    sharddb record either way; consider returning a success flag.
    """
    async with conn.cursor() as cur:
        await cur.execute('SELECT code.purge_user_steps()::text[] as res')
        steps_row = await cur.fetchone()

    logged_errors = (psycopg2.IntegrityError, psycopg2.InternalError, asyncio.TimeoutError)
    for purge_proc in steps_row['res']:
        try:
            await apply_purge_proc(conn, uid, purge_proc)
        except logged_errors as exc:
            log.exception(f'Got db exception while purging user, proc {purge_proc}: {exc}', uid=uid)
            return
    log.info('purge_user: successfully purge user', uid=uid)


async def purge_sharddb_deleted_user(conn, uid, shard_id):
    """Delete the ``shards.deleted_users`` row of ``uid`` on ``shard_id``.

    ``conn`` is anything that provides ``cursor()`` for sharddb (the caller passes
    a ShardDbCursorProvider). A success message is logged once the DELETE has run;
    the number of affected rows is not checked.
    """
    query_args = {'uid': uid, 'shard_id': shard_id}
    async with conn.cursor() as sharddb_cur:
        await sharddb_cur.execute(
            '''
            DELETE FROM shards.deleted_users
             WHERE uid = %(uid)s
               AND shard_id = %(shard_id)s
            ''',
            query_args,
        )
    log.info('purge_user: successfully purge user from sharddb', uid=uid)


async def get_user_deleted_box(conn, uid, chunk_size):
    """Yield chunks of message ids (``mid``) from ``mail.deleted_box`` for ``uid``.

    The query has no offset or keyset: every round re-reads up to ``chunk_size``
    rows from the start. Progress therefore relies on the consumer removing the
    yielded mids before requesting the next chunk.
    NOTE(review): if a consumer does not remove them, a full chunk is re-yielded
    forever.

    An empty list may be yielded (for example when the box is empty). Iteration
    stops after the first chunk that is shorter than ``chunk_size`` or empty.

    :param conn: maildb connection whose cursors return mapping rows.
    :param uid: owner of the deleted_box rows.
    :param chunk_size: maximum number of mids per chunk (SQL LIMIT).
    :yields: lists of mids.
    """
    while True:
        # Materialise the chunk and close the cursor before yielding. The consumer
        # (e.g. purge_deleted_box) queries the same connection, and aiopg keeps
        # only one open cursor per connection.
        async with conn.cursor() as cur:
            await cur.execute(
                '''
                SELECT mid
                  FROM mail.deleted_box
                 WHERE uid = %(uid)s
                 LIMIT %(chunk_size)s
                ''',
                dict(
                    uid=uid,
                    chunk_size=chunk_size,
                )
            )
            deleted = [row['mid'] async for row in cur]
        yield deleted
        # Stopping on an empty chunk also prevents an endless loop when chunk_size is 0.
        if not deleted or len(deleted) < chunk_size:
            break


async def purge_user_deleted_box(conn, uid, deleted_box_chunk_size):
    """Purge ``uid``'s deleted_box one chunk of mids at a time, skipping empty chunks."""
    mid_chunks = get_user_deleted_box(conn, uid, deleted_box_chunk_size)
    async for mids in mid_chunks:
        if not mids:
            continue
        await purge_deleted_box(conn, uid, mids)


async def get_user_backups(conn, uid):
    """Return all ``backup.backups`` rows of ``uid`` as a list of ``Backup`` objects."""
    async with conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT uid, backup_id
              FROM backup.backups
             WHERE uid = %(uid)s
            ''',
            {'uid': uid},
        )
        backups = []
        async for backup_row in cur:
            backups.append(Backup(**backup_row))
    return backups


async def purge_user_backups(conn, uid):
    """Purge every backup owned by ``uid``, one after another."""
    owned_backups = await get_user_backups(conn, uid)
    for owned_backup in owned_backups:
        await purge_backup(conn, owned_backup)


@dataclass
class UserInSharddb:
    """A user's sharddb record as returned by get_user_from_sharddb."""
    uid: str = None
    # Shard the sharddb record points at.
    shard_id: int = None
    # True when the record comes from shards.deleted_users, False for shards.users.
    is_deleted: bool = None


async def get_user_from_sharddb(conn, uid):
    """Look ``uid`` up in sharddb's ``shards.users`` and ``shards.deleted_users``.

    Returns an ``UserInSharddb`` built from the first row, whose ``is_deleted``
    tells which table it came from. Returns ``None`` when neither table has the uid.

    NOTE(review): the UNION has no ORDER BY, so if a uid is present in both tables
    the returned record is not deterministic.
    """
    async with conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT uid, shard_id, false as is_deleted
              FROM shards.users
             WHERE uid = %(uid)s
            UNION
            SELECT uid, shard_id, true as is_deleted
              FROM shards.deleted_users
             WHERE uid = %(uid)s
            ''',
            {'uid': uid},
        )
        found = await cur.fetchone()
    if not found:
        return found
    return UserInSharddb(**found)


async def get_user_from_maildb(conn, uid):
    """Fetch ``uid``'s ``mail.users`` row as a ``User``, or ``None`` if it is absent."""
    async with conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT uid, is_here, is_deleted
              FROM mail.users u
             WHERE uid = %(uid)s
            ''',
            {'uid': uid},
        )
        found = await cur.fetchone()
    if not found:
        return found
    return User(**found)


async def need_purge_user(sharddb_conn, sharpei, db_user, user, purge_shard_id, stats):
    """Decide whether a maildb purge candidate is safe to purge.

    ``user`` was selected on shard ``purge_shard_id``. As a side effect its
    ``shard_id`` is set to ``purge_shard_id``, which appears in the warnings below.
    The candidate is cross-checked with its sharddb record:

    * no sharddb record -> skip;
    * deleted in place (``is_here`` and ``is_deleted``) -> sharddb must point at this
      shard with the same ``is_deleted`` flag;
    * no longer here -> sharddb must point at a *different* shard, and that shard's
      maildb must have the user with ``is_here`` set and the same ``is_deleted``
      flag as sharddb.

    Every skip is logged as a warning. Returns True when the purge may proceed.
    """
    user_in_sharddb = await get_user_from_sharddb(sharddb_conn, user.uid)
    user.shard_id = purge_shard_id

    if user_in_sharddb is None:
        log.warning(f'purge_user: skipped, user {user} absent in sharddb', uid=user.uid)
        return False

    if user.is_here:
        # Still on this shard: only candidates flagged deleted need cross-checking.
        if user.is_deleted:
            same_shard = purge_shard_id == user_in_sharddb.shard_id
            same_flag = user.is_deleted == user_in_sharddb.is_deleted
            if not (same_shard and same_flag):
                log.warning(f'purge_user: skipped, mismatch between sharddb {user_in_sharddb} and maildb {user}', uid=user.uid)
                return False
        return True

    # The user left this shard, so sharddb must already point elsewhere.
    if purge_shard_id == user_in_sharddb.shard_id:
        log.warning(f'purge_user: skipped, mismatch between sharddb {user_in_sharddb} and maildb {user}', uid=user.uid)
        return False

    # Confirm that the new shard really owns the user before dropping the old copy.
    new_shard_dsn = await get_shard_dsn(sharpei, db_user, user_in_sharddb.shard_id, stats)
    async with aiopg.connect(new_shard_dsn, cursor_factory=RealDictCursor) as conn:
        user_in_maildb = await get_user_from_maildb(conn, user.uid)
        if user_in_maildb is None:
            log.warning(f'purge_user: skipped, sharddb user {user_in_sharddb} absent in new maildb', uid=user.uid)
            return False
        if user_in_maildb.is_here and user_in_maildb.is_deleted == user_in_sharddb.is_deleted:
            return True
        user_in_maildb.shard_id = user_in_sharddb.shard_id
        log.warning(f'purge_user: skipped, mismatch between sharddb {user_in_sharddb} and new maildb {user_in_maildb}', uid=user.uid)
        return False


@dataclass
class PurgeUserParams(TaskParams):
    """Settings shared by the purge_deleted_user and purge_transferred_user tasks.

    Other fields this module reads (sharpei, db_user, shard_id, jobs_count, job_no)
    are presumably inherited from TaskParams, which is not visible here.
    """
    # Minimum age of mail.users.purge_date before a user becomes a purge candidate.
    purge_ttl: timedelta = timedelta(days=1)
    # Maximum number of candidate users fetched per query round (SQL LIMIT).
    chunk_size: int = 10000
    # Maximum number of deleted_box mids fetched per round (used only by the deleted-user purge).
    deleted_box_chunk_size: int = 1000
    # When true, every candidate is purged without the need_purge_user sharddb/maildb cross-check.
    force: bool = False


@dataclass
class PurgeDeletedUserParams(PurgeUserParams):
    """Parameters for shard_purge_deleted_user."""
    # NOTE(review): presumably overrides a TaskParams field that names the task — confirm.
    task_name: str = 'purge_deleted_user'


@dataclass
class PurgeTransferredUserParams(PurgeUserParams):
    """Parameters for shard_purge_transferred_user."""
    # NOTE(review): presumably overrides a TaskParams field that names the task — confirm.
    task_name: str = 'purge_transferred_user'


async def shard_purge_deleted_user(params: PurgeDeletedUserParams, stats):
    """Purge users deleted in place (``is_here`` and ``is_deleted``) on ``params.shard_id``.

    Walks this job's slice of QUERY_GET_DELETED_USERS candidates. Unless
    ``params.force`` is set, each candidate is first approved by need_purge_user.
    For an approved user the steps run in this order: backups, deleted_box
    messages, the per-user purge steps, and finally the ``shards.deleted_users``
    record. The sharddb cursor provider is closed even if the walk fails.

    NOTE(review): purge_user swallows step failures and returns normally, so the
    sharddb record is deleted even after a partial purge. A later run then skips the
    user as "absent in sharddb" (unless force), leaving the maildb rows behind.
    """
    maildb_dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(maildb_dsn, cursor_factory=RealDictCursor) as conn:
        sharddb_conn = ShardDbCursorProvider(params.sharpei, stats)
        try:
            candidate_chunks = get_users_for_purge(
                conn=conn,
                purge_ttl=params.purge_ttl,
                jobs_count=params.jobs_count,
                job_no=params.job_no,
                chunk_size=params.chunk_size,
                query=QUERY_GET_DELETED_USERS,
            )
            async for users_chunk in candidate_chunks:
                for user in users_chunk:
                    approved = params.force or await need_purge_user(
                        sharddb_conn, params.sharpei, params.db_user, user, params.shard_id, stats,
                    )
                    if not approved:
                        continue
                    uid = user.uid
                    await purge_user_backups(conn, uid)
                    await purge_user_deleted_box(conn, uid, params.deleted_box_chunk_size)
                    await purge_user(conn, uid)
                    await purge_sharddb_deleted_user(sharddb_conn, uid, params.shard_id)
        finally:
            await sharddb_conn.close()


async def shard_purge_transferred_user(params: PurgeTransferredUserParams, stats):
    """Purge the maildb leftovers of users that are no longer here on ``params.shard_id``.

    Walks this job's slice of QUERY_GET_TRANSFERRED_USERS candidates. Unless
    ``params.force`` is set, only users approved by need_purge_user are purged.
    Only the per-user purge steps run here; there is no backup, deleted_box or
    sharddb cleanup. The sharddb cursor provider is closed even if the walk fails.
    """
    maildb_dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(maildb_dsn, cursor_factory=RealDictCursor) as conn:
        sharddb_conn = ShardDbCursorProvider(params.sharpei, stats)
        try:
            candidate_chunks = get_users_for_purge(
                conn=conn,
                purge_ttl=params.purge_ttl,
                jobs_count=params.jobs_count,
                job_no=params.job_no,
                chunk_size=params.chunk_size,
                query=QUERY_GET_TRANSFERRED_USERS,
            )
            async for users_chunk in candidate_chunks:
                for user in users_chunk:
                    approved = params.force or await need_purge_user(
                        sharddb_conn, params.sharpei, params.db_user, user, params.shard_id, stats,
                    )
                    if not approved:
                        continue
                    await purge_user(conn, user.uid)
        finally:
            await sharddb_conn.close()
