Source code for rkstiff.solvercs

r"""
Base solver infrastructure for rkstiff constant step PDE solvers
================================================================

Constant-step solver infrastructure for stiff time-dependent systems.

This module defines :class:`BaseSolverCS`, a foundation for solvers that
advance semi-linear differential equations using *fixed* time steps:

.. math::

        \frac{\partial \mathbf{U}}{\partial t}
        = \mathcal{L}\mathbf{U}
        + \mathcal{N}(\mathbf{U}),

where :math:`\mathcal{L}` is a (possibly stiff) linear operator and
:math:`\mathcal{N}` is a nonlinear function. Subclasses such as ETD4 or ETD5
implement constant-step exponential Runge-Kutta methods based on this interface.
"""

from abc import abstractmethod
from typing import Callable, Union, Literal
import numpy as np
from .solver import BaseSolver
from .util.solver_type import SolverType


# ======================================================================
# Base Class
# ======================================================================
[docs] class BaseSolverCS(BaseSolver): r""" Abstract base class for constant-step stiff solvers. Provides the common structure for fixed-step exponential Runge-Kutta and related integrators. Derived classes implement stage updates and solver-specific initialization. The governing semi-linear equation is .. math:: \frac{\partial \mathbf{U}}{\partial t} = \mathcal{L}\mathbf{U} + \mathcal{N}(\mathbf{U}), where :math:`\mathcal{L}` represents the linear (stiff) operator and :math:`\mathcal{N}` the nonlinear component. Parameters ---------- lin_op : np.ndarray Linear operator :math:`\mathcal{L}`. Must be 1D (diagonal) or a 2D square matrix. nl_func : Callable[[np.ndarray], np.ndarray] Nonlinear function :math:`\mathcal{N}(\mathbf{U})`. loglevel : str or int, default='WARNING' Logging verbosity. Accepts string names (``'DEBUG'``, ``'INFO'``…) or numeric :mod:`logging` constants. Attributes ---------- lin_op : np.ndarray Linear operator passed at construction. nl_func : Callable[[np.ndarray], np.ndarray] Nonlinear function for the system. t : list[float] Time points recorded during evolution. u : list[np.ndarray] Solution vectors recorded during evolution. Raises ------ ValueError If ``lin_op`` is not 1D or a square 2D matrix. Notes ----- - Subclasses must implement :meth:`_reset` and :meth:`_update_stages`. - The step size remains constant throughout evolution. - For adaptive time-stepping, use :class:`rkstiff.solveras.BaseSolverAS`. References ---------- P. Whalen, M. Brio, and J. V. Moloney, *Exponential time-differencing with embedded Runge-Kutta adaptive step control*, J. Comput. Phys. 280 (2015), 579–601. """ # ------------------------------------------------------------------ # Initialization # ------------------------------------------------------------------
[docs] def __init__( self, lin_op: np.ndarray, nl_func: Callable[[np.ndarray], np.ndarray], loglevel: Union[Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], int] = "WARNING", ) -> None: """Initialize a constant-step solver and validate inputs.""" super().__init__(lin_op, nl_func, loglevel) self.__tf, self.__tc = 0, 0
@property def solver_type(self) -> SolverType: """ Return the solver type for constant-step solvers. Returns ------- SolverType Always returns ``SolverType.CONSTANT_STEP``. Examples -------- >>> from rkstiff.if4 import IF4 >>> solver = IF4(lin_op, nl_func) >>> solver.solver_type == SolverType.CONSTANT_STEP True """ return SolverType.CONSTANT_STEP # ------------------------------------------------------------------ # Reset logic # ------------------------------------------------------------------
[docs] def reset(self) -> None: """Reset solver state and clear stored time/solution data.""" self.logger.debug("Resetting constant-step solver state") self.t, self.u = [], [] self.__tf, self.__tc = 0, 0 self._reset()
@abstractmethod def _reset(self) -> None: """Clear subclass-specific caches or precomputed data.""" # ------------------------------------------------------------------ # Stage update interface # ------------------------------------------------------------------ @abstractmethod def _update_stages(self, u: np.ndarray, h: float) -> np.ndarray: r""" Advance one step of size :math:`h`. Subclasses must implement this method to perform the internal Runge-Kutta or ETD stage updates. Parameters ---------- u : np.ndarray Current solution vector :math:`\mathbf{u}_n`. h : float Constant time step size :math:`\Delta t`. Returns ------- np.ndarray Updated state :math:`\mathbf{u}_{n+1}`. """ # ------------------------------------------------------------------ # Step routine # ------------------------------------------------------------------
[docs] def step(self, u: np.ndarray, h: float) -> np.ndarray: r""" Perform a single constant-step propagation. Parameters ---------- u : np.ndarray Current solution vector. h : float Constant step size (must be non-negative). Returns ------- np.ndarray Updated solution after one full time step. Notes ----- This method simply wraps :meth:`_update_stages` and performs minimal validation and logging. """ assert h >= 0.0 self.logger.debug("Executing constant step with h=%s", h) return self._update_stages(u, h)
# ------------------------------------------------------------------ # Evolution loop # ------------------------------------------------------------------
[docs] def evolve( self, u: np.ndarray, t0: float, tf: float, h: float, store_data: bool = True, store_freq: int = 1, ) -> np.ndarray: r""" Integrate the system from :math:`t_0` to :math:`t_f` using fixed step size. Repeatedly applies :meth:`step` with constant :math:`h` until the final time is reached. Parameters ---------- u : np.ndarray Initial solution vector at :math:`t_0`. t0 : float Initial time. tf : float Final time (integration stops when ``t ≥ tf``). h : float Constant step size :math:`\Delta t`. store_data : bool, default=True Whether to store intermediate results in :attr:`t` and :attr:`u`. store_freq : int, default=1 Frequency of storing data; every ``store_freq`` steps. Returns ------- np.ndarray Final solution vector at :math:`t_f`. Raises ------ ValueError If ``h`` exceeds the total time span ``tf - t0``. Notes ----- - The time grid is uniformly spaced with spacing :math:`h`. - Stored data can be accessed through :attr:`t` and :attr:`u`. Example ------- >>> solver = ETD4(lin_op, nl_func) >>> u_final = solver.evolve(u0, t0=0.0, tf=10.0, h=0.05) >>> len(solver.t) 200 """ self.reset() self.logger.info("Starting constant-step evolution from t=%s to t=%s", t0, tf) self.__tf, self.__tc = tf, t0 if store_data: self.t.append(t0) self.u.append(u) # Ensure step size is valid if self.__tc + h > self.__tf: raise ValueError("Step size h must be <= (tf - t0); reduce h or extend tf.") self.logger.debug("Step size h=%s, store_freq=%s", h, store_freq) step_count = 0 while self.__tc < self.__tf: u = self.step(u, h) self.__tc += h step_count += 1 if step_count % 100 == 0: self.logger.info( "Progress: t=%.6f/%.6f (%.1f%%), steps=%d", self.__tc, self.__tf, 100 * self.__tc / self.__tf, step_count, ) if store_data and (step_count % store_freq == 0): self.t.append(self.__tc) self.u.append(u) self.logger.debug("Stored snapshot at t=%.6f (step %d)", self.__tc, step_count) self.logger.info("Evolution complete after %d steps", step_count) self.logger.info("Stored %d snapshots", len(self.u)) return u