How to use the baal.bayesian.dropout.patch_module function in baal

To help you get started, we've selected a few baal examples based on popular ways patch_module is used in public projects.
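As background for the project example, here is a minimal sketch of the idea behind patching dropout layers for Monte Carlo (MC) Dropout: ordinary dropout layers are swapped for a variant that stays active in eval mode, so repeated forward passes give different predictions that can be averaged. It uses only PyTorch. The names `AlwaysOnDropout` and `swap_dropout` are hypothetical and are not part of baal, and baal's actual `patch_module` implementation may differ.

```python
import torch
from torch import nn


class AlwaysOnDropout(nn.Dropout):
    """Dropout that keeps sampling masks even in eval mode (hypothetical name)."""

    def forward(self, x):
        # training=True forces a fresh random mask on every call.
        return nn.functional.dropout(x, p=self.p, training=True)


def swap_dropout(module: nn.Module) -> nn.Module:
    """Recursively replace nn.Dropout children in place (illustrative helper)."""
    for name, child in list(module.named_children()):
        if type(child) is nn.Dropout:
            setattr(module, name, AlwaysOnDropout(p=child.p))
        else:
            swap_dropout(child)
    return module


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 3))
net = swap_dropout(net).eval()  # eval mode, yet dropout still samples

x = torch.randn(4, 8)
with torch.no_grad():
    # 20 stochastic forward passes, stacked as (iterations, batch, classes)
    samples = torch.stack([net(x).softmax(dim=-1) for _ in range(20)])

mean_prediction = samples.mean(dim=0)  # averaged class probabilities
disagreement = samples.std(dim=0)      # spread across passes, a rough uncertainty signal
```

In an active learning setting, a heuristic would typically rank unlabeled samples using a score derived from stacked predictions like `samples`. The standard deviation here is only a simple stand-in for such a score.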


Source: ElementAI/baal, experiments/vgg_mcdropout_cifar10.py (GitHub)
print("warning, the experiments would take ages to run on cpu")

    hyperparams = vars(args)

    active_set, test_set = get_datasets(hyperparams['initial_pool'])

    heuristic = get_heuristic(hyperparams['heuristic'],
                              hyperparams['shuffle_prop'])
    criterion = CrossEntropyLoss()
    model = vgg16(pretrained=False, num_classes=10)
    weights = load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth')
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    model.load_state_dict(weights, strict=False)

    # change dropout layer to MCDropout
    model = patch_module(model)

    if use_cuda:
        model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=hyperparams["lr"], momentum=0.9)

    # Wraps the model into a usable API.
    model = ModelWrapper(model, criterion)

    logs = {}
    logs['epoch'] = 0

    # for prediction we use a smaller batchsize
    # since it is slower
    active_loop = ActiveLearningLoop(active_set,
                                     model.predict_on_dataset,
                                     heuristic,