import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from torch.nn.modules.utils import _single, _pair, _triple

from torchmeta.modules.module import MetaModule


class MetaConv1d(nn.Conv1d, MetaModule):
    __doc__ = nn.Conv1d.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)
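# Usage sketch (not part of the library source; shapes and sizes below are
# illustrative assumptions). The optional `params` argument lets a caller
# substitute an explicit parameter dict, e.g. adapted "fast weights", instead
# of the module's own registered parameters:
#
#     >>> import torch
#     >>> conv = MetaConv1d(3, 8, kernel_size=5)
#     >>> x = torch.randn(4, 3, 32)
#     >>> out = conv(x)                          # falls back to self.named_parameters()
#     >>> fast = OrderedDict(conv.named_parameters())
#     >>> out = conv(x, params=fast)             # uses the supplied dict instead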
class MetaConv2d(nn.Conv2d, MetaModule):
    __doc__ = nn.Conv2d.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)
class MetaConv3d(nn.Conv3d, MetaModule):
    __doc__ = nn.Conv3d.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
                                (self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv3d(F.pad(input, expanded_padding, mode='circular'),
                            params['weight'], bias, self.stride,
                            _triple(0), self.dilation, self.groups)
        return F.conv3d(input, params['weight'], bias, self.stride,
                        self.padding, self.dilation, self.groups)
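# Illustrative sketch of why these modules take `params` at all: a MAML-style
# inner loop can differentiate through a gradient step and re-run the forward
# pass under the updated weights, without mutating the module. The loss,
# inner-loop learning rate (0.5), and tensor shapes are assumptions made up
# for this example:
#
#     >>> import torch
#     >>> model = MetaConv2d(3, 16, kernel_size=3, padding=1)
#     >>> x, target = torch.randn(2, 3, 8, 8), torch.randn(2, 16, 8, 8)
#     >>> loss = F.mse_loss(model(x), target)
#     >>> grads = torch.autograd.grad(loss, model.parameters(),
#     ...                             create_graph=True)
#     >>> fast = OrderedDict((name, param - 0.5 * grad)
#     ...                    for ((name, param), grad)
#     ...                    in zip(model.named_parameters(), grads))
#     >>> adapted_loss = F.mse_loss(model(x, params=fast), target)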
import torch.nn as nn

from torchmeta.modules.module import MetaModule
from torchmeta.modules.utils import get_subdict


class MetaSequential(nn.Sequential, MetaModule):
    __doc__ = nn.Sequential.__doc__

    def forward(self, input, params=None):
        for name, module in self._modules.items():
            if isinstance(module, MetaModule):
                input = module(input, params=get_subdict(params, name))
            elif isinstance(module, nn.Module):
                input = module(input)
            else:
                raise TypeError('The module must be either a torch module '
                    '(inheriting from `nn.Module`), or a `MetaModule`. '
                    'Got type: `{0}`'.format(type(module)))
        return input
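# Usage sketch (illustrative; MetaLinear is defined further down in this
# page, and the sizes are made up). `get_subdict` routes each child's slice of
# a flat parameter dict, keyed like '0.weight', to that child, so plain
# nn.Module children and MetaModule children can be mixed freely:
#
#     >>> import torch
#     >>> model = MetaSequential(MetaLinear(4, 8), nn.ReLU(), MetaLinear(8, 2))
#     >>> x = torch.randn(5, 4)
#     >>> fast = OrderedDict(model.named_parameters())
#     >>> out = model(x, params=fast)    # '0.weight' reaches the first MetaLinear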
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

from torchmeta.modules.module import MetaModule


class MetaLayerNorm(nn.LayerNorm, MetaModule):
    __doc__ = nn.LayerNorm.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        weight = params.get('weight', None)
        bias = params.get('bias', None)
        return F.layer_norm(
            input, self.normalized_shape, weight, bias, self.eps)
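# Note (illustrative): `params.get(..., None)` means a layer built with
# elementwise_affine=False, which registers no parameters, still works; weight
# and bias stay None and F.layer_norm normalizes without an affine transform,
# matching nn.LayerNorm:
#
#     >>> import torch
#     >>> norm = MetaLayerNorm(16)
#     >>> x = torch.randn(2, 16)
#     >>> y = norm(x, params=OrderedDict(norm.named_parameters()))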
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from torch.nn.modules.batchnorm import _BatchNorm

from torchmeta.modules.module import MetaModule


class _MetaBatchNorm(_BatchNorm, MetaModule):
    def forward(self, input, params=None):
        self._check_input_dim(input)
        if params is None:
            params = OrderedDict(self.named_parameters())

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        weight = params.get('weight', None)
        bias = params.get('bias', None)

        return F.batch_norm(
            input, self.running_mean, self.running_var, weight, bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
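# Concrete subclasses only supply the dimensionality check; a sketch along the
# lines of torchmeta's MetaBatchNorm2d (the error message mirrors
# nn.BatchNorm2d's):
#
#     class MetaBatchNorm2d(_MetaBatchNorm):
#         __doc__ = nn.BatchNorm2d.__doc__
#
#         def _check_input_dim(self, input):
#             if input.dim() != 4:
#                 raise ValueError('expected 4D input (got {}D input)'
#                                  .format(input.dim()))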
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

from torchmeta.modules.module import MetaModule


class MetaLinear(nn.Linear, MetaModule):
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        return F.linear(input, params['weight'], bias)


class MetaBilinear(nn.Bilinear, MetaModule):
    __doc__ = nn.Bilinear.__doc__

    def forward(self, input1, input2, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        return F.bilinear(input1, input2, params['weight'], bias)
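# Usage sketch (dimensions are illustrative assumptions). Both modules accept
# an explicit parameter dict, giving a functional-style forward pass:
#
#     >>> import torch
#     >>> linear = MetaLinear(10, 5)
#     >>> x = torch.randn(3, 10)
#     >>> fast = OrderedDict((name, p.clone())
#     ...                    for name, p in linear.named_parameters())
#     >>> out = linear(x, params=fast)
#     >>> bil = MetaBilinear(10, 10, 5)
#     >>> out2 = bil(x, x, params=OrderedDict(bil.named_parameters()))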