from dataclasses import dataclass
from datetime import timedelta

from mail.python.theatre.detail.tvm import TvmServiceTickets
from mail.python.theatre.profiling.http import ProfiledClientSession
from .freeze_helper import db_update_user_state, db_update_user_state_with_notifies_count, db_get_user_state, is_live_user
from .notify_users import FAKE_NOTIFIES_COUNT_FOR_BAD_SHARD
from .helpers import db_general_retry
from .export_helper import randomize_start_time
from .task import TaskParams
from .cursor_provider import create_cursor_provider, locked_transactional_cursor
from mail.shiva.stages.api.props.logger import get_uid_logger
from mail.shiva.stages.api.settings.passport import PassportSettings
from mail.shiva.stages.api.settings.log import http_logger
from mail.shiva.stages.api.settings.tvm import TvmSettings
from mail.shiva.stages.api.settings.surveillance import SurveillanceSettings
from mail.shiva.stages.api.props.services.passport import passport_update_user_state
from mail.shiva.stages.api.props.services.surveillance import get_surveillance_users
from mail.shiva.stages.api.props.services.sharpei import get_shard_id_by_uid

# Module logger; the log calls in this file pass the affected user as a
# ``uid=`` keyword argument.
log = get_uid_logger(__name__)

# Only users in state 'notified' whose notifies_count equals this value are
# selected for freezing (see the query in get_users_chunk and the re-check in
# freeze_user).
FREEZE_NOTIFIES_COUNT = 2


