Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Evaluation, per batch: run the model on one batch and accumulate per-image
# ground-truth / predicted label maps plus a few visualizations.
# NOTE(review): this is the interior of a loop/function whose header is outside
# this view; `data`, `target`, `model`, `val_loader`, `n_class`, `label_trues`,
# `label_preds` and `visualizations` are all defined elsewhere.
# NOTE(review): `volatile=True` is the pre-0.4 PyTorch way to disable autograd;
# newer PyTorch ignores it (with a warning), so gradients may be tracked here.
# Confirm the target PyTorch version and consider `with torch.no_grad():`.
data, target = Variable(data, volatile=True), Variable(target)
score = model(data)
# Bring the input batch back to CPU so it can be untransformed / visualized.
imgs = data.data.cpu()
# `max(1)` returns (values, indices); `[1]` keeps the argmax indices along
# dim 1 -- presumably the class axis of an (N, C, H, W) score map (TODO confirm).
# The trailing `[:, :, :]` is a full slice and does not change the data.
lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
# Ground truth stays a torch tensor here (no `.numpy()` call).
lbl_true = target.data.cpu()
# Per image: undo the dataset transform, store both label maps for the
# aggregate metrics, and render a visualization for the first 9 images only.
for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
img, lt = val_loader.dataset.untransform(img, lt)
label_trues.append(lt)
label_preds.append(lp)
if len(visualizations) < 9:
viz = fcn.utils.visualize_segmentation(
lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class,
label_names=val_loader.dataset.class_names)
visualizations.append(viz)
# Evaluation, after all batches: compute the segmentation metrics over every
# accumulated label map, report them as percentages, and save the
# visualizations collected earlier as one tiled image.
metrics = torchfcn.utils.label_accuracy_score(
label_trues, label_preds, n_class=n_class)
# Convert the four scores to an array so they can be scaled to percentages.
metrics = np.array(metrics)
metrics *= 100
print('''\
Accuracy: {0}
Accuracy Class: {1}
Mean IU: {2}
FWAV Accuracy: {3}'''.format(*metrics))
# Written relative to the current working directory.
viz = fcn.utils.get_tile_image(visualizations)
skimage.io.imsave('viz_evaluate.png', viz)
# Validation, per batch: abort on a NaN loss, accumulate the batch loss, and
# collect per-image label maps plus up to 9 visualizations.
# NOTE(review): `loss_data`, `data`, `score`, `target`, `val_loss`,
# `label_trues`, `label_preds`, `visualizations` and `n_class` come from code
# outside this view.
if np.isnan(loss_data):
raise ValueError('loss is nan while validating')
# Divides by the batch size. NOTE(review): whether `loss_data` is already a
# per-sample average is decided outside this view -- check for double
# normalization.
val_loss += loss_data / len(data)
imgs = data.data.cpu()
# Argmax along dim 1 (presumably the class axis -- TODO confirm); the
# `[:, :, :]` full slice does not change the data.
lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
lbl_true = target.data.cpu()
for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
img, lt = self.val_loader.dataset.untransform(img, lt)
label_trues.append(lt)
label_preds.append(lp)
# Unlike the standalone evaluation path, no `label_names` are passed here.
if len(visualizations) < 9:
viz = fcn.utils.visualize_segmentation(
lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class)
visualizations.append(viz)
# Validation, after all batches: compute metrics over all collected labels,
# save a tiled visualization for the current iteration, average the loss,
# and build the validation row for `log.csv`.
# `n_class` is passed positionally here (keyword elsewhere); metrics are not
# scaled to percentages on this path.
metrics = torchfcn.utils.label_accuracy_score(
label_trues, label_preds, n_class)
out = osp.join(self.out, 'visualization_viz')
# NOTE(review): check-then-create can race if another process creates the
# directory in between; `os.makedirs(out, exist_ok=True)` avoids that.
if not osp.exists(out):
os.makedirs(out)
# One image per validation run, named by the 12-digit zero-padded iteration.
out_file = osp.join(out, 'iter%012d.jpg' % self.iteration)
skimage.io.imsave(out_file, fcn.utils.get_tile_image(visualizations))
# Average the accumulated loss over the loader length (presumably the number
# of batches -- TODO confirm).
val_loss /= len(self.val_loader)
with open(osp.join(self.out, 'log.csv'), 'a') as f:
# NOTE(review): subtracting an aware datetime from a naive one raises
# TypeError, so `self.timestamp_start` must be timezone-aware.
elapsed_time = (
datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
self.timestamp_start).total_seconds()
# Row layout: epoch, iteration, 5 blank training columns, validation loss,
# validation metrics, elapsed seconds.
log = [self.epoch, self.iteration] + [''] * 5 + \
[val_loss] + list(metrics) + [elapsed_time]
# NOTE(review): within this view `log` is built but never written to `f`;
# confirm the write follows outside this fragment, otherwise the validation
# row is silently dropped.
# Training step: forward pass, loss, NaN guard, backward pass, optimizer update.
# NOTE(review): no `self.optim.zero_grad()` is visible in this fragment; unless
# it is called earlier (outside this view), gradients accumulate across steps.
score = self.model(data)
loss = cross_entropy2d(score, target,
size_average=self.size_average)
# Divide by the batch size. NOTE(review): if `size_average` already averages
# the loss, this normalizes twice -- confirm intent.
loss /= len(data)
# Python float copy of the loss, used for the NaN check and logging.
loss_data = loss.data.item()
# Fail before backward() so a NaN never reaches the parameters.
if np.isnan(loss_data):
raise ValueError('loss is nan while training')
loss.backward()
self.optim.step()
# Training metrics for the current batch: argmax predictions (along dim 1,
# presumably the class axis -- TODO confirm) versus ground truth, both as
# numpy arrays.
metrics = []
lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
lbl_true = target.data.cpu().numpy()
acc, acc_cls, mean_iu, fwavacc = \
torchfcn.utils.label_accuracy_score(
lbl_true, lbl_pred, n_class=n_class)
# Only one tuple is appended, so the mean over axis 0 simply turns it into a
# 1-D array of these four scores.
metrics.append((acc, acc_cls, mean_iu, fwavacc))
metrics = np.mean(metrics, axis=0)
# Append one training row to `log.csv`, then stop the enclosing loop (outside
# this view) once the iteration budget is reached.
with open(osp.join(self.out, 'log.csv'), 'a') as f:
# NOTE(review): `self.timestamp_start` must be timezone-aware, or this
# subtraction raises TypeError.
elapsed_time = (
datetime.datetime.now(pytz.timezone('Asia/Tokyo')) -
self.timestamp_start).total_seconds()
# Row layout: epoch, iteration, training loss, the four training metrics,
# 5 blank validation columns, elapsed seconds.
log = [self.epoch, self.iteration] + [loss_data] + \
metrics.tolist() + [''] * 5 + [elapsed_time]
# `map` is lazy; the join below consumes it exactly once.
log = map(str, log)
f.write(','.join(log) + '\n')
if self.iteration >= self.max_iter:
break