Source code for rkstiff.if34

r"""
Adaptive-Step Fourth Order (Third Order Embedding) Integrating Factor Integrator
================================================================================

Adaptive Integrating Factor 4(3) solver.

Implements the **IF(3,4)** exponential Runge–Kutta scheme with an embedded
third-order method for local error estimation and adaptive step control.
It solves stiff semi-linear systems of the form

.. math::

        \frac{\partial \mathbf{U}}{\partial t}
        = \mathcal{L}\mathbf{U}
        + \mathcal{N}(\mathbf{U}),

where :math:`\mathcal{L}` is the linear (stiff) operator and
:math:`\mathcal{N}` the nonlinear term.

The IF(3,4) method integrates this system in the exponential form

.. math::

    \mathbf{U}_{n+1} = e^{h\mathcal{L}}\mathbf{U}_n
        + h \sum_{i=1}^{s} b_i \,
        e^{(1-c_i)h\mathcal{L}} \, \mathcal{N}(\mathbf{U}_i),

where the intermediate stages :math:`\mathbf{U}_i` are computed using
exponential operators and the nonlinear evaluations.

The embedded third-order estimate is used to compute adaptive step sizes
according to local error tolerances.

References
----------
P. Whalen, M. Brio, and J. V. Moloney,
*Exponential time-differencing with embedded Runge-Kutta adaptive step control*,
J. Comput. Phys. **280**, 579-601 (2015).
"""

import logging
from typing import Callable, Union, Literal
import numpy as np
from scipy.linalg import expm
from .solveras import SolverConfig, BaseSolverAS


# ======================================================================
# Diagonal operator strategy
# ======================================================================
class _If34Diagonal:
    r"""
    IF(3,4) diagonal strategy.

    Optimized implementation for diagonal linear operators using
    element-wise exponentials. Solves

    .. math::
        \frac{d\mathbf{U}}{dt} = \Lambda \mathbf{U} + \mathcal{N}(\mathbf{U}),

    where :math:`\Lambda` is diagonal.
    """

    def __init__(
        self,
        lin_op: np.ndarray,
        nl_func: Callable[[np.ndarray], np.ndarray],
        logger: logging.Logger = logging.getLogger(__name__),
    ) -> None:
        """Initialize the diagonal IF(3,4) system strategy."""
        self.lin_op = lin_op
        self.nl_func = nl_func
        self.logger = logger

        n = lin_op.shape[0]
        self._EL, self._EL2 = [np.zeros(n, dtype=np.complex128) for _ in range(2)]
        self._NL1, self._NL2, self._NL3, self._NL4, self._NL5 = [np.zeros(n, dtype=np.complex128) for _ in range(5)]
        self._k = np.zeros(n, dtype=np.complex128)
        self._err = np.zeros(n, dtype=np.complex128)

    def update_coeffs(self, h: float) -> None:
        r"""
        Update exponential coefficients for the given step size.

        .. math::

            z = h \Lambda, \qquad
            E_L = e^{z}, \qquad
            E_{L/2} = e^{z/2}.
        """
        z = h * self.lin_op
        self._EL = np.exp(z)
        self._EL2 = np.exp(z / 2)

    def n1_init(self, u: np.ndarray) -> None:
        r"""Initialize the first nonlinear evaluation :math:`\mathcal{N}_1 = \mathcal{N}(\mathbf{u}_n)`."""
        self._NL1 = self.nl_func(u)

    def update_stages(self, u: np.ndarray, h: float, accept: bool) -> tuple[np.ndarray, np.ndarray]:
        r"""
        Compute the IF(3,4) Runge–Kutta stages and error estimate.

        Parameters
        ----------
        u : np.ndarray
            Current state :math:`\mathbf{u}_n`.
        h : float
            Step size.
        accept : bool
            Whether the previous step was accepted (FSAL reuse).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Next state and local error estimate.
        """
        if accept:
            self._NL1 = self._NL5.copy()

        self._k = self._EL2 * u + h * self._EL2 * self._NL1 / 2.0
        self._NL2 = self.nl_func(self._k)
        self._k = self._EL2 * u + h * self._NL2 / 2.0
        self._NL3 = self.nl_func(self._k)
        self._k = self._EL * u + h * self._EL2 * self._NL3
        self._NL4 = self.nl_func(self._k)
        self._k = self._EL * u + h * (
            self._EL * self._NL1 / 6.0 + self._EL2 * self._NL2 / 3.0 + self._EL2 * self._NL3 / 3.0 + self._NL4 / 6.0
        )
        self._NL5 = self.nl_func(self._k)
        self._err = h * (self._NL4 - self._NL5) / 6.0
        return self._k, self._err


