How to use the trains.binding.frameworks.WeightsFileHandler.restore_weights_file function in trains

To help you get started, we've selected a few trains examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github allegroai / trains / trains / binding / joblib_bind.py View on Github external
filename = f
        elif hasattr(f, 'name'):
            filename = f.name
        else:
            filename = None

        if not PatchedJoblib._current_task:
            return original_fn(f, *args, **kwargs)

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            # we assume scikit-learn, for the time being
            current_framework = Framework.scikitlearn
            filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
                                                               PatchedJoblib._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            current_framework = PatchedJoblib.get_model_framework(model)
            WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
                                                    PatchedJoblib._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model
github allegroai / trains / trains / binding / frameworks / xgboost_bind.py View on Github external
filename = None

        if not PatchXGBoostModelIO.__main_task:
            return original_fn(f, *args, **kwargs)

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                               PatchXGBoostModelIO.__main_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                    PatchXGBoostModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model
github allegroai / trains / trains / binding / frameworks / tensorflow_bind.py View on Github external
# Hack: disabled
        if False and running_remotely():
            # register/load model weights
            filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
                                                               PatchKerasModelIO.__main_task)
            if 'filepath' in kwargs:
                kwargs['filepath'] = filepath
            else:
                args = (filepath,) + args[1:]
            # load model
            return original_fn(self, *args, **kwargs)

        # try to load the files, if something happened exception will be raised before we register the file
        model = original_fn(self, *args, **kwargs)
        # register/load model weights
        WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task)
        return model
github allegroai / trains / trains / binding / frameworks / tensorflow_bind.py View on Github external
def _restore(original_fn, self, save_path, *args, **kwargs):
        """Run the wrapped TF2 restore, then register the checkpoint as the task's input model.

        :param original_fn: the unpatched restore callable being wrapped.
        :param self: receiver of the intercepted call; forwarded unchanged to
            ``original_fn`` and handed to ``WeightsFileHandler.restore_weights_file``.
        :param save_path: checkpoint path forwarded to ``original_fn``.
        :return: whatever ``original_fn`` returns.

        Registration is best-effort: any exception raised by
        ``WeightsFileHandler.restore_weights_file`` is swallowed so tracking can
        never break the caller's restore.
        """
        # Nothing is being tracked -- behave exactly like the unpatched function.
        if PatchTensorflow2ModelIO.__main_task is None:
            return original_fn(self, save_path, *args, **kwargs)

        # Hack: disabled -- the ``False and`` short-circuit keeps the remote path off
        # (running_remotely() is never even evaluated).
        restore_remotely = False and running_remotely()

        if not restore_remotely:
            # Restore first: a bad checkpoint raises here, before anything is registered.
            result = original_fn(self, save_path, *args, **kwargs)
            # noinspection PyBroadException
            try:
                WeightsFileHandler.restore_weights_file(
                    self, save_path, Framework.tensorflow, PatchTensorflow2ModelIO.__main_task)
            except Exception:
                pass
            return result

        # Disabled remote path: let the handler substitute the path before restoring;
        # on failure keep the caller's original save_path.
        # noinspection PyBroadException
        try:
            save_path = WeightsFileHandler.restore_weights_file(
                self, save_path, Framework.tensorflow, PatchTensorflow2ModelIO.__main_task)
        except Exception:
            pass
        return original_fn(self, save_path, *args, **kwargs)
github allegroai / trains / trains / binding / frameworks / tensorflow_bind.py View on Github external
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
        """Run the wrapped TF1 restore, then register the checkpoint as the task's input model.

        :param original_fn: the unpatched restore callable being wrapped.
        :param self: receiver of the intercepted call (presumably a ``tf.train.Saver``
            -- TODO confirm); forwarded unchanged to ``original_fn`` and handed to
            ``WeightsFileHandler.restore_weights_file``.
        :param sess: session argument forwarded unchanged to ``original_fn``.
        :param save_path: checkpoint path forwarded to ``original_fn``.
        :return: whatever ``original_fn`` returns.

        Registration is best-effort: the restore has already succeeded by the time
        the input model is registered, so a failure inside
        ``WeightsFileHandler.restore_weights_file`` is swallowed rather than
        discarding the restore result and propagating into user code.
        """
        if PatchTensorflowModelIO.__main_task is None:
            return original_fn(self, sess, save_path, *args, **kwargs)

        # Hack: disabled
        if False and running_remotely():
            # register/load model weights; on failure keep the caller's save_path
            # noinspection PyBroadException
            try:
                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                                    PatchTensorflowModelIO.__main_task)
            except Exception:
                pass
            # load model
            return original_fn(self, sess, save_path, *args, **kwargs)

        # load model, if something is wrong, exception will be raised before we register the input model
        model = original_fn(self, sess, save_path, *args, **kwargs)
        # register the input model; a tracking failure must not lose the restored result
        # noinspection PyBroadException
        try:
            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                    PatchTensorflowModelIO.__main_task)
        except Exception:
            pass
        return model
