def do_mvm(rhs):
    # Evaluate one output element per index; `row`, `dilation`, and
    # `_chunk_vmap` come from the enclosing scope. The vmap is chunked
    # to bound peak memory.
    @jit
    def compute_element(i):
        return np.dot(rhs, row(i))
    return _chunk_vmap(compute_element, np.arange(rhs.shape[-1]), rhs.shape[-1] // dilation)
return do_mvm  # returned from the enclosing factory function (elided here)
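The snippet above relies on a `_chunk_vmap` helper from its enclosing module. A minimal sketch of what such a helper might look like, splitting the index array into chunks and vmapping over each chunk in turn (names and behavior here are assumptions, not the original implementation):

import jax.numpy as np
from jax import vmap

def chunk_vmap_sketch(fn, arr, num_chunks):
    # hypothetical: vmap `fn` over `arr` one chunk at a time to cap memory use
    chunks = np.array_split(arr, max(num_chunks, 1))
    return np.concatenate([vmap(fn)(chunk) for chunk in chunks])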
def model(T=10, q=1, r=1, phi=0., beta=0.):
    def transition(state, i):
        x0, mu0 = state
        x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample('y', dist.Normal(mu1, r))
        numpyro.deterministic('y2', y1 * 2)
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
    _, xy = scan(transition, (x0, mu0), jnp.arange(T))
    x, y = xy
    return jnp.append(x0, x), jnp.append(y0, y)
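A hedged usage sketch for the scan-based model above, assuming the usual NumPyro imports (`scan` from `numpyro.contrib.control_flow`). Since nothing is observed, NUTS here effectively draws from the prior over the latent sites; the argument values are illustrative:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), T=10, q=1, r=1, phi=0.5, beta=0.1)
samples = mcmc.get_samples()  # draws for 'x_0', 'x', 'y_0', 'y'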
output_multiplier = sum(param_dims)
all_ones = (np.array(param_dims) == 1).all()
# Calculate the indices on the output corresponding to each parameter
ends = np.cumsum(np.array(param_dims), axis=0)
starts = np.concatenate((np.zeros(1), ends[:-1]))
param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]
# Hidden dimensions must not be less than the input dimension; otherwise it
# isn't possible to connect to the outputs correctly
for h in hidden_dims:
    if h < input_dim:
        raise ValueError('Hidden dimension must not be less than input dimension.')
if permutation is None:
    permutation = np.arange(input_dim)
# Create masks
masks, mask_skip = create_mask(input_dim=input_dim, hidden_dims=hidden_dims,
                               permutation=permutation,
                               output_dim_multiplier=output_multiplier)
main_layers = []
# Create masked layers; interleave the nonlinearity between them
for i, mask in enumerate(masks):
    main_layers.append(MaskedDense(mask))
    if i < len(masks) - 1:
        main_layers.append(nonlinearity)
if skip_connections:
    # add a masked skip connection from the inputs straight to the outputs
    net_init, net = stax.serial(stax.FanOut(2),
                                stax.parallel(stax.serial(*main_layers),
                                              MaskedDense(mask_skip, bias=False)),
                                stax.FanInSum)
else:
    net_init, net = stax.serial(*main_layers)
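To make the slice bookkeeping in the block above concrete, here is the same arithmetic with illustrative `param_dims = [1, 1]` (e.g. one loc and one scale parameter per input):

import numpy as np

param_dims = [1, 1]                                # e.g. loc and scale
ends = np.cumsum(np.array(param_dims))             # [1, 2]
starts = np.concatenate((np.zeros(1), ends[:-1]))  # [0., 1.]
param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]
# -> [slice(0, 1), slice(1, 2)]: the first output block is loc, the second scale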
def model(N, y=None):
    """
    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # initial population
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # measurement times
    ts = jnp.arange(float(N))
    # parameters alpha, beta, gamma, delta of dz_dt
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=jnp.array([0.5, 0.05, 1.5, 0.05]),
                             scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
    # integrate dz/dt, the result will have shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # measurement errors, we expect that measured hare has larger error than measured lynx
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # measured populations (in log scale)
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)
def do_mvm(rhs):
    # materialize the full matrix row by row, then do a dense matmul;
    # `row` and `N` come from the enclosing scope
    M = vmap(row)(np.arange(N))
    return np.matmul(M, rhs)
return do_mvm  # returned from the enclosing factory function (elided here)
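Compared with the chunked element-by-element version earlier, this variant materializes the full N x N matrix in one vmap, trading O(N^2) memory for a single dense matmul. A toy usage sketch, with `row` as a hypothetical stand-in for the real row constructor:

import jax.numpy as np
from jax import vmap

N = 4
row = lambda i: np.eye(N)[i]   # hypothetical: rows of the identity matrix
do_mvm = lambda rhs: np.matmul(vmap(row)(np.arange(N)), rhs)
print(do_mvm(np.ones(N)))      # [1. 1. 1. 1.]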
def fun(omega):
    _fun = lambda dim: process_singleton_pcg(dim, P, kappa, kX, omega, Y, eta1, eta2, c, rank1, rank2,
                                             cg_tol=cg_tol, max_iters=max_iters)
    return chunk_vmap(_fun, np.arange(P), chunk_size=probe_chunk_size)
def chosen_probabs(probab_actions, actions):
    """Picks out the probabilities of the actions along batch and time-steps.

    Args:
      probab_actions: ndarray of shape `[B, AT, A]`, where
        probab_actions[b, t, i] contains the log-probability of action = i at
        the t^th time-step in the b^th trajectory.
      actions: ndarray of shape `[B, AT]`, with each entry in [0, A) denoting
        which action was chosen in the b^th trajectory's t^th time-step.

    Returns:
      `[B, AT]` ndarray with the log-probabilities of the chosen actions.
    """
    B, AT = actions.shape  # pylint: disable=invalid-name
    assert (B, AT) == probab_actions.shape[:2]
    return probab_actions[np.arange(B)[:, None], np.arange(AT), actions]
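A small worked example of the advanced indexing in `chosen_probabs`, with illustrative values (B=2, AT=2, A=3):

import numpy as np

probab_actions = np.arange(12).reshape(2, 2, 3)  # [B=2, AT=2, A=3]
actions = np.array([[0, 2],
                    [1, 1]])
out = probab_actions[np.arange(2)[:, None], np.arange(2), actions]
# out[b, t] == probab_actions[b, t, actions[b, t]]
print(out)  # [[ 0  5]
            #  [ 7 10]]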
def range(self, start, limit=None, delta=1):
    return np.arange(start, limit, step=delta)
def lowrank_presolve(kX, D, eta1, eta2, c, kappa, rank1, rank2):
    N, P = kX.shape
    all_ones = np.ones((N, 1))
    kappa_indices = np.argsort(kappa)
    # keep the rank1 features with the largest kappa
    top_features = dynamic_slice_in_dim(kappa_indices, P - rank1, rank1)
    kX_top = np.take(kX, top_features, -1)
    if rank2 > 0:
        top_features2 = dynamic_slice_in_dim(kappa_indices, P - rank2, rank2)
        kX_top2 = np.take(kX, top_features2, -1)             # N rank2
        kX_top2 = kX_top2[:, None, :] * kX_top2[:, :, None]  # N rank2 rank2
        # keep only the strictly lower-triangular pairwise interactions
        lower_diag = np.ravel(np.arange(rank2) < np.arange(rank2)[:, None])
        kX_top2 = np.compress(lower_diag, kX_top2.reshape((N, -1)), axis=-1)
        Z = np.concatenate([eta2 * kX_top2, eta1 * kX_top, c * all_ones], axis=1)
    else:
        Z = np.concatenate([eta1 * kX_top, c * all_ones], axis=1)
    # Woodbury-style solve for (diag(D) + Z Z^T)^{-1} b via a small Cholesky
    ZD = Z / D[:, None]
    ZDZ = np.eye(ZD.shape[-1]) + np.matmul(np.transpose(Z), ZD)
    L = cho_factor(ZDZ, lower=True)[0]
    return lambda b: b / D - np.matmul(ZD, cho_solve((L, True), np.matmul(np.transpose(ZD), b)))
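The presolve above applies the Woodbury identity: for K = diag(D) + Z Z^T, K^{-1} b = D^{-1} b - D^{-1} Z (I + Z^T D^{-1} Z)^{-1} Z^T D^{-1} b, so only a small (rank x rank) factorization is needed instead of an N x N solve. A self-contained numerical check with random data (sizes are illustrative):

import numpy as onp

N, k = 50, 5
rng = onp.random.RandomState(0)
D = rng.rand(N) + 1.0
Z = rng.randn(N, k)
b = rng.randn(N)

# direct solve against the explicit N x N matrix
K = onp.diag(D) + Z @ Z.T
direct = onp.linalg.solve(K, b)

# Woodbury path, mirroring the presolve above
ZD = Z / D[:, None]
ZDZ = onp.eye(k) + Z.T @ ZD
woodbury = b / D - ZD @ onp.linalg.solve(ZDZ, ZD.T @ b)
assert onp.allclose(direct, woodbury)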