Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _patch_model_io():
    """Wrap ``torch.save`` / ``torch.load`` with ``_patched_call`` hooks.

    The hooks are ``PatchPyTorchModelIO._save`` and ``PatchPyTorchModelIO._load``.

    Does nothing if ``PatchPyTorchModelIO.__patched`` is already set, or if
    ``torch`` is not yet present in ``sys.modules``. The flag is set *before*
    patching is attempted, so a failed attempt is not retried by later calls.

    Patching is best effort: an ImportError is ignored, and any other error is
    logged at debug level instead of being raised.
    """
    if PatchPyTorchModelIO.__patched:
        return
    # Only patch when torch has already been imported by someone else.
    if 'torch' not in sys.modules:
        return
    PatchPyTorchModelIO.__patched = True
    # noinspection PyBroadException
    try:
        import torch
        torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
        torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
    except ImportError:
        pass
    except Exception:
        # Previously swallowed silently; keep it non-fatal but leave a trace.
        import logging
        logging.getLogger(__name__).debug('Failed patching pytorch', exc_info=True)
# Fragment: the enclosing function / `try:` is not visible here, and the
# original indentation was lost in extraction.
# Replace Sequential._updated_config and Sequential.from_config with
# `_patched_call(original, PatchKerasModelIO hook)` wrappers.
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
# When from_config exposes __func__ (presumably a classmethod accessed via the
# class), patch the underlying function and re-wrap it with classmethod().
if hasattr(Sequential.from_config, '__func__'):
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config))
else:
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
# Replace Network's config, save and weight-IO methods with `_patched_call`
# wrappers around the corresponding PatchKerasModelIO hooks.
if Network is not None:
    Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
    # Inspect Network.from_config itself, not Sequential.from_config. Sequential
    # may be None while Network is available, and the classmethod unwrap has to
    # match the attribute that is actually being replaced.
    if hasattr(Network.from_config, '__func__'):
        Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                        PatchKerasModelIO._from_config))
    else:
        Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
    Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
    Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
    Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
# Also wrap the module-level keras_saving.save_model / load_model when available.
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
# Handler for a `try:` that is not visible in this fragment: any patching
# failure is logged as a warning (message only) instead of propagating.
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
# Resolve tensorflow.saved_model.experimental as `saved_model`; it is set to
# None if the import fails. `tensorflow` and `save` are not referenced again in
# this fragment. Per the original comment, they are imported so that the right
# `save` is picked up.
try:
# make sure we import the correct version of save
import tensorflow
from tensorflow.saved_model.experimental import save
# actual import
import tensorflow.saved_model.experimental as saved_model
except ImportError:
saved_model = None
except Exception:
saved_model = None
# NOTE(review): this repeats the handler above. If it belongs to the same
# `try:` it is unreachable. Indentation was lost in extraction, so it may
# instead close an enclosing `try:` that is not visible here; confirm.
except Exception:
saved_model = None
# Skip patching when the import failed (saved_model was set to None).
if saved_model is not None:
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
# Wrap tensorflow.saved_model.load with PatchTensorflowModelIO._load.
# Best effort: an ImportError is ignored, and other failures are logged at debug level.
# noinspection PyBroadException
try:
# import tensorflow first; `load` below is not referenced again in this fragment
import tensorflow
# actual import
from tensorflow.saved_model import load
import tensorflow.saved_model as saved_model_load
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
except ImportError:
pass
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
# noinspection PyBroadException
try:
# NOTE(review): garbled fragment. The next line appears to be the tail of a
# `Sequential._updated_config = _patched_call(Sequential._updated_config, ...)`
# call whose beginning (and the enclosing `if Sequential is not None:`) is
# missing, so this span is not valid Python as shown.
PatchKerasModelIO._updated_config)
# Patch from_config, re-wrapping with classmethod() when it exposes __func__.
if hasattr(Sequential.from_config, '__func__'):
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config))
else:
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
# Replace Network's config, save and weight-IO methods with `_patched_call`
# wrappers around the corresponding PatchKerasModelIO hooks.
if Network is not None:
    Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
    # Inspect Network.from_config itself, not Sequential.from_config. Sequential
    # may be None while Network is available, and the classmethod unwrap has to
    # match the attribute that is actually being replaced.
    if hasattr(Network.from_config, '__func__'):
        Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                        PatchKerasModelIO._from_config))
    else:
        Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
    Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
    Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
    Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
