import logging
from dataclasses import dataclass
from typing import Optional

import aiopg
from psycopg2.extras import RealDictCursor

from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn
from .task import TaskParams, HuskydbEngine
from .transfer_users import get_shard_users, add_transfer_tasks

log = logging.getLogger(__name__)

# Estimated average footprint of one message in the database, in the same units
# as the shard disk size (presumably bytes — TODO confirm). Used to turn a
# fraction of the shard's capacity into a number of messages to move.
DB_AVG_MESSAGE_SIZE = 1800

# (min_count, max_count) ranges of per-user message counts, tried in the order
# listed when selecting users to transfer. get_shard_users_for_transfer passes
# max_count - 1 to get_shard_users, so max_count is presumably exclusive.
# Order as written: 1000..250000 ascending, then the small ranges 100..1000 and
# 10..100, and the largest range 250000..1000000 last. Users outside 10..1000000
# messages are never selected.
# NOTE(review): the ordering looks like a deliberate preference for mid-sized
# mailboxes — confirm before reordering.
TRANSFER_USERS_MESSAGE_COUNT = (
    (1000, 5000),
    (5000, 25000),
    (25000, 250000),
    (100, 1000),
    (10, 100),
    (250000, 1000000),
)


async def get_shard_users_for_transfer(conn, transfer_messages_count, users_chunk_size):
    """Pick uids whose mailboxes together hold about ``transfer_messages_count`` messages.

    Walks the TRANSFER_USERS_MESSAGE_COUNT ranges in order and asks
    get_shard_users for users in each range, shrinking the outstanding quota
    by what each range contributed. Stops once a single range covers the
    remaining quota, or when all ranges are exhausted.

    :param conn: maildb connection passed through to get_shard_users.
    :param transfer_messages_count: target number of messages to move.
    :param users_chunk_size: passed through to get_shard_users.
    :return: list of selected uids (possibly empty).
    """
    selected_uids = []
    remaining = transfer_messages_count
    for bucket_min, bucket_max in TRANSFER_USERS_MESSAGE_COUNT:
        chunk_uids, chunk_messages = await get_shard_users(
            conn=conn,
            users_chunk_size=users_chunk_size,
            total_count=remaining,
            min_count=bucket_min,
            # Range upper bound is treated as exclusive.
            max_count=bucket_max - 1,
            consider_deleted_messages=True,
        )
        selected_uids += chunk_uids
        if chunk_messages >= remaining:
            # This range alone satisfied what was still needed.
            break
        remaining -= chunk_messages
    return selected_uids


async def get_shard_disk_size(conn, shard_id):
    """Return ``disk_size`` from ``shiva.shards`` for ``shard_id``, or None if no row exists.

    The connection's cursor must return rows addressable by column name
    (the value is read as ``row['disk_size']``).
    """
    query = "SELECT disk_size FROM shiva.shards WHERE shard_id=%(shard_id)s"
    async with conn.cursor() as cursor:
        await cursor.execute(query, {'shard_id': shard_id})
        row = await cursor.fetchone()
    if row is None:
        return None
    return row['disk_size']


async def get_shard_used_size(conn):
    """Return ``pg_database_size('maildb')`` (bytes) via ``conn``, or None if no row came back.

    The connection's cursor must return rows addressable by column name.
    """
    async with conn.cursor() as cursor:
        await cursor.execute("SELECT pg_database_size('maildb') as used_size")
        row = await cursor.fetchone()
    return None if row is None else row['used_size']


async def get_shard_bloat_size(conn):
    """Return the summed non-negative ``bloat_size_bytes`` from ``code.get_heap_bloat_info()``.

    Negative per-row values are clamped to 0 before summing. Note that SQL
    ``sum()`` over zero rows yields NULL, so this returns None when the bloat
    function reports nothing (and also if no row is fetched at all).
    """
    sql = "SELECT sum(GREATEST(bloat_size_bytes, 0)) AS bloat_size FROM code.get_heap_bloat_info()"
    async with conn.cursor() as cursor:
        await cursor.execute(sql)
        row = await cursor.fetchone()
    if row is None:
        return None
    return row['bloat_size']


