solve_ivp
- deer.solve_ivp(func: Callable[[Array, Array, Any], Array], y0: Array, xinp: Array, params: Any, tpts: Array, method: SolveIVPMethod | None = None) → Result
Solve the initial value problem.
\[\frac{dy}{dt} = f(y, x; \theta)\]
with the initial condition \(y(0) = y_0\), where \(y\) is the output signal, \(x\) is the input signal, and \(\theta\) denotes the parameters of the function. The function returns the output signal \(y\) at the time points \(t\) given in tpts.
- Parameters:
func (Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]) – Function that evaluates the derivative of \(y\) with respect to \(t\). Its arguments are the output signal \(y\) with shape (ny,), the input signal \(x\) with shape (nx,), and the parameters \(\theta\) as a pytree. It returns \(\frac{dy}{dt}\) with shape (ny,).
y0 (jnp.ndarray) – Initial condition on \(y\), with shape (ny,).
xinp (jnp.ndarray) – The external input signal, with shape (nsamples, nx).
params (Any) – The parameters of the function func.
tpts (jnp.ndarray) – The time points at which to evaluate the solution, with shape (nsamples,).
method (Optional[SolveIVPMethod]) – The method used to solve the initial value problem. If None, the DEER() method is used.
- Returns:
res – The Result object, where .value is the solution of the IVP system at the given time points, with shape (nsamples, ny), and .success is a boolean array indicating whether the solver converged.
- Return type:
Result
Examples
>>> import jax.numpy as jnp
>>> from deer import solve_ivp
>>>
>>> def simple_harmonic_oscillator(y, x, params):
...     k, m = params
...     dydt = jnp.array([y[1], -k/m*y[0]])
...     return dydt
>>>
>>> y0 = jnp.array([1.0, 0.0])
>>> xinp = jnp.zeros((100, 0))  # no input signal
>>> params = (1.0, 1.0)  # k, m
>>> tpts = jnp.linspace(0, 10, 100)
>>>
>>> y = solve_ivp(simple_harmonic_oscillator, y0, xinp, params, tpts).value
>>> # The output y should be an array of shape (nsamples, ny)
>>> y.shape
(100, 2)
>>> # Check the first and last values (should be close to [1.0, 0.0] and [cos(10), -sin(10)], respectively)
>>> jnp.allclose(y[0], jnp.array([1.0, 0.0]))
Array(True, dtype=bool)
>>> jnp.allclose(y[-1], jnp.array([jnp.cos(10), -jnp.sin(10)]), atol=1e-2)
Array(True, dtype=bool)
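When checking results like the one above, an independent sequential solver that takes the same arguments can be a useful cross-check. The sketch below is not part of deer: `rk4_reference` is a hypothetical helper that integrates with classical fixed-step RK4 via `jax.lax.scan`, and it assumes the input \(x\) is held constant at its left-endpoint sample on each interval \([t_k, t_{k+1}]\), which may differ from how `solve_ivp` treats `xinp`.

```python
import jax
import jax.numpy as jnp


def rk4_reference(func, y0, xinp, params, tpts):
    """Hypothetical sequential reference solver (not part of deer).

    Uses the same argument layout as solve_ivp: func(y, x, params) -> (ny,),
    y0: (ny,), xinp: (nsamples, nx), tpts: (nsamples,). Returns (nsamples, ny).
    Assumption: x is held at its left-endpoint sample on each interval.
    """
    dts = jnp.diff(tpts)  # step sizes, shape (nsamples - 1,)

    def step(y, inputs):
        x, dt = inputs
        # Classical fourth-order Runge-Kutta stages with x frozen over the step
        k1 = func(y, x, params)
        k2 = func(y + 0.5 * dt * k1, x, params)
        k3 = func(y + 0.5 * dt * k2, x, params)
        k4 = func(y + dt * k3, x, params)
        y_next = y + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        return y_next, y_next

    # Scan over intervals; the last input sample is unused under this assumption
    _, ys = jax.lax.scan(step, y0, (xinp[:-1], dts))
    return jnp.concatenate([y0[None], ys], axis=0)


# Usage, mirroring the simple harmonic oscillator example (k = m = 1)
def simple_harmonic_oscillator(y, x, params):
    k, m = params
    return jnp.array([y[1], -k / m * y[0]])


tpts = jnp.linspace(0, 10, 100)
y_ref = rk4_reference(simple_harmonic_oscillator, jnp.array([1.0, 0.0]),
                      jnp.zeros((100, 0)), (1.0, 1.0), tpts)
```

Each step depends on the previous one, so this reference runs its nsamples - 1 steps strictly in sequence, and its accuracy depends on the spacing of tpts; agreement with solve_ivp is best judged with a tolerance rather than exactly.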
- method=solve_ivp.DEER()
solve_ivp.DEER(yinit_guess: Optional[jax.Array] = None, max_iter: int = 10000, atol: Optional[float] = None, rtol: Optional[float] = None)
Compute the solution of the initial value problem with the DEER method.
- Parameters:
yinit_guess (jnp.ndarray or None) – The initial guess of the output signal, with shape (nsamples, ny). If None, it is initialized with zeros.
max_iter (int) – The maximum number of iterations to perform.
atol (Optional[float]) – The absolute tolerance for the convergence of the solver.
rtol (Optional[float]) – The relative tolerance for the convergence of the solver.
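The parameters above do not describe how DEER iterates. As a rough illustration of the general idea (linearize \(f\) around the current guess of the whole trajectory, solve the resulting linear recurrence for all time points at once with a parallel associative scan, and repeat until the guess stops changing), here is a simplified sketch. It is not the library's implementation: `deer_like_solve`, `_affine_step`, and `_compose` are hypothetical names, the Jacobian and the input are frozen at the left endpoint of each interval, and each linearized interval is integrated exactly with a matrix exponential. For an autonomous \(f\), the converged result is the exponential Rosenbrock-Euler scheme, which need not match the discretization deer uses. The `max_iter` and `atol` arguments only loosely mirror the `DEER()` options; here `atol` bounds the largest absolute change between iterates.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def _affine_step(func, yk, xk, params, dt):
    """Linearize f around yk and integrate the linear ODE exactly over dt.

    Returns (M, c) such that y_{k+1} = M @ y_k + c.
    """
    ny = yk.shape[0]
    J = jax.jacfwd(func, argnums=0)(yk, xk, params)  # (ny, ny)
    b = func(yk, xk, params) - J @ yk  # f(y) ~= J y + b near yk
    # expm([[J, b], [0, 0]] * dt) = [[e^{J dt}, (int_0^dt e^{J s} ds) b], [0, 1]]
    aug = jnp.zeros((ny + 1, ny + 1), dtype=yk.dtype)
    aug = aug.at[:ny, :ny].set(J).at[:ny, ny].set(b)
    E = expm(aug * dt)
    return E[:ny, :ny], E[:ny, ny]


def _compose(a, b):
    """Compose affine maps: apply a = (Ma, ca) first, then b = (Mb, cb)."""
    Ma, ca = a
    Mb, cb = b
    return Mb @ Ma, (Mb @ ca[..., None])[..., 0] + cb


def deer_like_solve(func, y0, xinp, params, tpts, max_iter=100, atol=1e-6):
    """Hypothetical DEER-style solver; not the deer library's implementation."""
    nsamples, ny = tpts.shape[0], y0.shape[0]
    dts = jnp.diff(tpts)  # (nsamples - 1,)

    @jax.jit
    def iterate(yguess):
        # One linearized step per interval, all intervals evaluated in parallel
        M, c = jax.vmap(lambda yk, xk, dt: _affine_step(func, yk, xk, params, dt))(
            yguess[:-1], xinp[:-1], dts)
        # Prefix-compose the affine maps: element k maps y0 to y_{k+1}
        Mcum, ccum = jax.lax.associative_scan(_compose, (M, c))
        return jnp.concatenate([y0[None], Mcum @ y0 + ccum], axis=0)

    # Initial guess: zeros, except the first point, which is pinned to y0
    yguess = jnp.zeros((nsamples, ny), dtype=y0.dtype).at[0].set(y0)
    for _ in range(max_iter):
        ynew = iterate(yguess)
        converged = jnp.max(jnp.abs(ynew - yguess)) < atol
        yguess = ynew
        if converged:
            break
    return yguess


# Usage: logistic growth dy/dt = y (1 - y), y(0) = 0.1, no input signal
def logistic(y, x, params):
    return y * (1.0 - y)


t_log = jnp.linspace(0.0, 5.0, 51)
y_sketch = deer_like_solve(logistic, jnp.array([0.1]), jnp.zeros((51, 0)), None, t_log)
y_exact = 1.0 / (1.0 + 9.0 * jnp.exp(-t_log))  # closed-form solution
```

Pinning the first point of the guess to \(y_0\) means each iteration fixes at least one more time point, so in exact arithmetic the loop reaches its fixed point within nsamples iterations; for smooth problems the Newton-style updates often converge in far fewer.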