@db_general_retry
async def get_users_chunk(conn, state_ttl, last_uid, chunk_size):
    """Fetch one page of uids that are candidates for freezing.

    A candidate is a row of ``mail.users`` that is marked ``is_here``, is not
    deleted, is in state ``'notified'`` with ``notifies_count`` equal to
    ``FREEZE_NOTIFIES_COUNT`` and whose ``last_state_update`` is older than
    ``now() - state_ttl``.

    Args:
        conn: connection object exposing ``cursor()`` as an async context
            manager.
        state_ttl: how long a user must have stayed in the notified state.
        last_uid: only uids strictly greater than this value are returned.
        chunk_size: maximum number of uids to return (SQL ``LIMIT``).

    Returns:
        A list of uids in ascending order, at most ``chunk_size`` long.
    """
    query_args = {
        'state_ttl': state_ttl,
        'notifies_count': FREEZE_NOTIFIES_COUNT,
        'last_uid': last_uid,
        'chunk_size': chunk_size,
    }
    uids = []
    async with conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT uid
                  FROM mail.users
                 WHERE is_here
                   AND NOT is_deleted
                   AND state = 'notified'
                   AND notifies_count = %(notifies_count)s
                   AND last_state_update < (now() - %(state_ttl)s)
                   AND uid > %(last_uid)s
                 ORDER BY uid
                 LIMIT %(chunk_size)s
            ''',
            query_args,
        )
        # Rows are consumed while the cursor is still open.
        async for row in cur:
            uids.append(row['uid'])
    return uids


async def get_users(conn, state_ttl, chunk_size, users_count):
    """Yield pages of uids to freeze, up to ``users_count`` uids in total.

    Pages are fetched with ``get_users_chunk``; each next page starts after
    the last uid of the previous one. Iteration stops when the budget of
    ``users_count`` uids has been requested or when a page comes back shorter
    than requested (no more matching rows).

    Args:
        conn: connection passed through to ``get_users_chunk``.
        state_ttl: minimum age of the notified state, passed through.
        chunk_size: maximum page size; must be positive.
        users_count: total number of uids to request across all pages. A
            non-positive value yields nothing and issues no query.

    Yields:
        Non-empty lists of uids in ascending order.

    Raises:
        ValueError: if ``chunk_size`` is not positive while there is a
            positive budget. Previously ``chunk_size == 0`` looped forever
            issuing ``LIMIT 0`` queries, and a negative value produced a
            negative ``LIMIT``.
    """
    if users_count <= 0:
        # Nothing was asked for: avoid a pointless ``LIMIT 0`` round trip.
        return
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size!r}')

    remaining = users_count
    last_uid = 0
    while remaining > 0:
        # Never request more than what is left of the overall budget.
        limit = min(chunk_size, remaining)
        remaining -= limit

        chunk = await get_users_chunk(conn, state_ttl, last_uid, limit)

        if chunk:
            last_uid = chunk[-1]
            yield chunk

        # A short page means the table has no more matching rows.
        if len(chunk) < limit:
            return


async def freeze_user(passport_client, params, passport_tvm_ticket, conn, uid, stats, surveillance_users):
    """Try to move a single user to the ``'frozen'`` state.

    Inside ``locked_transactional_cursor`` the user's state is re-read and
    checked in this order:

    1. the user must still be live, in state ``'notified'`` with
       ``notifies_count == FREEZE_NOTIFIES_COUNT``; otherwise it is skipped;
    2. a user whose ``str(uid)`` is in ``surveillance_users`` is marked
       ``'special'`` with notifies count ``-1`` and skipped;
    3. a user whose shard (per ``get_shard_id_by_uid``) differs from
       ``params.shard_id`` is kept ``'notified'`` with
       ``FAKE_NOTIFIES_COUNT_FOR_BAD_SHARD`` and skipped;
    4. otherwise the DB state is set to ``'frozen'`` and then passport is
       updated.

    Any exception is logged and swallowed here, so it does not propagate to
    the caller.
    """
    try:
        async with locked_transactional_cursor(conn, uid) as cur:
            user_state = await db_get_user_state(cur, uid)

            # Short-circuits exactly like the original ``or`` chain: the
            # attributes are only read when the user is live.
            still_eligible = (
                is_live_user(user_state)
                and user_state.state == 'notified'
                and user_state.notifies_count == FREEZE_NOTIFIES_COUNT
            )
            if not still_eligible:
                log.info('user was skipped, because his state was changed', uid=uid)
                return

            uid_key = str(uid)
            if uid_key in surveillance_users:
                log.info('user was skipped, because he is special', uid=uid)
                await db_update_user_state_with_notifies_count(cur, user_state, 'special', -1)
                return

            actual_shard_id = await get_shard_id_by_uid(params.sharpei, uid, stats)
            if params.shard_id != actual_shard_id:
                await db_update_user_state_with_notifies_count(
                    cur, user_state, 'notified', FAKE_NOTIFIES_COUNT_FOR_BAD_SHARD,
                )
                log.info('user was skipped, because he is in different shard really', uid=uid)
                return

            # The DB row is updated first, passport second, both while the
            # cursor context is still open.
            await db_update_user_state(cur, user_state, 'frozen')
            await passport_update_user_state(
                passport_client, params.passport_settings, uid, passport_tvm_ticket, stats,
            )
            log.info('user was successfully frozen', uid=uid)
            stats.increase_task_meter('passport_update_user_state_success')
    except Exception as exc:
        log.exception(f'Got exception: {exc}', uid=uid)

@dataclass
class FreezeUsersParams(TaskParams):
    """Parameters of the ``freeze_users`` task (see ``shard_freeze_users``)."""

    task_name: str = 'freeze_users'
    # Minimum time since last_state_update before a notified user is frozen.
    state_ttl: timedelta = timedelta(days=10)
    # Maximum number of uids fetched by a single query.
    chunk_size: int = 1000
    # Upper bound on the number of uids requested during one task run.
    max_users_count: int = 10000
    # Passed to passport_update_user_state; its ``timeout`` is also used for
    # the HTTP client session.
    passport_settings: PassportSettings = None
    # Passed to get_surveillance_users.
    surveillance_settings: SurveillanceSettings = None

    # ``tvm_ids.passport_id`` selects which service ticket is requested.
    tvm_ids: TvmSettings = None
    # Ticket provider; ``get(...)`` is awaited once per task run.
    tvm_tickets: TvmServiceTickets = None
    # Forwarded to randomize_start_time(max_delay=...).
    # NOTE(review): presumably seconds (5 hours) - confirm against the helper.
    max_delay: int = 5 * 3600


async def shard_freeze_users(params: FreezeUsersParams, stats):
    """Run the freeze task for one shard.

    Steps, in order: call ``randomize_start_time`` (presumably a random
    start delay bounded by ``params.max_delay`` - confirm in the helper),
    obtain the passport TVM ticket, load the surveillance users, then page
    through freeze candidates and call ``freeze_user`` for each uid
    sequentially.

    Args:
        params: task parameters.
        stats: metrics object handed to the helpers and the HTTP session.
    """
    await randomize_start_time(max_delay=params.max_delay)

    ticket = await params.tvm_tickets.get(params.tvm_ids.passport_id)
    special_uids = await get_surveillance_users(params.surveillance_settings, stats)

    async with create_cursor_provider(params, stats) as conn:
        # The session is constructed only after the cursor provider has been
        # entered, same as in a nested ``async with``.
        session = ProfiledClientSession(
            metrics=stats,
            logger=http_logger.get_logger(),
            timeout=params.passport_settings.timeout,
        )
        async with session as client:
            batches = get_users(
                conn=conn,
                state_ttl=params.state_ttl,
                chunk_size=params.chunk_size,
                users_count=params.max_users_count,
            )
            async for batch in batches:
                for uid in batch:
                    await freeze_user(client, params, ticket, conn, uid, stats, special_uids)
