Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _predict(imdims,axes):
    """Predict on a random image and check the output size along ``Z``.

    A uniform random array of shape ``imdims`` is drawn from ``rng`` and
    passed to the model together with ``axes`` and ``factor``. The result
    along the ``Z`` axis must be ``factor`` times as large as the input.

    ``rng``, ``config``, ``model``, ``factor`` and ``axes_dict`` are taken
    from the enclosing scope, which is not part of this snippet.

    Parameters
    ----------
    imdims : sequence of int
        Shape of the random test image.
    axes : str
        Axes string describing ``imdims``; must contain ``'Z'``.
    """
    image = rng.uniform(size=imdims)
    if not config.probabilistic:
        mean = model.predict(image, axes, factor, None, None)
    else:
        dist = model.predict_probabilistic(image, axes, factor, None, None)
        mean = dist.mean()
        scale = dist.scale()
        # both distribution parameters must describe the same pixel grid
        assert mean.shape == scale.shape
    z_index = axes_dict(axes)['Z']
    expected_z = imdims[z_index]*factor
    assert expected_z == mean.shape[z_index]
patch_size[axes_dict(img_axes if patch_axes is None else patch_axes)[a]] = (
None if red_none else img_size[axes_dict(img_axes)[a]]
)
X,Y,XYaxes = create_patches_reduced_target (
raw_data = raw_data,
patch_size = patch_size,
patch_axes = patch_axes,
n_patches_per_image = n_patches_per_image,
reduction_axes = red_axes,
target_axes = rng.choice((None,img_axes)) if keepdims else ''.join(a for a in img_axes if a not in red_axes),
#
normalization = lambda patches_x, patches_y, *args: (patches_x, patches_y),
verbose = False,
)
assert len(X) == n_images*n_patches_per_image
_X = np.mean(X,axis=tuple(axes_dict(XYaxes)[a] for a in red_axes),keepdims=True)
err = np.max(np.abs(_X-Y))
assert err < 1e-5
def _predict(imdims,axes):
    """Predict on a random image and check the output size along ``Z``.

    A uniform random array of shape ``imdims`` is drawn from ``rng`` and
    passed to the model together with ``axes`` and ``factor``. The result
    along the ``Z`` axis must be ``factor`` times as large as the input.

    ``rng``, ``config``, ``model``, ``factor`` and ``axes_dict`` are taken
    from the enclosing scope, which is not part of this snippet.

    Parameters
    ----------
    imdims : sequence of int
        Shape of the random test image.
    axes : str
        Axes string describing ``imdims``; must contain ``'Z'``.
    """
    image = rng.uniform(size=imdims)
    if not config.probabilistic:
        mean = model.predict(image, axes, factor, None, None)
    else:
        dist = model.predict_probabilistic(image, axes, factor, None, None)
        mean = dist.mean()
        scale = dist.scale()
        # both distribution parameters must describe the same pixel grid
        assert mean.shape == scale.shape
    z_index = axes_dict(axes)['Z']
    expected_z = imdims[z_index]*factor
    assert expected_z == mean.shape[z_index]
