
import torch
import torchvision

from bran.train.nets.vgg import VGG
from bran.train.nets.svhn import SVHNNet

def model_from_config(config, device=None):
    """Build the network described by ``config``, move it to ``device`` and put it in eval mode.

    Args:
        config: Mapping with an ``'arch'`` key selecting the architecture. One of
            ``'resnet18_pretrained'``, ``'vgg16_pretrained'``, ``'resnet18_svhn'``
            or ``'vgg'``. The two ``*_pretrained`` architectures also read
            ``config['n_classes']`` to size the replacement output layer. For
            ``'vgg'`` the whole config is forwarded to ``VGG``.
        device: Target device passed to ``net.to``. When falsy (e.g. ``None``),
            ``cuda:0`` is used if CUDA is available, otherwise ``cpu``.

    Returns:
        The constructed ``torch.nn.Module``, already moved to ``device`` and
        switched to eval mode.

    Raises:
        ValueError: If ``config['arch']`` is not one of the supported names.
    """
    if not device:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    arch = config['arch']
    # NOTE(review): the "*_pretrained" archs are built without loading any
    # weights here. Presumably the caller loads a state_dict afterwards;
    # confirm before relying on ImageNet weights.
    if arch == 'resnet18_pretrained':
        net = torchvision.models.resnet18()
        num_features = net.fc.in_features
        # Swap the ImageNet head for one sized to this task's classes.
        net.fc = torch.nn.Linear(num_features, config['n_classes'])
    elif arch == 'vgg16_pretrained':
        net = torchvision.models.vgg16_bn()
        # The classifier ends with [..., ReLU, Dropout, Linear]. The input width
        # of the new head must come from the final Linear (index -1); index -2
        # is a Dropout and has no ``in_features``.
        num_features = net.classifier[-1].in_features
        layers = list(net.classifier.children())[:-1]
        layers.append(torch.nn.Linear(num_features, config['n_classes']))
        net.classifier = torch.nn.Sequential(*layers)
    elif arch == 'resnet18_svhn':
        net = SVHNNet()
    elif arch == 'vgg':
        net = VGG(config)
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on `net`.
        raise ValueError(
            "Unknown arch {!r}; expected one of 'resnet18_pretrained', "
            "'vgg16_pretrained', 'resnet18_svhn', 'vgg'".format(arch)
        )

    net.to(device)
    net.eval()
    return net
