"""Source code for ``deer.fsolve_idae``: solvers for implicit differential-algebraic equations (IDAE)."""

from abc import abstractmethod
from typing import Any, Callable, List, Optional
import jax
import jax.numpy as jnp
import optimistix as optx
from deer.deer_iter import deer_iteration
from deer.maths import matmul_recursive
from deer.utils import get_method_meta, check_method, Result
from deer.froot import root, RootMethod


__all__ = ["solve_idae"]

def solve_idae(
    func: Callable[[jnp.ndarray, jnp.ndarray, Any, Any], jnp.ndarray],
    y0: jnp.ndarray,
    xinp: Any,
    params: Any,
    tpts: jnp.ndarray,
    method: Optional["SolveIDAEMethod"] = None,
) -> Result:
    r"""
    Solve an implicit differential algebraic equation (IDAE) system.

    .. math::

        f(\dot{y}, y, x; \theta) = 0

    Here :math:`\dot{y}` is the time-derivative of the output signal :math:`y`,
    :math:`x` is the input signal sampled at the times :math:`t`, and
    :math:`\theta` are the parameters of the function.
    The tentative initial condition is :math:`y(0) = y_0`.

    Arguments
    ---------
    func: Callable[[jnp.ndarray, jnp.ndarray, Any, Any], jnp.ndarray]
        Residual of the IDAE system. It is called with
        (1) the time-derivative of the output signal :math:`\dot{y}` ``(ny,)``,
        (2) the output signal :math:`y` ``(ny,)``,
        (3) the input signal :math:`x` ``(*nx,)`` as a pytree, and
        (4) the parameters :math:`\theta` as a pytree.
        It must return the residual ``(ny,)``.
    y0: jnp.ndarray
        Tentative initial condition on :math:`y` ``(ny,)``. For systems with
        algebraic variables, the returned initial values of those variables
        may differ from what is supplied here.
    xinp: Any
        External input signal, a pytree whose leaves have shape ``(nsamples, *nx)``.
    params: Any
        Parameters forwarded to ``func``.
    tpts: jnp.ndarray
        Time points at which the solution is evaluated ``(nsamples,)``.
    method: Optional[SolveIDAEMethod]
        Method used to solve the implicit DAE. Defaults to ``BwdEulerDEER()``
        when None.

    Returns
    -------
    res: Result
        ``.value`` holds the solution at ``tpts`` with shape ``(nsamples, ny)``
        and ``.success`` is the boolean array signalling solver convergence.

    Examples
    --------
    The printed output assumes JAX 64-bit mode
    (``jax.config.update("jax_enable_x64", True)``).

    >>> import jax.numpy as jnp
    >>> def idae_func(dy, y, x, params):
    ...     return dy + y - x - params
    >>> y0 = jnp.array([1.0])
    >>> xinp = jnp.array([[0.0], [1.0], [2.0], [3.0]])
    >>> params = jnp.array([0.5])
    >>> tpts = jnp.array([0.0, 1.0, 2.0, 3.0])
    >>> solve_idae(idae_func, y0, xinp, params, tpts).value
    Array([[1.    ],
           [1.25  ],
           [1.875 ],
           [2.6875]], dtype=float64)
    """
    # fall back to the parallel (DEER) backward-Euler solver when no method is given
    chosen_method = BwdEulerDEER() if method is None else method
    check_method(chosen_method, solve_idae)
    return chosen_method.compute(func, y0, xinp, params, tpts)
class SolveIDAEMethod(metaclass=get_method_meta(solve_idae)):
    """
    Base class for the ``method`` argument of ``solve_idae``.

    Subclasses implement ``compute``, which receives the same arguments as
    ``solve_idae`` (except ``method``) and returns a ``Result``.
    """

    @abstractmethod
    def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any, Any], jnp.ndarray], y0: jnp.ndarray,
                xinp: Any, params: Any, tpts: jnp.ndarray) -> Result:
        pass


class BwdEuler(SolveIDAEMethod):
    """
    Solve the implicit DAE with backward Euler's method, one time step at a time.

    At each step ``i`` the nonlinear system
    ``func((y_i - y_{i-1}) / dt_i, y_i, x_i, params) = 0`` is solved for ``y_i``
    with a root finder started at ``y_{i-1}``. Once a step reports failure, the
    remaining steps skip the solve: they repeat the last value and are flagged
    as unsuccessful.

    Arguments
    ---------
    solver: Optional[RootMethod]
        The root finder solver. If None, then use Newton's method
        (``max_iter=200, atol=1e-6, rtol=1e-3``).
    """

    def __init__(self, solver: Optional[RootMethod] = None):
        if solver is None:
            solver = root.Newton(max_iter=200, atol=1e-6, rtol=1e-3)
        self.solver = solver

    def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any, Any], jnp.ndarray], y0: jnp.ndarray,
                xinp: Any, params: Any, tpts: jnp.ndarray) -> Result:
        # y0: (ny,) the initial states (it's not checked for correctness)
        # xinp: pytree, each leaf has `(nsamples, *nx)`
        # tpts: (nsamples,) the time points
        # returns: (nsamples, ny), including the initial states

        def fn(yi, args):
            # backward Euler residual: the derivative is the backward difference to y_{i-1}
            yim1, xi, dti, params = args
            return func((yi - yim1) / dti, yi, xi, params)

        def success_fn(carry, x):
            # all previous steps converged: solve the current step starting from y_{i-1}
            yprev, _ = carry
            xi, dti = x
            sol = root(fn, yprev, (yprev, xi, dti, params), method=self.solver)
            return sol.value, sol.success

        def fail_fn(carry, x):
            # an earlier step failed: skip the solve, keep the last value and flag the step as failed
            yprev, _ = carry
            return yprev, jnp.full_like(yprev, False, dtype=jnp.bool)

        def scan_fn(carry, x):
            _, success = carry
            res = jax.lax.cond(jnp.all(success), success_fn, fail_fn, carry, x)
            # the new (y_i, success_i) is both the next carry and the stacked per-step output
            return res, res

        init_success = jnp.full_like(y0, True, dtype=jnp.bool)  # (ny,)
        dti = tpts[1:] - tpts[:-1]  # (nsamples - 1,)
        xi = jax.tree_util.tree_map(lambda x: x[1:], xinp)  # (nsamples - 1, *nx)
        _, (y, success) = jax.lax.scan(scan_fn, (y0, init_success), (xi, dti))  # (nsamples - 1, ny)
        y = jnp.concatenate((y0[None], y), axis=0)  # (nsamples, ny)
        # The initial state is given, so its flag is always True. Build it from y0 rather than from
        # `success[:1]`: with a single time point the scan output is empty and `success[:1]` would have
        # zero rows, leaving `success` one row shorter than `y`.
        success = jnp.concatenate((init_success[None], success), axis=0)  # (nsamples, ny)
        # TODO: turn off the throw error in Newton, and check the convergence to be put in the Result here
        return Result(y, success)


