Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
torchfcn.datasets.VOC2011ClassSeg(
root, split='seg11valid', transform=True),
batch_size=1, shuffle=False, **kwargs)
# 2. model
model = torchfcn.models.FCN8sAtOnce(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # Deserialize onto the CPU. The model is still on the CPU at this point
    # (it is moved to the GPU further below), so this works on CPU-only hosts
    # and avoids allocating the checkpoint on whichever GPU it was saved from.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # Fresh run: initialise the network from pretrained VGG16 weights.
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)
if cuda:
    model = model.cuda()

# 3. optimizer
# Two parameter groups: the bias=True group (presumably the bias terms; see
# get_parameters) uses twice the base learning rate and no weight decay.
# The other group uses the defaults passed below.
optim = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': args.lr * 2, 'weight_decay': 0},
    ],
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay)
if args.resume:
torchfcn.datasets.VOC2011ClassSeg(
root, split='seg11valid', transform=True),
batch_size=1, shuffle=False, **kwargs)
# 2. model
model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # Deserialize onto the CPU. The model is still on the CPU at this point
    # (it is moved to the GPU further below), so this works on CPU-only hosts
    # and avoids allocating the checkpoint on whichever GPU it was saved from.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # Fresh run: initialise the network from pretrained VGG16 weights.
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)
if cuda:
    model = model.cuda()

# 3. optimizer
# Two parameter groups: the bias=True group (presumably the bias terms; see
# get_parameters) uses twice the base learning rate and no weight decay.
# The other group uses the defaults passed below.
optim = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': args.lr * 2, 'weight_decay': 0},
    ],
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay)
if args.resume:
batch_size=batch_size, shuffle=True, **kwargs)
# Validation loader: batches are produced in a fixed order (shuffle=False).
valid_loader = torch.utils.data.DataLoader(
    dataset_class(split='valid', transform=True),
    shuffle=False,
    batch_size=batch_size,
    **kwargs
)
# 2. model
# The number of output classes comes from the training dataset's class list.
n_class = len(train_loader.dataset.class_names)
model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=cfg['nodeconv'])
start_epoch = 0
if resume:
    # Deserialize onto the CPU. The model is still an unwrapped CPU module
    # here, so this works on CPU-only hosts and avoids allocating the
    # checkpoint on whichever GPU it was saved from.
    checkpoint = torch.load(resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
else:
    # Fresh run: initialise from pretrained VGG16, without copying fc8 and
    # without initialising the upscore layer from it.
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16, copy_fc8=False, init_upscore=False)
if cuda:
    if torch.cuda.device_count() == 1:
        model = model.cuda()
    else:
        # More than one visible GPU: replicate the model with DataParallel.
        model = torch.nn.DataParallel(model).cuda()

# 3. optimizer
optim = torch.optim.Adam(model.parameters(), lr=cfg['lr'],
                         weight_decay=cfg['weight_decay'])
if resume:
    # Optimizer.load_state_dict casts restored state to each parameter's
    # device, so the CPU-loaded checkpoint is valid after the move above.
    optim.load_state_dict(checkpoint['optim_state_dict'])
trainer = torchfcn.Trainer(
cuda=cuda,