import json
from typing import List
from dataclasses import dataclass
from datetime import timedelta

from .helpers import db_general_retry
from .task import TaskParams
from .cursor_provider import create_cursor_provider, locked_transactional_cursor, DbConnectionError
from .freeze_helper import db_update_user_state_with_notifies_count, db_get_user_state, is_live_user
from .notify_helper import User, Templates, Metrics, NotifyParams, blackbox, sendr, is_valid_default_address
from .export_helper import randomize_start_time
from mail.shiva.stages.api.props.logger import get_uid_logger
from mail.python.theatre.profiling.http import ProfiledClientSession
from mail.shiva.stages.api.settings.log import http_logger
from mail.shiva.stages.api.settings.surveillance import SurveillanceSettings
from mail.shiva.stages.api.props.services.surveillance import get_surveillance_users
from mail.shiva.stages.api.props.services.sharpei import get_shard_id_by_uid

log = get_uid_logger(__name__)
# Sentinel values written into ``notifies_count`` (together with state
# 'notified') for users that are deliberately skipped instead of notified.
# They are negative, so they cannot collide with real counts: those start at 0
# and grow by one per sent mail. A user marked this way is therefore not
# selected again by rules that match a non-negative ``notifies_count``.
# User has a recovery address other than the default one.
# NOTE(review): notify_user also writes a literal -1 for users moved to state 'special'.
FAKE_NOTIFIES_COUNT_FOR_FOREIGN_ADDRESS = -1
# Blackbox returned no data for the uid.
FAKE_NOTIFIES_COUNT_FOR_BAD_BLACKBOX = -2
# The default recovery address was rejected by is_valid_default_address.
FAKE_NOTIFIES_COUNT_FOR_BAD_DEFAULT_ADDRESS = -3
# User has a Direct subscription (``user.is_direct``).
FAKE_NOTIFIES_COUNT_FOR_DIRECT_USERS = -4
# NOTE(review): -5 is not used in this file; presumably reserved or defined elsewhere, so confirm before reusing it.
# Sharpei reports the user in a different shard than ``params.shard_id``.
FAKE_NOTIFIES_COUNT_FOR_BAD_SHARD = -6


