def sample(self):
return tf.stack([p.sample() for p in self.categoricals], axis=-1)
@classmethod
def fromflat(cls, flat):
"""
Create an instance of this from new logits values
:param flat: ([float]) the multi categorical logits input
:return: (ProbabilityDistribution) the instance from the given multi categorical input
"""
raise NotImplementedError
class DiagGaussianProbabilityDistribution(ProbabilityDistribution):
def __init__(self, flat):
"""
Probability distributions from multivariate Gaussian input
:param flat: ([float]) the multivariate Gaussian input data
"""
self.flat = flat
mean, logstd = tf.split(axis=len(flat.shape) - 1, num_or_size_splits=2, value=flat)
self.mean = mean
self.logstd = logstd
self.std = tf.exp(logstd)
super(DiagGaussianProbabilityDistribution, self).__init__()
def flatparam(self):
return self.flat
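# Illustrative sketch (hypothetical values, not part of stable-baselines): how
# `flat` packs the Gaussian parameters. The first half along the last axis is
# the mean, the second half the log standard deviation; exp() recovers std.
import numpy as np

flat_example = np.array([0.0, 1.0, -0.5, 0.5])
mean_ex, logstd_ex = np.split(flat_example, 2, axis=-1)
std_ex = np.exp(logstd_ex)
assert mean_ex.shape == std_ex.shape == (2,)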
def sample(self):
    # Gumbel-max trick to sample
    # a categorical distribution (see http://amid.fish/humble-gumbel)
    uniform = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
    return tf.argmax(self.logits - tf.log(-tf.log(uniform)), axis=-1)
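# Sketch of the Gumbel-max trick in plain NumPy (standalone, not part of the
# TensorFlow graph above): argmax(logits + Gumbel noise) draws index i with
# probability softmax(logits)[i], which is exactly what sample() exploits.
import numpy as np

rng = np.random.default_rng(0)
logits_ex = np.array([2.0, 0.5, 0.1])
u = rng.uniform(size=(100_000, 3))
draws = np.argmax(logits_ex - np.log(-np.log(u)), axis=-1)
empirical = np.bincount(draws, minlength=3) / draws.size
softmax = np.exp(logits_ex) / np.exp(logits_ex).sum()
# `empirical` closely matches `softmax` (both roughly [0.73, 0.16, 0.11])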
@classmethod
def fromflat(cls, flat):
"""
Create an instance of this from new logits values
:param flat: ([float]) the categorical logits input
:return: (ProbabilityDistribution) the instance from the given categorical input
"""
return cls(flat)
class MultiCategoricalProbabilityDistribution(ProbabilityDistribution):
def __init__(self, nvec, flat):
"""
Probability distributions from multicategorical input
:param nvec: ([int]) the sizes of the different categorical inputs
:param flat: ([float]) the categorical logits input
"""
self.flat = flat
self.categoricals = list(map(CategoricalProbabilityDistribution, tf.split(flat, nvec, axis=-1)))
super(MultiCategoricalProbabilityDistribution, self).__init__()
def flatparam(self):
return self.flat
def mode(self):
return tf.stack([p.mode() for p in self.categoricals], axis=-1)
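# Illustrative sketch (hypothetical values, not part of stable-baselines):
# splitting a flat logits vector by nvec and taking a per-block argmax mirrors
# how mode() is assembled from the per-dimension categoricals above.
import numpy as np

nvec_ex = [3, 2]                                   # two sub-actions with 3 and 2 choices
flat_logits_ex = np.array([0.1, 2.0, -1.0, 0.3, 1.5])
blocks = np.split(flat_logits_ex, np.cumsum(nvec_ex)[:-1], axis=-1)
mode_ex = np.stack([np.argmax(b, axis=-1) for b in blocks], axis=-1)
# mode_ex == array([1, 1]): the most likely choice in each sub-action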
serialized_params = file_.read("parameters")
params = bytes_to_params(
serialized_params, parameter_list
)
except zipfile.BadZipFile:
# load_path wasn't a zip file. Possibly a cloudpickle
# file. Show a warning and fall back to loading cloudpickle.
warnings.warn("It appears you are loading from a file with old format. " +
"Older cloudpickle format has been replaced with zip-archived " +
"models. Consider saving the model with new format.",
DeprecationWarning)
# Attempt loading with the cloudpickle format.
# If load_path is file-like, seek back to beginning of file
if not isinstance(load_path, str):
load_path.seek(0)
data, params = BaseRLModel._load_from_file_cloudpickle(load_path)
return data, params
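# Minimal standalone sketch of the same fallback pattern (an assumption for
# illustration, not the stable-baselines implementation): try the zip archive
# first and only fall back to the legacy pickle-style format when that fails.
import pickle
import warnings
import zipfile

def load_any_format(load_path):
    try:
        with zipfile.ZipFile(load_path, "r") as archive:
            return archive.read("parameters")
    except zipfile.BadZipFile:
        warnings.warn("Old (non-zip) model format detected, falling back.", DeprecationWarning)
        if not isinstance(load_path, str):
            load_path.seek(0)                 # file-like object: rewind before re-reading
            return pickle.load(load_path)
        with open(load_path, "rb") as file_:
            return pickle.load(file_)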
def _save_to_file(save_path, data=None, params=None, cloudpickle=False):
"""Save model to a zip archive or cloudpickle file.
:param save_path: (str or file-like) Where to store the model
:param data: (OrderedDict) Class parameters being stored
:param params: (OrderedDict) Model parameters being stored
:param cloudpickle: (bool) Use old cloudpickle format
(stable-baselines<=2.7.0) instead of a zip archive.
"""
if cloudpickle:
BaseRLModel._save_to_file_cloudpickle(save_path, data, params)
else:
BaseRLModel._save_to_file_zip(save_path, data, params)
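# Hedged usage sketch: in stable-baselines the public save() call ends up in
# this helper. The default is the zip archive; cloudpickle=True keeps the
# legacy single-file format (stable-baselines<=2.7.0). The file name below is
# hypothetical.
# model.save("my_model")                      # zip archive (default)
# model.save("my_model", cloudpickle=True)    # legacy cloudpickle file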
"(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
elif isinstance(observation_space, gym.spaces.MultiBinary):
if observation.shape == (observation_space.n,):
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
"environment, please use ({},) or ".format(observation_space.n) +
"(n_env, {}) for the observation shape.".format(observation_space.n))
else:
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))
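# Sketch of the MultiBinary shape check above with concrete values (assumes
# gym is installed; the space and observations here are hypothetical).
import numpy as np
import gym

space = gym.spaces.MultiBinary(4)
single_obs = np.zeros(4)             # shape (4,)   -> not vectorized
batched_obs = np.zeros((8, 4))       # shape (8, 4) -> vectorized over 8 envs
assert single_obs.shape == (space.n,)
assert len(batched_obs.shape) == 2 and batched_obs.shape[1] == space.n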
class ActorCriticRLModel(BaseRLModel):
"""
The base class for Actor-Critic models
:param policy: (BasePolicy) Policy object
:param env: (Gym environment) The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param policy_base: (BasePolicy) the base policy used by this method (default=ActorCriticPolicy)
:param requires_vec_env: (bool) Does this model require a vectorized environment
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
If None (default), use random seed. Note that if you want completely deterministic
results, you must set `n_cpu_tf_sess` to 1.
:param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations.
If None, the number of CPUs of the current machine will be used.
"""
raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
"Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
kwargs['policy_kwargs']))
model = cls(policy=data["policy"], env=None, _init_setup_model=False)
model.__dict__.update(data)
model.__dict__.update(kwargs)
model.set_env(env)
model.setup_model()
model.load_parameters(params)
return model
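# Hedged usage sketch of the load() flow above (the file name is hypothetical):
# stored attributes are restored into a fresh model, keyword arguments such as
# verbose override them, and a new environment can be attached via env (or a
# later set_env call). Assumes PPO2 as the concrete ActorCriticRLModel subclass.
from stable_baselines import PPO2

restored = PPO2.load("my_model", env=None, verbose=1)
action, _states = restored.predict(restored.observation_space.sample())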
class OffPolicyRLModel(BaseRLModel):
"""
The base class for off-policy RL models
:param policy: (BasePolicy) Policy object
:param env: (Gym environment) The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param replay_buffer: (ReplayBuffer) the type of replay buffer
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param requires_vec_env: (bool) Does this model require a vectorized environment
:param policy_base: (BasePolicy) the base policy used by this method
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param seed: (int) Seed for the pseudo-random generators (python, numpy, tensorflow).
If None (default), use random seed. Note that if you want completely deterministic
results, you must set `n_cpu_tf_sess` to 1.
:param n_cpu_tf_sess: (int) The number of threads for TensorFlow operations.
If None, the number of CPUs of the current machine will be used.