How to use the csbdeep.utils.tf.keras_import function in csbdeep

To help you get started, we’ve selected a few csbdeep examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github CSBDeep / CSBDeep / tests / test_models.py View on Github external
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

from itertools import product

# import warnings
import numpy as np
import pytest
from csbdeep.data import NoNormalizer, NoResizer
from csbdeep.internals.predict import tile_overlap
from csbdeep.utils.tf import keras_import
K = keras_import('backend')

from csbdeep.internals.nets import receptive_field_unet
from csbdeep.models import Config, CARE, UpsamplingCARE, IsotropicCARE
from csbdeep.models import ProjectionConfig, ProjectionCARE
from csbdeep.utils import axes_dict
from csbdeep.utils.six import FileNotFoundError



def config_generator(cls=Config, **kwargs):
    """Lazily yield one ``cls`` instance per combination of keyword values.

    Every keyword value that is a list or tuple is treated as a set of
    alternatives to sweep over; any other value is held fixed. The yielded
    objects are built as ``cls(**combination)``, one per element of the
    Cartesian product of all alternatives. An ``axes`` keyword is required
    (checked when iteration starts).
    """
    assert 'axes' in kwargs
    names = list(kwargs)
    choices = []
    for name in names:
        setting = kwargs[name]
        # scalars become single-option sweeps so product() treats them uniformly
        choices.append(setting if isinstance(setting, (list, tuple)) else [setting])
    for combo in product(*choices):
        yield cls(**{name: item for name, item in zip(names, combo)})
github CSBDeep / CSBDeep / csbdeep / models / config.py View on Github external
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter
from six import string_types

import numpy as np
import argparse
import warnings

from distutils.version import LooseVersion

from ..utils.tf import keras_import
keras = keras_import()
K = keras_import('backend')

from ..utils import _raise, axes_check_and_normalize, axes_dict, backend_channels_last


