import yt.wrapper as yt

import logging
import datetime
import aiopg
from psycopg2.extras import RealDictCursor
from dataclasses import dataclass

from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn
from .task import TaskParams
from .export_helper import (
    create_yt_table,
    write_data_to_yt,
    get_start_id,
    randomize_start_time,
    ExportParams,
)

log = logging.getLogger(__name__)


@dataclass
class Table:
    """A PostgreSQL table that takes part in a per-message size estimation.

    Instances are consumed by ``get_estimation``, which substitutes ``name``
    into ``pg_total_relation_size('<name>')`` and uses ``use_full`` to decide
    which row count the table's size is divided by.
    """

    # Table name as it is substituted into the SQL text (e.g. 'mail.box').
    name: str = None
    # Truthy: the size is added to the sum divided by the main table's row count.
    # Falsy: the size is added to the sum divided by the total table's row count.
    use_full: bool = None


async def get_estimation(conn, tables, main_table, total_table):
    """Estimate the average number of on-disk bytes per row of ``main_table``.

    The total relation sizes of ``tables`` are split into two sums:

    * tables with a truthy ``use_full`` -- the sum is divided by the row
      count of ``main_table``;
    * the remaining tables -- the sum is divided by the row count of
      ``total_table``.

    The result is the integer part of the sum of the two quotients. Row
    counts are taken from ``pg_class.reltuples``, i.e. they are planner
    statistics, not exact counts.

    :param conn: connection whose ``cursor()`` is an async context manager
        yielding mapping-like rows (the query results are read by column name).
    :param tables: iterable of objects with ``name`` and ``use_full`` attributes.
    :param main_table: table whose row count divides the ``use_full`` sizes.
    :param total_table: table whose row count divides the other sizes.
    :return: estimated bytes per row, or ``0`` when either row count is
        unknown/empty or the query produced no row.
    """
    get_count = "(SELECT reltuples FROM pg_class WHERE oid = '{table}'::regclass)"
    get_size = "(pg_total_relation_size('{table}'))::float"

    full_tables_size = '+'.join(get_size.format(table=t.name) for t in tables if t.use_full)
    not_full_tables_size = '+'.join(get_size.format(table=t.name) for t in tables if not t.use_full)

    query = '''
        SELECT ({full_size}) as full_size,
               ({not_full_size}) as not_full_size,
               ({main_count}) as main_count,
               ({total_count}) as total_count
    '''.format(
        # An empty group would otherwise render as "()", which is invalid SQL.
        full_size=full_tables_size or '0::float',
        not_full_size=not_full_tables_size or '0::float',
        main_count=get_count.format(table=main_table),
        total_count=get_count.format(table=total_table),
    )

    async with conn.cursor() as cur:
        await cur.execute(query)
        async for r in cur:
            main_count = int(r['main_count'])
            total_count = int(r['total_count'])
            # reltuples is 0 for an empty table and -1 (PostgreSQL 14+) for a
            # table that has never been vacuumed/analyzed; neither value can be
            # used as a divisor without producing an error or a negative size.
            if main_count <= 0 or total_count <= 0:
                return 0
            return int((r['full_size'] / main_count) + (r['not_full_size'] / total_count))

    # The query always yields one row; keep the return type an int regardless.
    return 0


async def get_deleted_estimation(conn):
    """Estimate the bytes occupied in the database by one deleted message.

    ``mail.deleted_box`` is sized against its own row count, while the
    message tables it shares with live mail are sized against the row count
    of ``mail.messages``.
    """
    # (table name, use_full) pairs, in the order they are handed to get_estimation.
    table_specs = (
        ('mail.deleted_box', True),
        ('mail.messages', False),
        ('mail.message_references', False),
        ('mail.windat_messages', False),
    )
    return await get_estimation(
        conn,
        [Table(table_name, use_full=is_full) for table_name, is_full in table_specs],
        main_table='mail.deleted_box',
        total_table='mail.messages',
    )


