r"""
Grid construction utilities for spectral and transform methods
==============================================================
This module provides functions for constructing **spatial** and **spectral**
grids used in Fourier, Chebyshev, and Hankel spectral methods.
These grids form the basis of spatial discretizations for PDE solvers
such as those implemented in :mod:`rkstiff`.
Overview
--------
Each function constructs one of the following:

- A **uniform grid** for FFT-based spectral differentiation.
- A **Chebyshev–Gauss–Lobatto** grid for non-periodic domains.
- A **radial grid** for axisymmetric problems using the Hankel transform.

All grids and wavenumber sets are compatible with NumPy and SciPy FFTs
and can be directly used with the derivative utilities in
:mod:`rkstiff.derivatives`.
Contents
--------
- :func:`construct_x_kx_rfft` — Uniform grid for rFFT (real-valued FFT)
- :func:`construct_x_kx_fft` — Uniform grid for FFT (complex-valued)
- :func:`construct_x_cheb` — Chebyshev spatial grid
- :func:`construct_x_dx_cheb` — Chebyshev grid with differentiation matrix
- :func:`construct_r_kr_hankel` — Radial grid for Hankel transforms
- :func:`mirror_grid` — Symmetric grid reflection utility
"""
from typing import Optional, Tuple
import numpy as np
import scipy.special as sp # type: ignore
[docs]
def construct_x_kx_rfft(n: int, a: float = 0.0, b: float = 2 * np.pi) -> Tuple[np.ndarray, np.ndarray]:
    r"""
    Construct a uniform 1D spatial grid and *rFFT-compatible* wavenumber grid.

    Parameters
    ----------
    n : int
        Number of grid points. Must be an even integer greater than 2.
        Python ``int`` and NumPy integer scalars are accepted.
    a : float, optional
        Left endpoint of the spatial domain. Default is ``0.0``.
    b : float, optional
        Right endpoint of the spatial domain (not included in the grid).
        Default is ``2π``. Must differ from ``a``.

    Returns
    -------
    x : np.ndarray
        Uniform spatial grid with exactly ``n`` points in :math:`[a, b)`.
    kx : np.ndarray
        Spectral wavenumber grid with ``n/2 + 1`` points (for use with rFFT).

    Raises
    ------
    TypeError
        If ``n`` is not an integer.
    ValueError
        If ``n`` is not an even integer greater than 2, or if ``a == b``.

    Notes
    -----
    The grid spacing and wavenumbers are given by:

    .. math::

        \Delta x = \frac{b - a}{n}, \qquad
        k_x = 2\pi \, \mathrm{rfftfreq}(n, \Delta x).

    The spatial grid is produced with ``np.linspace(a, b, n, endpoint=False)``
    rather than ``np.arange(a, b, Δx)``: with a floating-point step,
    ``arange`` can emit an extra point at (approximately) ``b``, which would
    break the ``len(x) == n`` contract.

    Examples
    --------
    >>> x, kx = construct_x_kx_rfft(128, a=0, b=2*np.pi)
    >>> x.shape
    (128,)
    >>> kx.shape
    (65,)
    """
    if not isinstance(n, (int, np.integer)):
        raise TypeError("n must be an integer.")
    if n <= 2 or (n % 2) != 0:
        raise ValueError("n must be an even integer greater than 2.")
    if b == a:
        # A zero-length domain gives dx == 0 and infinite wavenumbers.
        raise ValueError("a and b must differ (the domain length must be nonzero).")
    dx = (b - a) / n
    # endpoint=False guarantees exactly n points spaced by dx, starting at a.
    x = np.linspace(a, b, n, endpoint=False)
    kx = 2 * np.pi * np.fft.rfftfreq(n, d=dx)
    return x, kx
[docs]
def construct_x_kx_fft(n: int, a: float = 0.0, b: float = 2 * np.pi) -> Tuple[np.ndarray, np.ndarray]:
    r"""
    Construct a uniform 1D spatial grid and *FFT-compatible* wavenumber grid.

    Parameters
    ----------
    n : int
        Number of grid points. Must be an even integer greater than 2.
        Python ``int`` and NumPy integer scalars are accepted.
    a : float, optional
        Left endpoint of the spatial domain. Default is ``0.0``.
    b : float, optional
        Right endpoint of the spatial domain (not included in the grid).
        Default is ``2π``. Must differ from ``a``.

    Returns
    -------
    x : np.ndarray
        Uniform spatial grid with exactly ``n`` points in :math:`[a, b)`.
    kx : np.ndarray
        Spectral wavenumber grid with ``n`` points in the ordering used by
        ``np.fft.fftfreq`` (for use with FFT).

    Raises
    ------
    TypeError
        If ``n`` is not an integer.
    ValueError
        If ``n`` is not an even integer greater than 2, or if ``a == b``.

    Notes
    -----
    The uniform grid and Fourier frequencies are defined by:

    .. math::

        \Delta x = \frac{b - a}{n}, \qquad
        k_x = 2\pi \, \mathrm{fftfreq}(n, \Delta x).

    The spatial grid is produced with ``np.linspace(a, b, n, endpoint=False)``
    rather than ``np.arange(a, b, Δx)``: with a floating-point step,
    ``arange`` can emit an extra point at (approximately) ``b``, which would
    make ``x`` and ``kx`` differ in length.

    Examples
    --------
    >>> x, kx = construct_x_kx_fft(128)
    >>> x.shape, kx.shape
    ((128,), (128,))
    """
    if not isinstance(n, (int, np.integer)):
        raise TypeError("n must be an integer.")
    if n <= 2 or (n % 2) != 0:
        raise ValueError("n must be an even integer greater than 2.")
    if b == a:
        # A zero-length domain gives dx == 0 and infinite wavenumbers.
        raise ValueError("a and b must differ (the domain length must be nonzero).")
    dx = (b - a) / n
    # endpoint=False guarantees exactly n points spaced by dx, starting at a.
    x = np.linspace(a, b, n, endpoint=False)
    kx = 2 * np.pi * np.fft.fftfreq(n, d=dx)
    return x, kx
