Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
----------
Fm : torch.nn.Module
A torch.nn.Module encapsulating an arbitrary function
Gm : torch.nn.Module
A torch.nn.Module encapsulating an arbitrary function
(If not specified a deepcopy of Fm is used as a Module)
implementation_fwd : int
Switch between different Additive Operation implementations for forward pass. Default = 1
implementation_bwd : int
Switch between different Additive Operation implementations for inverse pass. Default = 1
"""
super(AdditiveBlock, self).__init__()
# mirror the passed module, without parameter sharing...
if Gm is None:
Gm = copy.deepcopy(Fm)
self.Gm = Gm
self.Fm = Fm
self.implementation_fwd = implementation_fwd
self.implementation_bwd = implementation_bwd
keep_input : bool
Retain the input information, by default it can be discarded since it will be
reconstructed upon the backward pass.
implementation_fwd : int
Switch between different Operation implementations for forward training. Default = 1
implementation_bwd : int
Switch between different Operation implementations for backward training. Default = 1
"""
super(ReversibleBlock, self).__init__()
if coupling == 'additive':
self.rev_block = AdditiveBlock(Fm, Gm, keep_input, implementation_fwd, implementation_bwd)
elif coupling == 'affine':
self.rev_block = AffineBlock(Fm, Gm, keep_input, implementation_fwd, implementation_bwd)
else:
raise NotImplementedError('Unknown coupling method: %s' % coupling)
keep_input_inverse : :obj:`bool`, optional
Set to retain the input information on inverse, by default it can be discarded since it will be
reconstructed upon the backward pass.
Raises
------
NotImplementedError
If an unknown coupling or implementation is given.
"""
super(ReversibleBlock, self).__init__()
self.keep_input = keep_input
self.keep_input_inverse = keep_input_inverse
if coupling == 'additive':
self.rev_block = AdditiveBlock(Fm, Gm,
implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
elif coupling == 'affine':
self.rev_block = AffineBlock(Fm, Gm, adapter=adapter,
implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
else:
raise NotImplementedError('Unknown coupling method: %s' % coupling)