import asyncio
from dataclasses import dataclass
from typing import Optional

import tenacity
import yt.wrapper as yt

from mail.shiva.stages.api.props.services.archive import ArchiveStorage
from mail.shiva.stages.api.props.logger import get_uid_logger
from .export_helper import create_yt_table, write_data_to_yt
from .cursor_provider import create_cursor_provider

log = get_uid_logger(__name__)

# YT table directories; callers append the shard id (plus the job number for
# archived stids) as the final path component.
YT_DELETED_STIDS_TABLE = '//home/mail-logs/core/mdb/mds/deleted_stids'
YT_ARCHIVED_STIDS_TABLE = '//home/mail-logs/core/mdb/mds/archived_stids'
# Page size of the deleted-messages query and flush threshold of the
# buffered YT writer.
STIDS_CHUNK_SIZE = 1_000_000
# Page size of the archived-users query.
USERS_CHUNK_SIZE = 30_000
# Passed as ``timeout=`` to ``conn.cursor`` (presumably seconds -- confirm
# against the cursor provider).
PG_TIMEOUT = 600
# Seconds to sleep before retrying a failed archived-users query.
SLEEP_ON_PG_FAILS = 120


@dataclass
class DeletedMessages:
    """The ``uid``, ``mid`` and ``st_id`` columns of one deleted-messages row.

    Instances are built with ``DeletedMessages(**rec)`` from the rows of the
    query in ``get_deleted_messages``. Every field defaults to ``None``, so the
    annotations are ``Optional`` to match.
    """

    uid: Optional[int] = None
    mid: Optional[int] = None
    st_id: Optional[str] = None


async def get_deleted_messages(conn):
    """Yield chunks of ``DeletedMessages`` read from ``mail.deleted_box``.

    The query joins ``mail.deleted_box`` with ``mail.messages`` and
    ``mail.users``. It keeps only users with ``is_here`` set and messages whose
    ``attributes`` do not contain ``'mulca-shared'``. Paging is keyset-based on
    ``(uid, mid)`` in ascending order, at most ``STIDS_CHUNK_SIZE`` rows per
    chunk. Query errors propagate to the caller; there is no retry here.

    Args:
        conn: object whose ``cursor(timeout=...)`` is an async context manager
            yielding a cursor that iterates dict-like rows.

    Yields:
        Non-empty lists of ``DeletedMessages``.
    """
    uid = 0
    mid = 0
    while True:
        async with conn.cursor(timeout=PG_TIMEOUT) as cur:
            # NOTE(review): in PostgreSQL `x <> ALL(NULL)` is NULL, so rows with
            # NULL `attributes` are filtered out -- confirm the column is NOT NULL.
            await cur.execute(
                '''
                SELECT uid, mid, st_id
                  FROM mail.deleted_box m JOIN
                       mail.messages db USING (uid, mid) JOIN
                       mail.users u using(uid)
                 WHERE (uid, mid) > (%(uid)s, %(mid)s)
                   AND u.is_here
                   AND 'mulca-shared' <> ALL(attributes)
                 ORDER BY uid, mid
                 LIMIT %(chunk_size)s
                ''',
                dict(
                    uid=uid,
                    mid=mid,
                    chunk_size=STIDS_CHUNK_SIZE,
                ),
            )
            chunk = [DeletedMessages(**rec) async for rec in cur]

        # Leave the cursor context before yielding, so the cursor is not held
        # open while the caller processes (e.g. uploads) the chunk.
        if not chunk:
            return
        uid, mid = chunk[-1].uid, chunk[-1].mid
        yield chunk


async def get_archived_users(conn, jobs_count, job_no):
    """Yield ascending chunks of uids of archived users assigned to this job.

    Selects users with ``is_here`` set, ``mail.users.state = 'archived'`` and
    ``mail.archives.state = 'archivation_complete'``. It keeps only uids with
    ``uid % jobs_count == job_no``, so jobs with different ``job_no`` get
    disjoint uid sets. Paging is keyset-based on ``uid``, at most
    ``USERS_CHUNK_SIZE`` uids per chunk.

    A failed query is logged and retried from the same position after
    ``SLEEP_ON_PG_FAILS`` seconds, with no limit on the number of retries.

    Args:
        conn: object whose ``cursor(timeout=...)`` is an async context manager
            yielding a cursor that iterates dict-like rows.
        jobs_count: number of parallel jobs sharing the work.
        job_no: this job's index (presumably in ``range(jobs_count)``).

    Yields:
        Non-empty lists of uids.
    """
    uid = 0
    while True:
        # Only the DB round-trip is retried. The cursor is released before
        # yielding, so it is not held open while the caller works through the
        # chunk, and an exception raised at ``yield`` is not swallowed here.
        try:
            async with conn.cursor(timeout=PG_TIMEOUT) as cur:
                await cur.execute(
                    '''
                        SELECT uid
                          FROM mail.users mu JOIN mail.archives ma USING (uid)
                         WHERE is_here
                           AND mu.state = 'archived'
                           AND ma.state  = 'archivation_complete'
                           AND uid > %(uid)s
                           AND uid %% %(jobs_count)s = %(job_no)s
                         ORDER BY uid
                         LIMIT %(chunk_size)s
                    ''',
                    dict(
                        uid=uid,
                        chunk_size=USERS_CHUNK_SIZE,
                        jobs_count=jobs_count,
                        job_no=job_no,
                    ),
                )
                chunk = [r['uid'] async for r in cur]
        except Exception as exc:
            log.exception(f'Got exception: {exc}', uid=uid)
            await asyncio.sleep(SLEEP_ON_PG_FAILS)
            continue

        if not chunk:
            return
        uid = chunk[-1]
        yield chunk