async def get_message_estimation(conn):
    """Estimate the bytes occupied in the database by one mailbox message.

    Tables marked ``True`` below are sized against the row count of
    ``mail.box``; the ones marked ``False`` are sized against the row count
    of ``mail.messages``.
    """
    # (table name, use_full) pairs, in the order they are handed to get_estimation.
    table_specs = (
        ('mail.box', True),
        ('mail.threads', True),
        ('mail.threads_hashes', True),
        ('mail.pop3_box', True),
        ('mail.messages', False),
        ('mail.message_references', False),
        ('mail.windat_messages', False),
        ('mailish.messages', True),
    )
    return await get_estimation(
        conn,
        [Table(table_name, use_full=is_full) for table_name, is_full in table_specs],
        main_table='mail.box',
        total_table='mail.messages',
    )


@dataclass
class SizeEstimation:
    """Per-message byte estimates for one shard (see ``get_size_estimations``)."""

    # Value returned by get_message_estimation; multiplied by a user's
    # message count in approximate_mailbox.
    message: int = 0
    # Value returned by get_deleted_estimation; multiplied by the deleted
    # message count in approximate_deleted.
    deleted: int = 0


async def get_size_estimations(conn):
    """Collect both per-message size estimates for the shard behind ``conn``.

    The message estimate is queried first, then the deleted one; the combined
    result is logged and returned as a ``SizeEstimation``.
    """
    message_bytes = await get_message_estimation(conn)
    deleted_bytes = await get_deleted_estimation(conn)
    estimation = SizeEstimation(message=message_bytes, deleted=deleted_bytes)
    log.info(f'successfully got shard estimations {estimation}')
    return estimation


async def create_yt_deleted_table(yt_client, table):
    """Call ``create_yt_table`` for ``table`` with the deleted-messages columns.

    All three columns are declared as ``uint64``.
    """
    column_names = ('from_shard_id', 'message_count', 'approx_bytes_in_db')
    await create_yt_table(yt_client, table, fields=dict.fromkeys(column_names, 'uint64'))


async def create_yt_mailbox_table(yt_client, table):
    """Call ``create_yt_table`` for ``table`` with the per-user mailbox columns.

    Every column is ``uint64`` except ``here_since``, which is declared as
    ``timestamp``. The column order is: mailbox data first, then one counter
    per filter action.
    """
    # mailbox data that follows the leading uid / shard / here_since columns
    mailbox_counters = (
        'folders_count',
        'message_count',
        'message_size',
        'attach_count',
        'attach_size',
        'approx_bytes_in_db',
    )
    # filter actions
    filter_actions = (
        'move',
        'delete',
        'movel',
        'status',
        'forward',
        'forwardwithstore',
        'reply',
        'notify',
    )

    fields = {
        'uid': 'uint64',
        'from_shard_id': 'uint64',
        'here_since': 'timestamp',
    }
    fields.update(dict.fromkeys(mailbox_counters + filter_actions, 'uint64'))
    await create_yt_table(yt_client, table, fields=fields)


async def read_deleted_count(conn):
    """Return the live-tuple statistic of ``mail.deleted_box`` as an int.

    The value comes from ``pg_stat_get_live_tuples``. Only the first row of
    the result is used; ``None`` is returned if the cursor yields no rows.
    """
    query = '''
            SELECT (pg_stat_get_live_tuples('mail.deleted_box'::regclass::oid)) as count
            '''
    async with conn.cursor() as cursor:
        await cursor.execute(query)
        async for row in cursor:
            return int(row['count'])


@dataclass
class UserData:
    """One row of the per-user query issued by ``read_users_mailbox``.

    Field names match the column aliases of that query, so a row can be
    expanded directly with ``UserData(**row)``.
    """

    # User id; also used as the pagination key.
    uid: int = None
    # mail.users.here_since; converted with .timestamp() in approximate_mailbox.
    here_since: datetime.datetime = None
    # Number of mail.folders rows joined to the user.
    folders_count: int = None
    # Sums of the same-named mail.folders columns over the user's folders.
    message_count: int = None
    message_size: int = None
    attach_count: int = None
    attach_size: int = None
    # Mapping of filter action name (filters.actions.oper) to its count;
    # the query substitutes an empty JSON object when the user has none.
    actions: dict = None


