# Source code for deer.fsolve_ivp

from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple
import jax.numpy as jnp
from deer.deer_iter import deer_iteration
from deer.maths import matmul_recursive
from deer.utils import get_method_meta, check_method, Result


__all__ = ["solve_ivp"]

def solve_ivp(
    func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
    y0: jnp.ndarray,
    xinp: jnp.ndarray,
    params: Any,
    tpts: jnp.ndarray,
    method: Optional["SolveIVPMethod"] = None,
) -> Result:
    r"""
    Solve the initial value problem.

    .. math::

        \frac{dy}{dt} = f(y, x; \theta)

    with given initial condition :math:`y(0) = y_0`,
    where :math:`y` is the output signal, :math:`x` is the input signal, and :math:`\theta` are the
    parameters of the function.
    This function returns the output signal :math:`y` at the time points :math:`t`; the initial
    condition is imposed at the first time point ``tpts[0]``.

    Arguments
    ---------
    func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]
        Function to evaluate the derivative of :math:`y` with respect to :math:`t`. The arguments are:
        output signal :math:`y` ``(ny,)``, input signal :math:`x` ``(nx,)``, and parameters
        :math:`\theta` in a pytree.
        The return value is the derivative of :math:`y` with respect to :math:`t`,
        i.e., :math:`\frac{dy}{dt}` ``(ny,)``.
    y0: jnp.ndarray
        Initial condition on :math:`y` ``(ny,)``.
    xinp: jnp.ndarray
        The external input signal of shape ``(nsamples, nx)``.
    params: Any
        The parameters of the function ``func``.
    tpts: jnp.ndarray
        The time points to evaluate the solution ``(nsamples,)``.
    method: Optional[SolveIVPMethod]
        The method to solve the initial value problem. If None, the ``DEER()`` method with its
        default settings is used.

    Returns
    -------
    res: Result
        The ``Result`` object where ``.value`` is the solution of the IVP system at the given time
        points with shape ``(nsamples, ny)`` and ``.success`` is the boolean array indicating the
        convergence of the solver.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from deer.fsolve_ivp import solve_ivp
    >>>
    >>> def simple_harmonic_oscillator(y, x, params):
    ...     k, m = params
    ...     dydt = jnp.array([y[1], -k/m*y[0]])
    ...     return dydt
    >>>
    >>> y0 = jnp.array([1.0, 0.0])
    >>> xinp = jnp.zeros((100, 0))  # no input signal
    >>> params = (1.0, 1.0)  # k, m
    >>> tpts = jnp.linspace(0, 10, 100)
    >>>
    >>> y = solve_ivp(simple_harmonic_oscillator, y0, xinp, params, tpts).value
    >>> # The output y should be an array of shape (nsamples, ny)
    >>> y.shape
    (100, 2)
    >>> # Check the first and last values (should be close to [1.0, 0.0] and [cos(10), -sin(10)] respectively)
    >>> jnp.allclose(y[0], jnp.array([1.0, 0.0]))
    Array(True, dtype=bool)
    >>> jnp.allclose(y[-1], jnp.array([jnp.cos(10), -jnp.sin(10)]), atol=1e-2)
    Array(True, dtype=bool)
    """
    # Fall back to the DEER solver with default settings when no method is given.
    if method is None:
        method = DEER()
    # NOTE(review): presumably validates that ``method`` is a method class registered for
    # ``solve_ivp`` (see deer.utils.check_method) -- confirm its exact contract there.
    check_method(method, solve_ivp)
    # Delegate the actual solve to the chosen method object.
    return method.compute(func, y0, xinp, params, tpts)
class SolveIVPMethod(metaclass=get_method_meta(solve_ivp)):
    """Abstract base class for the methods accepted by ``solve_ivp`` through its ``method`` argument."""

    @abstractmethod
    def compute(
        self,
        func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
        y0: jnp.ndarray,
        xinp: jnp.ndarray,
        params: Any,
        tpts: jnp.ndarray,
    ) -> Result:
        """Solve the IVP; the arguments have the same meaning as in ``solve_ivp``."""
        pass


