How to use the delve.torchcallback.CheckLayerSat function in delve

To help you get started, we’ve selected a few delve examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: delve-team/delve, example_fc.py (view on GitHub)
# Regression example: for each candidate hidden-layer size, fit a LayerCake
# model to random data while a CheckLayerSat tracker is updated every step.

# Use the first CUDA device when available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG once, before any random tensors are drawn.
torch.manual_seed(1)

# Output directory for the trackers below. A single makedirs(exist_ok=True)
# replaces the racy exists()/mkdir() pair and is hoisted out of the loop
# because it does not depend on h.
os.makedirs('regression', exist_ok=True)

for h in [10, 100, 300]:

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # h is passed as the second constructor argument; the other sizes are fixed.
    model = LayerCake(D_in, h, H2, H3, H4, H5, D_out)

    x, y, model = x.to(device), y.to(device), model.to(device)
    # NOTE(review): presumably writes CSV saturation stats under
    # 'regression/h<h>' -- confirm against the delve documentation.
    stats = CheckLayerSat('regression/h{}'.format(h), 'csv', model,
                          device=device, reset_covariance=True)

    # reduction='sum' is the current spelling of the deprecated
    # size_average=False (sum of squared errors, not the mean).
    loss_fn = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    steps_iter = trange(2000, desc='steps', leave=True, position=0)
    steps_iter.write("{:^80}".format(
        "Regression - SixLayerNet - Hidden layer size {}".format(h)))
    for i in steps_iter:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        # .item() extracts the scalar loss as a Python float for display.
        steps_iter.set_description('loss=%g' % loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update the tracker once per optimisation step.
        stats.add_saturations()
    steps_iter.write('\n')