github allegroai / trains / trains / binding / frameworks / tensorflow_bind.py View on Github external
def _load_model(original_fn, filepath, *args, **kwargs):
        """Run the wrapped Keras model loader, then register the file as the task's input model.

        :param original_fn: the unpatched load function being wrapped.
        :param filepath: model file path forwarded to ``original_fn``.
        :return: the model returned by ``original_fn``. When registration produced a
            truthy ``trains_in_model``, it is attached to the model (best-effort).

        Registration is best-effort: the model has already been loaded when the
        input model is registered, so a failure inside
        ``WeightsFileHandler.restore_weights_file`` is swallowed rather than
        discarding the loaded model and propagating into user code.
        """
        if not PatchKerasModelIO.__main_task:
            return original_fn(filepath, *args, **kwargs)

        # placeholder handed to restore_weights_file; presumably it sets
        # ``trains_in_model`` on it (read back below) -- confirm against WeightsFileHandler
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            # register/load model weights; on failure keep the caller's filepath
            # noinspection PyBroadException
            try:
                filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
                                                                   PatchKerasModelIO.__main_task)
            except Exception:
                pass
            model = original_fn(filepath, *args, **kwargs)
        else:
            # load first, so a broken file raises before anything is registered
            model = original_fn(filepath, *args, **kwargs)
            # register the input model; a tracking failure must not lose the loaded model
            # noinspection PyBroadException
            try:
                WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
                                                        PatchKerasModelIO.__main_task)
            except Exception:
                pass
        # update the input model object
        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass

        return model
github allegroai / trains / trains / binding / frameworks / pytorch_bind.py View on Github external
if isinstance(f, six.string_types):
                filename = f
            elif hasattr(f, 'as_posix'):
                filename = f.as_posix()
            elif hasattr(f, 'name'):
                filename = f.name
            else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
                                                               PatchPyTorchModelIO.__main_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
                                                    PatchPyTorchModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model
github allegroai / trains / trains / binding / frameworks / pytorch_bind.py View on Github external
else:
                filename = None
        except Exception:
            filename = None

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
                                                               PatchPyTorchModelIO.__main_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
                                                    PatchPyTorchModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model
github allegroai / trains / trains / binding / frameworks / tensorflow_bind.py View on Github external
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
        """Run the wrapped TF1 restore, then register the checkpoint as the task's input model.

        :param original_fn: the unpatched restore callable being wrapped.
        :param self: receiver of the intercepted call (presumably a ``tf.train.Saver``
            -- TODO confirm); forwarded unchanged to ``original_fn`` and handed to
            ``WeightsFileHandler.restore_weights_file``.
        :param sess: session argument forwarded unchanged to ``original_fn``.
        :param save_path: checkpoint path forwarded to ``original_fn``.
        :return: whatever ``original_fn`` returns.

        Registration is best-effort: the restore has already succeeded by the time
        the input model is registered, so a failure inside
        ``WeightsFileHandler.restore_weights_file`` is swallowed rather than
        discarding the restore result and propagating into user code.
        """
        if PatchTensorflowModelIO.__main_task is None:
            return original_fn(self, sess, save_path, *args, **kwargs)

        # Hack: disabled
        if False and running_remotely():
            # register/load model weights; on failure keep the caller's save_path
            # noinspection PyBroadException
            try:
                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                                    PatchTensorflowModelIO.__main_task)
            except Exception:
                pass
            # load model
            return original_fn(self, sess, save_path, *args, **kwargs)

        # load model, if something is wrong, exception will be raised before we register the input model
        model = original_fn(self, sess, save_path, *args, **kwargs)
        # register the input model; a tracking failure must not lose the restored result
        # noinspection PyBroadException
        try:
            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                    PatchTensorflowModelIO.__main_task)
        except Exception:
            pass
        return model
github allegroai / trains / trains / binding / frameworks / xgboost_bind.py View on Github external
filename = f
        elif hasattr(f, 'name'):
            filename = f.name
        elif len(args) == 1 and isinstance(args[0], six.string_types):
            filename = args[0]
        else:
            filename = None

        if not PatchXGBoostModelIO.__main_task:
            return original_fn(f, *args, **kwargs)

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                               PatchXGBoostModelIO.__main_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                    PatchXGBoostModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model