import logging
from dataclasses import dataclass

import aiopg
from psycopg2.extras import RealDictCursor

from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn, get_sharddb_dsn
from .task import TaskParams, HuskydbEngine
from .export_helper import randomize_start_time

log = logging.getLogger(__name__)


async def close_shard_for_registration(conn, shard_id):
    """Set ``reg_weight`` to 0 for the shard in ``shards.shards``.

    Only rows whose ``reg_weight`` is still positive are touched, so a shard
    that is already closed produces no update and no log line.

    :param conn: open connection exposing an async ``cursor()`` context manager.
    :param shard_id: identifier of the shard to close.
    """
    query_args = {'shard_id': shard_id}
    async with conn.cursor() as cursor:
        await cursor.execute(
            '''
            UPDATE shards.shards
               SET reg_weight = 0
             WHERE shard_id = %(shard_id)s
               AND reg_weight > 0
            ''',
            query_args,
        )
        if not cursor.rowcount:
            # Nothing was updated, so there is nothing to report.
            return
        log.info(f'Shard {shard_id} closed for registration')


async def close_shard_for_transfer(conn, shard_id):
    """Clear the ``can_transfer_to`` flag for the shard in ``shiva.shards``.

    Only rows where the flag is still set are touched, so a shard that is
    already closed produces no update and no log line.

    :param conn: open connection exposing an async ``cursor()`` context manager.
    :param shard_id: identifier of the shard to close.
    """
    query_args = {'shard_id': shard_id}
    async with conn.cursor() as cursor:
        await cursor.execute(
            '''
            UPDATE shiva.shards
               SET can_transfer_to = 'f'
             WHERE shard_id = %(shard_id)s
               AND can_transfer_to
            ''',
            query_args,
        )
        if not cursor.rowcount:
            # Nothing was updated, so there is nothing to report.
            return
        log.info(f'Shard {shard_id} closed for transfer')


@dataclass
class ShivaShard:
    """Shard attributes read back from ``shiva.shards``.

    ``update_shard_used_size`` builds an instance from the columns named in
    its ``RETURNING disk_size, can_transfer_to`` clause.
    """
    # Total size used as the base for the fill-ratio checks; presumably in the
    # same unit as the measured used size (bytes) -- TODO confirm.
    disk_size: int
    # Current value of the ``can_transfer_to`` column for the shard.
    can_transfer_to: bool


async def update_shard_used_size(conn, shard_id, used_size):
    """Store ``used_size`` for the shard and return its current attributes.

    :param conn: open connection exposing an async ``cursor()`` context manager.
        The fetched row is unpacked with ``**``, so the cursor has to yield
        mapping-style rows.
    :param shard_id: identifier of the shard whose row is updated.
    :param used_size: value written to the ``used_size`` column.
    :return: a ``ShivaShard`` built from the ``RETURNING`` columns, or the
        falsy fetch result (``None`` when no row matched) unchanged.
    """
    query_args = {'shard_id': shard_id, 'used_size': used_size}
    async with conn.cursor() as cursor:
        await cursor.execute(
            '''
            UPDATE shiva.shards
               SET used_size = %(used_size)s
             WHERE shard_id = %(shard_id)s
             RETURNING disk_size, can_transfer_to
            ''',
            query_args,
        )
        row = await cursor.fetchone()
        if not row:
            # Hand the falsy value back as-is, exactly like ``row and ...`` would.
            return row
        return ShivaShard(**row)


async def get_shard_used_size(conn):
    """Return the size of the ``maildb`` database plus the total size of its WAL files.

    :param conn: open connection exposing an async ``cursor()`` context manager.
        The row is read by column name, so the cursor has to yield
        mapping-style rows (the caller connects with ``RealDictCursor``).
    :return: the ``used_size`` column of the single result row, or ``None``
        when the query yields no row.
    """
    async with conn.cursor() as cursor:
        await cursor.execute("select (pg_database_size('maildb') + wal_size) as used_size from (select sum(size) as wal_size from pg_ls_waldir()) as wal")
        row = await cursor.fetchone()
        if row is None:
            return None
        return row['used_size']