async def read_users_mailbox(conn, from_uid, chunk_size):
    """Yield lists of ``UserData`` for the shard's users in ascending uid order.

    Each query selects at most ``chunk_size`` users with ``uid >= <cursor>``
    that are here, not deleted and whose state is neither 'archived' nor
    'deleted', together with their folder aggregates and filter action counts.
    The cursor starts at ``from_uid`` and moves to the last returned uid + 1
    after every chunk.

    Iteration stops after an empty chunk (nothing is yielded for it) or after
    yielding a chunk shorter than ``chunk_size``.
    """
    query = '''
                SELECT u.uid,
                       u.here_since,
                       u.folders_count,
                       u.message_count,
                       u.message_size,
                       u.attach_count,
                       u.attach_size,
                       coalesce(o.actions, '{}'::json) as actions
                  FROM (
                      SELECT uid,
                             u.here_since,
                             count(*) as folders_count,
                             sum(message_count) as message_count,
                             sum(message_size) as message_size,
                             sum(attach_count) as attach_count,
                             sum(attach_size) as attach_size
                        FROM mail.users u JOIN
                             mail.folders f using(uid)
                       WHERE uid >= %(from_uid)s
                         AND u.is_here
                         AND NOT u.is_deleted
                         AND u.state NOT IN ('archived', 'deleted')
                       GROUP BY uid
                       ORDER BY uid
                       LIMIT %(chunk_size)s
                ) u LEFT JOIN (
                    SELECT uid, json_object(array_agg(array[oper::text, op_count::text])) as actions
                      FROM (
                          SELECT uid, oper, count(*) as op_count
                            FROM filters.actions
                           GROUP BY uid, oper
                      ) s
                     GROUP BY uid
                ) o using(uid)
                ORDER BY uid
                '''
    next_uid = from_uid
    while True:
        # A fresh cursor per chunk; it stays open while the consumer handles the chunk.
        async with conn.cursor() as cursor:
            await cursor.execute(query, dict(from_uid=next_uid, chunk_size=chunk_size))
            chunk = [UserData(**row) async for row in cursor]
            if not chunk:
                return

            # Rows are ordered by uid, so the last one carries the highest uid seen.
            next_uid = chunk[-1].uid + 1
            yield chunk

            # A short chunk means the query ran out of users.
            if len(chunk) < chunk_size:
                return


def approximate_deleted(shard_id, deleted_count, estimation: SizeEstimation):
    """Build the YT row describing the deleted messages of one shard.

    ``shard_id`` and ``deleted_count`` are converted with ``int()``; the
    approximate size is the count multiplied by ``estimation.deleted``.
    """
    shard = int(shard_id)
    message_count = int(deleted_count)
    return dict(
        from_shard_id=shard,
        message_count=message_count,
        approx_bytes_in_db=message_count * estimation.deleted,
    )


def approximate_mailbox(shard_id, user: UserData, estimation: SizeEstimation):
    """Build the YT row describing one user's mailbox.

    :param shard_id: shard the user lives on; converted with ``int()`` the
        same way ``approximate_deleted`` does, since both rows carry it in a
        ``from_shard_id`` column.
    :param user: per-user aggregates; ``user.actions`` maps a filter action
        name to its count and may be ``None`` (treated as no actions).
    :param estimation: provides ``message``, the estimated bytes per message.
    :return: dict with the mailbox columns plus one integer entry per action
        present in ``user.actions``.
    """
    user_data = {
        'uid': user.uid,
        'from_shard_id': int(shard_id),
        # Milliseconds since the epoch.
        # NOTE(review): confirm this unit is what the 'timestamp' column expects.
        'here_since': round(user.here_since.timestamp() * 1000),
        'folders_count': user.folders_count,
        'message_count': user.message_count,
        'message_size': user.message_size,
        'attach_count': user.attach_count,
        'attach_size': user.attach_size,
        'approx_bytes_in_db': user.message_count * estimation.message,
    }
    # The action counts may arrive as text, hence the int() conversion;
    # UserData.actions defaults to None, so guard before iterating.
    actions = user.actions or {}
    user_data.update({action: int(count) for action, count in actions.items()})
    return user_data


