# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Regression sweep: for each hidden-layer size `h`, fit a fresh TwoLayerNet to
# random data with 2000 SGD steps while a CheckLayerSat instance observes
# `model.linear1` and `model.linear2`, logging under 'regression/h<h>'.
# NOTE(review): `device`, `TwoLayerNet`, `CheckLayerSat` and `trange` are not
# defined in this fragment; they must already be in scope when this loop runs.
# NOTE(review): indentation was reconstructed from a whitespace-stripped paste;
# confirm `stats.saturation()` is meant to run once per optimizer step.
for h in [3, 32, 128]:
    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, h, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    model = TwoLayerNet(D_in, H, D_out)
    x, y, model = x.to(device), y.to(device), model.to(device)

    # Layers handed to CheckLayerSat; each hidden size logs to its own path.
    layers = [model.linear1, model.linear2]
    stats = CheckLayerSat('regression/h{}'.format(h), layers, device=device)

    # `size_average=False` is deprecated in PyTorch; `reduction='sum'` is the
    # documented equivalent (sum of squared errors instead of the mean).
    loss_fn = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    steps_iter = trange(2000, desc='steps', leave=True, position=0)
    steps_iter.write("{:^80}".format(
        "Regression - TwoLayerNet - Hidden layer size {}".format(h)))
    for _ in steps_iter:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        # `.item()` returns a plain Python float for the progress-bar text;
        # reaching into `.data` is discouraged outside autograd internals.
        steps_iter.set_description('loss=%g' % loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        stats.saturation()
    steps_iter.write('\n')
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1)
if not os.path.exists('convNet'):
os.mkdir('convNet')
epochs = 5
for h2 in [8, 32, 128]: # compare various hidden layer sizes
net = resnet18(pretrained=False, num_classes=10)#Net(h2=h2) # instantiate network with hidden layer size `h2`
net.to(device)
logging_dir = 'convNet/simpson_h2-{}'.format(h2)
stats = CheckLayerSat(savefile=logging_dir, save_to=['plot', 'csv', 'npy'], modules=net, include_conv=True, stats=['dtrc', 'trc', 'cov', 'idim', 'lsat', 'det'], max_samples=1024,
verbose=True, writer_args={}, conv_method='channelwise', device='cpu', initial_epoch=5, interpolation_downsampling=4, interpolation_strategy='nearest')
#net = nn.DataParallel(net, device_ids=['cuda:0', 'cuda:1'])
print(net)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#stats.write( "CIFAR10 ConvNet - Changing fc2 - size {}".format(h2)) # optional
for epoch in range(epochs):
if epoch == 2:
stats.stop()
if epoch == 3:
stats.resume()
running_loss = 0.0
step = 0