Source code for deer.fseq1d

from typing import Callable, Tuple, Optional, Any, List
from abc import abstractmethod
from deer.utils import get_method_meta, check_method
import jax
import jax.numpy as jnp
from deer.deer_iter import deer_iteration
from deer.maths import matmul_recursive
from deer.utils import Result


# Public API of this module: a star-import exposes only ``seq1d``.
# NOTE(review): the ``seq1d`` docstring example reaches a solver class as
# ``seq1d.Sequential()``, so the method classes are presumably attached to the
# function by ``get_method_meta`` -- confirm against ``deer.utils``.
__all__ = ["seq1d"]

def seq1d(
    func: Callable[[jnp.ndarray, Any, Any], jnp.ndarray],
    y0: jnp.ndarray,
    xinp: Any,
    params: Any,
    method: Optional["Seq1DMethod"] = None,
) -> "Result":
    r"""
    Solve the discrete sequential equation

    .. math::

        y_{i + 1} = f(y_i, x_i; \theta)

    where :math:`f` is a non-linear function, :math:`y_i` is the output signal
    at time :math:`i`, :math:`x_i` is the input signal at time :math:`i`, and
    :math:`\theta` are the parameters of the function.

    Arguments
    ---------
    func: Callable[[jnp.ndarray, Any, Any], jnp.ndarray]
        Function to evaluate the next output signal :math:`y_{i+1}` from the
        current output signal :math:`y_i`.
        The arguments are: signal :math:`y` at the current time ``(ny,)``,
        input signal :math:`x` at the current time ``(*nx,)`` in a pytree, and
        parameters :math:`\theta` in a pytree.
        The return value is the next output signal :math:`y` at the next time
        ``(ny,)``.
    y0: jnp.ndarray
        Initial condition on :math:`y` ``(ny,)``.
    xinp: Any
        The external input signal in a pytree of shape ``(nsamples, *nx)``
    params: Any
        The parameters of the function ``func``.
    method: Optional[Seq1DMethod]
        The method to solve the 1D sequence. If None, then use the ``DEER()``
        method.

    Returns
    -------
    res: Result
        The ``Result`` object where ``.value`` is the solution of the
        sequential model with shape ``(nsamples, ny)`` and ``.success`` is the
        boolean array indicating the convergence of the solver.

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from deer.fseq1d import seq1d
    >>> def func(y, x, params):
    ...     return y ** 2 + x * params[0]
    >>> y0 = jnp.array([0.0])
    >>> xinp = jnp.linspace(0, 1, 10).reshape(-1, 1)
    >>> params = jnp.array([0.5])
    >>> y = seq1d(func, y0, xinp, params, method=seq1d.Sequential()).value
    >>> y
    Array([[0.        ],
           [0.05555556],
           [0.11419753],
           [0.17970774],
           [0.2545171 ],
           [0.34255673],
           [0.45067845],
           [0.59199995],
           [0.79490839],
           [1.13187934]], dtype=float64)
    """
    # Fall back to the DEER solver when the caller does not pick a method.
    if method is None:
        method = DEER()
    # Validate that the chosen method object is one registered for ``seq1d``
    # (the check itself lives in ``deer.utils.check_method``).
    check_method(method, seq1d)
    # The solver's return value is passed through unchanged.
    return method.compute(func, y0, xinp, params)
class Seq1DMethod(metaclass=get_method_meta(seq1d)):
    """
    Base class of the solver methods that can be passed to ``seq1d``.

    Subclasses implement ``compute`` which receives the same ``func``, ``y0``,
    ``xinp`` and ``params`` that were given to ``seq1d``.
    """

    @abstractmethod
    def compute(self, func: Callable[[jnp.ndarray, Any, Any], jnp.ndarray],
                y0: jnp.ndarray, xinp: Any, params: Any):
        """Solve the sequence; to be provided by each concrete method."""


class Sequential(Seq1DMethod):
    """
    Solve the 1D sequence the traditional way, one step after another.
    """

    def compute(self, func: Callable[[jnp.ndarray, Any, Any], jnp.ndarray],
                y0: jnp.ndarray, xinp: Any, params: Any):
        # Evaluates y[i] = f(y[i - 1], x[i]; params) with ``jax.lax.scan``.
        # xinp: pytree whose leaves are ``(nsamples, *nx)``
        # y0: (ny,) the initial state, used as the first carry
        # output: (nsamples, ny), the initial state itself is not included
        def step(prev_y, x_i):
            next_y = func(prev_y, x_i, params)
            # the new state is both the next carry and the stacked output
            return next_y, next_y

        stacked_y = jax.lax.scan(step, y0, xinp)[1]
        return Result(stacked_y)


