from dataclasses import dataclass
from datetime import timedelta

from .export_helper import randomize_start_time
from .cursor_provider import create_cursor_provider
from .task import TaskParams
from mail.shiva.stages.api.props.logger import get_uid_logger
from mail.shiva.stages.api.props.services.archive import ArchiveStorage
from mail.shiva.stages.api.settings.s3api import S3ApiSettings
from mail.shiva.stages.api.settings.tvm import TvmSettings
from mail.python.theatre.detail.tvm import TvmServiceTickets

# Module logger; calls in this module pass the affected user as ``uid=``.
log = get_uid_logger(__name__)
# Passed as ``timeout=`` to ``conn.cursor`` in ``clean_stids`` only; the other
# queries here open cursors without an explicit timeout.
# NOTE(review): presumably seconds — confirm against the cursor provider.
CLEAN_STIDS_TIMEOUT = 600


async def get_next_user_for_clean(conn, clean_ttl, uid):
    """Fetch the next user (by ascending uid) whose archive is being cleaned.

    Runs one ``SELECT ... LIMIT 1`` that returns the smallest uid greater than
    ``uid`` among users that:

    * are ``is_here``;
    * are in state ``'active'`` or ``'deleted'``;
    * have a ``mail.archives`` row in state ``'cleaning_in_progress'``;
    * have both ``mail.users.last_state_update`` and ``mail.archives.updated``
      older than ``now() - clean_ttl``.

    Because ``ma.state`` is filtered in ``WHERE``, users without an archives
    row never match despite the ``LEFT JOIN``.

    :param conn: object whose ``cursor()`` is an async context manager
        yielding a cursor with ``execute``/``fetchone``; fetched rows are
        indexed by column name.
    :param clean_ttl: age threshold bound to ``%(clean_ttl)s``.
    :param uid: exclusive lower bound, so repeated calls page through users.
    :returns: the matching uid, or the falsy fetch result (e.g. ``None``)
        when no user qualifies.
    """
    query_params = {'clean_ttl': clean_ttl, 'uid': uid}
    async with conn.cursor() as cursor:
        await cursor.execute(
            '''
                SELECT uid
                  FROM mail.users mu LEFT JOIN mail.archives ma USING (uid)
                 WHERE is_here
                   AND mu.state in ('active', 'deleted')
                   AND ma.state  = 'cleaning_in_progress'
                   AND mu.last_state_update < (now() - %(clean_ttl)s)
                   AND ma.updated < (now() - %(clean_ttl)s)
                   AND uid > %(uid)s
                 ORDER BY uid
                 LIMIT 1
            ''',
            query_params,
        )
        row = await cursor.fetchone()
        if not row:
            # No match: hand back the falsy fetch result unchanged.
            return row
        return row['uid']


async def delete_from_archive(conn, uid):
    """Remove the ``mail.archives`` row(s) belonging to ``uid``.

    :param conn: object whose ``cursor()`` is an async context manager
        yielding a cursor with an async ``execute``.
    :param uid: user whose archive record is deleted.
    """
    query_params = {'uid': uid}
    async with conn.cursor() as cursor:
        await cursor.execute(
            '''
            DELETE FROM mail.archives
            WHERE uid = %(uid)s
            ''',
            query_params,
        )


async def clean_stids(conn, uid, stids_to_clean):
    """Enqueue storage ids via ``code.add_to_storage_delete_queue``.

    :param conn: object whose ``cursor(timeout=...)`` is an async context
        manager yielding a cursor with an async ``execute``.
    :param uid: owner uid, bound as the SQL function's first argument.
    :param stids_to_clean: storage ids, bound as the second argument. When
        empty or otherwise falsy, nothing is executed and no cursor is opened.
    """
    if not stids_to_clean:
        return None

    query_params = {'uid': uid, 'stids': stids_to_clean}
    # Unlike the other queries in this module, this cursor gets an explicit
    # timeout (CLEAN_STIDS_TIMEOUT).
    async with conn.cursor(timeout=CLEAN_STIDS_TIMEOUT) as cursor:
        await cursor.execute(
            '''
            SELECT * FROM code.add_to_storage_delete_queue(%(uid)s, %(stids)s)
            ''',
            query_params,
        )


