Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def forward(self, in0, in1):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
(N,C,X,Y) = in0.size()
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
return value
elif(self.colorspace=='Lab'):
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var