[docs]
def construct_x_cheb(n: int, a: float = -1.0, b: float = 1.0) -> np.ndarray:
    r"""
    Construct a 1D grid with Chebyshev–Gauss–Lobatto points.

    Parameters
    ----------
    n : int
        Polynomial order (number of subintervals); must be at least 2. The
        grid has ``n + 1`` points. Python ``int`` and NumPy integer scalars
        are accepted.
    a, b : float, optional
        Interval endpoints. Defaults are ``a = -1``, ``b = 1``.

    Returns
    -------
    np.ndarray
        Grid of ``n + 1`` Chebyshev points mapped to the interval
        :math:`[a, b]`, ordered from ``a`` (first element) to ``b`` (last).

    Raises
    ------
    TypeError
        If ``n`` is not an integer.
    ValueError
        If ``n < 2``.

    Notes
    -----
    The points come from :func:`numpy.polynomial.chebyshev.chebpts2`, which
    returns them in ascending order on :math:`[-1, 1]`:

    .. math::

        x_j = -\cos\!\left(\frac{j\pi}{n}\right),
        \quad j = 0, 1, \dots, n,

    i.e. the classical :math:`\cos(j\pi/n)` set listed in reverse order.
    They are then linearly mapped to :math:`[a, b]` via
    :math:`a + (b - a)(x_j + 1)/2`.

    Examples
    --------
    >>> x = construct_x_cheb(10, a=-1, b=1)
    >>> len(x)
    11
    >>> float(x[0]), float(x[-1])
    (-1.0, 1.0)
    """
    if not isinstance(n, (int, np.integer)):
        raise TypeError("n must be an integer.")
    if n < 2:
        raise ValueError("n must be ≥ 2.")
    # Ascending Gauss–Lobatto points on [-1, 1].
    x = np.polynomial.chebyshev.chebpts2(n + 1)
    # Affine map [-1, 1] -> [a, b]; x = -1 maps to a and x = 1 maps to b.
    x = a + (b - a) * (x + 1) / 2.0
    return x
[docs]
def construct_x_dx_cheb(n: int, a: float = -1, b: float = 1) -> Tuple[np.ndarray, np.ndarray]:
    r"""
    Build a Chebyshev–Gauss–Lobatto grid together with its spectral
    differentiation matrix.

    Parameters
    ----------
    n : int
        Polynomial order (number of subintervals); the grid has ``n + 1``
        points. Validation of ``n`` is delegated to :func:`construct_x_cheb`.
    a, b : float, optional
        Interval endpoints. Defaults are ``a = -1``, ``b = 1``.

    Returns
    -------
    x : np.ndarray
        Chebyshev grid points on :math:`[a, b]`, exactly as returned by
        :func:`construct_x_cheb`.
    d_cheb_matrix : np.ndarray, shape (n+1, n+1)
        Differentiation matrix :math:`D` such that ``D @ f ≈ df/dx`` for
        ``f`` sampled on ``x``.

    Raises
    ------
    TypeError, ValueError
        Propagated from :func:`construct_x_cheb` for an invalid ``n``.

    Notes
    -----
    Off-diagonal entries use the closed form

    .. math::

        D_{ij} = \frac{c_i}{c_j (x_i - x_j)}, \qquad i \neq j,

    where :math:`c_i = 2(-1)^i` at the two endpoints and :math:`c_i = (-1)^i`
    elsewhere. Each diagonal entry is set to the negative sum of the
    off-diagonal entries in its row,
    :math:`D_{ii} = -\sum_{k \neq i} D_{ik}`, so :math:`D\mathbf{1} = 0`
    holds by construction. The differences :math:`x_i - x_j` are taken between
    the already-mapped points, so the interval scaling is built into
    :math:`D`.
    """
    x = construct_x_cheb(n, a, b)
    npts = n + 1

    # Endpoint weights are doubled; the sign alternates with the node index.
    weights = np.ones(npts)
    weights[0] = weights[-1] = 2.0
    c = weights * (-1.0) ** np.arange(npts)

    # diff[i, j] = x[i] - x[j]; adding the identity keeps the diagonal
    # division finite (those entries are corrected below).
    diff = x[:, np.newaxis] - x[np.newaxis, :]
    dmat = np.outer(c, 1.0 / c) / (diff + np.eye(npts))

    # Negative-sum trick: subtract each full row sum from its diagonal entry,
    # leaving D_ii = -(sum of the off-diagonal entries in row i).
    row_sums = dmat.sum(axis=1)
    dmat[np.diag_indices(npts)] -= row_sums
    return x, dmat