@dataclass
class SpaceBalancerParams(TaskParams):
    """Settings for the ``space_balancer`` task, consumed by shard_space_balancer.

    shard_space_balancer also reads ``shard_id``, ``sharpei`` and ``db_user``,
    which are presumably inherited from TaskParams.
    """

    # Task identifier; presumably used by the task framework to select this task.
    task_name: str = 'space_balancer'
    # Engine whose ``connection()`` opens the shiva DB (shards table, transfer tasks).
    # Must be provided before running; None is only a placeholder default.
    huskydb: Optional[HuskydbEngine] = None
    # Fraction of the shard's disk size to move off the shard in one run.
    transfer_ratio: float = 0.08
    # Trigger when total used size exceeds this fraction of the disk size.
    db_used_ratio: float = 0.87
    # Trigger when used size minus heap bloat exceeds this fraction of the disk size.
    db_used_without_bloat_ratio: float = 0.80
    # Passed to get_shard_users and add_transfer_tasks.
    users_chunk_size: int = 1000
    # Forwarded to add_transfer_tasks as ``extra_task_args``.
    task_args: Optional[str] = None
    # Forwarded to add_transfer_tasks as ``priority``.
    transfer_priority: int = -11
    # Forwarded to add_transfer_tasks as ``load_type``.
    load_type: str = 'dbaas_hot'
    # Forwarded to add_transfer_tasks as ``shards_count``.
    shards_count: int = 5


async def shard_space_balancer(params: SpaceBalancerParams, stats):
    """Queue user transfers off a shard whose disk usage is above the configured limits.

    Reads the maildb size and heap bloat from the shard itself and the shard's
    disk size from the shiva DB. If the used size exceeds
    ``db_used_ratio * total_size``, or the used size minus bloat exceeds
    ``db_used_without_bloat_ratio * total_size``, it selects users holding about
    ``transfer_ratio * total_size / DB_AVG_MESSAGE_SIZE`` messages. It then
    queues transfer tasks for those users via add_transfer_tasks.

    Returns early (after logging) when the used or total size is unavailable.
    When the bloat size is unavailable (None), only the raw used-size threshold
    is evaluated.

    :param params: balancer settings; ``sharpei``, ``db_user`` and ``shard_id``
        are also read from it.
    :param stats: passed through to get_shard_dsn.
    """
    dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(dsn, cursor_factory=RealDictCursor) as maildb_conn:
        async with params.huskydb.connection() as shivadb_conn:
            used_size = await get_shard_used_size(maildb_conn)
            bloat_size = await get_shard_bloat_size(maildb_conn)
            total_size = await get_shard_disk_size(shivadb_conn, params.shard_id)
            # Validate before any arithmetic so a missing size never reaches a
            # subtraction/comparison and raises TypeError.
            if not total_size or used_size is None:
                log.info("Can't retrive shard size")
                return
            log.info("DB used size %s, bloat size %s, total size %s", used_size, bloat_size, total_size)

            over_used = used_size > params.db_used_ratio * total_size
            # bloat_size is None when the bloat query yields no value (SQL sum of
            # zero rows); then the bloat-adjusted check cannot be evaluated.
            over_used_without_bloat = (
                bloat_size is not None
                and used_size - bloat_size > params.db_used_without_bloat_ratio * total_size
            )
            if not (over_used or over_used_without_bloat):
                return

            # Convert the share of disk capacity to move into a message count.
            transfer_messages_count = int(params.transfer_ratio * total_size / DB_AVG_MESSAGE_SIZE)
            uids = await get_shard_users_for_transfer(maildb_conn, transfer_messages_count, params.users_chunk_size)
            await add_transfer_tasks(
                conn=shivadb_conn,
                from_db=params.shard_id,
                load_type=params.load_type,
                shards_count=params.shards_count,
                extra_task_args=params.task_args,
                priority=params.transfer_priority,
                users_chunk_size=params.users_chunk_size,
                uids=uids,
            )