async def get_archived_stids(archive_storage, uid):
    """Yield, per archive key of ``uid``, the st_ids of its non-shared messages.

    Each key comes from ``archive_storage.list_user_objects(uid)``. Its
    messages come from ``archive_storage.get_messages(key)``, and messages
    whose ``is_shared`` is truthy are skipped.

    Best-effort: any exception is logged with the uid and ends iteration for
    this user. Lists yielded before the failure are not retracted.
    """
    try:
        object_keys = await archive_storage.list_user_objects(uid)
        for object_key in object_keys:
            archived = await archive_storage.get_messages(object_key)
            own_stids = [item['st_id'] for item in archived if not item['is_shared']]
            yield own_stids
    except Exception as exc:
        log.exception(f'Got exception: {exc}', uid=uid)


async def write_stids_to_yt_buffered(yt_client, yt_table_name, stids, buffer, force):
    """Append ``stids`` to ``buffer`` as ``{'stid': ...}`` rows and flush when due.

    The buffer is written to ``yt_table_name`` and then emptied when it is
    non-empty and either ``force`` is set or it holds at least
    ``STIDS_CHUNK_SIZE`` rows.

    Returns:
        True if a write to YT happened, False otherwise.
    """
    rows = [{'stid': stid} for stid in stids]
    buffer.extend(rows)
    if not buffer:
        return False
    if not force and len(buffer) < STIDS_CHUNK_SIZE:
        return False
    await write_data_to_yt(yt_client, yt_table_name, buffer)
    buffer.clear()
    return True


@tenacity.retry(
    stop=tenacity.stop_after_attempt(5),
    wait=tenacity.wait_fixed(1),
    reraise=True,
)
async def refresh_tvm(archive_storage):
    """Refresh the TVM ticket held by ``archive_storage``.

    Retried by tenacity for up to 5 attempts, waiting 1 second between them.
    If every attempt fails, the last exception is re-raised as-is rather than
    wrapped in ``tenacity.RetryError``.
    """
    refresh = archive_storage.refresh_ticket
    await refresh()


async def export_deleted_stids(params, stats):
    """Export the st_ids of this shard's deleted messages to YT.

    Creates ``<YT_DELETED_STIDS_TABLE>/<shard_id>`` and writes one
    ``{'stid': ...}`` row per message. Each chunk produced by
    ``get_deleted_messages`` is written as a separate call.
    """
    client = yt.YtClient(**params.settings.yt.yt_config)
    table = f'{YT_DELETED_STIDS_TABLE}/{params.shard_id}'
    await create_yt_table(client, table, auto_fields={'stid'})

    async with create_cursor_provider(params, stats) as conn:
        async for deleted_chunk in get_deleted_messages(conn):
            rows = [dict(stid=message.st_id) for message in deleted_chunk]
            await write_data_to_yt(client, table, rows)


async def export_archive_stids(params, stats):
    """Export the st_ids of archived users' non-shared messages to YT.

    Creates ``<YT_ARCHIVED_STIDS_TABLE>/<shard_id>_<job_no>``. For each chunk
    of this job's archived users, the storage TVM ticket is refreshed once.
    Then every user's archived st_ids are appended through the buffered writer.
    Rows left below the flush threshold are written at the end.
    """
    client = yt.YtClient(**params.settings.yt.yt_config)
    table = f'{YT_ARCHIVED_STIDS_TABLE}/{params.shard_id}_{params.job_no}'
    await create_yt_table(client, table, auto_fields={'stid'})
    storage = ArchiveStorage(
        s3api_settings=params.settings.s3api,
        s3_id=params.settings.tvm.s3_id,
        tvm=params.tvm_tickets,
        stats=stats,
        # The ticket is refreshed explicitly, once per chunk of users, below.
        auto_refresh_tvm_ticket=False,
    )

    pending = []
    async with create_cursor_provider(params, stats) as conn:
        user_chunks = get_archived_users(conn, jobs_count=params.jobs_count, job_no=params.job_no)
        async for uid_chunk in user_chunks:
            await refresh_tvm(storage)
            for uid in uid_chunk:
                async for stids in get_archived_stids(storage, uid):
                    await write_stids_to_yt_buffered(client, table, stids, pending, force=False)

        # Flush whatever is still buffered below the chunk-size threshold.
        await write_stids_to_yt_buffered(client, table, [], pending, force=True)
