            serialized_params = file_.read("parameters")
            params = bytes_to_params(
                serialized_params, parameter_list
            )
    except zipfile.BadZipFile:
        # load_path wasn't a zip file. Possibly a cloudpickle
        # file. Show a warning and fall back to loading cloudpickle.
        warnings.warn("It appears you are loading from a file with old format. " +
                      "Older cloudpickle format has been replaced with zip-archived " +
                      "models. Consider saving the model with new format.",
                      DeprecationWarning)
        # Attempt loading with the cloudpickle format.
        # If load_path is file-like, seek back to beginning of file
        if not isinstance(load_path, str):
            load_path.seek(0)
        data, params = BaseRLModel._load_from_file_cloudpickle(load_path)

    return data, params
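
# Usage sketch: the BadZipFile fallback above lets load() accept both the current
# zip-archive format and the legacy cloudpickle format. PPO2, the file names and
# the cloudpickle flag on save() are assumptions for illustration only.
from stable_baselines import PPO2

model = PPO2("MlpPolicy", "CartPole-v1")
model.save("example_model")                        # new format: example_model.zip
model.save("example_model_old", cloudpickle=True)  # legacy format: example_model_old.pkl

new_style = PPO2.load("example_model")             # read from the zip archive
old_style = PPO2.load("example_model_old.pkl")     # triggers the DeprecationWarning above
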
def _save_to_file(save_path, data=None, params=None, cloudpickle=False):
    """Save model to a zip archive or cloudpickle file.

    :param save_path: (str or file-like) Where to store the model
    :param data: (OrderedDict) Class parameters being stored
    :param params: (OrderedDict) Model parameters being stored
    :param cloudpickle: (bool) Use old cloudpickle format
        (stable-baselines<=2.7.0) instead of a zip archive.
    """
    if cloudpickle:
        BaseRLModel._save_to_file_cloudpickle(save_path, data, params)
    else:
        BaseRLModel._save_to_file_zip(save_path, data, params)
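
# Usage sketch: because save_path may be a string or a file-like object, a model can
# also be serialized into an in-memory buffer. PPO2 and the buffer round trip below
# are assumptions based on that docstring, not code from this file.
import io

from stable_baselines import PPO2

model = PPO2("MlpPolicy", "CartPole-v1")
buffer_ = io.BytesIO()
model.save(buffer_)            # zip archive written straight into the buffer
buffer_.seek(0)
restored = PPO2.load(buffer_)  # file-like objects are accepted on load as well
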
"(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
elif isinstance(observation_space, gym.spaces.MultiBinary):
if observation.shape == (observation_space.n,):
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
"environment, please use ({},) or ".format(observation_space.n) +
"(n_env, {}) for the observation shape.".format(observation_space.n))
else:
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))
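
# Usage sketch: the MultiBinary branch above distinguishes a single observation of
# shape (n,) from a batch of shape (n_env, n). The surrounding helper is called
# _is_vectorized_observation in stable-baselines; the arrays below are assumed inputs.
import numpy as np
import gym

space = gym.spaces.MultiBinary(4)
single = np.zeros(4, dtype=np.int8)      # shape (4,)   -> the check returns False
batch = np.zeros((3, 4), dtype=np.int8)  # shape (3, 4) -> the check returns True (n_env = 3)
bad = np.zeros((3, 5), dtype=np.int8)    # shape (3, 5) -> raises the ValueError above
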
class ActorCriticRLModel(BaseRLModel):
    """
    The base class for Actor-Critic models

    :param policy: (BasePolicy) Policy object
    :param env: (Gym environment) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param policy_base: (BasePolicy) the base policy used by this method (default=ActorCriticPolicy)
    :param requires_vec_env: (bool) Does this model require a vectorized environment
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
        If None (default), use a random seed. Note that if you want completely deterministic
        results, you must set `n_cpu_tf_sess` to 1.
    :param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations.
        If None, the number of CPUs of the current machine will be used.
    """
raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
kwargs['policy_kwargs']))
model = cls(policy=data["policy"], env=None, _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
model.set_env(env)
model.setup_model()
model.load_parameters(params)
return model
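
# Usage sketch: load() rebuilds the model from the archive, attaches an (optional)
# new environment through set_env() and lets extra keyword arguments override stored
# attributes before the parameters are restored. PPO2 and the file name are assumed.
import gym

from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

PPO2("MlpPolicy", "CartPole-v1").save("ppo2_cartpole")  # create an archive to load

env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = PPO2.load("ppo2_cartpole", env=env, verbose=1)  # verbose=1 overrides the stored value
model.learn(total_timesteps=1000)                       # continue training on the new env
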
class OffPolicyRLModel(BaseRLModel):
    """
    The base class for off-policy RL models

    :param policy: (BasePolicy) Policy object
    :param env: (Gym environment) The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param replay_buffer: (ReplayBuffer) the type of replay buffer
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param requires_vec_env: (bool) Does this model require a vectorized environment
    :param policy_base: (BasePolicy) the base policy used by this method
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
        If None (default), use a random seed. Note that if you want completely deterministic
        results, you must set `n_cpu_tf_sess` to 1.
    :param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations.
        If None, the number of CPUs of the current machine will be used.
    """
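
# Usage sketch: OffPolicyRLModel is the base of the replay-buffer algorithms
# (e.g. DQN, SAC, TD3 in stable-baselines); DQN and the settings below are
# assumptions for the example.
from stable_baselines import DQN

model = DQN("MlpPolicy", "CartPole-v1",
            buffer_size=50000,   # capacity of the replay buffer
            verbose=1,
            seed=0,
            n_cpu_tf_sess=1)
model.learn(total_timesteps=1000)
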
warnings.warn("Loading model parameters from a list. This has been replaced " +
"with parameter dictionaries with variable names and parameters. " +
"If you are loading from a file, consider re-saving the file.",
DeprecationWarning)
# Assume `load_path_or_dict` is list of ndarrays.
# Create param dictionary assuming the parameters are in same order
# as `get_parameter_list` returns them.
params = dict()
for i, param_name in enumerate(self._param_load_ops.keys()):
params[param_name] = load_path_or_dict[i]
else:
# Assume a filepath or file-like.
# Use existing deserializer to load the parameters.
# We only need the parameters part of the file, so
# only load that part.
_, params = BaseRLModel._load_from_file(load_path_or_dict, load_data=False)
params = dict(params)
feed_dict = {}
param_update_ops = []
# Keep track of not-updated variables
not_updated_variables = set(self._param_load_ops.keys())
for param_name, param_value in params.items():
placeholder, assign_op = self._param_load_ops[param_name]
feed_dict[placeholder] = param_value
# Create list of tf.assign operations for sess.run
param_update_ops.append(assign_op)
# Keep track which variables are updated
not_updated_variables.remove(param_name)
# Check that we updated all parameters if exact_match=True
if exact_match and len(not_updated_variables) > 0: