How to use the trains.binding.frameworks._patched_call function in trains

To help you get started, we've selected a few trains examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: allegroai / trains / trains / binding / frameworks / pytorch_bind.py (View on GitHub)
def _patch_model_io():
        """
        Install the PyTorch model save/load hooks, at most once.

        ``torch.save`` and ``torch.load`` are replaced by the result of
        ``_patched_call`` applied to the original function together with
        ``PatchPyTorchModelIO._save`` / ``PatchPyTorchModelIO._load``.

        The call returns early, without setting the "patched" flag, when the
        hooks are already installed or when ``torch`` is not yet present in
        ``sys.modules``. In the second case a later call can still install them.
        Any exception raised while importing or wrapping is deliberately ignored.
        """
        if PatchPyTorchModelIO.__patched or 'torch' not in sys.modules:
            return

        # Mark as done before trying: a failed attempt below is not retried.
        PatchPyTorchModelIO.__patched = True

        # noinspection PyBroadException
        try:
            # 'torch' is already in sys.modules (checked above); this binds the loaded module.
            import torch
            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
        except Exception:
            # Best effort (ImportError included): a patching failure must not surface to the caller.
            pass
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
if Sequential is not None:
                Sequential._updated_config = _patched_call(Sequential._updated_config,
                                                           PatchKerasModelIO._updated_config)
                if hasattr(Sequential.from_config, '__func__'):
                    Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
                                                         PatchKerasModelIO._from_config))
                else:
                    Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)

            if Network is not None:
                Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
                # Inspect Network's own from_config, not Sequential's: Sequential may be None here
                # (it is guarded separately above), and its from_config shape says nothing about Network's.
                # When read through the class, a classmethod gives a bound method exposing __func__,
                # which is unwrapped and re-wrapped as a classmethod; a plain function is wrapped directly.
                if hasattr(Network.from_config, '__func__'):
                    Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                                    PatchKerasModelIO._from_config))
                else:
                    Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
                Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
                Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
                Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)

            if keras_saving is not None:
                keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
                keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
try:
                # make sure we import the correct version of save
                import tensorflow
                from tensorflow.saved_model.experimental import save
                # actual import
                import tensorflow.saved_model.experimental as saved_model
            except ImportError:
                saved_model = None
            except Exception:
                saved_model = None

        except Exception:
            saved_model = None

        if saved_model is not None:
            saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)

        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
            import tensorflow
            # actual import
            from tensorflow.saved_model import load
            import tensorflow.saved_model as saved_model_load
            saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
PatchKerasModelIO._updated_config)
                if hasattr(Sequential.from_config, '__func__'):
                    Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
                                                         PatchKerasModelIO._from_config))
                else:
                    Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)

            if Network is not None:
                Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
                # Inspect Network's own from_config, not Sequential's: Sequential may be None here
                # (it is guarded separately), and its from_config shape says nothing about Network's.
                # When read through the class, a classmethod gives a bound method exposing __func__,
                # which is unwrapped and re-wrapped as a classmethod; a plain function is wrapped directly.
                if hasattr(Network.from_config, '__func__'):
                    Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                                    PatchKerasModelIO._from_config))
                else:
                    Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
                Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
                Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
                Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)

            if keras_saving is not None:
                keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
                keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
PatchTensorflow2ModelIO.__patched = True
        # noinspection PyBroadException
        try:
            # hack: make sure tensorflow.__init__ is called
            import tensorflow
            from tensorflow.python.training.tracking import util
            # noinspection PyBroadException
            try:
                util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
                                                         PatchTensorflow2ModelIO._save)
            except Exception:
                pass
            # noinspection PyBroadException
            try:
                util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
                                                            PatchTensorflow2ModelIO._restore)
            except Exception:
                pass
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
if hasattr(Sequential.from_config, '__func__'):
                    Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
                                                         PatchKerasModelIO._from_config))
                else:
                    Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)

            if Network is not None:
                Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
                # Inspect Network's own from_config, not Sequential's: Sequential may be None here
                # (it is guarded separately), and its from_config shape says nothing about Network's.
                # When read through the class, a classmethod gives a bound method exposing __func__,
                # which is unwrapped and re-wrapped as a classmethod; a plain function is wrapped directly.
                if hasattr(Network.from_config, '__func__'):
                    Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                                    PatchKerasModelIO._from_config))
                else:
                    Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
                Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
                Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
                Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)

            if keras_saving is not None:
                keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
                keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
        except Exception as ex:
            LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
import tensorflow
            # actual import
            from tensorflow.saved_model import loader as loader1
            loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
            import tensorflow
            # actual import
            from tensorflow.compat.v1.saved_model import loader as loader2
            loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
            import tensorflow
            from tensorflow.train import Checkpoint
            # noinspection PyBroadException
            try:
                Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
            except Exception:
                pass
            # noinspection PyBroadException
            try:
GitHub: allegroai / trains / trains / binding / frameworks / tensorflow_bind.py (View on GitHub)
return

        PatchTensorflowModelIO.__patched = True
        # noinspection PyBroadException
        try:
            # hack: make sure tensorflow.__init__ is called
            import tensorflow
            from tensorflow.python.training.saver import Saver
            # noinspection PyBroadException
            try:
                Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
            except Exception:
                pass
            # noinspection PyBroadException
            try:
                Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
            except Exception:
                pass
        except ImportError:
            pass
        except Exception:
            LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')

        # noinspection PyBroadException
        try:
            # make sure we import the correct version of save
            import tensorflow
            from tensorflow.saved_model import save
            # actual import
            from tensorflow.python.saved_model import save as saved_model
        except ImportError:
            # noinspection PyBroadException
GitHub: allegroai / trains / trains / binding / joblib_bind.py (View on GitHub)
def _patch_joblib():
        # noinspection PyBroadException
        try:
            if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules:
                PatchedJoblib._patched_joblib = True
                try:
                    import joblib
                except ImportError:
                    joblib = None

                if joblib:
                    joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
                    joblib.load = _patched_call(joblib.load, PatchedJoblib._load)

            if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
                PatchedJoblib._patched_sk_joblib = True
                try:
                    import sklearn
                    # avoid deprecation warning, we must import sklearn before, so we could catch it
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        from sklearn.externals import joblib as sk_joblib
                except ImportError:
                    sk_joblib = None

                if sk_joblib:
                    sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
                    sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)