import random
from datetime import timedelta
from typing import Optional

from asyncpg.pool import Pool

from ..settings.db import MultihostDbSettings
from ..db_helpers.types import DbHost, Db
from .db_tier_poller import DbTierPoller


class DbMultihostPool(DbTierPoller):
    """Hand out per-host connection pools for a multi-host database.

    One ``DbHost`` per ``geo -> host`` entry of ``db_settings.hosts`` is
    registered with the ``DbTierPoller`` base class, which is scheduled to
    run every ``db_host_status_poll_time``. Calling the instance returns the
    pool that should serve a query, chosen from the hosts' ``primary`` /
    ``dead`` flags (presumably refreshed by the poller -- confirm in
    ``DbTierPoller``).

    Relies on ``self._db`` being set by ``DbTierPoller`` from the ``db``
    argument (assumption -- not visible here).

    ``create_pool()`` must be awaited before the instance is called.
    """

    def __init__(
        self,
        db_settings: MultihostDbSettings,
        db_host_status_poll_time: timedelta
    ):
        self._db_settings = db_settings
        # Pools indexed by host geo; None until create_pool() is awaited and
        # reset to None by close().
        self._pools = None
        super().__init__(
            db=Db(
                hosts=[
                    DbHost(
                        dsn=db_settings.dsn_for_host(host),
                        geo=geo,
                        health_check_limit=db_settings.health_check_limit,
                    )
                    for geo, host in db_settings.hosts.items()
                ]
            ),
            ssl=db_settings.ssl,
            run_every=db_host_status_poll_time,
        )

    @property
    def poller(self):
        """Return ``self``: this object is its own ``DbTierPoller``."""
        return self

    async def create_pool(self) -> 'DbMultihostPool':
        """Create the per-host pools via ``db_settings.create_pools()``.

        Returns ``self`` so construction and pool creation can be chained.
        """
        self._pools = await self._db_settings.create_pools()
        return self

    async def close(self):
        """Close every pool created by ``create_pool()``.

        Does nothing if pools were never created. Once all pools are closed
        they are forgotten, so calling the instance afterwards raises
        ``RuntimeError`` instead of returning a closed pool.
        """
        if not self._pools:
            return
        for pool in self._pools.values():
            await pool.close()
        # Only cleared after every close() succeeded, so a failed close can
        # be retried with the remaining pools still referenced.
        self._pools = None

    def _get_rw_host(self) -> Optional[DbHost]:
        """Return the first host flagged ``primary`` (dead or not), else None."""
        return next((host for host in self._db.hosts if host.primary), None)

    def _get_geo_host(self, geo: Optional[str]) -> Optional[DbHost]:
        """Return the first alive host whose geo matches ``geo`` ignoring case.

        Returns None when ``geo`` is empty/None or no alive host matches.
        """
        if not geo:
            return None
        wanted = geo.lower()
        return next(
            (
                host for host in self._db.hosts
                if host.geo.lower() == wanted and not host.dead
            ),
            None,
        )

    def _get_ro_hosts(self) -> list:
        """Return alive hosts that are not flagged ``primary``."""
        return [host for host in self._db.hosts if not host.primary and not host.dead]

    def _get_alive_hosts(self) -> list:
        """Return all hosts not flagged ``dead``, primary included."""
        return [host for host in self._db.hosts if not host.dead]

    def __call__(self, ro: bool = False, geo_hint: Optional[str] = None) -> Pool:
        """Return the pool that should serve the next query.

        Selection order (first match wins):

        1. ``ro`` is False and some host is flagged ``primary``: that host's
           pool (its ``dead`` flag is not consulted).
        2. An alive host whose geo matches ``geo_hint`` case-insensitively.
        3. A random alive non-primary (read-only replica) host.
        4. A random alive host, primary included.
        5. A random host regardless of status.

        NOTE(review): with ``ro=False`` and no host flagged primary, steps
        2-5 can return a replica pool for a write -- confirm this is intended.

        Args:
            ro: whether the query only reads.
            geo_hint: preferred host geo; ignored when empty or None.

        Raises:
            RuntimeError: if ``create_pool()`` has not been awaited (or the
                pools were closed).
        """
        if not self._pools:
            raise RuntimeError('Uninitialized pools, should run create_pool() first')

        if not ro:
            rw_host = self._get_rw_host()
            if rw_host:
                return self._pools[rw_host.geo]

        geo_host = self._get_geo_host(geo_hint)
        if geo_host:
            return self._pools[geo_host.geo]

        # Prefer replicas so read traffic is kept off the primary when possible.
        ro_hosts = self._get_ro_hosts()
        if ro_hosts:
            return self._pools[random.choice(ro_hosts).geo]

        # All replicas are dead: fall back to any alive host (the primary).
        alive_hosts = self._get_alive_hosts()
        if alive_hosts:
            return self._pools[random.choice(alive_hosts).geo]

        # Everything looks dead: pick something rather than fail outright.
        return self._pools[random.choice(self._db.hosts).geo]