def get_stids_to_clean(messages):
    """Collect the ``st_id`` of every message whose ``is_shared`` flag is falsy.

    Input order is preserved. ``st_id`` is only read for non-shared messages.

    :param messages: iterable of mappings with ``'is_shared'`` and ``'st_id'`` keys.
    :returns: a new list of storage ids.
    """
    stids = []
    for message in messages:
        if message['is_shared']:
            # Shared messages are skipped and their stids are not queued.
            continue
        stids.append(message['st_id'])
    return stids


async def _clean_archive_key(conn, archive_storage, uid, key):
    # Queue the non-shared stids of one archive object, then delete that object.
    messages = await archive_storage.get_messages(key)
    stids = get_stids_to_clean(messages)
    await clean_stids(conn, uid, stids)
    await archive_storage.delete_objects([key])


async def clean_user_archive(conn, archive_storage, uid):
    """Purge every archived object of ``uid``, then drop its ``mail.archives`` row.

    Object keys come from ``archive_storage.list_user_objects(uid)``. Each
    object is processed in order:

    1. load its messages;
    2. queue the stids of non-shared messages with ``clean_stids``;
    3. delete the object from storage.

    The archive row is deleted only after every object has been handled.

    This is best-effort: any ``Exception`` is logged with ``log.exception``
    and not re-raised. A failure therefore stops work for this user (the
    archive row stays in place) without propagating to the caller.
    """
    try:
        object_keys = await archive_storage.list_user_objects(uid)
        for object_key in object_keys:
            await _clean_archive_key(conn, archive_storage, uid, object_key)

        await delete_from_archive(conn, uid)
        log.info('successfully clean user archive', uid=uid)
    except Exception as error:
        log.exception(f'Got exception: {error}', uid=uid)


@dataclass
class CleanArchivesParams(TaskParams):
    """Parameters for the ``clean_archives`` task, consumed by ``shard_clean_archives``.

    Fields inherited from ``TaskParams`` are declared elsewhere.
    """
    # Task identifier.
    task_name: str = 'clean_archives'
    # Age threshold: a user is picked only when both mail.users.last_state_update
    # and mail.archives.updated are older than now() - clean_ttl.
    clean_ttl: timedelta = timedelta(days=10)
    # Defaults to None but must be set: shard_clean_archives reads
    # ``tvm_ids.s3_id`` unconditionally.
    tvm_ids: TvmSettings = None
    # Passed to ArchiveStorage as ``tvm=``.
    tvm_tickets: TvmServiceTickets = None
    # Passed to ArchiveStorage as ``s3api_settings=``.
    s3api_settings: S3ApiSettings = None
    # Passed to randomize_start_time(max_delay=...).
    # NOTE(review): presumably seconds — confirm against that helper.
    max_delay: int = 3600


async def shard_clean_archives(params: CleanArchivesParams, stats):
    """Run one archive-cleaning pass over the shard behind ``params``.

    The pass proceeds as follows:

    1. Await ``randomize_start_time(max_delay=params.max_delay)``
       (presumably a randomized start delay — see that helper).
    2. Build an ``ArchiveStorage`` from the S3/TVM settings in ``params``.
    3. Walk the users returned by ``get_next_user_for_clean`` in ascending
       uid order, calling ``clean_user_archive`` for each.
    4. Stop when no further user qualifies.

    ``params.tvm_ids`` must be set: its ``s3_id`` is read unconditionally,
    and the dataclass default ``None`` would raise ``AttributeError`` here.
    """
    await randomize_start_time(max_delay=params.max_delay)

    storage = ArchiveStorage(
        s3api_settings=params.s3api_settings,
        s3_id=params.tvm_ids.s3_id,
        tvm=params.tvm_tickets,
        stats=stats,
    )

    async with create_cursor_provider(params, stats) as conn:

        async def next_uid_after(previous_uid):
            # Smallest qualifying uid strictly greater than previous_uid, or falsy.
            return await get_next_user_for_clean(
                conn=conn,
                clean_ttl=params.clean_ttl,
                uid=previous_uid,
            )

        # Priming read with 0 as the initial exclusive lower bound.
        current_uid = await next_uid_after(0)
        while current_uid:
            # clean_user_archive logs and swallows its own errors, so one
            # failing user does not end the pass.
            await clean_user_archive(conn, storage, current_uid)
            current_uid = await next_uid_after(current_uid)