[docs]
def construct_r_kr_hankel(nr: int, rmax: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    r"""
    Construct Hankel transform grids for axisymmetric domains.

    Parameters
    ----------
    nr : int
        Number of radial grid points (≥ 4).
    rmax : float
        Maximum radius of the domain; must be a positive number.

    Returns
    -------
    r : np.ndarray
        ``nr`` radial grid points, all strictly inside :math:`(0, r_\max)`.
    kr : np.ndarray
        ``nr`` spectral grid points in wavenumber space.
    bessel_zeros : np.ndarray
        First ``nr`` positive zeros of :math:`J_0`.
    jN : float
        The ``(nr+1)``-th positive zero of :math:`J_0`, used for normalization.

    Raises
    ------
    ValueError
        If ``nr < 4`` or if ``rmax`` is not positive (NaN is rejected too).

    Notes
    -----
    The grid is defined from the positive zeros
    :math:`j_1 < j_2 < \dots` of the Bessel function :math:`J_0`:

    .. math::

        r_i = \frac{j_i}{j_{N+1}} \, r_\max, \qquad
        k_i = \frac{j_i}{r_\max}, \qquad i = 1, \dots, N,

    with :math:`N` equal to ``nr``. Because :math:`j_N < j_{N+1}`, every
    radial point lies strictly below :math:`r_\max`; the domain edge
    corresponds to :math:`j_{N+1}` and is not itself a grid point.

    Examples
    --------
    >>> r, kr, zeros, jN = construct_r_kr_hankel(8, rmax=1.0)
    >>> r.shape, kr.shape, zeros.shape
    ((8,), (8,), (8,))
    >>> bool(r[-1] < 1.0)
    True
    """
    if nr < 4:
        raise ValueError("nr must be ≥ 4.")
    # Negated comparison so that NaN (for which every comparison is False)
    # is rejected as well.
    if not rmax > 0:
        raise ValueError("rmax must be positive.")
    all_zeros = sp.jn_zeros(0, nr + 1)
    # Only the first nr zeros become grid points; the extra (nr+1)-th zero
    # fixes the radial scaling.
    bessel_zeros, jN = all_zeros[:-1], all_zeros[-1]
    r = bessel_zeros * rmax / jN
    kr = bessel_zeros / rmax
    return r, kr, bessel_zeros, jN
[docs]
def mirror_grid(
    r: np.ndarray, u: Optional[np.ndarray] = None, axis: int = -1
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    r"""
    Mirror a radial grid (and optional function) to produce a symmetric domain.

    Parameters
    ----------
    r : np.ndarray
        1D radial grid on :math:`[0, r_\max]`.
    u : np.ndarray, optional
        Function values at ``r``. If provided, mirrored values are returned.
    axis : int, optional
        Axis of ``u`` that is sampled on ``r`` and along which ``u`` is
        mirrored. Negative values count from the end, as in NumPy.
        Default is ``-1`` (the last axis).

        * ``-1`` — last axis (the only axis for 1D input)
        * ``0`` — rows (stack vertically)
        * ``1`` — columns (stack horizontally)

    Returns
    -------
    rnew : np.ndarray
        Mirrored grid on :math:`[-r_\max, r_\max]`, with ``2 * len(r)`` points.
    unew : np.ndarray or None
        Mirrored function values, or ``None`` if ``u`` was not provided. Its
        length along ``axis`` is twice that of ``u``; the other axes are
        unchanged.

    Raises
    ------
    ValueError
        If ``axis`` is out of range for ``u``. NumPy raises ``AxisError``,
        which is a ``ValueError`` subclass.

    Notes
    -----
    Useful for visualizing radially symmetric solutions or enforcing
    symmetric boundary conditions. If ``r`` contains ``0``, that point
    appears twice in ``rnew`` (see the example).

    Examples
    --------
    >>> r = np.array([0, 1, 2, 3])
    >>> u = np.array([1, 2, 3, 4])
    >>> rnew, unew = mirror_grid(r, u)
    >>> rnew
    array([-3, -2, -1,  0,  0,  1,  2,  3])
    >>> unew
    array([4, 3, 2, 1, 1, 2, 3, 4])
    """
    rnew = np.hstack([-np.flipud(r), r])
    if u is None:
        return rnew, None
    u = np.asarray(u)
    # Flip and concatenate along the *same* axis so the mirrored values line
    # up with rnew for input of any dimensionality.
    unew = np.concatenate([np.flip(u, axis=axis), u], axis=axis)
    return rnew, unew