import re
import logging
import aiopg
from datetime import timedelta
from dataclasses import dataclass

from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn, get_shard_id_by_uid
from .task import TaskParams

log = logging.getLogger(__name__)


async def user_is_in_shard(sharpei_cfg, uid, shard_id, stats):
    """Return whether sharpei maps ``uid`` to ``shard_id``.

    :param sharpei_cfg: sharpei configuration passed to ``get_shard_id_by_uid``.
    :param uid: user id to look up.
    :param shard_id: expected shard id; converted with ``int()`` before comparison.
    :param stats: passed through to ``get_shard_id_by_uid``.
    :return: ``True`` if the looked-up shard id equals ``int(shard_id)``, else ``False``.

    Any ``Exception`` (an unparsable ``shard_id``, a failed sharpei lookup, ...)
    is logged and reported as ``False``. Callers therefore cannot distinguish
    "user lives elsewhere" from "lookup failed". ``BaseException`` subclasses
    such as ``asyncio.CancelledError`` propagate so task cancellation still works.
    """
    try:
        return int(shard_id) == await get_shard_id_by_uid(sharpei_cfg, uid, stats)
    except Exception:
        log.warning(f"can't check whether uid {uid} belongs to shard {shard_id}", exc_info=True)
        return False


async def get_outdated_prepared_transactions(conn, transaction_ttl):
    """Return the gids of prepared transactions older than ``transaction_ttl``.

    :param conn: open connection whose ``cursor()`` is used as an async context manager.
    :param transaction_ttl: age threshold. A transaction qualifies when its
        ``prepared`` timestamp is at or before ``now() - transaction_ttl``.
    :return: list of gid values, in the order the server returned them.

    NOTE(review): the query does not filter ``pg_prepared_xacts`` by database.
    Confirm that only transactions from the connected database are expected here.
    """
    query = 'SELECT gid FROM pg_prepared_xacts WHERE prepared <= now() - %(delay)s'
    gids = []
    async with conn.cursor() as cursor:
        await cursor.execute(query, {'delay': transaction_ttl})
        async for row in cursor:
            gids.append(row[0])
    return gids


async def commit_prepared_transaction(conn, transaction_id):
    """Execute ``COMMIT PREPARED`` for ``transaction_id`` and log it.

    :param conn: open connection whose ``cursor()`` is used as an async context manager.
    :param transaction_id: gid of the prepared transaction, bound as a query parameter.

    Errors from ``execute`` propagate to the caller, and nothing is logged in that case.
    """
    bind = {'transaction_id': transaction_id}
    async with conn.cursor() as cursor:
        await cursor.execute("COMMIT PREPARED %(transaction_id)s", bind)
    log.info(f"prepared transactions {transaction_id} was commited")


async def rollback_prepared_transaction(conn, transaction_id):
    """Execute ``ROLLBACK PREPARED`` for ``transaction_id`` and log it.

    :param conn: open connection whose ``cursor()`` is used as an async context manager.
    :param transaction_id: gid of the prepared transaction, bound as a query parameter.

    Errors from ``execute`` propagate to the caller, and nothing is logged in that case.
    """
    bind = {'transaction_id': transaction_id}
    async with conn.cursor() as cursor:
        await cursor.execute("ROLLBACK PREPARED %(transaction_id)s", bind)
    log.info(f"prepared transactions {transaction_id} was rollbacked")


async def process_outdated_prepared_transaction(conn, sharpei_cfg, shard_id, transaction_id, stats):
    """Finish one outdated prepared transaction on this shard.

    The uid is parsed from the gid, which is expected to start with
    ``reg_mdb_u<uid>_s<digits>``. The transaction is then handled as follows:

    * If ``user_is_in_shard`` reports the user lives in ``shard_id``, the
      transaction is committed. If the commit fails, it is rolled back instead.
    * Otherwise it is rolled back.

    All ``Exception``s are logged rather than raised, so one bad transaction does
    not stop the caller from processing the rest. ``BaseException`` subclasses
    such as ``asyncio.CancelledError`` propagate so task cancellation still works.

    :param conn: open connection to the shard.
    :param sharpei_cfg: sharpei configuration used for the shard lookup.
    :param shard_id: id of the shard ``conn`` is connected to.
    :param transaction_id: gid of the prepared transaction.
    :param stats: passed through to the shard lookup.
    """
    try:
        # Equivalent to searching r'(^reg_mdb_u)(\d+)(_s\d+)': the pattern is anchored at the start.
        match = re.match(r'reg_mdb_u(\d+)_s\d+', transaction_id)
        if match is None:
            log.error(f"can't process outdated prepared transaction {transaction_id}: unexpected gid format")
            return
        uid = match.group(1)
        # NOTE(review): user_is_in_shard also returns False when the sharpei lookup
        # itself fails, so a lookup error leads to a rollback here. Confirm this is intended.
        if await user_is_in_shard(sharpei_cfg, uid, shard_id, stats):
            try:
                await commit_prepared_transaction(conn, transaction_id)
            except Exception:
                log.exception(f"can't commit prepared transaction {transaction_id}, rolling it back")
                await rollback_prepared_transaction(conn, transaction_id)
        else:
            await rollback_prepared_transaction(conn, transaction_id)
    except Exception:
        log.exception(f"can't process outdated prepared transaction {transaction_id}")


@dataclass
class EndPreparedTransactionParams(TaskParams):
    """Parameters for the ``end_prepared_transaction`` task.

    Extends ``TaskParams``. ``shard_end_prepared_transaction`` also reads
    ``sharpei`` and ``shard_id`` from these params; those are presumably
    declared on ``TaskParams`` (not visible here — confirm).
    """
    task_name: str = 'end_prepared_transaction'  # name of this task
    db_user: str = 'sharpei'  # database user passed to get_shard_dsn when connecting to the shard
    transaction_ttl: timedelta = timedelta(minutes=1)  # prepared transactions at least this old are processed


async def shard_end_prepared_transaction(params: EndPreparedTransactionParams, stats):
    """Finish every outdated prepared transaction on ``params.shard_id``.

    The shard DSN is resolved through sharpei using ``params.db_user``. One
    connection is opened, and each transaction older than
    ``params.transaction_ttl`` is handed to
    ``process_outdated_prepared_transaction`` in turn.

    :param params: task parameters (sharpei config, shard id, db user, ttl).
    :param stats: passed through to the sharpei calls.
    """
    dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(dsn) as conn:
        outdated = await get_outdated_prepared_transactions(conn, params.transaction_ttl)
        for gid in outdated:
            await process_outdated_prepared_transaction(
                conn, params.sharpei, params.shard_id, gid, stats
            )
