import logging
from collections import defaultdict
from functools import cached_property
from typing import Callable, TypeVar

import time
from pydantic import BaseModel

from travel.avia.ad_feed.ad_feed.entities import Direction
from travel.avia.ad_feed.ad_feed.profit_cpa import ProfitCPA
from travel.avia.ad_feed.ad_feed.redir_balance_log import RedirBalanceLog, RedirLogRow

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class ClickPriceData(BaseModel):
    # Per-key aggregate filled in by ClickPriceCounter._collect_data.
    # Every field has a default, so ClickPriceData() can be built with no
    # arguments; that is how the defaultdict in _collect_data creates entries.

    # Sum of row.price_avg over all rows grouped under this key.
    price_cpc: float = 0.0
    # Sum of the CPA profit looked up by each row's marker
    # (a marker with no profit entry contributes 0.0).
    price_cpa: float = 0.0
    # NOTE(review): not updated by the code visible in this file, so it stays 0
    # unless something else fills it - confirm the intended producer.
    clicks: int = 0
    # Marker of every row grouped under this key, in iteration order.
    # Its length doubles as the row count used to compute price_avg.
    markers: list[str] = []
    # (price_cpc + price_cpa) / len(markers); set once all rows are processed.
    price_avg: float = 0.0


# Type of the grouping key returned by the `key` callable that is passed to
# ClickPriceCounter._collect_data; it becomes the key type of the result dict.
T = TypeVar('T')


class ClickPriceCounter:
    """Aggregate redir-balance-log rows into average click prices.

    Rows are grouped by a key (direction or destination settlement). For every
    group the CPC price (``row.price_avg``) and the CPA profit found for the
    row's marker are summed, and the average price per row is derived.
    """

    # Minimum number of seconds between two progress messages while iterating
    # over the redir-balance-log.
    _PROGRESS_LOG_INTERVAL_SECONDS = 30

    def __init__(self, redir_balance_log: RedirBalanceLog, profit_cpa: ProfitCPA):
        """Store the data sources; nothing is fetched until a property is read.

        :param redir_balance_log: source whose ``get_redir_balance_log()``
            yields the rows to aggregate.
        :param profit_cpa: source whose ``get_profit()`` returns a mapping
            from marker to CPA profit.
        """
        self.redir_balance_log: RedirBalanceLog = redir_balance_log
        self.profit_cpa: ProfitCPA = profit_cpa

    @cached_property
    def avg_click_price_by_direction(self) -> dict[Direction, ClickPriceData]:
        """Click price data keyed by ``Direction(row.from_id, row.to_id)``.

        Computed on first access and cached on the instance. The computation
        is a full ``_collect_data`` pass of its own: it fetches the log and
        the CPA profit independently of the other cached property.
        """
        return self._collect_data(lambda row: Direction(row.from_id, row.to_id))

    @cached_property
    def avg_click_price_by_settlement(self) -> dict[int, ClickPriceData]:
        """Click price data keyed by ``row.to_id``.

        Computed on first access and cached on the instance. The computation
        is a full ``_collect_data`` pass of its own: it fetches the log and
        the CPA profit independently of the other cached property.
        """
        return self._collect_data(lambda row: row.to_id)

    def _collect_data(self, key: Callable[[RedirLogRow], T]) -> dict[T, ClickPriceData]:
        """Group log rows by ``key(row)`` and compute per-group price totals.

        :param key: maps a log row to the (hashable) key it is grouped under.
        :return: mapping from key to its accumulated ``ClickPriceData``. The
            object is a ``defaultdict``: indexing it with a missing key
            inserts and returns an all-default ``ClickPriceData``; use
            ``.get()`` or ``in`` to probe without inserting.
        """
        log.info('Getting redir balance log')
        redir_balance_log = self.redir_balance_log.get_redir_balance_log()
        log.info('Getting cpa profit')
        cpa_profit_by_marker = self.profit_cpa.get_profit()

        result: dict[T, ClickPriceData] = defaultdict(ClickPriceData)

        log.info('Calculating total click price')
        # Stays 0 when the log is empty, so the summary message below is
        # still correct in that case.
        rblog_count = 0
        # A monotonic clock is used for the progress throttle: wall-clock
        # adjustments (NTP, manual changes) must not suppress or spam messages.
        last_informed = time.monotonic()
        for rblog_count, row in enumerate(redir_balance_log, start=1):
            now = time.monotonic()
            if now - last_informed > self._PROGRESS_LOG_INTERVAL_SECONDS:
                last_informed = now
                log.info('Ran through %d redir-balance-log records for now', rblog_count)
            current = result[key(row)]
            current.markers.append(row.marker)
            current.price_cpc += row.price_avg
            # Markers without a known CPA profit contribute nothing.
            current.price_cpa += cpa_profit_by_marker.get(row.marker, 0.0)
        log.info('Ran through %d redir-balance-log records in total', rblog_count)

        log.info('Calculating average click price')
        for value in result.values():
            # An entry is only created while processing a row, which also
            # appends a marker, so len(value.markers) is always >= 1 here.
            value.price_avg = (value.price_cpc + value.price_cpa) / len(value.markers)

        log.info('Directions total: %d', len(result))
        return result