# ======================================================================
# Diagonalized via eigen-decomposition
# ======================================================================
class _If34Diagonalized(_If34Diagonal):
    r"""
    IF(3,4) strategy for diagonalizable linear systems.

    Performs eigenvalue decomposition

    .. math::
        \mathcal{L} = S \Lambda S^{-1},

    then evolves the system in the diagonal eigenbasis for efficiency.
    """

    def __init__(
        self,
        lin_op: np.ndarray,
        nl_func: Callable[[np.ndarray], np.ndarray],
        logger: logging.Logger = logging.getLogger(__name__),
    ) -> None:
        """Initialize eigen-basis transformation for the linear operator."""
        super().__init__(lin_op, nl_func, logger)
        if len(lin_op.shape) == 1:
            raise ValueError("Cannot diagonalize a 1D system")
        cond = np.linalg.cond(lin_op)
        if cond > 1e16:
            raise ValueError("Linear operator is near-singular and cannot be diagonalized")
        if cond > 1000:
            self.logger.warning(f"High condition number ({cond:.2e}); diagonalization may be unstable")
        self._eig_vals, self._S = np.linalg.eig(lin_op)
        self._Sinv = np.linalg.inv(self._S)
        self._v = np.zeros(lin_op.shape[0])

    def update_coeffs(self, h: float) -> None:
        """Update exponentials based on eigenvalues."""
        z = h * self._eig_vals
        self._EL = np.exp(z)
        self._EL2 = np.exp(z / 2)

    def n1_init(self, u: np.ndarray) -> None:
        """Initialize in the diagonalized (eigen) basis."""
        self._NL1 = self._Sinv.dot(self.nl_func(u))
        self._v = self._Sinv.dot(u)

    def update_stages(self, u: np.ndarray, h: float, accept: bool) -> tuple[np.ndarray, np.ndarray]:
        """Compute stages and return next state + error in transformed basis."""
        if accept:
            self._NL1 = self._NL5.copy()
            self._v = self._Sinv.dot(u)

        self._k = self._EL2 * self._v + h * self._EL2 * self._NL1 / 2.0
        self._NL2 = self._Sinv.dot(self.nl_func(self._S.dot(self._k)))
        self._k = self._EL2 * self._v + h * self._NL2 / 2.0
        self._NL3 = self._Sinv.dot(self.nl_func(self._S.dot(self._k)))
        self._k = self._EL * self._v + h * self._EL2 * self._NL3
        self._NL4 = self._Sinv.dot(self.nl_func(self._S.dot(self._k)))
        self._k = self._EL * self._v + h * (
            self._EL * self._NL1 / 6.0 + self._EL2 * self._NL2 / 3.0 + self._EL2 * self._NL3 / 3.0 + self._NL4 / 6.0
        )
        self._NL5 = self._Sinv.dot(self.nl_func(self._S.dot(self._k)))
        self._err = h * (self._NL4 - self._NL5) / 6.0
        return self._S.dot(self._k), self._err