class DEER(Seq1DMethod):
    """
    Solve the 1D sequence with the DEER iteration.

    Arguments
    ---------
    yinit_guess: Optional[jnp.ndarray]
        Starting guess for the output signal ``(nsamples, ny)``. When None,
        every sample of the guess is set to ``y0``.
    max_iter: int
        Upper bound on the number of DEER iterations.
    atol: Optional[float]
        Absolute tolerance for the convergence of the DEER iteration.
    rtol: Optional[float]
        Relative tolerance for the convergence of the DEER iteration.
    """

    def __init__(self, yinit_guess: Optional[jnp.ndarray] = None, max_iter: int = 10000,
                 atol: Optional[float] = None, rtol: Optional[float] = None):
        # The options are only stored here; they are consumed in ``compute``.
        self.rtol = rtol
        self.atol = atol
        self.max_iter = max_iter
        self.yinit_guess = yinit_guess

    def compute(self, func: Callable[[jnp.ndarray, Any, Any], jnp.ndarray],
                y0: jnp.ndarray, xinp: Any, params: Any):
        # The first leaf of ``xinp`` supplies the number of samples and the
        # dtype used for the default initial guess.
        leaves, _ = jax.tree_util.tree_flatten(xinp)
        first_leaf = leaves[0]

        guess = self.yinit_guess
        if guess is None:
            nsamples = first_leaf.shape[0]
            ny = y0.shape[-1]
            # broadcast y0 over all the samples: (nsamples, ny)
            guess = jnp.zeros((nsamples, ny), dtype=first_leaf.dtype) + y0

        def func_on_shifts(yshifts: List[jnp.ndarray], x: Any, params: Any) -> jnp.ndarray:
            # yshifts holds a single element of shape (ny,): the previous state
            return func(yshifts[0], x, params)

        def shift_by_one(y: jnp.ndarray, shifter_params: Any) -> List[jnp.ndarray]:
            # y: (nsamples, ny)
            # shifter_params = (y0,)
            (init_state,) = shifter_params
            # prepend the initial state and drop the last sample, so that
            # row i holds the state preceding y[i]: (nsamples, ny)
            shifted = jnp.concatenate((init_state[None, :], y[:-1, :]), axis=0)
            return [shifted]

        # run the DEER iteration and hand back its result unchanged
        return deer_iteration(
            inv_lin=self.seq1d_inv_lin,
            func=func_on_shifts,
            shifter_func=shift_by_one,
            p_num=1,
            params=params,
            xinput=xinp,
            inv_lin_params=(y0,),
            shifter_func_params=(y0,),
            yinit_guess=guess,
            max_iter=self.max_iter,
            clip_ytnext=True,
            atol=self.atol,
            rtol=self.rtol,
        )

    def seq1d_inv_lin(self, gmat: List[jnp.ndarray], rhs: jnp.ndarray,
                      inv_lin_params: Tuple[jnp.ndarray]) -> jnp.ndarray:
        """
        Inverse of the linear operator for solving the discrete sequential
        equation ``y[i + 1] + G[i] y[i] = rhs[i]`` with ``y[0] = y0``.

        Arguments
        ---------
        gmat: List[jnp.ndarray]
            The list of 1 G-matrix of shape (nsamples, ny, ny).
        rhs: jnp.ndarray
            The right hand side of the equation of shape (nsamples, ny).
        inv_lin_params: Tuple[jnp.ndarray]
            The parameters of the linear operator.
            The first element is the initial condition (ny,).

        Returns
        -------
        y: jnp.ndarray
            The solution of the linear equation of shape (nsamples, ny).
        """
        # unpack the single parameter and the single G-matrix
        (init_state,) = inv_lin_params
        g_single = gmat[0]
        # recursive matrix multiplication with the negated G-matrix
        full_sequence = matmul_recursive(-g_single, rhs, init_state)
        # the first element is dropped: (nsamples, ny)
        return full_sequence[1:]