Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
all((s % 2 == 1 for s in kernel_size)) or _raise(ValueError('kernel size should be odd in all dimensions.'))
channel_axis = -1 if backend_channels_last() else 1
n_dim = len(kernel_size)
conv = Conv2D if n_dim==2 else Conv3D
input = Input(input_shape, name = "input")
unet = unet_block(n_depth, n_filter_base, kernel_size,
activation=activation, dropout=dropout, batch_norm=batch_norm,
n_conv_per_depth=n_conv_per_depth, pool=pool_size)(input)
final = conv(n_channel_out, (1,)*n_dim, activation='linear')(unet)
if residual:
if not (n_channel_out == input_shape[-1] if backend_channels_last() else n_channel_out == input_shape[0]):
raise ValueError("number of input and output channels must be the same for a residual net.")
final = Add()([final, input])
final = Activation(activation=last_activation)(final)
if prob_out:
scale = conv(n_channel_out, (1,)*n_dim, activation='softplus')(unet)
scale = Lambda(lambda x: x+np.float32(eps_scale))(scale)
final = Concatenate(axis=channel_axis)([final,scale])
return Model(inputs=input, outputs=final)
def resnet_block(n_filter, kernel_size=(3,3), pool=(1,1), n_conv_per_block=2,
batch_norm=False, kernel_initializer='he_normal', activation='relu'):
# Factory for a 2D/3D convolutional block: validates the arguments, selects the
# convolution layer class from len(kernel_size) and defines the inner function
# `f` that chains the layers on an input tensor.
# NOTE(review): this excerpt is truncated -- the end of `f` and the return
# statement of the factory are not visible here.
# Argument checks use the `condition or _raise(...)` idiom.
n_conv_per_block >= 2 or _raise(ValueError('required: n_conv_per_block >= 2'))
len(pool) == len(kernel_size) or _raise(ValueError('kernel and pool sizes must match.'))
n_dim = len(kernel_size)
n_dim in (2,3) or _raise(ValueError('resnet_block only 2d or 3d.'))
conv_layer = Conv2D if n_dim == 2 else Conv3D
# Keyword arguments shared by every convolution in the block; the bias term is
# switched off whenever batch normalization is enabled.
conv_kwargs = dict (
padding = 'same',
use_bias = not batch_norm,
kernel_initializer = kernel_initializer,
)
# Axis handed to BatchNormalization: last for channels-last backends, else 1.
channel_axis = -1 if backend_channels_last() else 1
def f(inp):
# First convolution: the only one that passes `pool` as its strides.
x = conv_layer(n_filter, kernel_size, strides=pool, **conv_kwargs)(inp)
if batch_norm:
x = BatchNormalization(axis=channel_axis)(x)
x = Activation(activation)(x)
# Intermediate convolutions (n_conv_per_block-2 of them), each followed by the
# optional batch normalization and the activation.
for _ in range(n_conv_per_block-2):
x = conv_layer(n_filter, kernel_size, **conv_kwargs)(x)
if batch_norm:
x = BatchNormalization(axis=channel_axis)(x)
x = Activation(activation)(x)
# Last convolution: optional batch normalization, no activation within the
# visible part of the excerpt.
x = conv_layer(n_filter, kernel_size, **conv_kwargs)(x)
if batch_norm:
x = BatchNormalization(axis=channel_axis)(x)
def from_tensor(x, channel=-1, single_sample=True):
    """Move the backend's channel axis of `x` to position `channel`.

    Parameters
    ----------
    x : array
        Array laid out for the current backend, including the leading
        sample axis.
    channel : int
        Destination index of the channel axis in the returned array.
    single_sample : bool
        If True, only the first sample ``x[0]`` is returned, i.e. the sample
        axis is dropped before the channel axis is moved.

    Returns
    -------
    array
        Result of ``np.moveaxis`` with the channel axis at index `channel`.
    """
    if backend_channels_last():
        # The last axis stays the last axis whether or not the sample axis
        # has been removed.
        source_axis = -1
    else:
        # Channels-first layout is (sample, channel, ...). Once the sample
        # axis is dropped the channel axis sits at index 0, not 1; using 1
        # there would move a spatial axis instead of the channel axis.
        source_axis = 0 if single_sample else 1
    data = x[0] if single_sample else x
    return np.moveaxis(data, source_axis, channel)
