Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def forward(self, x):
    """Compute the block's forward pass with the implementation chosen by
    ``self.implementation_fwd``.

    Dispatch:
        * ``0``  -> ``AdditiveBlockFunction.apply``
        * ``1``  -> ``AdditiveBlockFunction2.apply``
        * ``-1`` -> direct computation in this method (emits a
          ``NonMemorySavingWarning``), defined as::

              x1, x2 = chunk(x, 2, dim=1)
              y1 = x1 + Fm(x2)
              y2 = x2 + Gm(y1)
              out = cat([y1, y2], dim=1)

          Presumably the ``0``/``1`` Functions compute the same values in a
          memory-saving way; the warning text implies that, but their code is
          not visible here -- confirm.

    Args:
        x: Input tensor. It is split into two chunks along dim 1; the additions
            assume both chunks have matching shapes (i.e. an even size on
            dim 1) -- TODO confirm with callers.

    Returns:
        The output tensor produced by the selected implementation.

    Raises:
        NotImplementedError: If ``self.implementation_fwd`` is not 0, 1 or -1.
    """
    impl = self.implementation_fwd

    if impl in (0, 1):
        block_fn = AdditiveBlockFunction if impl == 0 else AdditiveBlockFunction2
        # The sub-modules' parameters are passed as explicit extra inputs,
        # presumably so the custom autograd Function can return gradients for
        # them -- confirm against the Function's backward(). They are gathered
        # only here because the other branches do not need them.
        return block_fn.apply(x, self.Fm, self.Gm,
                              *self.Fm.parameters(), *self.Gm.parameters())

    if impl == -1:
        warnings.warn('Using direct non-memory saving implementation.', NonMemorySavingWarning)
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1, x2 = x1.contiguous(), x2.contiguous()
        # NOTE(review): calling .forward directly (rather than self.Fm(x2))
        # skips any registered nn.Module hooks; kept as-is, presumably for
        # parity with the memory-saving Functions -- confirm.
        fmd = self.Fm.forward(x2)
        y1 = x1 + fmd
        gmd = self.Gm.forward(y1)
        y2 = x2 + gmd
        return torch.cat([y1, y2], dim=1)

    raise NotImplementedError("Selected implementation ({}) not implemented..."
                              .format(impl))