from walle.clients.racktables import shorten_switch_name
from walle.constants import HostType
from walle.errors import BadRequestError
from walle.hosts import Host, HostLocation
from walle.util.mongo import MongoDocument


def _get_location_filter(location):
    """Build a Mongo filter matching hosts whose stored location equals the parsed ``location``.

    :param location: physical location string; it is split into named fields by
        ``HostLocation.split_physical_location_string_to_fields``.
    :return: dict mapping ``"<Host.location db field>.<HostLocation field db field>"`` to the
        parsed value, for every field returned by the parser.
    :raises BadRequestError: if the parsed location has a missing or empty ``rack`` component.
    """
    location_fields = HostLocation.split_physical_location_string_to_fields(location)
    if not location_fields.get("rack"):
        raise BadRequestError(f"'{location}' is not a rack")

    # The parent field name is the same for every key; look it up once.
    location_prefix = Host.location.db_field
    return {
        f"{location_prefix}.{getattr(HostLocation, field_name).db_field}": field_value
        for field_name, field_value in location_fields.items()
    }


def _fetch_hosts(hosts_query):
    """Fetch hosts of type ``SERVER`` or ``SHADOW_SERVER`` that match ``hosts_query``.

    Only the ``inv`` and ``project`` fields are requested; ``_id`` is excluded.

    :param hosts_query: Mongo filter document. It is NOT modified; the type restriction is
        applied to a copy (and overrides any ``"type"`` key the caller supplied).
    :return: list of the matching documents.
    """
    # Build a fresh filter rather than writing into the caller's dict, so callers can
    # safely reuse the query object they passed in.
    query = {**hosts_query, "type": {"$in": [HostType.SERVER, HostType.SHADOW_SERVER]}}
    return list(MongoDocument.for_model(Host).find(query, {"inv": 1, "project": 1, "_id": 0}))


def get_hosts_invs(hosts):
    """Collect the ``inv`` attribute of every host, preserving input order.

    :param hosts: iterable of objects exposing an ``inv`` attribute.
    :return: list of the ``inv`` values in the same order as ``hosts``.
    """
    inventory_numbers = [host_doc.inv for host_doc in hosts]
    return inventory_numbers


def _get_switch_filter(switch):
    """Build a Mongo filter on the hosts' location switch field.

    The switch name is passed through ``shorten_switch_name`` before being used as the value.
    """
    switch_field_path = f"{Host.location.db_field}.{HostLocation.switch.db_field}"
    short_name = shorten_switch_name(switch)
    return {switch_field_path: short_name}


def get_hosts_invs_by_location(location):
    """Return inventory numbers of server/shadow-server hosts at ``location``.

    :param location: physical location string that must include a rack.
    :return: list of host ``inv`` values.
    :raises BadRequestError: if ``location`` has no rack component or no hosts match it.
    """
    location_query = _get_location_filter(location)
    matched_hosts = _fetch_hosts(location_query)
    if matched_hosts:
        return get_hosts_invs(matched_hosts)

    raise BadRequestError(f"There are no hosts associated with location '{location}'")


def get_hosts_invs_by_switch(switch):
    """Return inventory numbers of server/shadow-server hosts recorded on ``switch``.

    :param switch: switch name; it is shortened before being matched.
    :return: list of host ``inv`` values.
    :raises BadRequestError: if no hosts match the switch.
    """
    switch_query = _get_switch_filter(switch)
    matched_hosts = _fetch_hosts(switch_query)
    if matched_hosts:
        return get_hosts_invs(matched_hosts)

    raise BadRequestError(f"There are no hosts associated with switch '{switch}'")
