import logging
from dataclasses import dataclass

import random
import aiopg
import ujson
from psycopg2.extras import RealDictCursor

from .helpers import chunks
from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn
from .task import TaskParams, HuskydbEngine

log = logging.getLogger(__name__)

# Timeout passed to ``conn.cursor(timeout=...)`` for the maildb queries in this
# module (aiopg timeouts are expressed in seconds).
PG_TIMEOUT = 600
# Fixed weight each selected user adds to the budget in ``get_shard_users`` on
# top of its message count, so users with few or no messages still "cost" something.
EMPTY_USER_MESSAGES_EQUIVALENT = 5


@dataclass
class User:
    """A user row from ``get_shard_users_by_messages_count``: uid plus summed folder message count."""
    # NOTE(review): annotated as str, but uids are compared against an integer
    # start_uid and cast to bigint[] in the deleted_box query — likely an int; confirm.
    uid: str = None
    # sum(message_count) over the user's mail.folders rows.
    messages_count: int = None


def get_task_args(from_db, to_db, extra_args):
    """Build the JSON ``task_args`` string for one transfer task.

    ``extra_args`` is an optional JSON object string; when it is empty/None an
    empty object is used instead. ``from_db`` and ``to_db`` are stringified and
    written on top, overriding any same-named keys from ``extra_args``.
    """
    merged = {}
    if extra_args:
        merged = ujson.loads(extra_args)
    merged.update(from_db=f'{from_db}', to_db=f'{to_db}')
    return ujson.dumps(merged)


async def get_open_shards(conn, excl_shard_id, load_type, shards_count):
    """Pick destination shards that accept transfers.

    Considers shards with ``can_transfer_to`` set and the given ``load_type``,
    excluding ``excl_shard_id``. Returns ``shards_count`` of them chosen at
    random, or every match (in query order) when fewer are available.
    Rows are read by column name, so ``conn`` must yield dict-like rows.
    """
    query_params = {
        'load_type': load_type,
        'shard_id': excl_shard_id,
    }
    async with conn.cursor() as cursor:
        await cursor.execute(
            '''
            SELECT shard_id from shiva.shards
             WHERE can_transfer_to
               AND load_type = %(load_type)s
               AND shard_id != %(shard_id)s
            ''',
            query_params,
        )
        candidates = [row['shard_id'] async for row in cursor]
        if len(candidates) < shards_count:
            return candidates
        return random.sample(candidates, shards_count)


