import logging

import sqlalchemy.exc
from sqlalchemy.orm import Session
from typing import Callable, Generator

from travel.avia.shared_flights.lib.python.db_models.apm_imported_file import ApmImportedFile
from travel.avia.shared_flights.lib.python.db_models.blacklist import Blacklist
from travel.avia.shared_flights.lib.python.db_models.carrier import Carrier
from travel.avia.shared_flights.lib.python.db_models.db_lock import DbLock
from travel.avia.shared_flights.lib.python.db_models.designated_carrier import DesignatedCarrier
from travel.avia.shared_flights.lib.python.db_models.flight_base import FlightBase, SirenaFlightBase, ApmFlightBase
from travel.avia.shared_flights.lib.python.db_models.flight_merge_rule import FlightMergeRule
from travel.avia.shared_flights.lib.python.db_models.flight_pattern import FlightPattern, SirenaFlightPattern, \
    ApmFlightPattern
from travel.avia.shared_flights.lib.python.db_models.flight_status_source import FlightStatusSource
from travel.avia.shared_flights.lib.python.db_models.flying_carrier import FlyingCarrier
from travel.avia.shared_flights.lib.python.db_models.iata_correction_rule import IataCorrectionRule
from travel.avia.shared_flights.lib.python.db_models.last_imported_info import LastImportedInfo
from travel.avia.shared_flights.lib.python.db_models.station import Station
from travel.avia.shared_flights.lib.python.db_models.stop_point import StopPoint
from travel.avia.shared_flights.lib.python.db_models.timezone import Timezone
from travel.avia.shared_flights.lib.python.db_models.transport_model import TransportModel
from travel.avia.shared_flights.tasks.monitoring.db.metric import Metric

logger = logging.getLogger(__name__)


class TableSizeMonitor:
    """Reports the number of rows in each monitored database table.

    A session is obtained from ``session_factory`` for every ``get_metrics``
    run and is closed when that run ends.
    """

    def __init__(self, session_factory: Callable[[], Session]):
        """Remember the zero-argument callable that returns a session."""
        self.session_factory = session_factory

    def get_metrics(self) -> Generator[Metric, None, None]:
        """Yield one ``db.table.size`` metric per monitored table.

        Each metric is labelled with the table name and carries the row count
        as its value. A table whose count query raises
        ``sqlalchemy.exc.ProgrammingError`` is logged and skipped; any other
        exception propagates to the caller. The session is closed in every
        case, including when the generator is closed before being exhausted.

        Because this is a generator, the session is only created once
        iteration starts.
        """
        session = self.session_factory()

        models = [
            ApmFlightBase,
            ApmFlightPattern,
            ApmImportedFile,
            Blacklist,
            Carrier,
            DbLock,
            DesignatedCarrier,
            FlightBase,
            FlightMergeRule,
            FlightPattern,
            FlightStatusSource,
            FlyingCarrier,
            IataCorrectionRule,
            LastImportedInfo,
            SirenaFlightBase,
            SirenaFlightPattern,
            Station,
            StopPoint,
            Timezone,
            TransportModel,
        ]
        try:
            for model in models:
                table_name = model.__table__.name

                # Only the query is guarded: the handler below must not catch
                # errors coming from logging, metric construction, or
                # exceptions thrown into the generator at the yield point.
                try:
                    row_count = session.query(model).count()
                except sqlalchemy.exc.ProgrammingError:
                    logger.exception("Cannot get table %s size", table_name)
                    # Roll back the failed transaction, then move on to the
                    # next table.
                    session.rollback()
                    continue

                logger.debug("Table: %r size %d", table_name, row_count)
                yield Metric(
                    sensor='db.table.size',
                    labels={
                        'table_name': table_name,
                    },
                    value=row_count,
                )
        finally:
            session.close()
