Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, flat):
    """
    Build a diagonal Gaussian distribution from a flat parameter tensor.

    The last axis of ``flat`` is split into two halves with ``tf.split``:
    the first half is stored as the mean, the second half as the log
    standard deviation. The standard deviation is ``exp(logstd)``.

    :param flat: ([float]) the multivariate Gaussian input data
    """
    self.flat = flat
    # Split along the final axis (index computed from the static rank).
    last_axis = len(flat.shape) - 1
    halves = tf.split(axis=last_axis, num_or_size_splits=2, value=flat)
    self.mean, self.logstd = halves
    self.std = tf.exp(self.logstd)
    super(DiagGaussianProbabilityDistribution, self).__init__()
def kl(self, other):
    """
    Closed-form KL divergence KL(self || other) between two diagonal Gaussians.

    Per dimension this is
        log(sigma_other / sigma_self)
        + (sigma_self^2 + (mu_self - mu_other)^2) / (2 * sigma_other^2)
        - 1/2
    and the result is summed over the last axis.

    :param other: (DiagGaussianProbabilityDistribution) the distribution to compare against
    :return: (TensorFlow Tensor) the KL divergence reduced over the last axis
    """
    assert isinstance(other, DiagGaussianProbabilityDistribution)
    # log(sigma_other / sigma_self), expressed via the stored log-stds.
    log_ratio = other.logstd - self.logstd
    self_var = tf.square(self.std)
    mean_gap_sq = tf.square(self.mean - other.mean)
    numerator = self_var + mean_gap_sq
    denominator = 2.0 * tf.square(other.std)
    per_dim = log_ratio + numerator / denominator - 0.5
    return tf.reduce_sum(per_dim, axis=-1)
def probability_distribution_class(self):
    """
    Return the probability distribution class associated with this object.

    :return: (type) ``DiagGaussianProbabilityDistribution``
    """
    distribution_cls = DiagGaussianProbabilityDistribution
    return distribution_cls