Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
print(" non-finite elements exist, try to fix")
y0_[y0_ != y0_] = 0.
y0_[y0_ == float("Inf")] = 0.
y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU)
########################################################
# Error Ratio #
########################################################
mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()
########################################################
# Update RK State #
########################################################
dt_next = _optimal_step_size(
dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5)
if not (dt_next<0.02): #not (dt_next<0.02 or dt_next>0.1):
y_next = y1 if accept_step else y0
f_next = f1 if accept_step else f0
t_next = t0 + dt if accept_step else t0
interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff
else:
if dt_next<0.02:
print("warning the step of dopri5 {} is too small, set to 0.01".format(dt_next))
dt_next = _convert_to_tensor(0.01, dtype=torch.float64, device=y0[0].device)
if dt_next>0.1:
print("warning the step of dopri5 {} is too big, set to 0.1".format(dt_next))
dt_next = _convert_to_tensor(0.1, dtype=torch.float64, device=y0[0].device)
y_next = y1
f_next = f1
t_next = t0 + dt
y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_ADAPTIVE_HEUN_TABLEAU)
########################################################
# Error Ratio #
########################################################
mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1)
accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all()
########################################################
# Update RK State #
########################################################
y_next = y1 if accept_step else y0
f_next = f1 if accept_step else f0
t_next = t0 + dt if accept_step else t0
interp_coeff = _interp_fit_adaptive_heun(y0, y1, k, dt) if accept_step else interp_coeff
dt_next = _optimal_step_size(
dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5
)
rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
return rk_state