async def add_transfer_tasks(conn, from_db, load_type, shards_count, extra_task_args, priority, users_chunk_size, uids):
    """Spread ``uids`` over open destination shards and enqueue transfer tasks.

    Up to ``shards_count`` shards of ``load_type`` (excluding ``from_db``) are
    chosen by ``get_open_shards``. ``uids`` is cut into contiguous slices, one per
    destination shard, and each slice is enqueued via
    ``add_transfer_tasks_to_shard`` with task args built by ``get_task_args``.

    Args:
        uids: a sized sequence of user ids (``len`` is taken).

    Raises:
        RuntimeError: if no open shard with ``load_type`` is available.
    """
    shard_ids = await get_open_shards(conn, from_db, load_type, shards_count)
    if not shard_ids:
        raise RuntimeError(f'No aviable open shards to transfer with load_type {load_type}')

    # Ceiling integer division: the smallest slice size that still fits every uid
    # into len(shard_ids) slices, so zip() below never drops uids. Using
    # ``int(n / k) + 1`` instead overshoots when n divides evenly (8 uids / 4
    # shards -> slices of 3, leaving a shard idle). max(1, ...) keeps the size
    # positive when ``uids`` is empty.
    shard_chunk_size = max(1, -(-len(uids) // len(shard_ids)))
    for uids_chunk, to_db in zip(chunks(uids, shard_chunk_size), shard_ids):
        task_args = get_task_args(from_db, to_db, extra_task_args)
        await add_transfer_tasks_to_shard(conn, to_db, task_args, priority, users_chunk_size, uids_chunk)


async def add_transfer_tasks_to_shard(conn, shard_id, task_args, priority, users_chunk_size, uids):
    """Queue ``transfer`` rows for ``uids`` into ``transfer.users_in_dogsleds``.

    One INSERT is issued per batch of ``users_chunk_size`` uids, each with its
    own cursor. Conflicting rows are skipped (``ON CONFLICT DO NOTHING``), so the
    logged number is the batch size attempted, not necessarily rows inserted.
    """
    for batch in chunks(uids, users_chunk_size):
        query_params = {
            'uids': batch,
            'priority': priority,
            'task_args': task_args,
            'shard_id': shard_id,
        }
        async with conn.cursor() as cursor:
            await cursor.execute(
                '''
                INSERT INTO transfer.users_in_dogsleds(uid, priority, task, task_args, shard_id)
                VALUES(UNNEST(%(uids)s), %(priority)s, 'transfer', %(task_args)s, %(shard_id)s)
                ON CONFLICT DO NOTHING
                ''',
                query_params,
            )
            log.info(f'Added {len(batch)} users to transfer to shard {shard_id}')


async def get_shard_users_by_messages_count(conn, start_uid, users_count, min_count, max_count):
    """Fetch one page of users whose summed folder message count is in range.

    Returns up to ``users_count`` ``User`` objects with ``uid > start_uid`` and
    ``is_here`` set, whose ``sum(message_count)`` over ``mail.folders`` lies in
    ``[min_count, max_count]`` (SQL BETWEEN, inclusive), ordered by uid. Rows are
    unpacked as keyword arguments, so ``conn`` must yield dict-like rows.
    """
    query_params = {
        'start_uid': start_uid,
        'users_count': users_count,
        'min_count': min_count,
        'max_count': max_count,
    }
    async with conn.cursor(timeout=PG_TIMEOUT) as cursor:
        await cursor.execute(
            '''
            SELECT uid, sum(message_count) as messages_count from mail.folders join mail.users using (uid)
             WHERE uid > %(start_uid)s
               AND is_here
             GROUP BY uid
            HAVING sum(message_count) between %(min_count)s AND %(max_count)s
             ORDER BY uid
             LIMIT %(users_count)s
            ''',
            query_params,
        )
        return [User(**row) async for row in cursor]


async def get_shard_users_deleted_messages_count(conn, users):
    """Count ``mail.deleted_box`` rows for each of ``users``.

    Args:
        conn: maildb connection; rows are read by column name, so its cursors
            must yield dict-like rows.
        users: iterable of ``User``; only ``uid`` is read.

    Returns:
        dict mapping uid -> number of deleted_box rows. Users with no rows are
        absent (the query groups by uid), so callers should default to 0.
    """
    uids = [user.uid for user in users]
    if not uids:
        # Nothing to look up: skip the round-trip instead of querying with an
        # empty array under a PG_TIMEOUT-long cursor.
        return {}
    async with conn.cursor(timeout=PG_TIMEOUT) as cur:
        await cur.execute(
            '''
            SELECT uid, count(*) as messages_count from mail.deleted_box
             WHERE uid = ANY(%(uids)s::bigint[])
             GROUP BY uid
            ''',
            dict(
                uids=uids,
            )
        )
        return {rec['uid']: rec['messages_count'] async for rec in cur}


def get_user_deleted_messages_count(deleted_messages, uid):
    """Return the deleted-message count for ``uid``, or 0 if it has none.

    ``deleted_messages`` is a uid -> count mapping in which users without
    deleted messages are simply missing.
    """
    return deleted_messages.get(uid, 0)


async def get_shard_users(conn, users_chunk_size, total_count, min_count, max_count, consider_deleted_messages):
    """Select shard users until their accumulated weight reaches ``total_count``.

    Pages through users ordered by uid, ``users_chunk_size`` per page, whose
    summed message count lies in ``[min_count, max_count]``. Each selected user
    adds ``EMPTY_USER_MESSAGES_EQUIVALENT`` + its message count and, when
    ``consider_deleted_messages`` is true, + its ``mail.deleted_box`` row count.

    The user whose weight pushes the total past ``total_count`` is still
    included, so the result can overshoot the budget by one user.

    Returns:
        ``(uids, weight)``: selected uids in ascending order and their
        accumulated weight.
    """
    uids = []
    selected_users_messages_count = 0
    start_uid = 0
    while selected_users_messages_count < total_count:
        users = await get_shard_users_by_messages_count(conn, start_uid, users_chunk_size, min_count, max_count)
        if not users:
            # No more matching users after start_uid. Stopping here also avoids
            # a wasted deleted_box query for an empty page.
            break
        deleted_messages = {}
        if consider_deleted_messages:
            deleted_messages = await get_shard_users_deleted_messages_count(conn, users)
        for user in users:
            uids.append(user.uid)
            selected_users_messages_count += (
                EMPTY_USER_MESSAGES_EQUIVALENT
                + user.messages_count
                + get_user_deleted_messages_count(deleted_messages, user.uid)
            )
            if selected_users_messages_count > total_count:
                break
            # Keyset pagination: the next page starts strictly after this uid.
            start_uid = user.uid
    return uids, selected_users_messages_count


@dataclass
class TransferUsersParams(TaskParams):
    """Parameters for the ``transfer_users`` task run by ``shard_transfer_users``.

    ``sharpei``, ``db_user`` and ``shard_id`` are also read by
    ``shard_transfer_users`` but are not declared here — presumably inherited
    from ``TaskParams``; confirm.
    """
    task_name: str = 'transfer_users'
    # Engine whose ``connection()`` provides the connection transfer tasks are inserted through.
    huskydb: HuskydbEngine = None
    # Budget for get_shard_users' accumulated weight (messages plus a per-user constant).
    messages_count: int = 1000000
    # Inclusive bounds on a user's summed folder message count (SQL BETWEEN).
    min_messages_per_user: int = 0
    max_messages_per_user: int = 1000
    # Page size when reading users from maildb, and batch size for task INSERTs.
    users_chunk_size: int = 10000
    # Optional JSON object string merged into every task's task_args.
    task_args: str = None
    # Value written to the priority column of each queued transfer row.
    transfer_priority: int = -11
    # Only shards with this load_type are chosen as transfer destinations.
    load_type: str = 'dbaas_hot'
    # Maximum number of destination shards to spread the selected users over.
    shards_count: int = 5


async def shard_transfer_users(params: TransferUsersParams, stats):
    """Enqueue a message-budgeted batch of users on ``params.shard_id`` for transfer.

    Resolves the source shard's DSN, selects users from its maildb (deleted
    messages are not counted), then, while that connection is still open,
    enqueues ``transfer`` tasks in huskydb spread over open shards of
    ``params.load_type``.
    """
    source_dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(source_dsn, cursor_factory=RealDictCursor) as source_conn:
        selected_uids, _weight = await get_shard_users(
            conn=source_conn,
            users_chunk_size=params.users_chunk_size,
            total_count=params.messages_count,
            min_count=params.min_messages_per_user,
            max_count=params.max_messages_per_user,
            consider_deleted_messages=False,
        )
        async with params.huskydb.connection() as queue_conn:
            await add_transfer_tasks(
                conn=queue_conn,
                from_db=params.shard_id,
                load_type=params.load_type,
                shards_count=params.shards_count,
                extra_task_args=params.task_args,
                priority=params.transfer_priority,
                users_chunk_size=params.users_chunk_size,
                uids=selected_uids,
            )
