How to use the csbdeep.utils.axes_dict function in csbdeep

To help you get started, we’ve selected a few csbdeep examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github CSBDeep / CSBDeep / tests / test_models.py View on Github external
def _predict(imdims,axes):
        """Predict on a random image and check the scaled 'Z' extent.

        A random image of shape `imdims` is drawn from `rng` and fed to `model`
        (`rng`, `model`, `config` and `factor` are taken from the surrounding
        scope). For a probabilistic config the mean and scale of the predicted
        distribution must have the same shape. In either case the size of the
        result along the 'Z' axis must equal the input size times `factor`.
        """
        img = rng.uniform(size=imdims)
        if not config.probabilistic:
            mean = model.predict(img, axes, factor, None, None)
        else:
            prob = model.predict_probabilistic(img, axes, factor, None, None)
            mean = prob.mean()
            scale = prob.scale()
            assert mean.shape == scale.shape
        # position of the 'Z' axis within `axes`
        z = axes_dict(axes)['Z']
        expected_z = imdims[z]*factor
        assert expected_z == mean.shape[z]
github CSBDeep / CSBDeep / tests / test_datagen.py View on Github external
patch_size[axes_dict(img_axes if patch_axes is None else patch_axes)[a]] = (
                None if red_none else img_size[axes_dict(img_axes)[a]]
            )
        X,Y,XYaxes = create_patches_reduced_target (
            raw_data            = raw_data,
            patch_size          = patch_size,
            patch_axes          = patch_axes,
            n_patches_per_image = n_patches_per_image,
            reduction_axes      = red_axes,
            target_axes         = rng.choice((None,img_axes)) if keepdims else ''.join(a for a in img_axes if a not in red_axes),
            #
            normalization       = lambda patches_x, patches_y, *args: (patches_x, patches_y),
            verbose             = False,
        )
        assert len(X) == n_images*n_patches_per_image
        _X = np.mean(X,axis=tuple(axes_dict(XYaxes)[a] for a in red_axes),keepdims=True)
        err = np.max(np.abs(_X-Y))
        assert err < 1e-5
github CSBDeep / CSBDeep / tests / test_models.py View on Github external
def _predict(imdims,axes):
        """Predict on a random image and check the scaled 'Z' extent.

        A random image of shape `imdims` is drawn from `rng` and fed to `model`
        (`rng`, `model`, `config` and `factor` are taken from the surrounding
        scope). For a probabilistic config the mean and scale of the predicted
        distribution must have the same shape. In either case the size of the
        result along the 'Z' axis must equal the input size times `factor`.
        """
        img = rng.uniform(size=imdims)
        if not config.probabilistic:
            mean = model.predict(img, axes, factor, None, None)
        else:
            prob = model.predict_probabilistic(img, axes, factor, None, None)
            mean = prob.mean()
            scale = prob.scale()
            assert mean.shape == scale.shape
        # position of the 'Z' axis within `axes`
        z = axes_dict(axes)['Z']
        expected_z = imdims[z]*factor
        assert expected_z == mean.shape[z]