class BaseConfig(argparse.Namespace):

    def __init__(self, axes='YX', n_channel_in=1, n_channel_out=1, allow_new_parameters=False, **kwargs):

        # parse and check axes
        axes = axes_check_and_normalize(axes)
        ax = axes_dict(axes)
        ax = {a: (ax[a] is not None) for a in ax}

        (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
        # not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))
github CSBDeep / CSBDeep / csbdeep / utils / utils.py View on Github external
def backend_channels_last():
    """Report whether the Keras backend uses the channels-last data format.

    Returns
    -------
    bool
        True if ``image_data_format()`` of the backend is ``'channels_last'``,
        False if it is ``'channels_first'``. Any other value trips the assert.
    """
    # function-scope import; presumably kept local to avoid an import cycle
    from .tf import keras_import
    data_format = keras_import('backend').image_data_format()
    assert data_format in ('channels_first','channels_last')
    return data_format == 'channels_last'
github CSBDeep / CSBDeep / csbdeep / internals / blocks.py View on Github external
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

from ..utils import _raise, backend_channels_last

from ..utils.tf import keras_import
K = keras_import('backend')
Conv2D, MaxPooling2D, UpSampling2D, Conv3D, MaxPooling3D, UpSampling3D, Cropping2D, Cropping3D, Concatenate, Add, Dropout, Activation, BatchNormalization = \
    keras_import('layers', 'Conv2D', 'MaxPooling2D', 'UpSampling2D', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Cropping2D', 'Cropping3D', 'Concatenate', 'Add', 'Dropout', 'Activation', 'BatchNormalization')



def conv_block2(n_filter, n1, n2,
                activation="relu",
                border_mode="same",
                dropout=0.0,
                batch_norm=False,
                init="glorot_uniform",
                **kwargs):

    def _func(lay):
        if batch_norm:
            s = Conv2D(n_filter, (n1, n2), padding=border_mode, kernel_initializer=init, **kwargs)(lay)
github CSBDeep / CSBDeep / csbdeep / models / care_projection.py View on Github external
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
from collections import namedtuple

from ..utils.tf import keras_import
K = keras_import('backend')
Model = keras_import('models', 'Model')
Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
softmax = keras_import('activations', 'softmax')

from .care_standard import CARE
from .config import Config
from ..utils import _raise, axes_dict, axes_check_and_normalize
from ..internals import nets
from ..internals.predict import tile_overlap


class ProjectionConfig(Config):

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        ax = axes_dict(self.axes)
        self.proj_axis              = kwargs.get('proj_axis', 'Z')
github CSBDeep / CSBDeep / csbdeep / models / care_projection.py View on Github external
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
from collections import namedtuple

from ..utils.tf import keras_import
K = keras_import('backend')
Model = keras_import('models', 'Model')
Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
softmax = keras_import('activations', 'softmax')

from .care_standard import CARE
from .config import Config
from ..utils import _raise, axes_dict, axes_check_and_normalize
from ..internals import nets
from ..internals.predict import tile_overlap


class ProjectionConfig(Config):

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        ax = axes_dict(self.axes)
github CSBDeep / CSBDeep / csbdeep / internals / losses.py View on Github external
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

from ..utils import _raise, backend_channels_last

import numpy as np
from ..utils.tf import keras_import
K = keras_import('backend')



def _mean_or_not(mean):
    # return (lambda x: K.mean(x,axis=(-1 if backend_channels_last() else 1))) if mean else (lambda x: x)
    # Keras also only averages over axis=-1, see https://github.com/keras-team/keras/blob/master/keras/losses.py
    return (lambda x: K.mean(x,axis=-1)) if mean else (lambda x: x)


def loss_laplace(mean=True):
    R = _mean_or_not(mean)
    C = np.log(2.0)
    if backend_channels_last():
        def nll(y_true, y_pred):
            n     = K.shape(y_true)[-1]
            mu    = y_pred[...,:n]
github CSBDeep / CSBDeep / csbdeep / models / care_projection.py View on Github external
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
from collections import namedtuple

from ..utils.tf import keras_import
K = keras_import('backend')
Model = keras_import('models', 'Model')
Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
softmax = keras_import('activations', 'softmax')

from .care_standard import CARE
from .config import Config
from ..utils import _raise, axes_dict, axes_check_and_normalize
from ..internals import nets
from ..internals.predict import tile_overlap


class ProjectionConfig(Config):

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        ax = axes_dict(self.axes)
        self.proj_axis              = kwargs.get('proj_axis', 'Z')
        self.proj_n_depth           = 4
        self.proj_n_filt            = 8
github CSBDeep / CSBDeep / csbdeep / models / care_projection.py View on Github external
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
from collections import namedtuple

from ..utils.tf import keras_import
K = keras_import('backend')
Model = keras_import('models', 'Model')
Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
softmax = keras_import('activations', 'softmax')

from .care_standard import CARE
from .config import Config
from ..utils import _raise, axes_dict, axes_check_and_normalize
from ..internals import nets
from ..internals.predict import tile_overlap


class ProjectionConfig(Config):

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        ax = axes_dict(self.axes)
        self.proj_axis              = kwargs.get('proj_axis', 'Z')
        self.proj_n_depth           = 4
github CSBDeep / CSBDeep / csbdeep / internals / blocks.py View on Github external
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

from ..utils import _raise, backend_channels_last

from ..utils.tf import keras_import
K = keras_import('backend')
Conv2D, MaxPooling2D, UpSampling2D, Conv3D, MaxPooling3D, UpSampling3D, Cropping2D, Cropping3D, Concatenate, Add, Dropout, Activation, BatchNormalization = \
    keras_import('layers', 'Conv2D', 'MaxPooling2D', 'UpSampling2D', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Cropping2D', 'Cropping3D', 'Concatenate', 'Add', 'Dropout', 'Activation', 'BatchNormalization')



def conv_block2(n_filter, n1, n2,
                activation="relu",
                border_mode="same",
                dropout=0.0,
                batch_norm=False,
                init="glorot_uniform",
                **kwargs):

    def _func(lay):
        if batch_norm:
            s = Conv2D(n_filter, (n1, n2), padding=border_mode, kernel_initializer=init, **kwargs)(lay)
            s = BatchNormalization()(s)
            s = Activation(activation)(s)