A torch.nn.Module encapsulating an arbitrary function
(If not specified, a deepcopy of Fm is used as a Module)
adapter : torch.nn.Module class
An optional wrapper class A for Fm and Gm which must output
s, t = A(x) with shape(s) = shape(t) = shape(x)
s, t are respectively the scale and shift tensors for the affine coupling.
implementation_fwd : int
Switch between different Affine Operation implementations for forward pass. Default = 1
implementation_bwd : int
Switch between different Affine Operation implementations for inverse pass. Default = 1
"""
super(AffineBlock, self).__init__()
# mirror the passed module, without parameter sharing...
if Gm is None:
    Gm = copy.deepcopy(Fm)
# apply the adapter class if it is given
self.Gm = adapter(Gm) if adapter is not None else Gm
self.Fm = adapter(Fm) if adapter is not None else Fm
self.implementation_fwd = implementation_fwd
self.implementation_bwd = implementation_bwd
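
The adapter contract described in the docstring above (s, t = A(x) with matching shapes) and the deepcopy default can be illustrated with a small pure-Python sketch. Everything here is an assumption made for illustration: `Linear1D`, `ExpAdapter`, `affine_forward` and `affine_inverse` are hypothetical names, plain lists stand in for tensors, and the coupling equations are one common affine form, not necessarily the one selected by `implementation_fwd` or `implementation_bwd`.

```python
import copy
import math


class Linear1D:
    """Hypothetical stand-in for a torch.nn.Module: elementwise w * x + b."""

    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return [self.w * v + self.b for v in x]


class ExpAdapter:
    """Hypothetical adapter A: returns (s, t) with len(s) == len(t) == len(x)."""

    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        t = self.module(x)                       # shift values
        s = [math.exp(math.tanh(v)) for v in t]  # scale values, strictly positive
        return s, t


def affine_forward(x1, x2, Fm, Gm):
    """Assumed coupling: y1 = x1 * s_F + t_F, then y2 = x2 * s_G + t_G."""
    s, t = Fm(x2)
    y1 = [a * si + ti for a, si, ti in zip(x1, s, t)]
    s, t = Gm(y1)
    y2 = [a * si + ti for a, si, ti in zip(x2, s, t)]
    return y1, y2


def affine_inverse(y1, y2, Fm, Gm):
    """Undo affine_forward by reversing the two steps in opposite order."""
    s, t = Gm(y1)
    x2 = [(a - ti) / si for a, si, ti in zip(y2, s, t)]
    s, t = Fm(x2)
    x1 = [(a - ti) / si for a, si, ti in zip(y1, s, t)]
    return x1, x2


base = Linear1D(w=0.5, b=1.0)
Fm = ExpAdapter(base)
Gm = ExpAdapter(copy.deepcopy(base))  # mirrored copy, no shared state
Gm.module.w = -0.25                   # changing the copy leaves Fm untouched

x1, x2 = [1.0, -2.0], [0.5, 3.0]
y1, y2 = affine_forward(x1, x2, Fm, Gm)
r1, r2 = affine_inverse(y1, y2, Fm, Gm)  # recovers x1, x2 up to rounding
```

Keeping the scale strictly positive is a design choice of this sketch: it guarantees the division in the inverse is always defined.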
Retain the input information, by default it can be discarded since it will be
reconstructed upon the backward pass.
implementation_fwd : int
Switch between different Operation implementations for forward training. Default = 1
implementation_bwd : int
Switch between different Operation implementations for backward training. Default = 1
"""
super(ReversibleBlock, self).__init__()
if coupling == 'additive':
    self.rev_block = AdditiveBlock(Fm, Gm, keep_input, implementation_fwd, implementation_bwd)
elif coupling == 'affine':
    self.rev_block = AffineBlock(Fm, Gm, keep_input, implementation_fwd, implementation_bwd)
else:
    raise NotImplementedError('Unknown coupling method: %s' % coupling)
reconstructed upon the backward pass.
Raises
------
NotImplementedError
If an unknown coupling or implementation is given.
"""
super(ReversibleBlock, self).__init__()
self.keep_input = keep_input
self.keep_input_inverse = keep_input_inverse
if coupling == 'additive':
    self.rev_block = AdditiveBlock(Fm, Gm,
                                   implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
elif coupling == 'affine':
    self.rev_block = AffineBlock(Fm, Gm, adapter=adapter,
                                 implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
else:
    raise NotImplementedError('Unknown coupling method: %s' % coupling)
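
The dispatch on `coupling` in the constructor above can be exercised in isolation with a minimal sketch. The stub classes and the `make_rev_block` helper are hypothetical names introduced only for this illustration; they record their arguments and do no computation, so this shows the selection logic and the error path rather than the real blocks.

```python
class _StubBlock:
    """Hypothetical placeholder for a coupling block; only records its arguments."""

    def __init__(self, Fm, Gm=None, **kwargs):
        self.Fm, self.Gm, self.kwargs = Fm, Gm, kwargs


class AdditiveStub(_StubBlock):
    pass


class AffineStub(_StubBlock):
    pass


def make_rev_block(coupling, Fm, Gm=None, adapter=None,
                   implementation_fwd=1, implementation_bwd=1):
    """Select a block by coupling name; only the affine branch receives the adapter."""
    impl = dict(implementation_fwd=implementation_fwd,
                implementation_bwd=implementation_bwd)
    if coupling == 'additive':
        return AdditiveStub(Fm, Gm, **impl)
    elif coupling == 'affine':
        return AffineStub(Fm, Gm, adapter=adapter, **impl)
    raise NotImplementedError('Unknown coupling method: %s' % coupling)


additive = make_rev_block('additive', Fm=abs)
affine = make_rev_block('affine', Fm=abs, adapter=None)
```

Passing the implementation switches by keyword, as the later constructor does, keeps the call correct even if the block signatures gain or reorder positional parameters.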