Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
inputs : tf.Tensor
input tensor
name : str
scope name
Returns
-------
tf.Tensor
"""
kwargs = cls.fill_params('body', **kwargs)
layout, kernel_size = cls.pop(['layout', 'kernel_size'], kwargs)
x, inputs = inputs, None
with tf.variable_scope(name):
if downsample:
x = conv_block(x, layout='cna', kernel_size=2, strides=2, name='downsample', **kwargs)
x = ResNet.block(x, layout=layout, kernel_size=kernel_size, downsample=0, name='conv', **kwargs)
return x
Returns
-------
tf.Tensor
"""
kwargs = cls.fill_params('body', **kwargs)
layout, filters = cls.pop(['layout', 'filters'], kwargs)
mask = cls.pop('mask', kwargs)
trunk = cls.pop('trunk', kwargs)
mask = {**kwargs, **mask}
trunk = {**kwargs, **trunk}
x, inputs = inputs, None
with tf.variable_scope(name):
for i, b in enumerate(layout):
if b == 'r':
x = ResNet.block(x, filters=filters[i], name='resblock-%d' % i, **{**kwargs, 'downsample':True})
else:
x = cls.attention(x, level=int(b), filters=filters[i], name='attention-%d' % i, **kwargs)
return x
Parameters
----------
inputs : tf.Tensor
input tensor
level : int
nested mask level
name : str
scope name
Returns
-------
tf.Tensor
"""
with tf.variable_scope(name):
x, inputs = inputs, None
x = ResNet.block(x, name='initial', **kwargs)
t = cls.trunk(x, **kwargs)
m = cls.mask((x, t), level=level, **kwargs)
x = conv_block(m, layout='nac nac', kernel_size=1, name='scale',
**{**kwargs, 'filters': kwargs['filters']*4})
x = tf.sigmoid(x, name='attention_map')
x = (1 + x) * t
x = ResNet.block(x, name='last', **kwargs)
return x
""" An ordinary ResNet block
Parameters
----------
inputs : tf.Tensor
input tensor
name : str
scope name
Returns
-------
tf.Tensor
"""
kwargs = cls.fill_params('body/br', **kwargs)
kwargs['filters'] = cls.num_channels(inputs, data_format=kwargs['data_format'])
return ResNet.block(inputs, name=name, **kwargs)
nested mask level
name : str
scope name
Returns
-------
tf.Tensor
"""
kwargs = cls.fill_params('body/mask', **kwargs)
upsample_args = cls.pop('upsample', kwargs)
with tf.variable_scope(name):
x, skip = inputs
inputs = None
x = conv_block(x, layout='p', name='pool', **kwargs)
b = ResNet.block(x, name='resblock_1', **kwargs)
c = ResNet.block(b, name='resblock_2', **kwargs)
if level > 0:
i = cls.mask((b, b), level=level-1, name='submask-%d' % level, **kwargs)
c = ResNet.block(c + i, name='resblock_3', **kwargs)
x = cls.upsample(c, resize_to=skip, name='interpolation', data_format=kwargs['data_format'],
**upsample_args)
return x
name : str
scope name
Returns
-------
tf.Tensor
"""
kwargs = cls.fill_params('body/mask', **kwargs)
upsample_args = cls.pop('upsample', kwargs)
with tf.variable_scope(name):
x, skip = inputs
inputs = None
x = conv_block(x, layout='p', name='pool', **kwargs)
b = ResNet.block(x, name='resblock_1', **kwargs)
c = ResNet.block(b, name='resblock_2', **kwargs)
if level > 0:
i = cls.mask((b, b), level=level-1, name='submask-%d' % level, **kwargs)
c = ResNet.block(c + i, name='resblock_3', **kwargs)
x = cls.upsample(c, resize_to=skip, name='interpolation', data_format=kwargs['data_format'],
**upsample_args)
return x
-------
tf.Tensor
"""
kwargs = cls.fill_params('body/mask', **kwargs)
upsample_args = cls.pop('upsample', kwargs)
with tf.variable_scope(name):
x, skip = inputs
inputs = None
x = conv_block(x, layout='p', name='pool', **kwargs)
b = ResNet.block(x, name='resblock_1', **kwargs)
c = ResNet.block(b, name='resblock_2', **kwargs)
if level > 0:
i = cls.mask((b, b), level=level-1, name='submask-%d' % level, **kwargs)
c = ResNet.block(c + i, name='resblock_3', **kwargs)
x = cls.upsample(c, resize_to=skip, name='interpolation', data_format=kwargs['data_format'],
**upsample_args)
return x
tf.Tensor
"""
with tf.variable_scope(name):
x, inputs = inputs, None
x = ResNet.block(x, name='initial', **kwargs)
t = cls.trunk(x, **kwargs)
m = cls.mask((x, t), level=level, **kwargs)
x = conv_block(m, layout='nac nac', kernel_size=1, name='scale',
**{**kwargs, 'filters': kwargs['filters']*4})
x = tf.sigmoid(x, name='attention_map')
x = (1 + x) * t
x = ResNet.block(x, name='last', **kwargs)
return x
tf.Tensor
"""
with tf.variable_scope(name):
x, inputs = inputs, None
x = ResNet.block(x, name='initial', **kwargs)
t = cls.trunk(x, **kwargs)
m = cls.mask((x, t), level=level, **kwargs)
x = conv_block(m, layout='nac nac', kernel_size=1, name='scale',
**{**kwargs, 'filters': kwargs['filters']*4})
x = tf.sigmoid(x, name='attention_map')
x = (1 + x) * t
x = ResNet.block(x, name='last', **kwargs)
return x