@db_general_retry
async def get_users_chunk(conn, state_ttl, notifies_count, last_uid, chunk_size):
    """Return the next page of uids that are due for a notification.

    Selects users that are present (``is_here``), not deleted, in state
    ``'notified'`` with exactly ``notifies_count`` notifications so far, and
    whose ``last_state_update`` is older than ``now() - state_ttl``.  Paging is
    keyset-based: only uids strictly greater than ``last_uid`` are returned, in
    ascending order, at most ``chunk_size`` of them.

    Args:
        conn: connection exposing an async ``cursor()`` context manager.
        state_ttl: minimum age of the last state update (bound into the SQL).
        notifies_count: exact stored ``notifies_count`` to match.
        last_uid: exclusive lower bound for uids (``get_users`` starts at 0).
        chunk_size: SQL ``LIMIT`` for this page.

    Returns:
        List of uids in ascending order; empty when nothing matches.
    """
    query_args = {
        'state_ttl': state_ttl,
        'notifies_count': notifies_count,
        'last_uid': last_uid,
        'chunk_size': chunk_size,
    }
    async with conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT uid
                  FROM mail.users
                 WHERE is_here
                   AND NOT is_deleted
                   AND state = 'notified'
                   AND notifies_count = %(notifies_count)s
                   AND last_state_update < (now() - %(state_ttl)s)
                   AND uid > %(last_uid)s
                 ORDER BY uid
                 LIMIT %(chunk_size)s
            ''',
            query_args,
        )
        uids = []
        async for row in cur:
            uids.append(row['uid'])
        return uids


async def get_users(conn, state_ttl, notifies_count, chunk_size, users_count):
    """Yield uids eligible for the next notification, one chunk at a time.

    Pages through ``get_users_chunk`` using the last uid of each chunk as the
    cursor for the next query.  Iteration stops when a query returns fewer
    rows than requested (no more matches), or when ``users_count`` uids have
    been requested in total.

    Args:
        conn: DB connection passed through to ``get_users_chunk``.
        state_ttl: minimum age of the user's last state update.
        notifies_count: stored ``notifies_count`` the users must have.
        chunk_size: maximum uids per query; non-positive means nothing to fetch.
        users_count: overall cap on requested uids; non-positive means nothing
            to fetch.

    Yields:
        Non-empty lists of uids in ascending order.
    """
    # A zero chunk_size used to loop forever: it never reduced users_count and
    # every LIMIT 0 query came back empty.  A negative size would reach SQL as
    # a negative LIMIT.  Both are treated as "nothing to fetch".
    if chunk_size <= 0 or users_count <= 0:
        return

    last_uid = 0
    remaining = users_count
    while remaining > 0:
        # Never ask for more rows than the remaining overall budget.
        limit = min(chunk_size, remaining)
        remaining -= limit

        chunk = await get_users_chunk(conn, state_ttl, notifies_count, last_uid, limit)
        if not chunk:
            return

        last_uid = chunk[-1]
        yield chunk

        # A short page means the matching rows are exhausted.
        if len(chunk) < limit:
            return


@dataclass
class NotifyRule:
    """One step of the notification schedule.

    Users whose stored ``notifies_count`` equals ``notifies_count`` and whose
    last state update is older than ``state_ttl`` are selected for the next
    notification, which uses ``mail_templates``.
    """

    # Number of notifications the user must already have received.
    notifies_count: int
    # Minimum time since the user's last state update before the next mail is due.
    state_ttl: timedelta
    # Templates for the notification mail (default/ru/en/tr variants in the defaults).
    mail_templates: Templates


@dataclass
class NotifyUsersParams(TaskParams, NotifyParams):
    """Parameters of the ``notify_users`` task.

    Extends the generic task and notification parameters with the
    notification schedule, paging limits and the random start delay.
    """

    task_name: str = 'notify_users'
    # Passed to get_surveillance_users; users listed there are moved to state
    # 'special' instead of being notified.
    surveillance_settings: SurveillanceSettings = None
    # Notification schedule, processed in order by shard_notify_users.
    # NOTE(review): annotated as List, but the default is a tuple; an immutable
    # default also avoids sharing one mutable list between instances.
    notify_rules: List[NotifyRule] = (
        NotifyRule(
            notifies_count=0,
            state_ttl=timedelta(days=1),
            mail_templates=Templates(
                default='freeze-30d-def',
                ru='freeze-30d-ru',
                en='freeze-30d-en',
                tr='freeze-30d-tr',
            ),
        ),
        NotifyRule(
            notifies_count=1,
            state_ttl=timedelta(days=20),
            mail_templates=Templates(
                default='freeze-10d-def',
                ru='freeze-10d-ru',
                en='freeze-10d-en',
                tr='freeze-10d-tr',
            ),
        ),
    )
    # Maximum number of uids requested from the DB in one query.
    chunk_size: int = 200
    # Cap on uids fetched per notify rule in one run.
    max_users_count: int = 10000
    # Upper bound for the random start delay passed to randomize_start_time
    # (presumably seconds, i.e. 10 hours; confirm with randomize_start_time).
    max_delay: int = 10 * 3600


def has_only_default_recovery_address(user: User):
    """Return whether the user's sole recovery address is the default one.

    The default address must be set, must appear in ``user.addresses``, and
    must be the only entry there.
    """
    default_address = user.default_address
    if default_address is None:
        return False
    return default_address in user.addresses and len(user.addresses) == 1


def is_appropriate_user_state(user_state, notifies_count):
    """Check that a user may still be moved to ``notifies_count`` notifications.

    The user must be live (per ``is_live_user``), be in state ``'notified'``,
    and currently have exactly one notification fewer than ``notifies_count``.
    If the state changed since the user was selected, this returns a falsy
    value and the caller skips the user.
    """
    alive = is_live_user(user_state)
    if not alive:
        # Propagate is_live_user's own falsy result unchanged.
        return alive
    if user_state.state != 'notified':
        return False
    return notifies_count == user_state.notifies_count + 1


@db_general_retry
async def update_user_state(conn, uid, notifies_count):
    """Record ``notifies_count`` for ``uid`` if its state still allows it.

    Re-reads the user inside ``locked_transactional_cursor`` and writes state
    ``'notified'`` with the new count only when ``is_appropriate_user_state``
    accepts it; otherwise it only logs.  notify_user uses this to retry the
    DB write after a ``DbConnectionError``.
    """
    async with locked_transactional_cursor(conn, uid) as cur:
        current_state = await db_get_user_state(cur, uid)
        if is_appropriate_user_state(current_state, notifies_count):
            await db_update_user_state_with_notifies_count(cur, current_state, 'notified', notifies_count)
        else:
            log.info('cannot update notified user, because his state was changed', uid=uid)


async def notify_user(conn, user: User, notifies_count, metrics, params: NotifyUsersParams, templates: Templates, surveillance_users):
    """Send the next notification to ``user`` and record it in the DB.

    All checks and the send happen inside ``locked_transactional_cursor`` for
    the user (presumably a per-user lock; confirm in cursor_provider).  Users
    that must not receive mail are handled as follows:

    * state changed since selection: skipped, nothing written;
    * uid listed in ``surveillance_users``: moved to state 'special';
    * sharpei reports another shard: FAKE_NOTIFIES_COUNT_FOR_BAD_SHARD;
    * Direct subscription: FAKE_NOTIFIES_COUNT_FOR_DIRECT_USERS;
    * any non-default recovery address: FAKE_NOTIFIES_COUNT_FOR_FOREIGN_ADDRESS;
    * default address rejected by is_valid_default_address:
      FAKE_NOTIFIES_COUNT_FOR_BAD_DEFAULT_ADDRESS.

    Otherwise the mail is sent through ``sendr``.  On success ``notifies_count``
    is stored.  If that write fails with ``DbConnectionError``, it is retried
    after the transaction via ``update_user_state``.  If sending fails, the
    DB is left untouched, so the user still matches the rule on the next run.

    Any exception is logged and swallowed, so the caller's loop continues
    with the next user.

    Args:
        conn: DB connection used for the locked transaction.
        user: blackbox user data (uid, addresses, default address, Direct flag).
        notifies_count: the count to store once the mail is sent.
        metrics: metrics sink passed to HTTP helpers.
        params: task parameters (shard id, sharpei and sendr settings).
        templates: mail templates for this notification.
        surveillance_users: uids (as strings) that must never be notified.
    """
    need_db_update_retries = False
    try:
        async with locked_transactional_cursor(conn, user.uid) as cur:
            user_state = await db_get_user_state(cur, user.uid)
            if not is_appropriate_user_state(user_state, notifies_count):
                log.info('user was skipped, because his state was changed', uid=user.uid)
                return
            if str(user.uid) in surveillance_users:
                log.info('user was skipped, because he is special', uid=user.uid)
                await db_update_user_state_with_notifies_count(cur, user_state, 'special', -1)
                return
            if params.shard_id != await get_shard_id_by_uid(params.sharpei, user.uid, metrics):
                await db_update_user_state_with_notifies_count(cur, user_state, 'notified', FAKE_NOTIFIES_COUNT_FOR_BAD_SHARD)
                log.info('user was skipped, because he is in different shard really', uid=user.uid)
                return
            if user.is_direct:
                await db_update_user_state_with_notifies_count(cur, user_state, 'notified', FAKE_NOTIFIES_COUNT_FOR_DIRECT_USERS)
                log.info('user was skipped, because he has Direct subscription', uid=user.uid)
                return
            if not has_only_default_recovery_address(user):
                await db_update_user_state_with_notifies_count(cur, user_state, 'notified', FAKE_NOTIFIES_COUNT_FOR_FOREIGN_ADDRESS)
                log.info('user was skipped, because he has non-default recovery address', uid=user.uid)
                return
            if not await is_valid_default_address(params, user, metrics):
                await db_update_user_state_with_notifies_count(cur, user_state, 'notified', FAKE_NOTIFIES_COUNT_FOR_BAD_DEFAULT_ADDRESS)
                log.info('user was skipped, because he has bad default recovery address', uid=user.uid)
                return

            async with ProfiledClientSession(metrics=metrics, logger=http_logger.get_logger(), timeout=params.sendr_settings.timeout) as client:
                sent = await sendr(cfg=params.sendr_settings, user=user, metrics=metrics, client=client, templates=templates)
            if not sent:
                # Previously this fell through to the "successfully notified"
                # log even though nothing was delivered.
                log.info('user was not notified, because sending the mail failed', uid=user.uid)
                return

            try:
                await db_update_user_state_with_notifies_count(cur, user_state, 'notified', notifies_count)
            except DbConnectionError:
                # The mail is already sent.  Retry the write outside this
                # transaction so the user is not selected and mailed again.
                need_db_update_retries = True
        if need_db_update_retries:
            await update_user_state(conn, user.uid, notifies_count)
    except Exception as exc:
        log.exception(f'Got exception {exc}', uid=user.uid)
        return
    log.info('user was successfully notified', uid=user.uid)


async def update_bad_bb_user(conn, uid, notifies_count):
    """Mark a uid that blackbox did not return with FAKE_NOTIFIES_COUNT_FOR_BAD_BLACKBOX.

    The write happens only while ``is_appropriate_user_state`` still accepts
    ``notifies_count`` for the user; otherwise the user is just logged as
    skipped.  Exceptions are logged and swallowed.
    """
    try:
        async with locked_transactional_cursor(conn, uid) as cur:
            current_state = await db_get_user_state(cur, uid)
            if is_appropriate_user_state(current_state, notifies_count):
                await db_update_user_state_with_notifies_count(cur, current_state, 'notified', FAKE_NOTIFIES_COUNT_FOR_BAD_BLACKBOX)
                log.info('user was skipped, because he has bad data in blackbox', uid=uid)
            else:
                log.info('user was skipped, because his state was changed', uid=uid)
    except Exception as exc:
        log.exception(f'Got exception: {exc}', uid=uid)


async def shard_notify_users(params: NotifyUsersParams, stats):
    """Run the ``notify_users`` task for one shard.

    Waits a random delay (``randomize_start_time`` with ``params.max_delay``)
    and loads the surveillance users.  Then, for each rule in
    ``params.notify_rules``, it:

    * pages through eligible uids with ``get_users``;
    * resolves each chunk in blackbox;
    * notifies every user blackbox returned;
    * marks uids blackbox did not return via ``update_bad_bb_user``.

    Args:
        params: task parameters.
        stats: stats sink used for metrics and service clients.
    """
    await randomize_start_time(max_delay=params.max_delay)
    metrics = Metrics(stats)
    surveillance_users = await get_surveillance_users(params.surveillance_settings, stats)
    async with create_cursor_provider(params, stats) as conn:
        for rule in params.notify_rules:
            # Selected users have rule.notifies_count mails; this run sends the next one.
            next_notifies_count = rule.notifies_count + 1
            # Use the rule's own templates.  The previous lookup
            # params.notify_rules[rule.notifies_count] assumed rules are stored
            # in notifies_count order starting at 0.  Otherwise it picked another
            # rule's templates or raised IndexError.
            templates = rule.mail_templates
            async for users_chunk in get_users(
                conn=conn,
                state_ttl=rule.state_ttl,
                notifies_count=rule.notifies_count,
                chunk_size=params.chunk_size,
                users_count=params.max_users_count,
            ):
                bb_tvm_ticket = await params.tvm_tickets.get(params.tvm_ids.bb_id)

                # Set for O(1) membership checks below.
                bb_uids = set()
                async with ProfiledClientSession(metrics=metrics, logger=http_logger.get_logger(), timeout=params.bb_settings.timeout, json_serialize=json.dumps) as client:
                    for bb_user in await blackbox(cfg=params.bb_settings, uids=users_chunk, bb_tvm_ticket=bb_tvm_ticket,
                                                  metrics=metrics, client=client):
                        bb_uids.add(int(bb_user.uid))
                        await notify_user(conn=conn, user=bb_user, notifies_count=next_notifies_count,
                                          metrics=metrics, params=params, templates=templates, surveillance_users=surveillance_users)

                for db_uid in users_chunk:
                    if db_uid not in bb_uids:
                        await update_bad_bb_user(conn, db_uid, next_notifies_count)
