Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, gdice_kwargs, *, k=2.5):
    """Set up a generalized Dice loss together with a penalty factor ``k``.

    Args:
        gdice_kwargs: keyword arguments forwarded to ``GDiceLoss``; its
            ``apply_nonlin`` is always fixed to ``softmax_helper`` here.
        k: penalty strength stored as ``self.k``. Defaults to 2.5, the value
            that was previously hard-coded, so existing callers are
            unaffected. How ``k`` enters the loss is decided by this class's
            ``forward`` (not visible in this chunk).
    """
    super(PenaltyGDiceLoss, self).__init__()
    self.k = k
    self.gdc = GDiceLoss(apply_nonlin=softmax_helper, **gdice_kwargs)
def forward(self, net_output, target, bound):
    """Boundary loss: foreground probabilities weighted by a distance map.

    Args:
        net_output: network output, shape (batch_size, class, x, y, z);
            presumably unnormalized scores, since ``softmax_helper`` is
            applied to it here.
        target: ground truth, shape (batch_size, 1, x, y, z). Unused by
            this method.
        bound: precomputed distance map, shape (batch_size, class, x, y, z).

    Returns:
        Scalar tensor: the mean, over batch, foreground channels and all
        spatial positions, of ``softmax_helper(net_output) * bound``.

    Channel 0 is dropped from both tensors, so only channels 1.. contribute.
    The element-wise product is identical to the former
    ``einsum("bcxyz,bcxyz->bcxyz")`` for 3-D volumes, but it also works for
    inputs with a different number of spatial dimensions (e.g. 2-D slices).
    """
    probs = softmax_helper(net_output)
    # Drop channel 0 from both prediction and distance map.
    pc = probs[:, 1:, ...].type(torch.float32)
    dc = bound[:, 1:, ...].type(torch.float32)
    return (pc * dc).mean()
def __init__(self, soft_dice_kwargs, bd_kwargs, aggregate="sum"):
    """Hold a soft Dice term and a boundary (BD) term for a combined loss.

    Args:
        soft_dice_kwargs: keyword arguments forwarded to ``SoftDiceLoss``;
            its ``apply_nonlin`` is fixed to ``softmax_helper`` here.
        bd_kwargs: keyword arguments forwarded to ``BDLoss``.
        aggregate: aggregation mode, stored as ``self.aggregate``; how it is
            interpreted is outside this chunk.
    """
    super().__init__()
    self.aggregate = aggregate
    # Keep the assignment order (bd, then dc) unchanged in case registration
    # order matters (e.g. if these are nn.Module submodules).
    boundary_term = BDLoss(**bd_kwargs)
    self.bd = boundary_term
    dice_term = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
    self.dc = dice_term
def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
             unpack_data=True, deterministic=True, fp16=False):
    """Initialise the parent trainer, then install a GDiceLossV2 loss.

    All constructor arguments are forwarded positionally, unchanged and in
    the same order, to the parent ``__init__``. Afterwards the
    non-linearity is set to ``softmax_helper`` and the loss to
    ``GDiceLossV2`` with ``smooth=1e-5``.
    """
    parent_args = (
        plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
        unpack_data, deterministic, fp16,
    )
    super().__init__(*parent_args)
    self.apply_nonlin = softmax_helper
    # The loss reuses whatever non-linearity was just stored on the instance.
    self.loss = GDiceLossV2(apply_nonlin=self.apply_nonlin, smooth=1e-5)
def forward(self, net_output, gt):
"""
net_output: (batch_size, 2, x,y,z)
target: ground truth, shape: (batch_size, 1, x,y,z)
"""
net_output = softmax_helper(net_output)
# one hot code for gt
with torch.no_grad():
if len(net_output.shape) != len(gt.shape):
gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(net_output.shape)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
gt_temp = gt[:,0, ...].type(torch.float32)