# Also wrap the module-level keras_saving.save_model / load_model when available.
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
# Handler for the `try:` opened earlier in this fragment: any patching failure
# is logged as a warning (message only) instead of propagating.
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
# Fragment of the TF2 checkpoint patcher. The "patched" flag is set before
# patching is attempted.
PatchTensorflow2ModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.training.tracking import util
# save and restore are patched in separate `try` blocks, so a failure on one
# does not prevent patching the other. Such failures are silently ignored.
# noinspection PyBroadException
try:
util.TrackableSaver.save = _patched_call(util.TrackableSaver.save,
PatchTensorflow2ModelIO._save)
except Exception:
pass
# noinspection PyBroadException
try:
util.TrackableSaver.restore = _patched_call(util.TrackableSaver.restore,
PatchTensorflow2ModelIO._restore)
except Exception:
pass
# A missing module is ignored; any other failure is logged at debug level.
except ImportError:
pass
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow v2')
# Fragment: the enclosing `if Sequential is not None:` / `try:` is not visible here.
# Patch Sequential.from_config, re-wrapping with classmethod() when the current
# attribute exposes __func__.
if hasattr(Sequential.from_config, '__func__'):
Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config))
else:
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
# Replace Network's config, save and weight-IO methods with `_patched_call`
# wrappers around the corresponding PatchKerasModelIO hooks.
if Network is not None:
    Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
    # Inspect Network.from_config itself, not Sequential.from_config. Sequential
    # may be None while Network is available, and the classmethod unwrap has to
    # match the attribute that is actually being replaced.
    if hasattr(Network.from_config, '__func__'):
        Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                        PatchKerasModelIO._from_config))
    else:
        Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
    Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
    Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
    Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
# Also wrap the module-level keras_saving.save_model / load_model when available.
if keras_saving is not None:
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
# Handler for a `try:` that is not visible in this fragment: any patching
# failure is logged as a warning (message only) instead of propagating.
except Exception as ex:
LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))
# Fragment: the opening `try:` for the handlers below is missing from this view.
# Wraps tensorflow.saved_model.loader.load with PatchTensorflowModelIO._load.
import tensorflow
# actual import
from tensorflow.saved_model import loader as loader1
loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
# A missing module is ignored; any other failure is logged at debug level.
except ImportError:
pass
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
# Wrap tensorflow.compat.v1.saved_model.loader.load with PatchTensorflowModelIO._load.
# Best effort: an ImportError is ignored, and other failures are logged at debug level.
# noinspection PyBroadException
try:
# import tensorflow before the compat.v1 loader module
import tensorflow
# actual import
from tensorflow.compat.v1.saved_model import loader as loader2
loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
except ImportError:
pass
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
# Wrap tf.train.Checkpoint.save with PatchTensorflowModelIO._ckpt_save.
# A failure to patch it is silently ignored.
# noinspection PyBroadException
try:
import tensorflow
from tensorflow.train import Checkpoint
# noinspection PyBroadException
try:
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
except Exception:
pass
# NOTE(review): garbled. A `try:` is followed directly by `return` with no
# handler; the start of another routine appears to have been spliced in here
# (extraction artifact). This span is not valid Python as shown.
# noinspection PyBroadException
try:
return
# Fragment of the TF1 model-IO patcher. The "patched" flag is set before
# patching is attempted.
PatchTensorflowModelIO.__patched = True
# noinspection PyBroadException
try:
# hack: make sure tensorflow.__init__ is called
import tensorflow
from tensorflow.python.training.saver import Saver
# Saver.save and Saver.restore are patched in separate `try` blocks, so one
# failing does not block the other. Such failures are silently ignored.
# noinspection PyBroadException
try:
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
except Exception:
pass
# noinspection PyBroadException
try:
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
except Exception:
pass
# A missing module is ignored; any other failure is logged at debug level.
except ImportError:
pass
except Exception:
LoggerRoot.get_base_logger(TensorflowBinding).debug('Failed patching tensorflow')
# Resolve `saved_model` from tensorflow.python.saved_model.save.
# NOTE(review): this fragment ends at `except ImportError:` with no handler
# body visible. The fallback path is not shown here; confirm against the full source.
# noinspection PyBroadException
try:
# make sure we import the correct version of save
import tensorflow
from tensorflow.saved_model import save
# actual import
from tensorflow.python.saved_model import save as saved_model
except ImportError:
# noinspection PyBroadException
# Wrap joblib.dump / joblib.load and, if importable, sklearn.externals.joblib's
# dump / load with PatchedJoblib._dump / PatchedJoblib._load.
# Each target is handled only when its package is already in sys.modules and
# its PatchedJoblib flag (_patched_joblib / _patched_sk_joblib) is not yet set.
# The flag is set before the import is attempted, so an ImportError also marks
# that target as done.
# NOTE(review): the handler for the outer `try:` lies beyond this view.
def _patch_joblib():
# noinspection PyBroadException
try:
if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules:
PatchedJoblib._patched_joblib = True
try:
import joblib
except ImportError:
joblib = None
if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
PatchedJoblib._patched_sk_joblib = True
try:
import sklearn
# Import sklearn first, outside the filter, so that only the
# sklearn.externals.joblib import runs with warnings suppressed
# (it is expected to emit a deprecation warning).
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from sklearn.externals import joblib as sk_joblib
except ImportError:
sk_joblib = None
if sk_joblib:
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)