def compute_element(i):
    return np.dot(rhs, row(i))
return _chunk_vmap(compute_element, np.arange(rhs.shape[-1]), rhs.shape[-1] // dilation)
grad_sum2 = np.dot(np.ones((1,N)), grads)
assert np.allclose(grad_sum, grad_sum2)
# Now make things go fast
from functools import partial
from jax import grad, jit, vmap
grad_fun = jit(grad(loss))
grads = vmap(partial(grad_fun, w))(X, y)
assert np.allclose(grads, grads2)
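The snippet above assumes a `loss` defined earlier in the file. A self-contained sketch of the same jit + grad + vmap per-example-gradient pattern, with an illustrative logistic-regression loss and random data (all names and values below are stand-ins, not the original file's definitions; `predict` is included because the Hessian check that follows uses it):

# Self-contained sketch: per-example gradients for binary logistic regression.
from functools import partial
import jax.numpy as np
from jax import grad, jit, vmap, random

def predict(w, inputs):
    # Sigmoid probabilities; works for a single example (D,) or a batch (N, D).
    return 1.0 / (1.0 + np.exp(-np.dot(inputs, w)))

def loss(w, inputs, targets):
    # Summed negative log-likelihood: -y * z + log(1 + exp(z)) per example.
    logits = np.dot(inputs, w)
    return np.sum(-targets * logits + np.log1p(np.exp(logits)))

N, D = 8, 3
X = random.normal(random.PRNGKey(0), (N, D))
y = (random.uniform(random.PRNGKey(1), (N,)) > 0.5).astype(np.float32)
w = np.zeros(D)

grad_fun = jit(grad(loss))                             # d loss / d w for one example
per_example_grads = vmap(partial(grad_fun, w))(X, y)   # shape (N, D)
# Per-example gradients sum to the gradient of the summed batch loss.
assert np.allclose(per_example_grads.sum(axis=0), grad_fun(w, X, y))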
# Logistic regression
H1 = hessian(loss)(w, X, y)
mu = predict(w, X)
S = np.diag(mu * (1-mu))
H2 = np.dot(np.dot(X.T, S), X)
assert np.allclose(H1, H2)
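With a summed logistic negative log-likelihood like the one sketched above, the Hessian in w is sum_i mu_i * (1 - mu_i) * x_i x_i^T, which is exactly X^T S X. Forming the N x N diagonal matrix is unnecessary; an equivalent check that broadcasts the per-example weights instead (illustrative, same result as H2):

# Same as np.dot(np.dot(X.T, np.diag(mu * (1 - mu))), X), but without building
# the N x N diagonal matrix.
H2_no_diag = np.dot(X.T * (mu * (1 - mu)), X)
assert np.allclose(H1, H2_no_diag)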
def pcg_quad_form_log_det_jvp(primals, tangents):
    A, b, probes, cg_tol, max_iters = primals
    A_dot, b_dot, _, _, _ = tangents
    D = b.shape[-1]
    # Solve A x = rhs for b and the probe vectors in a single batched PCG call.
    b_probes = np.concatenate([b[None, :], probes])
    Ainv_b_probes, res_norm, iters = pcg_batch_b(b_probes, A, max_iters=max_iters)
    Ainv_b, Ainv_probes = Ainv_b_probes[0], Ainv_b_probes[1:]
    # Directional derivatives of the quadratic form b^T A^{-1} b w.r.t. A and b.
    quad_form_dA = -np.dot(Ainv_b, np.matmul(A_dot, Ainv_b))
    quad_form_db = 2.0 * np.dot(Ainv_b, b_dot)
    # Hutchinson estimate of tr(A^{-1} A_dot), the directional derivative of log det A.
    log_det_dA = np.mean(np.einsum('...i,...i->...', np.matmul(probes, A_dot), Ainv_probes))
    tangent_out = log_det_dA + quad_form_dA + quad_form_db
    quad_form = np.dot(b, Ainv_b)
    return quad_form, tangent_out
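A JVP rule with this (primals, tangents) signature is normally registered on a primal function via jax.custom_jvp. The wiring below is only a sketch: the primal name pcg_quad_form_log_det and the stand-in pcg_batch_b (a dense solve instead of batched preconditioned CG) are illustrative assumptions, not the repo's actual definitions.

import jax
import jax.numpy as np

def pcg_batch_b(rhs, A, max_iters=100):
    # Stand-in for a batched preconditioned CG solver: exact solves, dummy diagnostics.
    x = np.linalg.solve(A, rhs.T).T
    return x, np.zeros(rhs.shape[0]), max_iters

@jax.custom_jvp
def pcg_quad_form_log_det(A, b, probes, cg_tol, max_iters):
    # Assumed primal: returns the quadratic form b^T A^{-1} b; the log-det term
    # only enters through the custom tangent rule above.
    return np.dot(b, np.linalg.solve(A, b))

pcg_quad_form_log_det.defjvp(pcg_quad_form_log_det_jvp)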
Args:
  y0: function value at the start of the interval.
  y1: function value at the end of the interval.
  y_mid: function value at the mid-point of the interval.
  dy0: derivative value at the start of the interval.
  dy1: derivative value at the end of the interval.
  dt: width of the interval.

Returns:
  Coefficients `[a, b, c, d, e]` for the polynomial
  `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e`
"""
v = np.stack([dy0, dy1, y0, y1, y_mid])
a = np.dot(np.hstack([-2. * dt, 2. * dt, np.array([-8., -8., 16.])]), v)
b = np.dot(np.hstack([5. * dt, -3. * dt, np.array([18., 14., -32.])]), v)
c = np.dot(np.hstack([-4. * dt, dt, np.array([-11., -5., 16.])]), v)
d = dt * dy0
e = y0
return a, b, c, d, e
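These coefficients define a quartic in the normalized coordinate x = (t - t0) / dt on [0, 1]: by construction p(0) = y0, p(1) = y1, p(0.5) = y_mid, and the time derivatives at the two endpoints are dy0 and dy1. A small self-contained check of that reading (the values below are arbitrary illustrations):

# Check the interpolation conditions, assuming x is normalized time in [0, 1].
import jax.numpy as np

def check_fit(y0, y1, y_mid, dy0, dy1, dt):
    v = np.stack([dy0, dy1, y0, y1, y_mid])
    a = np.dot(np.hstack([-2. * dt, 2. * dt, np.array([-8., -8., 16.])]), v)
    b = np.dot(np.hstack([5. * dt, -3. * dt, np.array([18., 14., -32.])]), v)
    c = np.dot(np.hstack([-4. * dt, dt, np.array([-11., -5., 16.])]), v)
    d, e = dt * dy0, y0
    p = lambda x: (((a * x + b) * x + c) * x + d) * x + e        # value at x
    dp = lambda x: (((4 * a * x + 3 * b) * x + 2 * c) * x + d) / dt  # d/dt, not d/dx
    assert np.allclose([p(0.), p(1.), p(0.5)], [y0, y1, y_mid])
    assert np.allclose([dp(0.), dp(1.)], [dy0, dy1])

check_fit(y0=1.0, y1=2.0, y_mid=1.2, dy0=0.3, dy1=-0.5, dt=0.7)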
def matvec(A, b):
    return np.dot(A, b)
nontrainable_params=self._nontrainable_params,
state=self._model_state,
rng=k1))
opt_step += 1
self._total_opt_step += 1
# Compute the approximate KL divergence for early stopping. Use the whole
# dataset; since we only do inference, it should fit in memory.
(log_probab_actions_new, _) = (
    self._policy_and_value_net_apply(
        padded_observations,
        weights=self._policy_and_value_net_weights,
        state=self._model_state,
        rng=k2))
action_mask = np.dot(
    np.pad(reward_mask, ((0, 0), (0, 1))), self._rewards_to_actions
)
approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj,
                               action_mask)
early_stopping = approx_kl > 1.5 * self._target_kl
if early_stopping:
    logging.vlog(
        1, 'Early stopping policy and value optimization after %d steps, '
        'with approx_kl: %0.2f', opt_step, approx_kl)
    # We don't return right away; we want the code below to execute on the last
    # iteration.
t2 = time.time()
if (opt_step % self._print_every_optimizer_steps == 0 or
    opt_step == self._n_optimizer_steps or early_stopping):
value_prediction_old=value_prediction_old,
epsilon=epsilon)
(ppo_loss, ppo_summaries) = ppo_loss_given_predictions(
    log_probab_actions_new,
    log_probab_actions_old,
    value_prediction_old,
    padded_actions,
    rewards_to_actions,
    padded_rewards,
    reward_mask,
    gamma=gamma,
    lambda_=lambda_,
    epsilon=epsilon)
# Pad the reward mask to be compatible with rewards_to_actions.
padded_reward_mask = np.pad(reward_mask, ((0, 0), (0, 1)))
action_mask = np.dot(padded_reward_mask, rewards_to_actions)
entropy_bonus = masked_entropy(log_probab_actions_new, action_mask)
combined_loss_ = ppo_loss + (c1 * value_loss) - (c2 * entropy_bonus)
summaries = {
"combined_loss": combined_loss_,
"entropy_bonus": entropy_bonus,
}
for loss_summaries in (value_summaries, ppo_summaries):
summaries.update(loss_summaries)
return (combined_loss_, (ppo_loss, value_loss, entropy_bonus), summaries)
def cg_body_fun(state, mvm):
    # One conjugate-gradient iteration: update the solution x, residual r, and
    # search direction p, given a matrix-vector product function mvm.
    x, r, p, r_dot_r, iteration = state
    Ap = mvm(p)
    alpha = r_dot_r / np.dot(p, Ap)
    x = x + alpha * p
    r = r - alpha * Ap
    beta_denom = r_dot_r
    r_dot_r = np.dot(r, r)
    beta = r_dot_r / beta_denom
    p = r + beta * p
    return CGState(x, r, p, r_dot_r, iteration + 1)
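A minimal driver for this body function could look like the sketch below, assuming CGState is a namedtuple with the fields unpacked above; the squared-residual stopping rule and the small dense example system are illustrative only.

from collections import namedtuple
import jax.numpy as np
from jax import lax

CGState = namedtuple("CGState", ["x", "r", "p", "r_dot_r", "iteration"])

def cg_solve(A, b, tol=1e-10, max_iters=100):
    mvm = lambda v: np.dot(A, v)
    r0 = b  # starting from x = 0, the initial residual is b
    init = CGState(np.zeros_like(b), r0, r0, np.dot(r0, r0), 0)
    def cond_fun(state):
        return (state.r_dot_r > tol) & (state.iteration < max_iters)
    final = lax.while_loop(cond_fun, lambda state: cg_body_fun(state, mvm), init)
    return final.x

# Example: solve a small SPD system and compare against the right-hand side.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = cg_solve(A, b)
assert np.allclose(np.dot(A, x), b, atol=1e-4)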
gamma=gamma)
# (B, RT)
advantages = gae_advantages(
    td_deltas, reward_mask, lambda_=lambda_, gamma=gamma)
# Normalize the advantages.
advantage_mean = np.mean(advantages)
advantage_std = np.std(advantages)
advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)
# Scatter advantages over padded_actions.
# rewards_to_actions is RT + 1 -> AT, so we pad the advantages and the reward
# mask by 1.
advantages = np.dot(np.pad(advantages, ((0, 0), (0, 1))), rewards_to_actions)
action_mask = np.dot(
    np.pad(reward_mask, ((0, 0), (0, 1))), rewards_to_actions
)
# (B, AT)
ratios = compute_probab_ratios(log_probab_actions_new, log_probab_actions_old,
                               padded_actions, action_mask)
assert (B, AT) == ratios.shape
# (B, AT)
objective = clipped_objective(
    ratios, advantages, action_mask, epsilon=epsilon)
assert (B, AT) == objective.shape
# ()
average_objective = np.sum(objective) / np.sum(action_mask)
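For reference, the standard PPO clipped-surrogate objective with a masked mean over valid action positions, sketched with stand-in names (this is the textbook formula, not necessarily the exact clipped_objective helper used above):

import jax.numpy as np

def clipped_objective_sketch(ratios, advantages, mask, epsilon=0.1):
    # min(r * A, clip(r, 1 - eps, 1 + eps) * A), masked to ignore padding.
    clipped_ratios = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
    objective = np.minimum(ratios * advantages, clipped_ratios * advantages) * mask
    # Masked mean, matching the average_objective computation above.
    return np.sum(objective) / np.sum(mask)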