See :func:`predict` for parameter explanations.
Returns
-------
tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray` or None)
If model is probabilistic, returns a tuple `(mean, scale)` that defines the parameters
of per-pixel Laplace distributions. Otherwise, returns the restored image via a tuple `(restored,None)`
"""
normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
axes = axes_check_and_normalize(axes,img.ndim)
_permute_axes = self._make_permute_axes(axes, self.config.axes)
x = _permute_axes(img)
channel = axes_dict(self.config.axes)['C']
self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
# normalize
x = normalizer.before(x,self.config.axes)
# resize: make divisible by power of 2 to allow downsampling steps in unet
div_n = 2 ** self.config.unet_n_depth
x = resizer.before(x,div_n,exclude=channel)
done = False
while not done:
try:
if n_tiles == 1:
x = predict_direct(self.keras_model,x,channel_in=channel,channel_out=channel)
else:
overlap = tile_overlap(self.config.unet_n_depth, self.config.unet_kern_size)
"""TODO."""
if all(t==1 for t in n_tiles):
pred = predict_direct(keras_model,x,axes_in,axes_out,**kwargs)
if pbar is not None:
pbar.update()
return pred
###
if axes_out is None:
axes_out = axes_in
axes_in, axes_out = axes_check_and_normalize(axes_in,x.ndim), axes_check_and_normalize(axes_out)
assert 'S' not in axes_in
assert 'C' in axes_in and 'C' in axes_out
ax_in, ax_out = axes_dict(axes_in), axes_dict(axes_out)
channel_in, channel_out = ax_in['C'], ax_out['C']
assert set(axes_out).issubset(set(axes_in))
axes_lost = set(axes_in).difference(set(axes_out))
def _to_axes_out(seq,elem):
# assumption: prediction size is same as input size along all axes, except for channel (and lost axes)
assert len(seq) == len(axes_in)
# 1. re-order 'seq' from axes_in to axes_out semantics
seq = [seq[ax_in[a]] for a in axes_out]
# 2. replace value at channel position with 'elem'
seq[ax_out['C']] = elem
return tuple(seq)
###
print(raw_data.description)
print('='*66)
print('Transformations:')
for t in transforms:
print('{t.size} x {t.name}'.format(t=t))
print('='*66)
sys.stdout.flush()
## sample patches from each pair of transformed raw images
X = np.empty((n_patches,)+tuple(patch_size),dtype=np.float32)
Y = np.empty_like(X)
for i, (x,y,_axes,mask) in tqdm(enumerate(image_pairs),total=n_images):
if i==0:
axes = axes_check_and_normalize(_axes,len(patch_size))
channel = axes_dict(axes)['C']
# checks
# len(axes) >= x.ndim or _raise(ValueError())
axes == axes_check_and_normalize(_axes) or _raise(ValueError('not all images have the same axes.'))
x.shape == y.shape or _raise(ValueError())
mask is None or mask.shape == x.shape or _raise(ValueError())
(channel is None or (isinstance(channel,int) and 0<=channel
def _make_permute_axes(self,axes_in,axes_out=None):
if axes_out is None:
axes_out = self.config.axes
channel_in = axes_dict(axes_in) ['C']
channel_out = axes_dict(axes_out)['C']
assert channel_out is not None
def _permute_axes(data,undo=False):
if data is None:
return None
if undo:
if channel_in is not None:
return move_image_axes(data, axes_out, axes_in, True)
else:
# input is single-channel and has no channel axis
data = move_image_axes(data, axes_out, axes_in+'C', True)
# output is single-channel -> remove channel axis
if data.shape[-1] == 1:
data = data[...,0]
return data
else:
X_t = move_channel_for_backend(X_t,channel=channel)
Y_t = move_channel_for_backend(Y_t,channel=channel)
X = move_channel_for_backend(X,channel=channel)
Y = move_channel_for_backend(Y,channel=channel)
axes = axes.replace('C','') # remove channel
if backend_channels_last():
axes = axes+'C'
else:
axes = axes[:1]+'C'+axes[1:]
data_val = (X_t,Y_t) if validation_split > 0 else None
if verbose:
ax = axes_dict(axes)
n_train, n_val = len(X), len(X_t) if validation_split>0 else 0
image_size = tuple( X.shape[ax[a]] for a in axes if a in 'TZYX' )
n_dim = len(image_size)
n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]
print('number of training images:\t', n_train)
print('number of validation images:\t', n_val)
print('image size (%dD):\t\t'%n_dim, image_size)
print('axes:\t\t\t\t', axes)
print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)
return (X,Y), data_val, axes
def _guess_n_tiles(self, img):
    """Suggest a number of tiles for every axis of ``img``.

    For each non-channel axis the tile count is
    ``ceil(size / (patch_size * b))`` with
    ``b = train_batch_size ** (1 / n_dim)``, where the patch sizes, batch
    size and ``n_dim`` are read from ``self.config``. If the axes returned
    by ``self._normalize_axes`` contain ``'C'``, the channel axis is left
    out of that computation and gets exactly one tile.

    Parameters
    ----------
    img : array-like
        Image whose ``shape`` is inspected.

    Returns
    -------
    tuple of int
        One tile count per axis of ``img``, in axis order.
    """
    axes = self._normalize_axes(img, axes=None)
    has_channel = 'C' in axes

    # work on the shape without the channel axis
    sizes = list(img.shape)
    if has_channel:
        del sizes[axes_dict(axes)['C']]

    b = self.config.train_batch_size**(1.0/self.config.n_dim)
    n_tiles = []
    for size, patch in zip(sizes, self.config.train_patch_size):
        n_tiles.append(int(np.ceil(size/(patch*b))))

    # the channel axis is never tiled
    if has_channel:
        n_tiles.insert(axes_dict(axes)['C'], 1)
    return tuple(n_tiles)
def predict_direct(keras_model,x,axes_in,axes_out=None,**kwargs):
    """Apply ``keras_model`` to ``x`` with a single ``predict`` call.

    ``x`` is converted with ``to_tensor``, passed to
    ``keras_model.predict`` and the result is converted back with
    ``from_tensor``.

    Parameters
    ----------
    keras_model : object
        Model exposing a ``predict`` method.
    x : array-like
        Input data; its ``ndim`` must equal ``len(axes_in)``.
    axes_in : str
        Axes of ``x``.
    axes_out : str or None
        Axes of the prediction; defaults to ``axes_in``.
    **kwargs
        Forwarded unchanged to ``keras_model.predict``.

    Returns
    -------
    The prediction after ``from_tensor``; its ``ndim`` equals
    ``len(axes_out)``.

    Raises
    ------
    ValueError
        Via ``_raise``, if the number of axes does not match the
        dimensionality of the input or of the prediction.
    """
    if axes_out is None:
        axes_out = axes_in

    ax_in = axes_dict(axes_in)
    ax_out = axes_dict(axes_out)
    channel_in = ax_in['C']
    channel_out = ax_out['C']
    # input counts as a single sample when the 'S' entry of the axes dict is None
    single_sample = ax_in['S'] is None

    if len(axes_in) != x.ndim:
        _raise(ValueError())

    tensor_in = to_tensor(x,channel=channel_in,single_sample=single_sample)
    tensor_out = keras_model.predict(tensor_in,**kwargs)
    pred = from_tensor(tensor_out,channel=channel_out,single_sample=single_sample)

    if len(axes_out) != pred.ndim:
        _raise(ValueError())
    return pred