Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Training setup for FCN8s: select the GPU, seed the RNGs, build the SBD
# 'train' / VOC2011 'seg11valid' loaders, then either resume from a
# checkpoint or load a pretrained FCN16s.
# NOTE(review): this is a snippet - `args`, `os`, `osp`, `torch` and
# `torchfcn` must come from surrounding code not shown here.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
cuda = torch.cuda.is_available()
# Fixed seed (1337) so runs are reproducible.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)

# 1. dataset
root = osp.expanduser('~/data/datasets')
# DataLoader workers and pinned memory are only enabled when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model
model = torchfcn.models.FCN8s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # torch.load unpickles the file: only load checkpoints you trust.
    # On a CPU-only host, map tensors saved from a GPU onto the CPU;
    # with CUDA available the default placement (None) is kept.
    checkpoint = torch.load(
        args.resume, map_location=None if cuda else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    fcn16s = torchfcn.models.FCN16s()
    state_dict = torch.load(
        args.pretrained_model, map_location=None if cuda else 'cpu')
    # NOTE(review): snippet appears truncated here - `fcn16s` and
    # `state_dict` are unused in the visible code; presumably the weights
    # are loaded into fcn16s and copied into `model` afterwards. Confirm.
def main():
    """Parse CLI args, build the VOC2011 'seg11valid' loader and an FCN model.

    The FCN variant is selected from the prefix of the model file's
    basename: 'fcn32s', 'fcn16s', 'fcn8s-atonce' or 'fcn8s'. The model's
    class count is taken from the validation dataset's ``class_names``.

    NOTE(review): this snippet ends after model selection; loading the
    weights from ``model_file`` and running the evaluation are not shown.

    Raises:
        ValueError: if the model file name has none of the known prefixes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', help='Model path')
    parser.add_argument('-g', '--gpu', type=int, default=0)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    model_file = args.model_file

    root = osp.expanduser('~/data/datasets')
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False,
        num_workers=4, pin_memory=True)

    # Use the dataset's own class count instead of a hard-coded 21 so the
    # model's output channels always match the labels being evaluated.
    n_class = len(val_loader.dataset.class_names)

    model_name = osp.basename(model_file)
    if model_name.startswith('fcn32s'):
        model = torchfcn.models.FCN32s(n_class=n_class)
    elif model_name.startswith('fcn16s'):
        model = torchfcn.models.FCN16s(n_class=n_class)
    # 'fcn8s-atonce' must be tested before the 'fcn8s' prefix it also matches.
    elif model_name.startswith('fcn8s-atonce'):
        model = torchfcn.models.FCN8sAtOnce(n_class=n_class)
    elif model_name.startswith('fcn8s'):
        model = torchfcn.models.FCN8s(n_class=n_class)
    else:
        raise ValueError(
            'Unsupported model file (expected fcn32s/fcn16s/fcn8s prefix): '
            '{!r}'.format(model_file))
# Training setup for FCN16s: select the GPU, seed the RNGs, build the SBD
# 'train' / VOC2011 'seg11valid' loaders, then either resume from a
# checkpoint or load a pretrained FCN32s.
# NOTE(review): this is a snippet - `args`, `os`, `osp`, `torch` and
# `torchfcn` must come from surrounding code not shown here.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
cuda = torch.cuda.is_available()
# Fixed seed (1337) so runs are reproducible.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)

# 1. dataset
root = osp.expanduser('~/data/datasets')
# DataLoader workers and pinned memory are only enabled when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model
model = torchfcn.models.FCN16s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # torch.load unpickles the file: only load checkpoints you trust.
    # On a CPU-only host, map tensors saved from a GPU onto the CPU;
    # with CUDA available the default placement (None) is kept.
    checkpoint = torch.load(
        args.resume, map_location=None if cuda else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    fcn32s = torchfcn.models.FCN32s()
    state_dict = torch.load(
        args.pretrained_model, map_location=None if cuda else 'cpu')
    # NOTE(review): snippet appears truncated here - `fcn32s` and
    # `state_dict` are unused in the visible code; presumably the weights
    # are loaded into fcn32s and copied into `model` afterwards. Confirm.
# Training setup for FCN32s: seed the RNGs, build the SBD 'train' /
# VOC2011 'seg11valid' loaders, then either resume from a checkpoint or
# build an FCN32s from weights stored under ~/data/models/torch.
# NOTE(review): `resume` is not defined in this snippet - presumably it is
# assigned earlier (e.g. `resume = args.resume`); confirm. `osp`, `torch`
# and `torchfcn` must also come from surrounding code.
cuda = torch.cuda.is_available()
# Fixed seed (1337) so runs are reproducible.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)

# 1. dataset
root = osp.expanduser('~/data/datasets')
# DataLoader workers and pinned memory are only enabled when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model
model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if resume:
    # torch.load unpickles the file: only load checkpoints you trust.
    # On a CPU-only host, map tensors saved from a GPU onto the CPU;
    # with CUDA available the default placement (None) is kept.
    checkpoint = torch.load(resume, map_location=None if cuda else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    vgg16_fcn32s = torchfcn.models.FCN32s(n_class=21)
    vgg16_fcn32s.load_state_dict(torch.load(
        osp.expanduser('~/data/models/torch/vgg16-fcn32s.pth'),
        map_location=None if cuda else 'cpu'))
    # NOTE(review): the visible code never copies vgg16_fcn32s's weights
    # into `model`; either the snippet is truncated or this is a bug. Confirm.
# Training setup for FCN8sAtOnce: select the GPU, seed the RNGs, build the
# SBD 'train' / VOC2011 'seg11valid' loaders, then either resume from a
# checkpoint or initialise the model from a pretrained VGG16.
# NOTE(review): this is a snippet - `args`, `os`, `osp`, `torch` and
# `torchfcn` must come from surrounding code not shown here.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
cuda = torch.cuda.is_available()
# Fixed seed (1337) so runs are reproducible.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)

# 1. dataset
root = osp.expanduser('~/data/datasets')
# DataLoader workers and pinned memory are only enabled when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model
model = torchfcn.models.FCN8sAtOnce(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # torch.load unpickles the file: only load checkpoints you trust.
    # On a CPU-only host, map tensors saved from a GPU onto the CPU;
    # with CUDA available the default placement (None) is kept.
    checkpoint = torch.load(
        args.resume, map_location=None if cuda else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)