Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Training setup (script fragment): device detection, seeding, data loaders
# and model construction.  `cfg`, `torch` and `torchfcn` must already be in
# scope.
cuda = torch.cuda.is_available()

# Three samples per visible GPU.  torch.cuda.device_count() is 0 on a
# CPU-only host, which would give batch_size == 0 and a ZeroDivisionError on
# the max_iter line below, so a host without GPUs counts as one device.
batch_size = max(torch.cuda.device_count(), 1) * 3
max_iter = cfg['max_iteration'] // batch_size

# Seed torch (and CUDA, when available) with a fixed value.
torch.manual_seed(1)
if cuda:
    torch.cuda.manual_seed(1)

# 1. dataset
# Fall back to 'v2' when the config does not name a dataset; the resolved
# value is written back into cfg.
cfg['dataset'] = cfg.get('dataset', 'v2')
if cfg['dataset'] == 'v2':
    dataset_class = torchfcn.datasets.APC2016V2
elif cfg['dataset'] == 'v3':
    dataset_class = torchfcn.datasets.APC2016V3
else:
    raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

# Extra loader options are only passed when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    dataset_class(split='train', transform=True),
    batch_size=batch_size, shuffle=True, **kwargs)
valid_loader = torch.utils.data.DataLoader(
    dataset_class(split='valid', transform=True),
    batch_size=batch_size, shuffle=False, **kwargs)

# 2. model
# The number of classes is taken from the training dataset's class_names.
n_class = len(train_loader.dataset.class_names)
model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=cfg['nodeconv'])
start_epoch = 0
# Training setup (script fragment): load the config, detect the device, seed
# the RNGs and build the data loaders.  `config_file`, `load_config_file`,
# `torch` and `torchfcn` must already be in scope.
cfg, out = load_config_file(config_file)

cuda = torch.cuda.is_available()

# Three samples per visible GPU.  torch.cuda.device_count() is 0 on a
# CPU-only host, which would give batch_size == 0 and a ZeroDivisionError on
# the max_iter line below, so a host without GPUs counts as one device.
batch_size = max(torch.cuda.device_count(), 1) * 3
max_iter = cfg['max_iteration'] // batch_size

# Seed torch (and CUDA, when available) with a fixed value.
torch.manual_seed(1)
if cuda:
    torch.cuda.manual_seed(1)

# 1. dataset
# Fall back to 'v2' when the config does not name a dataset; the resolved
# value is written back into cfg.
cfg['dataset'] = cfg.get('dataset', 'v2')
if cfg['dataset'] == 'v2':
    dataset_class = torchfcn.datasets.APC2016V2
elif cfg['dataset'] == 'v3':
    dataset_class = torchfcn.datasets.APC2016V3
else:
    raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

# Extra loader options are only passed when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    dataset_class(split='train', transform=True),
    batch_size=batch_size, shuffle=True, **kwargs)
valid_loader = torch.utils.data.DataLoader(
    dataset_class(split='valid', transform=True),
    batch_size=batch_size, shuffle=False, **kwargs)

# 2. model
# The number of classes is taken from the training dataset's class_names.
n_class = len(train_loader.dataset.class_names)
# Training setup (script fragment): pick the GPU, seed the RNGs, build the
# data loaders and create or restore the model.  `args`, `os`, `osp`, `torch`
# and `torchfcn` must already be in scope.

# Expose only the device(s) named by args.gpu to CUDA; this is set before
# CUDA availability is queried.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
cuda = torch.cuda.is_available()

# Seed torch (and CUDA, when available) with a fixed value.
torch.manual_seed(1337)
if cuda:
    torch.cuda.manual_seed(1337)

# 1. dataset
root = osp.expanduser('~/data/datasets')
# Extra loader options are only passed when CUDA is available.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model
model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # NOTE(review): torch.load unpickles the file, so only resume from
    # checkpoints you trust.
    # Without CUDA, remap every storage to CPU so that a checkpoint written
    # on a GPU machine can still be restored; with CUDA the default placement
    # is kept (map_location=None), exactly as before.
    map_location = None if cuda else (lambda storage, location: storage)
    checkpoint = torch.load(args.resume, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Continue counting from where the checkpoint left off.
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # Fresh run: initialise the model from the pretrained VGG16 parameters.
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)