Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
self.n_buckets % factor == 0 and
factor % 2 == 0 and
(self.n_buckets // factor) % 2 == 0):
factor -= 1
if factor > 2: # Factor of 2 does not warrant the effort.
rot_size = factor + (self.n_buckets // factor)
factor_list = [factor, self.n_buckets // factor]
random_rotations_shape = (
vecs.shape[-1],
self.n_hashes if self._rehash_each_round else 1,
rot_size // 2)
rng = jax.lax.tie_in(vecs, rng)
rng, subrng = backend.random.split(rng)
random_rotations = jax.random.normal(
rng, random_rotations_shape).astype('float32')
# TODO(lukaszkaiser): the dropout mask will be used for all rounds of
# hashing, so it's shared between them. Check if that's what we want.
dropped_vecs = self.drop_for_hash(vecs, subrng)
rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
if self._rehash_each_round:
if self._factorize_hash and len(factor_list) > 1:
# We factorized self.n_buckets as the product of factor_list.
# Get the buckets for them and combine.
buckets, cur_sum, cur_product = None, 0, 1
for factor in factor_list:
rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
cur_sum += factor // 2
rv = np.concatenate([rv, -rv], axis=-1)
if buckets is None:
def test_improper_normal():
    """Check that NUTS recovers the location of a Normal whose `loc` is a constrained param.

    `loc` is declared with `numpyro.param` under an `interval(0., alpha)` constraint,
    where `alpha` is itself sampled from Uniform(0, 1). The observations are
    `expected_loc` plus standard-normal noise. The posterior mean of `loc` must fall
    within 0.05 of `expected_loc`.
    """
    expected_loc = 0.9

    def normal_model(observed):
        # Upper bound of the admissible interval for `loc`.
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.param('loc', 0., constraint=constraints.interval(0., upper))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=observed)

    noise = random.normal(random.PRNGKey(0), (1000,))
    observations = expected_loc + noise
    sampler = MCMC(NUTS(model=normal_model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observations)
    posterior_loc = sampler.get_samples()['loc']
    assert_allclose(np.mean(posterior_loc, 0), expected_loc, atol=0.05)
def init(rng, shape):
    """Draw a normal-initialized float32 array of `shape`.

    The standard deviation is `scale / sqrt(fan_avg * receptive)`, where:
    - `fan_avg` is the mean of the in/out fans, each multiplied by the size of the
      mapped axis `axis_name`;
    - `receptive` is the product of every dimension other than `in_axis` and `out_axis`.

    `axis_name`, `in_axis`, `out_axis` and `scale` come from the enclosing scope.
    """
    # psum of the constant 1 over `axis_name` yields the number of participants on
    # that axis. NOTE(review): presumably each participant holds one shard of the
    # full weight, hence the fan scaling -- confirm against the caller.
    n_shards = lax.psum(1, axis_name)
    fan_in = shape[in_axis] * n_shards
    fan_out = shape[out_axis] * n_shards
    receptive = onp.prod(onp.delete(shape, [in_axis, out_axis]))
    denom = (fan_in + fan_out) / 2. * receptive
    std = scale / np.sqrt(denom)
    draws = random.normal(rng, shape, dtype=np.float32)
    return std * draws
return init
def get_batches(batches=100, sequence_length=1000, key=PRNGKey(0)):
    """Lazily yield `batches` arrays of standard-normal noise.

    Each yielded array has shape ``(1, receptive_field + sequence_length, 1)``.
    ``receptive_field`` is not a parameter; it is looked up by name on every
    iteration. The PRNG key is split once per batch. One half is carried forward
    for the next split, and the other half draws the current batch.
    """
    for _batch_idx in range(batches):
        key, draw_key = random.split(key)
        total_length = receptive_field + sequence_length
        yield random.normal(draw_key, (1, total_length, 1))
def _rvs(self, s):
    """Return exp(s * Z), where Z is standard normal of shape `self._size`.

    Z is drawn with `self._random_state` as the PRNG key, so the result is a
    log-normal variate with shape parameter `s`.
    """
    standard = random.normal(self._random_state, self._size)
    return np.exp(s * standard)
def gaussian_sample(rng, mu, sigmasq):
    """Draw one sample from N(mu, diag(sigmasq)) using the PRNG key `rng`.

    `sigmasq` holds per-element variances; the sample has the shape of `mu`.
    """
    std = np.sqrt(sigmasq)
    eps = random.normal(rng, mu.shape)
    return mu + std * eps
def diag_gaussian_sample(rng, mean, log_std):
    """Draw one sample from a diagonal Gaussian given means and log standard deviations.

    The standard deviations are `exp(log_std)`; the sample has the shape of `mean`.
    """
    sigma = np.exp(log_std)
    z = random.normal(rng, mean.shape)
    return mean + sigma * z
def noise_(n):
    """Return `n` plus Gaussian noise with standard deviation `std_dev`.

    `std_dev` and `rng` are read from the enclosing scope.
    """
    # NOTE(review): the same `rng` key is used on every call, so arrays of equal
    # shape receive identical noise. If this is mapped over several arrays
    # (e.g. per-parameter gradients), consider splitting `rng` per array -- confirm intent.
    return n + std_dev * random.normal(rng, n.shape)


def normalize_(n):
    """Return `n` divided by `batch_size`, converted to float so the division is never integer-only."""
    return n / float(batch_size)
def gaussian_sample(rng, mu, sigmasq):
    # Reparameterized draw: shift and scale a standard-normal sample.
    # `sigmasq` holds per-element variances, so the scale is its square root.
    stddev = np.sqrt(sigmasq)
    std_normal = random.normal(rng, mu.shape)
    return mu + stddev * std_normal