import asyncio
import concurrent.futures
from decimal import Decimal
from typing import Dict

from yql.api.v1.client import YqlClient, YqlTableReadIterator

from maps_adv.adv_store.lib.config import config
from maps_adv.adv_store.lib.db import db
from maps_adv.adv_store.v2.lib.db import tables

from .campaign_change_log import (
    CampaignsChangeLogActionName,
    insert_campaign_change_log,
    refresh_campaigns_change_logs,
)


async def retrieve_campaigns_display_probability() -> Dict[int, Decimal]:
    """Read per-campaign display probabilities from the configured YQL table.

    Reads the ``campaign_id`` and ``new_probability`` columns of
    ``config.YQL_DISPLAY_PROBABILITY_TABLE`` on ``config.YQL_CLUSTER``,
    authenticating with ``config.YQL_TOKEN``. The read uses the blocking
    YQL client, so it runs in a worker thread rather than on the event loop.

    Returns:
        Mapping of campaign id to its probability as a ``Decimal``. If a
        campaign id appears in several rows, the last row read wins.
    """

    def _read_table() -> Dict[int, Decimal]:
        cluster_name = config.YQL_CLUSTER
        table_name = config.YQL_DISPLAY_PROBABILITY_TABLE

        with YqlClient(token=config.YQL_TOKEN):
            rows = YqlTableReadIterator(
                table_name,
                cluster=cluster_name,
                column_names=["campaign_id", "new_probability"],
            )
            # Convert through ``str`` so that a float value becomes the Decimal
            # of its shortest repr (e.g. "0.1") instead of its exact binary
            # expansion (Decimal(0.1) == 0.1000000000000000055511...).
            # Rows are consumed while the client context is still active.
            return {
                campaign_id: Decimal(str(probability))
                for campaign_id, probability in rows
            }

    # Run the read on a dedicated single-thread executor. It is shut down
    # with wait=False on purpose: exiting a ``with ThreadPoolExecutor()``
    # block calls shutdown(wait=True), which, if this coroutine is cancelled
    # mid-read, would block the event-loop thread until the read completed.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return await asyncio.get_running_loop().run_in_executor(pool, _read_table)
    finally:
        pool.shutdown(wait=False)


async def update_campaign_display_probability(campaign_id: int, value: Decimal):
    """Store ``value`` as the display probability of campaign ``campaign_id``.

    Inside ``db.rw.transaction()`` this, in order:
      1. inserts a change-log entry with action
         CAMPAIGN_REFRESH_DISPLAY_PROBABILITY for the campaign;
      2. updates ``display_probability`` on the matching campaign row;
      3. calls ``refresh_campaigns_change_logs`` for the new entry's id.
    """
    action = CampaignsChangeLogActionName.CAMPAIGN_REFRESH_DISPLAY_PROBABILITY
    # Building the statement is pure; it is only executed inside the
    # transaction below.
    set_probability = (
        tables.campaign.update()
        .values({"display_probability": value})
        .where(tables.campaign.c.id == campaign_id)
    )

    async with db.rw.transaction():
        # NOTE(review): the log entry is created before the row changes and
        # refreshed after — presumably so it captures both states; keep order.
        log_id = await insert_campaign_change_log(
            campaign_id=campaign_id, action_name=action
        )
        await db.rw.execute(set_probability)
        await refresh_campaigns_change_logs(ids=[log_id])
