Source code for deer.froot

from typing import Callable, Any, Optional
from functools import partial
from abc import abstractmethod
import jax
import jax.numpy as jnp
from deer.utils import get_method_meta, check_method
from deer.utils import Result


[docs]def root(func: Callable[[jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, params: Any, method: Optional["RootMethod"] = None) -> Result: r""" Solve the root of the function, .. math:: f(y; \theta) = 0 Arguments --------- func: Callable[[jnp.ndarray, Any], jnp.ndarray] The function to find the root. The function that takes the current value of the root and the parameters. y0: jnp.ndarray The initial guess of the root. params: Any The parameters of the function. method: Optional[RootMethod] The method to solve the root. If None, then use the ``Newton()`` method. Returns ------- res: Result The result of the root finding. """ if method is None: method = Newton() check_method(method, root) return method.compute(func, y0, params)
class RootMethod(metaclass=get_method_meta(root)): @abstractmethod def compute(self, func: Callable[[jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, params: Any): pass class Newton(RootMethod): """ Compute the root-finding method using Newton's method """ def __init__(self, max_iter: int = 100, atol: float = 1e-6, rtol: float = 1e-3): self.max_iter = max_iter self.atol = atol self.rtol = rtol def compute(self, func: Callable[[jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, params: Any): # y0: (ny,) # func: (ny,) -> (ny,) return newton_iter(func, y0, params, max_iter=self.max_iter, atol=self.atol, rtol=self.rtol) @partial(jax.custom_jvp, nondiff_argnums=(0, 3, 4, 5)) def newton_iter(func: Callable[[jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, params: Any, max_iter: int = 100, atol: float = 1e-6, rtol: float = 1e-3) -> Result: y, is_converged, jac = newton_iter_helper( func, y0, params, max_iter=max_iter, atol=atol, rtol=rtol) return Result(y, is_converged) def newton_iter_helper(func: Callable[[jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, # gradable as 0 params: Any, # gradable max_iter: int = 100, atol: float = 1e-6, rtol: float = 1e-3) -> Result: def iter_func(carry): y, err, tol, iiter, jac0 = carry jac = jax.jacfwd(func)(y, params) fy = func(y, params) jacinvfy = jnp.linalg.solve(jac, fy) # doing lstsq to handle singular matrix jacinvfy = jax.lax.cond(jnp.all(jnp.isfinite(jacinvfy)), lambda : jacinvfy, lambda : jnp.linalg.lstsq(jac, fy)[0]) ynext = y - jacinvfy # ynext = y - jnp.linalg.lstsq(jac, fy)[0] # clip nans and infs clip = 1e8 ynext = jnp.clip(ynext, min=-clip, max=clip) ynext = jnp.where(jnp.isnan(ynext), 0.0, ynext) err = jnp.abs(ynext - y) tol = atol + rtol * jnp.abs(ynext) iiter += 1 return ynext, err, tol, iiter, jac def cond_func(carry): y, err, tol, iiter, jac0 = carry return jnp.logical_and(jnp.any(err > tol), iiter < max_iter) err = jnp.full_like(y0, jnp.inf) tol = jnp.zeros_like(y0) iiter = jnp.array(0, dtype=jnp.int32) jac0 = jnp.zeros((y0.size, y0.size)) y, err, tol, iiter, jac = jax.lax.while_loop(cond_func, iter_func, (y0, err, tol, iiter, jac0)) is_converged = iiter < max_iter return y, is_converged, jac @newton_iter.defjvp def newton_iter_jvp( # collect non-gradable input first func: Callable[[jnp.ndarray, Any], jnp.ndarray], max_iter: int, atol: float, rtol: float, # meaningful arguments primals, tangents): y0, params = primals _, grad_params = tangents # compute the iterations yt, is_converged, jac = newton_iter_helper( func, y0, params, max_iter=max_iter, atol=atol, rtol=rtol) # compute grad of f func_partial_y = partial(func, yt) # grad_func: (ny,) _, grad_func = jax.jvp(func_partial_y, (params,), (grad_params,)) grad_y = jnp.linalg.solve(jac, -grad_func) # (ny,) is_converged_tangent = jnp.zeros_like(is_converged, dtype=jax.dtypes.float0) result = Result(yt, is_converged) grad_result = Result(grad_y, success=is_converged_tangent) return result, grad_result