import random
from dataclasses import dataclass, asdict
from datetime import timedelta
import aiopg

from .helpers import db_general_retry
from .task import TaskParams, HuskydbEngine
from .export_helper import randomize_start_time
from .freeze_helper import log_update_user_state, is_live_user, UserState
from .cursor_provider import create_cursor_provider, locked_transactional_cursor, ShardDbCursorProvider
from .transfer_users import get_open_shards
from mail.python.theatre.detail.tvm import TvmServiceTickets
from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn
from mail.shiva.stages.api.props.logger import get_uid_logger
from mail.shiva.stages.api.props.services.archive import ArchiveStorage
from mail.shiva.stages.api.settings.s3api import S3ApiSettings
from mail.shiva.stages.api.settings.tvm import TvmSettings

# Module logger; calls in this file pass the affected user's id as the ``uid=`` keyword.
log = get_uid_logger(__name__)


@dataclass
class User:
    """Archivation candidate row produced by ``get_next_user``."""
    # mail.users.uid of the candidate
    uid: int = None
    # mail.archives.state joined via LEFT JOIN: None when no archivation state is
    # recorded for the user, otherwise 'archivation_in_progress' (the only other
    # value get_next_user's filter lets through)
    archivation_state: str = None


async def get_next_user(conn, state_ttl, uid):
    """Return the next archivation candidate with a uid strictly above ``uid``.

    A candidate is a user of this shard (``is_here``, not deleted) in state
    ``frozen`` whose last state update is older than ``now() - state_ttl``.
    It must have either no archivation state or one still
    ``archivation_in_progress``. It must also not be present in
    ``mail.deleted_box`` or ``mail.storage_delete_queue``. Candidates are
    taken in ascending uid order, one at a time.

    :param conn: cursor provider exposing ``cursor()`` as an async context manager
    :param state_ttl: interval subtracted from ``now()`` (a ``timedelta`` in this module)
    :param uid: exclusive lower bound for the returned uid
    :return: a ``User``, or the falsy fetch result (``None``) when nothing matches
    """
    query_args = {'state_ttl': state_ttl, 'uid': uid}
    async with conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT uid, ma.state as archivation_state
                  FROM mail.users mu LEFT JOIN mail.archives ma USING (uid)
                 WHERE is_here
                   AND NOT is_deleted
                   AND mu.state = 'frozen'
                   AND (ma.state is NULL or ma.state = 'archivation_in_progress')
                   AND mu.last_state_update < (now() - %(state_ttl)s)
                   AND uid > %(uid)s
                   AND NOT EXISTS (SELECT uid FROM mail.deleted_box mdb WHERE mdb.uid = mu.uid LIMIT 1)
                   AND NOT EXISTS (SELECT uid FROM mail.storage_delete_queue msdq WHERE msdq.uid = mu.uid LIMIT 1)
                 ORDER BY uid
                 LIMIT 1
            ''',
            query_args,
        )
        row = await cur.fetchone()
    if not row:
        return row
    return User(**row)


@dataclass
class Message:
    """One message of a user's mailbox as read by ``get_user_messages_chunk``."""
    # message id within the user's mailbox
    mid: int = None
    # storage id of the message
    st_id: str = None
    # type of the folder holding the message (mail.folders.type)
    folder_type: str = None
    # received date truncated to whole seconds since the Unix epoch
    received_date: int = None
    # result of impl.is_shared_message(attributes)
    is_shared: bool = None

    def saved_data(self) -> dict:
        """Return every field except ``mid`` as a dict, in declaration order."""
        return {name: value for name, value in asdict(self).items() if name != 'mid'}


async def get_user_messages_chunk(conn, uid, last_mid, chunk_size):
    """Fetch up to ``chunk_size`` messages of ``uid`` whose mid is above ``last_mid``.

    Rows come from ``mail.box`` joined with ``mail.messages`` and
    ``mail.folders`` and are ordered by mid. ``received_date`` is converted to
    whole epoch seconds.

    :return: list of ``Message`` (empty when nothing is left)
    """
    query_args = {
        'uid': uid,
        'last_mid': last_mid,
        'chunk_size': chunk_size,
    }
    async with conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT mid, st_id,
                       type as folder_type,
                       impl.is_shared_message(attributes) as is_shared,
                       TRUNC(EXTRACT(EPOCH FROM received_date))::bigint as received_date
                  FROM mail.box
                  JOIN mail.messages USING (uid, mid)
                  JOIN mail.folders USING (uid, fid)
                 WHERE uid = %(uid)s
                   AND mid > %(last_mid)s
                 ORDER BY mid
                 LIMIT %(chunk_size)s
            ''',
            query_args,
        )
        chunk = [Message(**row) async for row in cur]
    return chunk