def loss_mae(mean=True):
    """Create a mean-absolute-error loss function for the current backend.

    The returned function `mae(y_true, y_pred)` compares `y_true` against the
    leading channels of `y_pred` (as many channels as `y_true` has) and passes
    the absolute differences to the reduction obtained from
    `_mean_or_not(mean)`.

    The backend channel layout is queried once, when this factory is called,
    not each time the loss is evaluated.
    """
    reduce = _mean_or_not(mean)
    channels_last = backend_channels_last()

    def mae(y_true, y_pred):
        if channels_last:
            # channel axis is the last one
            n_out = K.shape(y_true)[-1]
            prediction = y_pred[..., :n_out]
        else:
            # channel axis directly follows the sample axis
            n_out = K.shape(y_true)[1]
            prediction = y_pred[:, :n_out, ...]
        return reduce(K.abs(prediction - y_true))

    return mae
self.unet_n_depth = 3
self.unet_kernel_size = 3,3
self.unet_n_filter_base = 32
self.unet_n_conv_per_depth = 2
self.unet_pool = 2,2
self.unet_activation = 'relu'
self.unet_last_activation = 'relu'
self.unet_batch_norm = False
self.unet_dropout = 0.0
self.unet_prefix = ''
self.net_conv_after_unet = 128
else:
# TODO: resnet backbone for 2D model?
raise ValueError("backbone '%s' not supported." % self.backbone)
if backend_channels_last():
self.net_input_shape = None,None,self.n_channel_in
self.net_mask_shape = None,None,1
else:
self.net_input_shape = self.n_channel_in,None,None
self.net_mask_shape = 1,None,None
self.train_shape_completion = False
self.train_completion_crop = 32
self.train_patch_size = 256,256
self.train_background_reg = 1e-4
self.train_dist_loss = 'mae'
self.train_loss_weights = 1,0.2
self.train_epochs = 400
self.train_steps_per_epoch = 100
self.train_learning_rate = 0.0003
if validation_split > 0:
n_val = int(round(n_images * validation_split))
n_train = n_images - n_val
assert 0 < n_val and 0 < n_train
X_t, Y_t = X[-n_val:], Y[-n_val:]
X, Y = X[:n_train], Y[:n_train]
assert X.shape[0] == n_train and X_t.shape[0] == n_val
X_t = move_channel_for_backend(X_t,channel=channel)
Y_t = move_channel_for_backend(Y_t,channel=channel)
X = move_channel_for_backend(X,channel=channel)
Y = move_channel_for_backend(Y,channel=channel)
axes = axes.replace('C','') # remove channel
if backend_channels_last():
axes = axes+'C'
else:
axes = axes[:1]+'C'+axes[1:]
data_val = (X_t,Y_t) if validation_split > 0 else None
if verbose:
ax = axes_dict(axes)
n_train, n_val = len(X), len(X_t) if validation_split>0 else 0
image_size = tuple( X.shape[ax[a]] for a in axes if a in 'TZYX' )
n_dim = len(image_size)
n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]
print('number of training images:\t', n_train)
print('number of validation images:\t', n_val)
print('image size (%dD):\t\t'%n_dim, image_size)
activation="relu",
batch_norm=False,
dropout=0.0,
pool_size=(2,2,2),
n_channel_out=1,
residual=False,
prob_out=False,
eps_scale=1e-3):
""" TODO """
if last_activation is None:
raise ValueError("last activation has to be given (e.g. 'sigmoid', 'relu')!")
all((s % 2 == 1 for s in kernel_size)) or _raise(ValueError('kernel size should be odd in all dimensions.'))
channel_axis = -1 if backend_channels_last() else 1
n_dim = len(kernel_size)
conv = Conv2D if n_dim==2 else Conv3D
input = Input(input_shape, name = "input")
unet = unet_block(n_depth, n_filter_base, kernel_size,
activation=activation, dropout=dropout, batch_norm=batch_norm,
n_conv_per_depth=n_conv_per_depth, pool=pool_size)(input)
final = conv(n_channel_out, (1,)*n_dim, activation='linear')(unet)
if residual:
if not (n_channel_out == input_shape[-1] if backend_channels_last() else n_channel_out == input_shape[0]):
raise ValueError("number of input and output channels must be the same for a residual net.")
final = Add()([final, input])
final = Activation(activation=last_activation)(final)
def _build_unet(self):
# Builds the network inputs and the initial pooling stage for the 'unet'
# backbone, driven by the values stored in self.config.
# NOTE(review): this excerpt is truncated inside the pooling loop -- the rest
# of the model construction is not visible here.
assert self.config.backbone == 'unet'
input_img = Input(self.config.net_input_shape, name='input')
# Shape of the mask input: every known spatial size of net_mask_shape is
# integer-divided by the matching grid entry (None stays None), with a single
# channel placed last or first depending on the backend.
if backend_channels_last():
grid_shape = tuple(n//g if n is not None else None for g,n in zip(self.config.grid, self.config.net_mask_shape[:-1])) + (1,)
else:
grid_shape = (1,) + tuple(n//g if n is not None else None for g,n in zip(self.config.grid, self.config.net_mask_shape[1:]))
input_mask = Input(grid_shape, name='dist_mask')
# Collect all config attributes named 'unet_*', keyed without the prefix.
unet_kwargs = {k[len('unet_'):]:v for (k,v) in vars(self.config).items() if k.startswith('unet_')}
# maxpool input image to grid size
pooled = np.array([1,1,1])
pooled_img = input_img
# Each pass doubles the accumulated factor along every axis where the grid is
# still larger than `pooled` (pool is 2 there and 1 elsewhere), until the
# accumulated factors equal the grid.
while tuple(pooled) != tuple(self.config.grid):
pool = 1 + (np.asarray(self.config.grid) > pooled)
pooled *= pool
for _ in range(self.config.unet_n_conv_per_depth):
pooled_img = Conv3D(self.config.unet_n_filter_base, self.config.unet_kernel_size,
padding="same", activation=self.config.unet_activation)(pooled_img)
def __init__(self, axes='YX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
"""See class docstring."""
# NOTE(review): this excerpt is truncated -- the handling of
# `allow_new_parameters` and `**kwargs` is not visible here.
super(Config, self).__init__(axes, n_channel_in, n_channel_out)
# Reject axes that contain both 'Z' and 'T'.
not ('Z' in self.axes and 'T' in self.axes) or _raise(ValueError('using Z and T axes together not supported.'))
self.probabilistic = bool(probabilistic)
# default config (can be overwritten by kwargs below)
# Residual default is on exactly when input and output channel counts match.
self.unet_residual = self.n_channel_in == self.n_channel_out
self.unet_n_depth = 2
# Kernel size default: 5 when n_dim is 2, otherwise 3.
self.unet_kern_size = 5 if self.n_dim==2 else 3
self.unet_n_first = 32
self.unet_last_activation = 'linear'
# Input shape: n_dim unspecified (None) sizes plus the input channel count,
# placed last or first depending on the backend.
if backend_channels_last():
self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,)
else:
self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,)
# Default loss name depends on whether the model is probabilistic.
self.train_loss = 'laplace' if self.probabilistic else 'mae'
self.train_epochs = 100
self.train_steps_per_epoch = 400
self.train_learning_rate = 0.0004
self.train_batch_size = 16
self.train_tensorboard = True
# the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
min_delta_key = 'epsilon' if LooseVersion(keras.__version__)<=LooseVersion('2.1.5') else 'min_delta'
self.train_reduce_lr = {'factor': 0.5, 'patience': 10, min_delta_key: 0}
# disallow setting 'n_dim' manually
def loss_laplace(mean=True):
    """Create a Laplace negative log-likelihood loss for the current backend.

    The returned function `nll(y_true, y_pred)` splits `y_pred` along the
    channel axis: the first n channels (n = channel count of `y_true`) are
    used as `mu`, the remaining channels as `sigma`. It evaluates
    ``|mu - y_true| / sigma + log(sigma) + log(2)`` and passes the result to
    the reduction obtained from `_mean_or_not(mean)`.

    The backend channel layout is queried once, when this factory is called,
    not each time the loss is evaluated.
    """
    reduce = _mean_or_not(mean)
    log_two = np.log(2.0)
    channels_last = backend_channels_last()

    def nll(y_true, y_pred):
        if channels_last:
            # channel axis is the last one
            n_out = K.shape(y_true)[-1]
            mu = y_pred[..., :n_out]
            sigma = y_pred[..., n_out:]
        else:
            # channel axis directly follows the sample axis
            n_out = K.shape(y_true)[1]
            mu = y_pred[:, :n_out, ...]
            sigma = y_pred[:, n_out:, ...]
        scaled_error = K.abs((mu - y_true) / sigma)
        return reduce(scaled_error + K.log(sigma) + log_two)

    return nll