# ======================================================================
# Full matrix exponential (non-diagonal)
# ======================================================================
class _If34NonDiagonal:
    r"""
    IF(3,4) strategy for full (non-diagonalizable) linear operators.

    Uses direct matrix exponentials:

    .. math::

        E_L = e^{h\mathcal{L}}, \qquad
        E_{L/2} = e^{h\mathcal{L}/2}.
    """

    def __init__(self, lin_op: np.ndarray, nl_func: Callable[[np.ndarray], np.ndarray]) -> None:
        """Initialize the general IF(3,4) solver strategy."""
        self.lin_op = lin_op
        self.nl_func = nl_func

        n = lin_op.shape[0]
        self._EL, self._EL2 = [np.zeros(shape=lin_op.shape, dtype=np.complex128) for _ in range(2)]
        self._NL1, self._NL2, self._NL3, self._NL4, self._NL5 = [np.zeros(n, dtype=np.complex128) for _ in range(5)]
        self._k = np.zeros(n, dtype=np.complex128)
        self._err = np.zeros(n, dtype=np.complex128)

    def update_coeffs(self, h: float) -> None:
        """Update matrix exponentials for current step size."""
        z = h * self.lin_op
        self._EL = expm(z)
        self._EL2 = expm(z / 2)

    def n1_init(self, u: np.ndarray) -> None:
        """Initialize nonlinear evaluation."""
        self._NL1 = self.nl_func(u)

    def update_stages(self, u: np.ndarray, h: float, accept: bool) -> tuple[np.ndarray, np.ndarray]:
        """Compute IF(3,4) stages and error using full matrix operations."""
        if accept:
            self._NL1 = self._NL5.copy()

        self._k = self._EL2.dot(u) + h * self._EL2.dot(self._NL1 / 2.0)
        self._NL2 = self.nl_func(self._k)
        self._k = self._EL2.dot(u) + h * self._NL2 / 2.0
        self._NL3 = self.nl_func(self._k)
        self._k = self._EL.dot(u) + h * self._EL2.dot(self._NL3)
        self._NL4 = self.nl_func(self._k)
        self._k = self._EL.dot(u) + h * (
            self._EL.dot(self._NL1 / 6.0)
            + self._EL2.dot(self._NL2 / 3.0)
            + self._EL2.dot(self._NL3 / 3.0)
            + self._NL4 / 6.0
        )
        self._NL5 = self.nl_func(self._k)
        self._err = h * (self._NL4 - self._NL5) / 6.0
        return self._k, self._err


# ======================================================================
# Public adaptive solver
# ======================================================================
[docs] class IF34(BaseSolverAS): r""" Adaptive Integrating-Factor 4(3) solver. Fourth-order integrating factor Runge–Kutta scheme with an embedded third-order pair for local error control. Parameters ---------- lin_op : np.ndarray Linear operator :math:`\mathcal{L}`. nl_func : Callable[[np.ndarray], np.ndarray] Nonlinear function :math:`\mathcal{N}(\mathbf{U})`. config : SolverConfig, optional Adaptive step configuration. diagonalize : bool, default=False Attempt eigenvalue diagonalization if linear operator is 2D. loglevel : str or int, default='WARNING' Logging verbosity. Notes ----- - Implements **FSAL** (First Same As Last) reuse for efficiency. - Coefficients are recomputed only when step size changes. - Supports diagonal, diagonalizable, and full-matrix systems. """ _method: Union[_If34Diagonal, _If34Diagonalized, _If34NonDiagonal]
[docs] def __init__( self, lin_op: np.ndarray, nl_func: Callable[[np.ndarray], np.ndarray], config: SolverConfig = SolverConfig(), diagonalize: bool = False, loglevel: Union[Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], int] = "WARNING", ) -> None: """Initialize the IF(3,4) adaptive solver.""" super().__init__(lin_op, nl_func, config=config, loglevel=loglevel) self._method = ( _If34Diagonal(lin_op, nl_func, self.logger) if self._diag else (_If34Diagonalized(lin_op, nl_func, self.logger) if diagonalize else _If34NonDiagonal(lin_op, nl_func)) ) self.__n1_init = False self._h_coeff = None self._accept = False
def _reset(self) -> None: """Reset solver state (reinitializes stage storage and flags).""" self.__n1_init = False self._h_coeff = None self._accept = False def _update_coeffs(self, h: float) -> None: """Recompute coefficients if step size changed.""" if h != self._h_coeff: self._h_coeff = h self._method.update_coeffs(h) self.logger.debug("IF34 coefficients updated for step size h=%s", h) def _update_stages(self, u: np.ndarray, h: float) -> tuple[np.ndarray, np.ndarray]: r""" Perform one adaptive IF(3,4) integration step. Returns ------- tuple[np.ndarray, np.ndarray] Next state and local error vector :math:`\mathbf{e}_n`. """ self._update_coeffs(h) if not self.__n1_init: self._method.n1_init(u) self.__n1_init = True return self._method.update_stages(u, h, self._accept) def _q(self) -> int: """Order of method used for adaptive control (4).""" return 4