async def get_user_messages(conn, uid, last_mid, chunk_size):
    """Yield the user's messages in ascending mid order, ``chunk_size`` at a time.

    Each request asks for messages with a mid greater than the last one
    already yielded, starting after ``last_mid``. Empty chunks are never
    yielded. Iteration stops after the first chunk shorter than
    ``chunk_size``.

    :param last_mid: exclusive lower bound for the first chunk
    :param chunk_size: maximum number of messages per yielded chunk
    :raises ValueError: if ``chunk_size`` is not positive. With a zero limit
        every chunk is empty, yet ``len(chunk) < chunk_size`` never holds, so
        the loop would never terminate.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    while True:
        chunk = await get_user_messages_chunk(conn, uid, last_mid, chunk_size)

        if chunk:
            last_mid = chunk[-1].mid
            yield chunk

        # A short chunk means the mailbox has been exhausted.
        if len(chunk) < chunk_size:
            return


async def start_archivation(conn, uid):
    """Mark archivation of ``uid`` as started.

    Inserts a ``mail.archives`` row with state ``archivation_in_progress``
    through ``locked_transactional_cursor(conn, uid)``. There is no
    ``ON CONFLICT`` clause, so an existing row for ``uid`` presumably makes the
    insert fail (``fail_archivation`` relies on a unique ``uid`` via
    ``ON CONFLICT (uid)``).
    """
    insert_args = {'uid': uid}
    async with locked_transactional_cursor(conn, uid) as cur:
        await cur.execute(
            '''
                INSERT INTO mail.archives (uid, state)
                VALUES (%(uid)s, 'archivation_in_progress')
            ''',
            insert_args,
        )


async def fail_archivation(conn, uid, reason):
    """Record an archivation error for ``uid`` and log that the user was skipped.

    Upserts the ``mail.archives`` row in one of two ways:
    - If no row exists, it is inserted in state ``archivation_error``.
    - If a row already exists, it is switched to that state, ``notice`` is
      overwritten and ``updated`` is set to ``now()``.

    In both cases ``notice`` receives ``reason``. The statement runs through
    ``locked_transactional_cursor(conn, uid)``.
    """
    upsert_args = {'uid': uid, 'reason': reason}
    async with locked_transactional_cursor(conn, uid) as cur:
        await cur.execute(
            '''
                INSERT INTO mail.archives
                    (uid, state, notice)
                VALUES
                    (%(uid)s, 'archivation_error', %(reason)s)
                ON CONFLICT (uid) DO UPDATE
                    SET state = 'archivation_error',
                        notice = %(reason)s,
                        updated = now()
            ''',
            upsert_args,
        )
        log.info(f'user was skipped, because {reason}', uid=uid)


async def get_user_shard_id(sharddb_conn, uid):
    """Look up the shard id that sharddb (``shards.users``) records for ``uid``.

    :return: the ``shard_id`` value, or the falsy fetch result (``None``) when
        sharddb has no row for the user
    """
    lookup_args = {'uid': uid}
    async with sharddb_conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT shard_id
              FROM shards.users
             WHERE uid = %(uid)s
            ''',
            lookup_args,
        )
        row = await cur.fetchone()
    return row['shard_id'] if row else row


@db_general_retry
async def update_shard_id(sharddb_conn, uid, old_shard_id, new_shard_id):
    """Move ``uid`` to ``new_shard_id`` in sharddb if it is still on ``old_shard_id``.

    The UPDATE is conditional on the current ``shard_id``, so it changes
    nothing when the user has already been moved elsewhere. Retry behaviour
    comes from ``db_general_retry`` (defined in ``.helpers``).

    :return: True if a row was updated, False otherwise
    """
    update_args = {
        'uid': uid,
        'new_shard_id': new_shard_id,
        'old_shard_id': old_shard_id,
    }
    async with sharddb_conn.cursor() as cur:
        await cur.execute(
            '''
            UPDATE shards.users
               SET shard_id = %(new_shard_id)s
             WHERE uid = %(uid)s
               AND shard_id = %(old_shard_id)s
            RETURNING *
            ''',
            update_args,
        )
        updated_row = await cur.fetchone()
        return updated_row is not None


@dataclass
class UsersData(UserState):
    """``UserState`` plus the revision and total message count read by ``get_user_data``."""
    # value returned by code.acquire_current_revision(uid)
    revision: int = None
    # sum of mail.folders.message_count over the user's folders; presumably None
    # when the user has no folders (SQL sum over zero rows) — confirm callers cope
    message_count: int = None


async def get_user_data(cursor, uid):
    """Read the ``mail.users`` row of ``uid`` on the given ``cursor``.

    Besides the user's state columns, the query adds two values:
    - the total ``message_count`` over the user's folders;
    - the revision returned by ``code.acquire_current_revision(uid)``.

    :return: ``UsersData``, or the falsy fetch result (``None``) when there is
        no such user
    """
    await cursor.execute(
        '''
            SELECT uid, state, is_here, is_deleted, notifies_count,
            (SELECT sum(message_count) from mail.folders mf where mu.uid = mf.uid) as message_count,
            code.acquire_current_revision(uid) as revision
              FROM mail.users mu
             WHERE uid = %(uid)s
        ''',
        {'uid': uid},
    )
    row = await cursor.fetchone()
    if not row:
        return row
    return UsersData(**row)


async def mark_user_as_moved_from_here(cursor, uid):
    """Flag ``uid`` as no longer residing on this shard.

    Sets ``is_here`` to false and ``purge_date`` to the current timestamp. The
    row is touched only while ``is_here`` is still set. Executes on the
    caller-supplied ``cursor``; ``archive_user`` passes its locked
    transactional cursor.
    """
    update_args = {'uid': uid}
    await cursor.execute(
        '''UPDATE mail.users
              SET is_here=false,
                  purge_date=current_timestamp
            WHERE uid=%(uid)s
              AND is_here
        ''',
        update_args,
    )


async def is_user_in_shard(conn, uid):
    """Return True if the database behind ``conn`` has a ``mail.users`` row for ``uid``."""
    lookup_args = {'uid': uid}
    async with conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT uid
                  FROM mail.users
                 WHERE uid = %(uid)s
            ''',
            lookup_args,
        )
        found = await cur.fetchone()
    return found is not None


async def get_open_shards_for_move(params):
    """Return candidate target shard ids for archived users, in random order.

    The ids come from ``get_open_shards`` queried through ``params.huskydb``
    with the current ``params.shard_id``, ``params.load_type`` and
    ``shards_count=100``. The exact selection semantics live in
    ``transfer_users``; presumably they exclude the current shard — confirm.

    :raises RuntimeError: when no open shard is available
    """
    async with params.huskydb.connection() as huskydb_conn:
        shard_ids = await get_open_shards(
            huskydb_conn,
            params.shard_id,
            params.load_type,
            shards_count=100,
        )
        if not shard_ids:
            raise RuntimeError(f'No available open shards for move archived user with load_type {params.load_type}')
    # Shuffling does not need the huskydb connection any more.
    random.shuffle(shard_ids)
    return shard_ids


async def create_archived_user(conn, uid, revision, message_count):
    """Create the archived-user record for ``uid`` on ``conn``'s database.

    Calls the stored function ``code.create_archived_user`` with ``uid``,
    ``revision`` (as bigint) and ``message_count`` (as integer). The result
    set of the SELECT is not fetched.
    """
    call_args = {
        'uid': uid,
        'revision': revision,
        'message_count': message_count,
    }
    async with conn.cursor() as cur:
        await cur.execute(
            'SELECT * FROM code.create_archived_user(%(uid)s::code.uid, %(revision)s::bigint, %(message_count)s::integer)',
            call_args,
        )


async def move_archived_user(params, uid, revision, message_count, open_shard_ids, stats):
    """Create the archived copy of ``uid`` on the first shard that does not hold it yet.

    Shards are tried in the order of ``open_shard_ids``. Each one is reached
    through a fresh aiopg connection to the DSN resolved by ``get_shard_dsn``
    for ``params.db_user``.

    :return: id of the shard where the archived user was created
    :raises RuntimeError: when every candidate shard already has the user
    """
    for candidate_shard_id in open_shard_ids:
        dsn = await get_shard_dsn(params.sharpei, params.db_user, candidate_shard_id, stats)
        async with aiopg.connect(dsn) as shard_conn:
            if await is_user_in_shard(shard_conn, uid):
                continue
            await create_archived_user(shard_conn, uid, revision, message_count)
            return candidate_shard_id
    raise RuntimeError(f'No appropriate shard for move archived user with load_type {params.load_type}')


async def archive_user(conn, sharddb_conn, params, user, archive_storage, messages_chunk_size, open_shard_ids, stats):
    """Archive a single frozen user and move it away from this shard.

    Steps:
      1. Determine the starting point. A user with no archivation state yet
         is marked as started. For a resumed user, ``archive_storage`` is
         asked for the last message id already saved, so saving continues
         after it.
      2. If sharddb does not place the user on ``params.shard_id``, record an
         ``archivation_error`` and stop.
      3. Save the remaining messages to ``archive_storage`` in chunks of
         ``messages_chunk_size``.
      4. Under ``locked_transactional_cursor`` for the uid, re-read the user
         and stop if it is no longer live or no longer ``frozen``. Otherwise:
         - create an archived copy on the first suitable shard of
           ``open_shard_ids``;
         - switch the user's shard in sharddb, only if it still points to
           ``params.shard_id``;
         - mark the user as moved from here.
      5. If the sharddb switch did not apply, record an ``archivation_error``
         after the locked transaction has been closed.

    Every exception is logged and swallowed so that one failing user does not
    abort the caller's batch; the function always returns ``None``.
    """
    try:
        last_archived_mid = 0
        if user.archivation_state is None:
            await start_archivation(conn, user.uid)
        else:
            last_archived_mid = await archive_storage.get_last_saved_mid(user.uid)

        user_shard_id = await get_user_shard_id(sharddb_conn, user.uid)
        if user_shard_id is None or user_shard_id != params.shard_id:
            await fail_archivation(conn, user.uid, 'user in different shard')
            return

        async for messages in get_user_messages(
            conn=conn,
            uid=user.uid,
            last_mid=last_archived_mid,
            chunk_size=messages_chunk_size,
        ):
            # get_user_messages only yields non-empty chunks.
            await archive_storage.save_messages(user.uid, messages)

        async with locked_transactional_cursor(conn, user.uid) as cur:
            user_data = await get_user_data(cur, user.uid)
            if not is_live_user(user_data) or user_data.state != 'frozen':
                log.info('user was skipped, because his state was changed', uid=user.uid)
                return

            new_shard_id = await move_archived_user(params, user.uid, user_data.revision, user_data.message_count, open_shard_ids, stats)
            shard_switched = await update_shard_id(sharddb_conn, user.uid, params.shard_id, new_shard_id)
            if shard_switched:
                log_update_user_state(user_data, 'archived', 0)
                await mark_user_as_moved_from_here(cur, user.uid)

        if not shard_switched:
            # fail_archivation opens its own locked_transactional_cursor for
            # the same uid. Running it while the transaction above is still open
            # would nest two locked transactions for one uid, which may block on
            # the lock held by the outer one (depending on how
            # locked_transactional_cursor locks). So the error is recorded only
            # after that transaction has been closed.
            # NOTE(review): the archived copy already created on new_shard_id is
            # not cleaned up here.
            await fail_archivation(conn, user.uid, 'user in different shard')
            return
    except Exception as exc:
        log.exception(f'Got exception: {exc}', uid=user.uid)
        return
    log.info('user was successfully archived', uid=user.uid)


@dataclass
class ArchiveUsersParams(TaskParams):
    """Configuration of the ``archive_users`` shard task run by ``shard_archive_users``."""
    task_name: str = 'archive_users'
    # database user passed to get_shard_dsn when connecting to target shards
    db_user: str = 'sharpei'
    # only users whose mail.users.last_state_update is older than now() - state_ttl are picked
    state_ttl: timedelta = timedelta(days=190)
    # number of messages fetched from the DB and saved to archive storage per chunk
    messages_chunk_size: int = 10000
    # upper bound on the number of users picked in one run
    max_users_count: int = 10000
    # load type passed to get_open_shards when choosing target shards
    load_type: str = 'dbaas_hot'
    # passed to randomize_start_time as max_delay; presumably seconds (5 hours) — confirm
    max_delay: int = 5 * 3600
    # TVM settings; its s3_id is handed to ArchiveStorage
    tvm_ids: TvmSettings = None
    # TVM service tickets handed to ArchiveStorage
    tvm_tickets: TvmServiceTickets = None
    # S3 API settings handed to ArchiveStorage
    s3api_settings: S3ApiSettings = None
    # engine whose connection is used to look up open target shards
    huskydb: HuskydbEngine = None


async def shard_archive_users(params: ArchiveUsersParams, stats):
    """Archive up to ``params.max_users_count`` eligible users of this shard.

    The run proceeds in this order:
    1. Sleep a random delay bounded by ``params.max_delay``.
    2. Build the archive storage client.
    3. Pick the candidate target shards once.
    4. Walk candidates in ascending uid order, archiving each one.

    The sharddb cursor provider is always closed at the end, even on error.
    """
    await randomize_start_time(max_delay=params.max_delay)
    archive_storage = ArchiveStorage(
        s3api_settings=params.s3api_settings,
        s3_id=params.tvm_ids.s3_id,
        tvm=params.tvm_tickets,
        stats=stats,
    )

    open_shard_ids = await get_open_shards_for_move(params)
    sharddb_conn = ShardDbCursorProvider(params.sharpei, stats)
    try:
        async with create_cursor_provider(params, stats) as conn:
            last_uid = 0
            for _ in range(params.max_users_count):
                candidate = await get_next_user(
                    conn=conn,
                    state_ttl=params.state_ttl,
                    uid=last_uid,
                )
                if not candidate:
                    break
                # The next lookup starts strictly after this user, whatever
                # the outcome of its archivation.
                last_uid = candidate.uid
                await archive_user(
                    conn=conn,
                    sharddb_conn=sharddb_conn,
                    params=params,
                    user=candidate,
                    archive_storage=archive_storage,
                    messages_chunk_size=params.messages_chunk_size,
                    open_shard_ids=open_shard_ids,
                    stats=stats,
                )
    finally:
        await sharddb_conn.close()
