Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def forward(ctx, *args):
    """Solve the ODE forward and stash on ``ctx`` what a later pass can read back.

    ``args`` is a flat positional pack. Its last seven entries are, in order,
    ``func, t, flat_params, rtol, atol, method, options``; everything before
    them is collected as the tuple ``y0``.

    The solve itself runs inside ``torch.no_grad()``. Afterwards ``t``,
    ``flat_params`` and every element of the solution are handed to
    ``ctx.save_for_backward``, while ``func`` and the solver settings are kept
    as plain attributes on ``ctx``.

    Returns:
        Whatever ``odeint`` returned (it is star-unpacked when saved, so it is
        presumably a tuple of tensors -- confirm against ``odeint``).
    """
    # Internal invariant: seven trailing arguments plus at least one entry for y0.
    assert len(args) >= 8, 'Internal error: all arguments required.'

    # Peel the fixed trailing arguments off the end; the rest is the state.
    y0 = args[:-7]
    func = args[-7]
    t = args[-6]
    flat_params = args[-5]
    rtol = args[-4]
    atol = args[-3]
    method = args[-2]
    options = args[-1]

    # Non-tensor settings go on ctx as ordinary attributes.
    ctx.func = func
    ctx.rtol = rtol
    ctx.atol = atol
    ctx.method = method
    ctx.options = options

    with torch.no_grad():
        solution = odeint(
            func,
            y0,
            t,
            rtol=rtol,
            atol=atol,
            method=method,
            options=options,
        )

    ctx.save_for_backward(t, flat_params, *solution)
    return solution
grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
func_i = func(t[i], ans_i)
# Compute the effect of moving the current time measurement point.
dLd_cur_t = sum(
torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
)
adj_time = adj_time - dLd_cur_t
time_vjps.append(dLd_cur_t)
# Run the augmented system backwards in time.
if adj_params.numel() == 0:
adj_params = torch.tensor(0.).to(adj_y[0])
aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
aug_ans = odeint(
augmented_dynamics, aug_y0,
torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
)
# Unpack aug_ans.
adj_y = aug_ans[n_tensors:2 * n_tensors]
adj_time = aug_ans[2 * n_tensors]
adj_params = aug_ans[2 * n_tensors + 1]
adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
if len(adj_time) > 0: adj_time = adj_time[1]
if len(adj_params) > 0: adj_params = adj_params[1]
adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))
del aug_y0, aug_ans