Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment -- these lines appear to be the tail of one or more
# gradient tests whose `def` lines (and the loop that defines `i`) are not in
# this chunk; f, tuple_f, y0 and t_points come from that unseen scope.
# Solve an ODE whose state is the tuple (y0, y0) with dopri5, keep only the
# i-th element of the returned result, and check its gradients.
func = lambda y0, t_points: torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='dopri5')[i]
# torch.autograd.gradcheck compares autograd gradients against finite-difference
# estimates with respect to both y0 and t_points.
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
# Same gradient check for a single-state problem solved with adaptive_heun.
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
# Same gradient check for a single-state problem solved with dopri5.
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
def test_adams(self):
    """Check the 'adams' solver on every registered test problem.

    Each entry of ``problems.PROBLEMS`` is constructed with ``reverse=True``,
    solved, and compared against its reference solution inside its own
    subtest. A failure, including an exception raised while building or
    solving the problem, is therefore reported against the specific ODE.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # Pass the loop variable through. Without it every iteration built
            # the same default problem and the per-ODE subtests tested nothing new.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint(f, y0, t_points, method='adams')
            self.assertLess(rel_error(sol, y), error_tol)
def test_rk4(self):
    """Solve the default test problem (built with reverse=True) using 'rk4'.

    The relative error against the reference solution returned by
    ``problems.construct_problem`` must stay below ``error_tol``.
    """
    func, initial, times, expected = problems.construct_problem(TEST_DEVICE, reverse=True)
    solution = torchdiffeq.odeint(func, initial, times, method='rk4')
    error = rel_error(expected, solution)
    self.assertLess(error, error_tol)
def test_euler(self):
    """Solve the default test problem (built with reverse=True) using 'euler'.

    The relative error against the reference solution returned by
    ``problems.construct_problem`` must stay below ``error_tol``.
    """
    func, initial, times, expected = problems.construct_problem(TEST_DEVICE, reverse=True)
    solution = torchdiffeq.odeint(func, initial, times, method='euler')
    error = rel_error(expected, solution)
    self.assertLess(error, error_tol)
def test_explicit_adams(self):
    """Run 'explicit_adams' on a time grid containing only the first point.

    The solver output must match the first reference value ``sol[0]`` to
    within ``error_tol`` in max-absolute difference.
    """
    func, initial, times, expected = problems.construct_problem(TEST_DEVICE, reverse=True)
    first_only = times[:1]
    solution = torchdiffeq.odeint(func, initial, first_only, method='explicit_adams')
    deviation = max_abs(expected[0] - solution)
    self.assertLess(deviation, error_tol)
# NOTE(review): fragment -- these lines appear to belong to gradient tests whose
# `def` lines are not in this chunk; f, y0 and t_points come from that unseen
# scope (they are not the locals of the test defined just above).
# Gradient check of the rk4 solve with respect to y0 and t_points:
# torch.autograd.gradcheck compares autograd gradients with finite differences.
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='rk4')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
# Same gradient check for the midpoint solver.
func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='midpoint')
self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
# NOTE(review): fragment of what appears to be an ODE2VAE forward pass. The
# enclosing `def` is not in this chunk; it presumably binds X, N, T, nc, d, L,
# dt, method, inst_enc and Ndata (e.g. from X.shape and keyword arguments) --
# TODO confirm. The trailing `else:` branch also continues past this chunk.
# Encode only the first frame X[:,0] to obtain the initial-state posterior.
h = self.encoder(X[:,0])
qz0_m, qz0_logv = self.fc1(h), self.fc2(h) # N,2q & N,2q
# q is half the latent width; only the second half (columns q:) is decoded below.
q = qz0_m.shape[1]//2
# latent samples
# Reparameterised sample: z0 = mean + eps * exp(qz0_logv).
# NOTE(review): if qz0_logv is a log-variance, the standard deviation would be
# exp(0.5*qz0_logv); confirm this matches the KL term computed in self.elbo.
eps = torch.randn_like(qz0_m) # N,2q
z0 = qz0_m + eps*torch.exp(qz0_logv) # N,2q
# Log-density of the noise under self.mvn; integrated jointly with z0 as the
# second component of the ODE state.
logp0 = self.mvn.log_prob(eps) # N
# ODE
# T evaluation times 0, dt, 2*dt, ... placed on the same device as z0.
t = dt * torch.arange(T,dtype=torch.float).to(z0.device)
ztL = []
logpL = []
# sample L trajectories
# Each iteration draws one differential function f from self.bnn and
# integrates the joint state (z, logp) over t.
for l in range(L):
f = self.bnn.draw_f() # draw a differential function
# The lambda closes over the current f and is consumed by odeint immediately.
# NOTE(review): if the solver ever re-evaluated oderhs after the loop advanced
# (e.g. during a backward pass), it would see a later f; binding f as a
# default argument would make that explicit -- verify which odeint is in use.
oderhs = lambda t,vs: self.ode2vae_rhs(t,vs,f) # make the ODE forward function
zt,logp = odeint(oderhs,(z0,logp0),t,method=method) # T,N,2q & T,N
# Swap the time and batch axes and add a leading sample axis for concatenation.
ztL.append(zt.permute([1,0,2]).unsqueeze(0)) # 1,N,T,2q
logpL.append(logp.permute([1,0]).unsqueeze(0)) # 1,N,T
# Stack the L sampled trajectories along the leading axis.
ztL = torch.cat(ztL,0) # L,N,T,2q
logpL = torch.cat(logpL) # L,N,T
# decode
st_muL = ztL[:,:,:,q:] # L,N,T,q
# Flatten samples, batch and time so every latent is decoded as one image.
s = self.fc3(st_muL.contiguous().view([L*N*T,q]) ) # L*N*T,h_dim
Xrec = self.decoder(s) # L*N*T,nc,d,d
Xrec = Xrec.view([L,N,T,nc,d,d]) # L,N,T,nc,d,d
# likelihood and elbo
if inst_enc:
# When inst_enc is set, every frame (flattened to N*T images) is also encoded.
# The per-frame posteriors go to self.elbo, which then returns an extra
# KL term, inst_KL.
h = self.encoder(X.contiguous().view([N*T,nc,d,d]))
qz_enc_m, qz_enc_logv = self.fc1(h), self.fc2(h) # N*T,2q & N*T,2q
lhood, kl_z, kl_w, inst_KL = self.elbo(qz0_m, qz0_logv, ztL, logpL, X, Xrec, L, qz_enc_m, qz_enc_logv)
# Likelihood and latent KL terms are scaled by Ndata; kl_w is subtracted once, unscaled.
elbo = Ndata*(lhood-kl_z-inst_KL) - kl_w
else: