Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, nvec, flat):
    """
    Build a distribution made of several categorical distributions.

    The flat logits input is cut along its last axis into consecutive pieces
    whose widths are listed in ``nvec``. Each piece is wrapped in its own
    ``CategoricalProbabilityDistribution``, in the same order as ``nvec``.

    :param nvec: ([int]) the sizes of the different categorical inputs
    :param flat: ([float]) the categorical logits input
    """
    self.flat = flat
    # tf.split requires the entries of nvec to add up to the size of
    # flat's last axis; it yields one logits slice per categorical input.
    logit_slices = tf.split(flat, nvec, axis=-1)
    self.categoricals = [
        CategoricalProbabilityDistribution(logit_slice)
        for logit_slice in logit_slices
    ]
    super(MultiCategoricalProbabilityDistribution, self).__init__()
def __init__(self, logits):
    """
    Create a categorical distribution from its logits.

    :param logits: ([float]) the categorical logits input
    """
    parent = super(CategoricalProbabilityDistribution, self)
    # The logits are stored before the parent initializer runs.
    self.logits = logits
    parent.__init__()
def probability_distribution_class(self):
    """
    Give the distribution class associated with this type.

    :return: (type) the ``CategoricalProbabilityDistribution`` class itself
        (not an instance)
    """
    distribution_cls = CategoricalProbabilityDistribution
    return distribution_cls