from .models import squeezenet
from .models import resnet18
from .models import resnet50


def get_network_fn(train_config):
    """Return the ``get_model`` callable for the network named in the config.

    Args:
        train_config: Subscriptable config (presumably a dict parsed from a
            config file; confirm against callers) whose ``"net"`` entry names
            the network: ``"squeezenet"``, ``"resnet18"`` or ``"resnet50"``.

    Returns:
        The ``get_model`` attribute of the matching ``.models`` submodule.

    Raises:
        KeyError: If ``train_config`` (as a dict) has no ``"net"`` entry.
        RuntimeError: If the requested name matches none of the known networks.
    """
    requested = train_config["net"]

    # Ordered (name, resolver) pairs. The lambdas defer the
    # ``<module>.get_model`` attribute lookup until a name actually matches,
    # so only the selected module is touched. A linear equality scan is used
    # instead of a dict so that unhashable values still fall through to the
    # RuntimeError below.
    known_networks = (
        ("squeezenet", lambda: squeezenet.get_model),
        ("resnet18", lambda: resnet18.get_model),
        ("resnet50", lambda: resnet50.get_model),
    )

    for known_name, resolve in known_networks:
        if known_name == requested:
            return resolve()

    raise RuntimeError("Unknown network name: {}".format(requested))