class BwdEulerDEER(SolveIDAEMethod):
    """
    Solve the implicit DAE with the DEER method applied to backward Euler's discretization.

    The whole trajectory ``(nsamples, ny)`` is refined by ``deer_iteration``,
    starting from ``yinit_guess``. The linearized recurrence arising at each
    iteration is solved by ``solve_idae_inv_lin``.

    Arguments
    ---------
    yinit_guess: Optional[jnp.ndarray]
        The initial guess of the output signal ``(nsamples, ny)``. If None, it will be
        initialized as all ``y0``.
    max_iter: int
        The maximum number of DEER iterations to perform.
    atol: Optional[float]
        The absolute tolerance of the DEER iteration convergence.
    rtol: Optional[float]
        The relative tolerance of the DEER iteration convergence.
    """

    def __init__(self, yinit_guess: Optional[jnp.ndarray] = None, max_iter: int = 200,
                 atol: Optional[float] = None, rtol: Optional[float] = None):
        self.yinit_guess = yinit_guess
        self.max_iter = max_iter
        self.atol = atol
        self.rtol = rtol

    def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any, Any], jnp.ndarray], y0: jnp.ndarray,
                xinp: Any, params: Any, tpts: jnp.ndarray) -> Result:
        # y0: (ny,) the initial states (it's not checked for correctness)
        # xinp: pytree, each leaf has `(nsamples, *nx)`
        # tpts: (nsamples,) the time points
        # returns: (nsamples, ny), including the initial states

        # set the default initial guess: y0 repeated at every time point
        yinit_guess = self.yinit_guess
        if yinit_guess is None:
            yinit_guess = jnp.zeros((tpts.shape[0], y0.shape[-1]), dtype=tpts.dtype) + y0

        def func2(yshifts: List[jnp.ndarray], x: Any, params: Any) -> jnp.ndarray:
            # yshifts: [y_i, y_{i-1}], each (ny,)
            # x: (dt_i, x_i) where x_i is the input pytree at sample i
            y, ym1 = yshifts
            dt, xinp = x
            return func((y - ym1) / dt, y, xinp, params)

        def linfunc(y: jnp.ndarray, lin_params: Any) -> List[jnp.ndarray]:
            # y: (nsamples, ny)
            # backward Euler needs y_{i-1}; row 0 is paired with itself because y_0 has no predecessor
            ym1 = jnp.concatenate((y[:1], y[:-1]), axis=0)  # (nsamples, ny)
            return [y, ym1]

        # dt[i] = t[i] - t[i - 1]; dt[0] is only a placeholder (copied from dt[1]): at i = 0 the
        # backward difference is y_0 - y_0, and row 0 is dropped by solve_idae_inv_lin anyway
        dt_partial = tpts[1:] - tpts[:-1]  # (nsamples - 1,)
        dt = jnp.concatenate((dt_partial[:1], dt_partial), axis=0)  # (nsamples,)
        xinput = (dt, xinp)
        inv_lin_params = (y0,)

        result = deer_iteration(
            inv_lin=self.solve_idae_inv_lin,
            func=func2,
            shifter_func=linfunc,
            p_num=2,
            params=params,
            xinput=xinput,
            inv_lin_params=inv_lin_params,
            shifter_func_params=None,
            yinit_guess=yinit_guess,
            max_iter=self.max_iter,
            clip_ytnext=True,
            atol=self.atol,
            rtol=self.rtol,
        )
        return result

    def solve_idae_inv_lin(self, jacs: List[jnp.ndarray], z: jnp.ndarray,
                           inv_lin_params: Any) -> jnp.ndarray:
        # solving the linear recurrence: M0_i @ y_i + M1_i @ y_{i-1} = z_i
        # jacs: [M0, M1], each (nsamples, ny, ny)
        # z: (nsamples, ny)
        # inv_lin_params: (y0,) where y0: (ny,)
        M0, M1 = jacs
        y0, = inv_lin_params

        # using index [1:] because we don't need to compute y_0 again (it's already available from y0)
        M01 = M0[1:]
        # rewrite as y_i = (-M0_i^{-1} M1_i) @ y_{i-1} + M0_i^{-1} z_i
        M0invM1 = -jax.vmap(jnp.linalg.solve)(M01, M1[1:])
        M0invz = jax.vmap(jnp.linalg.solve)(M01, z[1:])
        # (nsamples, ny) -- assumes matmul_recursive includes y0 as the first row; TODO confirm
        y = matmul_recursive(M0invM1, M0invz, y0)
        return y