class DEER(SolveIVPMethod):
    """
    Compute the solution of initial value problem with the DEER method.

    Arguments
    ---------
    yinit_guess: jnp.ndarray or None
        The initial guess of the output signal ``(nsamples, ny)``.
        If None, the guess is the constant trajectory ``y0`` repeated at every time point
        (i.e. ``zeros((nsamples, ny)) + y0``).
    max_iter: int
        The maximum number of iterations to perform.
    atol: Optional[float]
        The absolute tolerance for the convergence of the solver.
    rtol: Optional[float]
        The relative tolerance for the convergence of the solver.
    """

    def __init__(
        self,
        yinit_guess: Optional[jnp.ndarray] = None,
        max_iter: int = 10000,
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
    ):
        self.yinit_guess = yinit_guess
        self.max_iter = max_iter
        self.atol = atol
        self.rtol = rtol

    def compute(
        self,
        func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
        y0: jnp.ndarray,
        xinp: jnp.ndarray,
        params: Any,
        tpts: jnp.ndarray,
    ) -> Result:
        # Default initial guess: y0 broadcast to every time point, shape (nsamples, ny).
        # The resulting dtype follows JAX promotion of tpts.dtype and y0.dtype.
        yinit_guess = self.yinit_guess
        if yinit_guess is None:
            yinit_guess = jnp.zeros((tpts.shape[0], y0.shape[-1]), dtype=tpts.dtype) + y0

        def func2(ylist: List[jnp.ndarray], x: jnp.ndarray, params: Any) -> jnp.ndarray:
            # Adapts ``func`` to the list-of-signals interface used with ``shifter_func``
            # (presumably what deer_iteration expects); an ODE only needs the current state.
            return func(ylist[0], x, params)

        def shifter_func(y: jnp.ndarray, params: Any) -> List[jnp.ndarray]:
            # y: (nsamples, ny). No time shift is needed for an ODE, so the only
            # "shifted" signal is y itself (p_num=1 below).
            return [y]

        # perform the deer iteration
        inv_lin_params = (tpts, y0)
        result = deer_iteration(
            inv_lin=self.solve_ivp_inv_lin,
            p_num=1,
            func=func2,
            shifter_func=shifter_func,
            params=params,
            xinput=xinp,
            inv_lin_params=inv_lin_params,
            shifter_func_params=(),
            yinit_guess=yinit_guess,
            max_iter=self.max_iter,
            clip_ytnext=True,
            atol=self.atol,
            rtol=self.rtol,
        )
        return result

    def solve_ivp_inv_lin(
        self,
        gmat: List[jnp.ndarray],
        rhs: jnp.ndarray,
        inv_lin_params: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> jnp.ndarray:
        """
        Inverse of the linear operator for solving the initial value problem
        dy/dt + G(t) y = rhs(t), y(0) = y0.

        On each interval ``[t_i, t_{i+1}]``, G and rhs are replaced by their mid-point averages and
        the resulting constant-coefficient equation is stepped with an exponential integrator:
        with ``A = G_mid * dt``,

            y_{i+1} = expm(-A) y_i + phi(A) (rhs_mid * dt),
            phi(A) = (I - expm(-A)) A^{-1} ~= I - A/2 + A^2/6 - A^3/24,
            expm(-A) = I - phi(A) A,

        where the series for ``phi`` is truncated after the cubic term.

        Arguments
        ---------
        gmat: list of jnp.ndarray
            The list of 1 G-matrix of shape (nsamples, ny, ny).
        rhs: jnp.ndarray
            The right hand side of the equation of shape (nsamples, ny).
        inv_lin_params: Tuple[jnp.ndarray, jnp.ndarray]
            The parameters of the linear operator.
            The first element is the time points (nsamples,),
            and the second element is the initial condition (ny,).

        Returns
        -------
        y: jnp.ndarray
            The solution of the linear equation of shape (nsamples, ny).
        """
        # extract the parameters
        tpts, y0 = inv_lin_params
        gmat = gmat[0]  # (nsamples, ny, ny)
        eye = jnp.eye(gmat.shape[-1], dtype=gmat.dtype)  # (ny, ny)

        # taking the mid-point of gmat and rhs, pre-multiplied by the step size
        half_dt = 0.5 * (tpts[1:] - tpts[:-1])  # (nsamples - 1,)
        gtmid_dt = (gmat[1:] + gmat[:-1]) * half_dt[..., None, None]  # (nsamples - 1, ny, ny)
        htmid_dt = (rhs[1:] + rhs[:-1]) * half_dt[..., None]  # (nsamples - 1, ny)

        # get the matrices and vectors to be convolved
        gtmid_dt2 = gtmid_dt @ gtmid_dt  # (nt - 1, ny, ny)
        gtmid_dt3 = gtmid_dt @ gtmid_dt2  # (nt - 1, ny, ny)
        # truncated Taylor series of phi(A) = (I - expm(-A)) A^{-1}
        htbar_helper = eye - gtmid_dt / 2 + gtmid_dt2 / 6 - gtmid_dt3 / 24
        gtbar = (eye - htbar_helper @ gtmid_dt)  # (nt - 1, ny, ny), approximate expm(-gtmid_dt)
        htbar = jnp.einsum("...ij,...j->...i", htbar_helper, htmid_dt)  # (nt - 1, ny)

        # compute the recursive matrix multiplication; presumably
        # y_{i+1} = gtbar_i @ y_i + htbar_i starting from y0 (see deer.maths.matmul_recursive)
        yt = matmul_recursive(gtbar, htbar, y0)  # (nt, ny)
        return yt