@dataclass
class CloseForLoadParams(TaskParams):
    """Parameters of the ``shard_close_for_load`` task.

    Extends ``TaskParams``; the task also reads ``sharpei`` and ``shard_id``
    from the params object, which presumably come from that base class --
    TODO confirm.
    """
    task_name: str = 'close_for_load'
    # User name passed to ``get_shard_dsn`` when connecting to the shard database.
    db_user: str = 'maildb'
    # The shard is closed for transfer once used size exceeds this fraction of disk_size.
    max_transfer_ratio: float = 0.75
    # The shard is closed for registration once used size exceeds this fraction of disk_size.
    max_registration_ratio: float = 0.75
    # The task calls ``huskydb.connection()``, so the ``None`` default has to be
    # replaced with a real engine before the task runs.
    huskydb: HuskydbEngine = None
    # Passed to ``randomize_start_time`` as ``max_delay``; presumably seconds -- TODO confirm.
    max_delay: int = 600


def is_valid_shiva_shard(shard):
    """Tell whether ``shard`` can be used for the fill-ratio checks.

    A shard is valid when a row was found (``shard`` is not ``None``) and its
    ``disk_size`` is a positive number.

    :param shard: object with a ``disk_size`` attribute, or ``None``.
    :return: ``True`` for a usable shard, ``False`` otherwise.
    """
    if shard is None:
        return False
    disk_size = shard.disk_size
    # A missing disk_size (e.g. SQL NULL mapped to None) means the shard is
    # not usable; comparing None with 0 would raise TypeError instead.
    return disk_size is not None and disk_size > 0


async def shard_close_for_load(params: CloseForLoadParams, stats):
    """Measure the shard's database size and close the shard when it is too full.

    Steps:

    1. Wait via ``randomize_start_time`` (bounded by ``params.max_delay``).
    2. Read the used size from the shard database; stop if it is unavailable.
    3. Store the used size in ``shiva.shards`` and read back the shard's
       ``disk_size`` and ``can_transfer_to``; stop if the shard is invalid.
    4. Close the shard for transfer and/or registration when the used size
       exceeds the corresponding fraction of ``disk_size``.

    :param params: task parameters (ratios, DSN inputs, huskydb engine).
    :param stats: passed through to the DSN helpers.
    """
    await randomize_start_time(max_delay=params.max_delay)

    maildb_dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    # Mapping-style rows are needed because the size is read by column name.
    async with aiopg.connect(maildb_dsn, cursor_factory=RealDictCursor) as maildb_conn:
        used_size = await get_shard_used_size(maildb_conn)
    if used_size is None:
        log.info("Can't retrive shard size")
        return

    shard_id = params.shard_id
    async with params.huskydb.connection() as huskydb_conn:
        shiva_shard = await update_shard_used_size(huskydb_conn, shard_id, used_size)
        if not is_valid_shiva_shard(shiva_shard):
            log.info(f'bad shard {params.shard_id} in shiva')
            return

        total_size = shiva_shard.disk_size
        log.info(f"DB used size {used_size}, total size {total_size}")

        # The flag is tested first so the ratio is only evaluated for shards
        # that are still open for transfer.
        if shiva_shard.can_transfer_to and used_size > params.max_transfer_ratio * total_size:
            await close_shard_for_transfer(huskydb_conn, shard_id)

        registration_limit = params.max_registration_ratio * total_size
        if used_size > registration_limit:
            sharddb_dsn = await get_sharddb_dsn(params.sharpei, 'sharpei', stats)
            async with aiopg.connect(sharddb_dsn) as sharddb_conn:
                await close_shard_for_registration(sharddb_conn, shard_id)
