Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.apply_nonlin is not None:
softmax_output = self.apply_nonlin(net_output)
# no object value
bg_onehot = 1 - y_onehot
squared_error = (y_onehot - softmax_output)**2
specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)
sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)
ss = self.r * specificity_part + (1-self.r) * sensitivity_part
if not self.do_bg:
if self.batch_dice:
ss = ss[1:]
else:
ss = ss[:, 1:]
ss = ss.mean()
return ss
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
def __init__(self, gdice_kwargs):
super(PenaltyGDiceLoss, self).__init__()
self.k = 2.5
self.gdc = GDiceLoss(apply_nonlin=softmax_helper, **gdice_kwargs)
def forward(self, net_output, target, bound):
"""
net_output: (batch_size, class, x,y,z)
target: ground truth, shape: (batch_size, 1, x,y,z)
bound: precomputed distance map, shape (batch_size, class, x,y,z)
"""
net_output = softmax_helper(net_output)
# print('net_output shape: ', net_output.shape)
pc = net_output[:, 1:, ...].type(torch.float32)
dc = bound[:,1:, ...].type(torch.float32)
multipled = torch.einsum("bcxyz,bcxyz->bcxyz", pc, dc)
bd_loss = multipled.mean()
return bd_loss
def __init__(self, soft_dice_kwargs, bd_kwargs, aggregate="sum"):
super(DC_and_BD_loss, self).__init__()
self.aggregate = aggregate
self.bd = BDLoss(**bd_kwargs)
self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
unpack_data=True, deterministic=True, fp16=False):
super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
unpack_data, deterministic, fp16)
self.apply_nonlin = softmax_helper
self.loss = GDiceLossV2(apply_nonlin=self.apply_nonlin, smooth=1e-5)