github CSBDeep / CSBDeep / csbdeep / models.py View on Github external
See :func:`predict` for parameter explanations.

        Returns
        -------
        tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray` or None)
            If model is probabilistic, returns a tuple `(mean, scale)` that defines the parameters
            of per-pixel Laplace distributions. Otherwise, returns the restored image via a tuple `(restored,None)`

        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes = axes_check_and_normalize(axes,img.ndim)
        _permute_axes = self._make_permute_axes(axes, self.config.axes)

        x = _permute_axes(img)
        channel = axes_dict(self.config.axes)['C']

        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())

        # normalize
        x = normalizer.before(x,self.config.axes)
        # resize: make divisible by power of 2 to allow downsampling steps in unet
        div_n = 2 ** self.config.unet_n_depth
        x = resizer.before(x,div_n,exclude=channel)

        done = False
        while not done:
            try:
                if n_tiles == 1:
                    x = predict_direct(self.keras_model,x,channel_in=channel,channel_out=channel)
                else:
                    overlap = tile_overlap(self.config.unet_n_depth, self.config.unet_kern_size)
github CSBDeep / CSBDeep / csbdeep / internals / predict.py View on Github external
"""TODO."""

    if all(t==1 for t in n_tiles):
        pred = predict_direct(keras_model,x,axes_in,axes_out,**kwargs)
        if pbar is not None:
            pbar.update()
        return pred

    ###

    if axes_out is None:
        axes_out = axes_in
    axes_in, axes_out = axes_check_and_normalize(axes_in,x.ndim), axes_check_and_normalize(axes_out)
    assert 'S' not in axes_in
    assert 'C' in axes_in and 'C' in axes_out
    ax_in, ax_out = axes_dict(axes_in), axes_dict(axes_out)
    channel_in, channel_out = ax_in['C'], ax_out['C']

    assert set(axes_out).issubset(set(axes_in))
    axes_lost = set(axes_in).difference(set(axes_out))

    def _to_axes_out(seq,elem):
        """Translate a per-axis sequence from `axes_in` order to `axes_out` order.

        `seq` must hold one entry per axis of `axes_in`. The result is a tuple
        with one entry per axis of `axes_out`, where the entry at the output
        channel position is replaced by `elem`.

        Assumption: the prediction size equals the input size along all axes,
        except for the channel axis (and axes that are lost in the output).
        """
        assert len(seq) == len(axes_in)
        # step 1: pick the entries of 'seq' in the order given by axes_out
        reordered = []
        for a in axes_out:
            reordered.append(seq[ax_in[a]])
        # step 2: overwrite the entry at the output channel position
        reordered[ax_out['C']] = elem
        return tuple(reordered)

    ###
github CSBDeep / CSBDeep / csbdeep / datagen.py View on Github external
print(raw_data.description)
        print('='*66)
        print('Transformations:')
        for t in transforms:
            print('{t.size} x {t.name}'.format(t=t))
        print('='*66)
        sys.stdout.flush()

    ## sample patches from each pair of transformed raw images
    X = np.empty((n_patches,)+tuple(patch_size),dtype=np.float32)
    Y = np.empty_like(X)

    for i, (x,y,_axes,mask) in tqdm(enumerate(image_pairs),total=n_images):
        if i==0:
            axes = axes_check_and_normalize(_axes,len(patch_size))
            channel = axes_dict(axes)['C']
        # checks
        # len(axes) >= x.ndim or _raise(ValueError())
        axes == axes_check_and_normalize(_axes) or _raise(ValueError('not all images have the same axes.'))
        x.shape == y.shape or _raise(ValueError())
        mask is None or mask.shape == x.shape or _raise(ValueError())
        (channel is None or (isinstance(channel,int) and 0<=channel
github CSBDeep / CSBDeep / csbdeep / models.py View on Github external
def _make_permute_axes(self,axes_in,axes_out=None):
        if axes_out is None:
            axes_out = self.config.axes
        channel_in  = axes_dict(axes_in) ['C']
        channel_out = axes_dict(axes_out)['C']
        assert channel_out is not None

        def _permute_axes(data,undo=False):
            if data is None:
                return None
            if undo:
                if channel_in is not None:
                    return move_image_axes(data, axes_out, axes_in, True)
                else:
                    # input is single-channel and has no channel axis
                    data = move_image_axes(data, axes_out, axes_in+'C', True)
                    # output is single-channel -> remove channel axis
                    if data.shape[-1] == 1:
                        data = data[...,0]
                    return data
            else:
github CSBDeep / CSBDeep / csbdeep / io / __init__.py View on Github external
X_t = move_channel_for_backend(X_t,channel=channel)
        Y_t = move_channel_for_backend(Y_t,channel=channel)

    X = move_channel_for_backend(X,channel=channel)
    Y = move_channel_for_backend(Y,channel=channel)

    axes = axes.replace('C','') # remove channel
    if backend_channels_last():
        axes = axes+'C'
    else:
        axes = axes[:1]+'C'+axes[1:]

    data_val = (X_t,Y_t) if validation_split > 0 else None

    if verbose:
        ax = axes_dict(axes)
        n_train, n_val = len(X), len(X_t) if validation_split>0 else 0
        image_size = tuple( X.shape[ax[a]] for a in axes if a in 'TZYX' )
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t'%n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X,Y), data_val, axes
github mpicbg-csbd / stardist / stardist / models / base.py View on Github external
def _guess_n_tiles(self, img):
        """Guess the number of tiles per axis for `img`.

        The axes of `img` are obtained from `self._normalize_axes`. For every
        non-channel axis the tile count is ``ceil(size / (patch * b))``, where
        `patch` is the matching entry of `config.train_patch_size` and
        ``b = config.train_batch_size ** (1 / config.n_dim)``. If the axes
        contain 'C', a tile count of 1 is inserted at the channel position.

        Returns
        -------
        tuple of int
            One tile count per axis of `img`.
        """
        axes = self._normalize_axes(img, axes=None)
        has_channel = 'C' in axes
        sizes = list(img.shape)
        if has_channel:
            # drop the channel extent; it is not matched against the patch size
            del sizes[axes_dict(axes)['C']]
        batch_factor = self.config.train_batch_size**(1.0/self.config.n_dim)
        n_tiles = []
        for size, patch in zip(sizes, self.config.train_patch_size):
            n_tiles.append(int(np.ceil(size/(patch*batch_factor))))
        if has_channel:
            # the channel axis gets a single tile
            n_tiles.insert(axes_dict(axes)['C'],1)
        return tuple(n_tiles)
github CSBDeep / CSBDeep / csbdeep / internals / predict.py View on Github external
def predict_direct(keras_model,x,axes_in,axes_out=None,**kwargs):
    """Apply `keras_model.predict` to `x` in a single call.

    Parameters
    ----------
    keras_model
        Object providing a ``predict`` method.
    x
        Input array; its ``ndim`` must equal ``len(axes_in)``.
    axes_in
        Axes of `x`.
    axes_out
        Axes of the prediction; defaults to `axes_in` when ``None``.
    **kwargs
        Forwarded unchanged to ``keras_model.predict``.

    Returns
    -------
    The prediction after conversion by `from_tensor`; its ``ndim`` must equal
    ``len(axes_out)``.

    A ``ValueError`` is passed to `_raise` if either dimensionality check fails.
    """
    axes_out = axes_in if axes_out is None else axes_out
    ax_in = axes_dict(axes_in)
    ax_out = axes_dict(axes_out)
    channel_in = ax_in['C']
    channel_out = ax_out['C']
    # input counts as a single sample when `axes_in` has no 'S' entry
    single_sample = ax_in['S'] is None
    if not (len(axes_in) == x.ndim):
        _raise(ValueError())
    x = to_tensor(x,channel=channel_in,single_sample=single_sample)
    raw_pred = keras_model.predict(x,**kwargs)
    pred = from_tensor(raw_pred,channel=channel_out,single_sample=single_sample)
    if not (len(axes_out) == pred.ndim):
        _raise(ValueError())
    return pred