# Derive the per-layer growth rate so that the dense block ends up with `filters` output channels
if filters > self.input_num_channels:
    growth_rate = (filters - self.input_num_channels) // num_layers
else:
    growth_rate = filters // num_layers
filters = growth_rate

if bottleneck:
    # `bottleneck is True` means the default expansion factor of 4
    bottleneck = 4 if bottleneck is True else bottleneck
    # Prepend a 1x1 'cna' bottleneck layer before the main convolution
    layout = 'cna' + layout
    kernel_size = [1, kernel_size]
    strides = [1, strides]
    filters = [growth_rate * bottleneck, filters]

layout = 'R' + layout + '.'
self.layer = ConvBlock(layout=layout, kernel_size=kernel_size, strides=strides,
                       dropout_rate=dropout_rate, filters=filters,
                       n_repeats=num_layers, inputs=inputs, **kwargs)
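
# --- A minimal standalone sketch of the growth-rate arithmetic above; the function
# --- name and example values are illustrative, not part of the library.
def dense_block_filters(filters, input_num_channels, num_layers, bottleneck=False):
    if filters > input_num_channels:
        growth_rate = (filters - input_num_channels) // num_layers
    else:
        growth_rate = filters // num_layers
    filters = growth_rate
    if bottleneck:
        factor = 4 if bottleneck is True else bottleneck
        # the 1x1 bottleneck widens to growth_rate * factor; the main conv narrows back
        filters = [growth_rate * factor, filters]
    return filters

print(dense_block_filters(256, 64, 6, bottleneck=True))   # [128, 32]
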
if isinstance(factor, int):
    factor = [factor] * num_stages
elif not isinstance(factor, list):
    raise TypeError('factor should be int or list of int, but %s was given' % type(factor))
block_args = kwargs.pop('blocks')
upsample_args = kwargs.pop('upsample')
combine_args = kwargs.pop('combine')

self.decoder_b, self.decoder_u, self.decoder_c = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
for i in range(num_stages):
    for letter in decoder_layout:
        if letter in ['b']:
            args = {**kwargs, **block_args, **unpack_args(block_args, i, num_stages)}
            layer = ConvBlock(inputs=x, **args)
            x = layer(x)
            self.decoder_b.append(layer)
        elif letter in ['u']:
            args = {'factor': factor[i],
                    **kwargs, **upsample_args, **unpack_args(upsample_args, i, num_stages)}
            layer = Upsample(inputs=x, **args)
            x = layer(x)
            self.decoder_u.append(layer)
        elif letter in ['c']:
            args = {**kwargs, **combine_args, **unpack_args(combine_args, i, num_stages)}
            if skip and (i < len(inputs) - 2):
                # Merge the upsampled tensor with the matching encoder skip connection;
                # the call must pass tensors in the same order used at construction time
                layer = Combine(inputs=[x, inputs[-i - 3]], **args)
                x = layer([x, inputs[-i - 3]])
                self.decoder_c.append(layer)
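
# --- `unpack_args` is assumed here to select the i-th value from any per-stage list
# --- of length num_stages; a sketch of that behavior, not the library's exact code.
def unpack_args_sketch(args, index, num_stages):
    return {key: value[index] if isinstance(value, list) and len(value) == num_stages
            else value
            for key, value in args.items()}

# e.g. unpack_args_sketch({'filters': [256, 128, 64], 'layout': 'cna'}, 1, 3)
# -> {'filters': 128, 'layout': 'cna'}
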
filters : int
    number of output filters
downsample : dict
    parameters for downsampling blocks
encoder : dict
    parameters for encoder blocks

Returns
-------
nn.Module
"""
if downsample:
    downsample = cls.get_defaults('body/downsample', downsample)
    down_block = ConvBlock(inputs, filters=filters, **{**kwargs, **downsample})
    inputs = down_block
encoder = cls.get_defaults('body/encoder', encoder)
enc_block = ConvBlock(inputs, filters=filters, **{**kwargs, **encoder})
return nn.Sequential(down_block, enc_block) if downsample else enc_block
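
# --- Why the merge is written as **{**kwargs, **downsample}: later keys win, so
# --- stage-specific settings override the shared kwargs. Values are illustrative.
kwargs = {'layout': 'cna', 'kernel_size': 3}
downsample = {'layout': 'p', 'pool_size': 2}
print({**kwargs, **downsample})   # {'layout': 'p', 'kernel_size': 3, 'pool_size': 2}
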
@classmethod
def initial_block(cls, inputs, **kwargs):
    """ Transform inputs. Usually used for initial preprocessing, e.g. reshaping, downsampling etc.

    Notes
    -----
    For parameters see :class:`~.torch.layers.ConvBlock`.

    Returns
    -------
    torch.nn.Module or None
    """
    kwargs = cls.get_defaults('initial_block', kwargs)
    if kwargs.get('layout') or kwargs.get('base_block'):
        return ConvBlock(inputs=inputs, **kwargs)
    return None
@classmethod
def head(cls, inputs, num_classes, **kwargs):
    """ 1x1 convolution

    Parameters
    ----------
    inputs
        input tensor
    num_classes : int
        number of classes (and number of filters in the last 1x1 convolution)

    Returns
    -------
    nn.Module
    """
    kwargs = cls.get_defaults('head', kwargs)
    x = ConvBlock(inputs, filters=num_classes, **kwargs)
    return x
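
# --- What the head reduces to in plain PyTorch for 2D inputs: a single 1x1
# --- convolution mapping the embedding to per-class scores. Shapes are illustrative.
import torch
import torch.nn as nn

features = torch.randn(8, 64, 32, 32)           # batch, channels, height, width
head_conv = nn.Conv2d(64, 10, kernel_size=1)    # num_classes = 10
print(head_conv(features).shape)                # torch.Size([8, 10, 32, 32])
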
def __init__(self, **kwargs):
    super().__init__()
    # Collect uppercase class-level attributes (e.g. LAYOUT, FILTERS) as default kwargs;
    # explicit keyword arguments take precedence because **kwargs is merged last
    attrs = {name.lower(): value for name, value in vars(type(self)).items()
             if name.isupper()}
    kwargs = {**attrs, **kwargs}
    self.layer = ConvBlock(**kwargs)
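
# --- The uppercase-attribute trick, shown standalone with a hypothetical class:
# --- class-level constants become lowercased default kwargs; explicit kwargs win.
class Defaults:
    LAYOUT = 'cna'
    FILTERS = 64

attrs = {name.lower(): value for name, value in vars(Defaults).items() if name.isupper()}
print({**attrs, 'filters': 128})   # {'layout': 'cna', 'filters': 128}
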
if bottleneck:
    # `bottleneck is True` means the default expansion factor of 4
    bottleneck = 4 if bottleneck is True else bottleneck
    # Wrap the main convolution in 1x1 reduce / 1x1 expand layers
    layout = 'cna' + layout + 'cna'
    kernel_size = [1] + kernel_size + [1]
    strides = [1] + strides + [1]
    strides_downsample = [1] + strides_downsample + [1]
    groups = [1] + groups + [1]
    filters = [filters[0]] + filters + [filters[0] * bottleneck]
if se:
    layout += 'S*'    # append a squeeze-and-excitation block
layout = 'B' + layout + op

# Only the first repetition uses downsampling strides; the rest keep defaults
layer_params = [{'strides': strides_downsample, 'side_branch/strides': side_branch_stride_downsample}]
layer_params += [{}] * (n_reps - 1)
self.layer = ConvBlock(*layer_params, inputs=inputs, layout=layout, filters=filters,
                       kernel_size=kernel_size, strides=strides, groups=groups,
                       side_branch={'layout': 'c', 'filters': filters[-1], 'strides': side_branch_stride},
                       **kwargs)
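
# --- The bottleneck list surgery above, traced with illustrative values
# --- (one 3x3 main convolution, expansion factor 4):
layout, kernel_size, strides, filters = 'cna', [3], [1], [64]
layout = 'cna' + layout + 'cna'                       # 'cnacnacna'
kernel_size = [1] + kernel_size + [1]                 # [1, 3, 1]
strides = [1] + strides + [1]                         # [1, 1, 1]
filters = [filters[0]] + filters + [filters[0] * 4]   # [64, 64, 256]
# i.e. 1x1 reduce -> 3x3 conv -> 1x1 expand: the classic ResNet bottleneck shape.
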
@classmethod
def head(cls, inputs, num_classes, **kwargs):
    """ Conv block with 1x1 convolution

    Parameters
    ----------
    inputs
        input tensor
    num_classes : int
        number of classes (and number of filters in the last 1x1 convolution)

    Returns
    -------
    nn.Module
    """
    kwargs = cls.get_defaults('head', kwargs)
    return ConvBlock(inputs, filters=num_classes, **kwargs)
shortcut = ConvBlock(inputs,
                     **{**kwargs, **dict(layout='c', filters=inputs_channels,
                                         kernel_size=1, strides=strides)})
if inputs_channels < x_channels:
    # Zero-pad the channel dimension instead of projecting with a convolution
    padding = [(0, 0) for _ in range(get_num_dims(inputs))]
    padding[-1] = (0, x_channels - inputs_channels)
    padding = sum(tuple(padding), ())
elif inputs_channels != x_channels or downsample:
    # Project the shortcut with a 1x1 convolution to match the main branch
    shortcut = ConvBlock(inputs,
                         **{**kwargs, **dict(layout='c', filters=x_channels,
                                             kernel_size=1, strides=strides)})

x = ResBlock(x, shortcut, padding)

if post_activation:
    p = ConvBlock(x, layout=post_activation, **kwargs)
    x = nn.Sequential(x, p)
return x
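
# --- The channel-padding branch above, shown standalone. Assumes get_num_dims
# --- returns the number of non-batch dimensions (3 for an NCHW tensor); F.pad
# --- reads (before, after) pairs starting from the last dimension, so the last
# --- pair in the flattened tuple lands on the channel axis.
import torch
import torch.nn.functional as F

x = torch.randn(8, 64, 32, 32)                  # NCHW
padding = [(0, 0) for _ in range(x.dim() - 1)]
padding[-1] = (0, 128 - 64)                     # grow channels from 64 to 128
flat = sum(tuple(padding), ())                  # (0, 0, 0, 0, 0, 64)
print(F.pad(x, flat).shape)                     # torch.Size([8, 128, 32, 32])
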
@classmethod
def body(cls, inputs, **kwargs):
    """ Base layers which produce a network embedding.

    Notes
    -----
    For parameters see :class:`~.torch.layers.ConvBlock`.

    Returns
    -------
    torch.nn.Module or None
    """
    kwargs = cls.get_defaults('body', kwargs)
    if kwargs.get('layout') or kwargs.get('base_block'):
        return ConvBlock(inputs=inputs, **kwargs)
    return None