async def export_deleted(conn, table, shard_id, estimation, yt_client):
    """Write the shard's single deleted-messages row to the YT ``table``.

    The table is set up first, then the deleted count is read from ``conn``
    and the resulting row is written as a one-element batch.
    """
    await create_yt_deleted_table(yt_client, table)

    row = approximate_deleted(shard_id, await read_deleted_count(conn), estimation)
    await write_data_to_yt(yt_client, table, [row])
    log.info('successfully written deleted to YT')


async def export_interval(conn, chunk_size, table, shard_id, from_uid, estimation, yt_client):
    """Export per-user mailbox rows to the YT ``table`` chunk by chunk.

    A ``from_uid`` of 0 means the starting uid is taken from
    ``get_start_id(yt_client, table) + 1`` (presumably the uid following the
    last one already exported -- verify against ``get_start_id``). Every chunk
    read from the database is converted and written to YT before the next one
    is fetched.
    """
    if from_uid == 0:
        start_id = await get_start_id(yt_client, table)
        from_uid = start_id + 1
    log.info(f'Process will be resumed from uid = {from_uid}')

    await create_yt_mailbox_table(yt_client, table)

    async for users in read_users_mailbox(conn, from_uid, chunk_size):
        result = [approximate_mailbox(shard_id, user, estimation) for user in users]
        await write_data_to_yt(yt_client, table, result)
        log.info(f'successfully written {len(result)} lines to YT')


async def export(conn, chunk_size, table_prefix, shard_id, from_uid, estimation, yt_client):
    """Export the deleted-messages row and then the mailbox rows of one shard.

    Both target tables live under ``<table_prefix><YYYY-MM-DD>/``, where the
    date is today's date as given by ``datetime.datetime.today()``:
    ``deleted/<shard_id>`` and ``mailbox/<shard_id>``.
    """
    date_prefix = datetime.datetime.today().strftime('%Y-%m-%d')
    mailbox_table = f'{table_prefix}{date_prefix}/mailbox/{shard_id}'
    deleted_table = f'{table_prefix}{date_prefix}/deleted/{shard_id}'

    # Arguments shared by both export steps.
    shared = dict(
        conn=conn,
        shard_id=shard_id,
        estimation=estimation,
        yt_client=yt_client,
    )
    await export_deleted(table=deleted_table, **shared)
    await export_interval(
        table=mailbox_table,
        chunk_size=chunk_size,
        from_uid=from_uid,
        **shared,
    )


@dataclass
class PnlEstimationExportParams(TaskParams, ExportParams):
    """Parameters accepted by ``shard_pnl_estimation_export``.

    Extends ``TaskParams`` and ``ExportParams`` and provides defaults for the
    task name and the YT path prefix.
    """

    # Name of the task.
    task_name: str = 'pnl_estimation_export'
    # YT path prefix; ``export`` appends '<YYYY-MM-DD>/mailbox/<shard_id>'
    # and '<YYYY-MM-DD>/deleted/<shard_id>' to it.
    table_prefix: str = '//home/mail-logs/core/mdb/pnl-data/'


async def shard_pnl_estimation_export(params: PnlEstimationExportParams, stats):
    """Run the size-estimation export for the shard described by ``params``.

    After ``randomize_start_time`` returns, the shard DSN is resolved, a
    connection with dict-style rows is opened, the per-message size estimates
    are computed and the deleted/mailbox tables are exported to YT.
    """
    await randomize_start_time(max_delay=params.max_delay)
    dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(dsn, cursor_factory=RealDictCursor) as conn:
        log.info(f'Start mailbox_export with chunk size {params.chunk_size}')
        yt_client = yt.YtClient(**params.yt_config)
        estimation = await get_size_estimations(conn)
        export_kwargs = dict(
            conn=conn,
            chunk_size=params.chunk_size,
            table_prefix=params.table_prefix,
            shard_id=params.shard_id,
            from_uid=params.from_uid,
            estimation=estimation,
            yt_client=yt_client,
        )
        await export(**export_kwargs)
