From 0d79342388863567e8475540235c4646b330ae29 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 23 Jun 2026 15:39:56 -0400 Subject: [PATCH] Track gmres move to arraycontext --- examples/helmholtz-dirichlet.py | 2 +- examples/laplace-dirichlet-3d.py | 2 +- examples/laplace-dirichlet-simple.py | 2 +- experiments/cahn-hilliard.py | 2 +- experiments/find-photonic-mode-sk.py | 2 +- experiments/layerpot-coverage.py | 2 +- experiments/maxwell.py | 2 +- experiments/maxwell_sphere.py | 2 +- experiments/poisson.py | 2 +- experiments/stokes-2d-interior.py | 2 +- experiments/two-domain-helmholtz.py | 2 +- pytential/linalg/gmres.py | 410 +-------------------------- pytential/solve.py | 11 - pytential/symbolic/execution.py | 2 +- test/test_beltrami.py | 2 +- test/test_maxwell.py | 2 +- test/test_scalar_int_eq.py | 2 +- test/test_stokes.py | 2 +- test/test_tools.py | 27 -- 19 files changed, 24 insertions(+), 456 deletions(-) delete mode 100644 pytential/solve.py diff --git a/examples/helmholtz-dirichlet.py b/examples/helmholtz-dirichlet.py index f9283116f..8d8c7bd84 100644 --- a/examples/helmholtz-dirichlet.py +++ b/examples/helmholtz-dirichlet.py @@ -135,7 +135,7 @@ def u_incoming_func(x): bvp_rhs = bind(places, sqrt_w*sym.var("bc"))(actx, bc=bc) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(actx, sigma_sym.name, dtype=np.complex128, k=k), bvp_rhs, tol=1e-8, progress=True, diff --git a/examples/laplace-dirichlet-3d.py b/examples/laplace-dirichlet-3d.py index 3273b8cc8..a39cfc8e5 100644 --- a/examples/laplace-dirichlet-3d.py +++ b/examples/laplace-dirichlet-3d.py @@ -120,7 +120,7 @@ def u_incoming_func(x): bc = u_incoming_func(nodes) bvp_rhs = bind(places, sqrt_w*sym.var("bc"))(actx, bc=bc) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(actx, "sigma", dtype=np.float64), bvp_rhs, tol=1e-14, progress=True, diff --git a/examples/laplace-dirichlet-simple.py b/examples/laplace-dirichlet-simple.py index 572f6498d..a77ca5667 100644 --- a/examples/laplace-dirichlet-simple.py +++ b/examples/laplace-dirichlet-simple.py @@ -94,7 +94,7 @@ def main(mesh_name="starfish", visualize=False): nodes = actx.thaw(density_discr.nodes()) bvp_rhs = actx.np.sin(nodes[0]) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(actx, sigma_sym.name, dtype=np.float64), bvp_rhs, tol=1e-8, progress=True, diff --git a/experiments/cahn-hilliard.py b/experiments/cahn-hilliard.py index d1b697460..bc975c375 100644 --- a/experiments/cahn-hilliard.py +++ b/experiments/cahn-hilliard.py @@ -73,7 +73,7 @@ def g(xvec): -g(nodes), ]) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(queue, "sigma", dtype=np.complex128), bc, tol=1e-8, progress=True, diff --git a/experiments/find-photonic-mode-sk.py b/experiments/find-photonic-mode-sk.py index d85f8af81..b40442096 100644 --- a/experiments/find-photonic-mode-sk.py +++ b/experiments/find-photonic-mode-sk.py @@ -107,7 +107,7 @@ def find_mode(): bound_op = bind(qbx, op, auto_where="source") def muller_solve_func(ne): - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(queue, "u", np.complex128, ne=ne, **base_context), diff --git a/experiments/layerpot-coverage.py b/experiments/layerpot-coverage.py index aad11e39d..2c7ad8eea 100644 --- a/experiments/layerpot-coverage.py +++ b/experiments/layerpot-coverage.py @@ -82,7 +82,7 @@ def reference_solu(rvec): op.operator(sym.var('sigma'))) rhs = bind(qbx.density_discr, op.prepare_rhs(sym.var("bc")))(queue, bc=bvals) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(queue, "sigma", dtype=np.float64), rhs, diff --git a/experiments/maxwell.py b/experiments/maxwell.py index 73cc9fe1b..8c4cb3496 100644 --- a/experiments/maxwell.py +++ b/experiments/maxwell.py @@ -210,7 +210,7 @@ def dipole3eall(x,y,z,sources,strengths,k): bound_op = bind(qbx, sym_operator) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres if 0: gmres_result = gmres( bound_op.scipy_op(queue, "sigma", dtype=np.complex128, k=k), diff --git a/experiments/maxwell_sphere.py b/experiments/maxwell_sphere.py index 1c7220bd9..3668bfc49 100644 --- a/experiments/maxwell_sphere.py +++ b/experiments/maxwell_sphere.py @@ -208,7 +208,7 @@ def dipole3eall(x,y,z,sources,strengths,k): bound_op = bind(qbx, sym_operator) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres if 1: gmres_result = gmres( bound_op.scipy_op(queue, "sigma", dtype=np.complex128, k=k), diff --git a/experiments/poisson.py b/experiments/poisson.py index ec216a144..363c6d244 100644 --- a/experiments/poisson.py +++ b/experiments/poisson.py @@ -248,7 +248,7 @@ def get_kernel(): bvp_rhs = bind(bdry_discr, op.prepare_rhs(sym.var("bc")))(queue, bc=bvp_bc) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(queue, "sigma", dtype=np.float64), bvp_rhs, tol=1e-14, progress=True, diff --git a/experiments/stokes-2d-interior.py b/experiments/stokes-2d-interior.py index f91a9bd88..74f55994c 100644 --- a/experiments/stokes-2d-interior.py +++ b/experiments/stokes-2d-interior.py @@ -126,7 +126,7 @@ def couette_soln(x, y, dp, h): # Get rhs vector bvp_rhs = bind(qbx, sqrt_w*sym.make_sym_vector("bc",dim))(queue, bc=bc) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(queue, "sigma", np.float64, mu=mu, normal=normal), bvp_rhs, tol=1e-9, progress=True, diff --git a/experiments/two-domain-helmholtz.py b/experiments/two-domain-helmholtz.py index 1562190e1..33fcb27b3 100644 --- a/experiments/two-domain-helmholtz.py +++ b/experiments/two-domain-helmholtz.py @@ -136,7 +136,7 @@ def main(): bvp_rhs[i_bc] *= sqrt_w - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_pde_op.scipy_op(queue, "unknown", dtype=np.complex128, domains=[sym.DEFAULT_TARGET]*2, K0=K0, K1=K1), diff --git a/pytential/linalg/gmres.py b/pytential/linalg/gmres.py index 33b808dda..94aadcd08 100644 --- a/pytential/linalg/gmres.py +++ b/pytential/linalg/gmres.py @@ -22,412 +22,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from warnings import warn -__doc__ = """ -.. autofunction:: gmres +from arraycontext.linalg.solve import gmres -.. autoclass:: GMRESResult -.. autoexception:: GMRESError -.. autoclass:: ResidualPrinter -.. autoclass:: InnerProduct - :members: - :undoc-members: - :special-members: __call__ -.. autoclass:: CallableOperator - :members: - :undoc-members: - :special-members: __call__ -.. autoclass:: HasMatVec - :members: - :undoc-members: -""" - -from dataclasses import dataclass -from functools import partial -from typing import TYPE_CHECKING, Generic, Protocol - -import numpy as np - -from arraycontext import ArrayContext, ArrayOrContainerT -from pytools import T - - -if TYPE_CHECKING: - from collections.abc import Callable, Sequence - - -class InnerProduct(Protocol, Generic[T]): - """A :class:`~typing.Protocol` for the inner product used by :func:`gmres`.""" - - def __call__(self, a: T, b: T) -> T: ... - - -class CallableOperator(Protocol, Generic[T]): - """A :class:`~typing.Protocol` for the operator used by :func:`gmres`.""" - - @property - def shape(self) -> tuple[int, int]: ... - - def __call__(self, x: T) -> T: ... - - -class HasMatVec(Protocol, Generic[T]): - """A :class:`~typing.Protocol` for the operator used by :func:`gmres`.""" - - @property - def shape(self) -> tuple[int, int]: ... - - def matvec(self, x: T) -> T: ... - - -def structured_vdot(x: ArrayOrContainerT, y: ArrayOrContainerT, - array_context: ArrayContext | None = None) -> float: - """vdot() implementation that is aware of scalars and host or - PyOpenCL arrays. It also recurses down nested object arrays. - """ - - if type(x) is not type(y): - raise TypeError("'structured_vdot' entries have different types: " - f"{type(x).__name__} and {type(y).__name__}") - - from numbers import Number - if (isinstance(x, Number) - or (isinstance(x, np.ndarray) and x.dtype.char != "O")): - return np.vdot(x, y) - else: - if array_context is None: - raise ValueError("'array_context' is required for non-scalar inputs") - - # actx.np.vdot works on PyOpenCL arrays and arbitrarily nested - # array containers, so this should handle all remaining cases - r = array_context.to_numpy(array_context.np.vdot(x, y)) - if isinstance(r, np.ndarray) and r.shape == (): - r = r[()] - - return r - - -# {{{ gmres - -# Modified Python port of ./Apps/Acoustics/root/matlab/gmres_restart.m -# from hellskitchen. -# Necessary because SciPy gmres is not reentrant and thus does -# not allow recursive solves. - - -class GMRESError(RuntimeError): - pass - - -# {{{ main routine - -@dataclass(frozen=True) -class GMRESResult(Generic[T]): - """ - .. autoattribute:: solution - .. autoattribute:: residual_norms - .. autoattribute:: iteration_count - .. autoattribute:: success - .. autoattribute:: state - """ - - solution: T - residual_norms: Sequence[float] - iteration_count: int - success: bool - """A :class:`bool` indicating whether the iteration succeeded.""" - state: str - """A description of the outcome.""" - - -def _gmres( - A: CallableOperator[ArrayOrContainerT] | HasMatVec[ArrayOrContainerT], - b: ArrayOrContainerT, - restart: int | None = None, - tol: float | None = None, - x0: ArrayOrContainerT | None = None, - dot: InnerProduct[ArrayOrContainerT] | None = None, - maxiter: int | None = None, - hard_failure: bool | None = None, - require_monotonicity: bool = True, - no_progress_factor: float | None = None, - stall_iterations: int | None = None, - callback: Callable[[ArrayOrContainerT], None] | None = None - ) -> GMRESResult[ArrayOrContainerT]: - - # {{{ input processing - - n, _ = A.shape - if not callable(A): - a_call = A.matvec - else: - a_call = A - - if dot is None: - raise ValueError("'dot' not provided") - - if restart is None: - restart = min(n, 20) - - if tol is None: - tol = 1e-5 - - if maxiter is None: - maxiter = 2*n - - if hard_failure is None: - hard_failure = True - - if stall_iterations is None: - stall_iterations = 10 - - if no_progress_factor is None: - no_progress_factor = 1.25 - - # }}} - - def norm(x: ArrayOrContainerT) -> float: - return np.sqrt(abs(dot(x, x))) - - if x0 is None: - x: ArrayOrContainerT = 0*b - r = b - recalc_r = False - else: - x = x0 - del x0 - recalc_r = True - - Ae: list[ArrayOrContainerT] = [None]*restart - e: list[ArrayOrContainerT] = [None]*restart - - k = 0 - - norm_b = norm(b) - last_resid_norm = None - residual_norms: list[float] = [] - - iteration = 0 - for iteration in range(maxiter): - # restart if required - if k == restart: - k = 0 - orth_count = restart - else: - orth_count = k - - # recalculate residual every 10 steps - if recalc_r: - r = b - a_call(x) - - norm_r = norm(r) - residual_norms.append(norm_r) - - if callback is not None: - callback(r) - - if norm_r < tol*norm_b or norm_r == 0: - return GMRESResult( - solution=x, - residual_norms=residual_norms, - iteration_count=iteration, - success=True, - state="success") - if last_resid_norm is not None: - if norm_r > 1.25*last_resid_norm: - state = "non-monotonic residuals" - if require_monotonicity: - if hard_failure: - raise GMRESError(state) - else: - return GMRESResult( - solution=x, - residual_norms=residual_norms, - iteration_count=iteration, - success=False, - state=state) - else: - print("*** WARNING: non-monotonic residuals in GMRES") - - if (stall_iterations - and len(residual_norms) > stall_iterations - and norm_r > ( - residual_norms[-stall_iterations] - / no_progress_factor)): - - state = "stalled" - if hard_failure: - raise GMRESError(state) - else: - return GMRESResult( - solution=x, - residual_norms=residual_norms, - iteration_count=iteration, - success=False, - state=state) - - last_resid_norm = norm_r - - # initial new direction guess - w = a_call(r) - - # {{{ double-orthogonalize the new direction against preceding ones - - rp = r - - for _orth_trips in range(2): - for j in range(orth_count): - d = dot(Ae[j], w) - w = w - d * Ae[j] - rp = rp - d * e[j] - - # normalize - d = 1/norm(w) - w = d*w - rp = d*rp - - # }}} - - Ae[k] = w - e[k] = rp - - # update the residual and solution - d = dot(Ae[k], r) - - recalc_r = (iteration+1) % 10 == 0 - if not recalc_r: - r = r - d*Ae[k] - - x = x + d*e[k] - - k += 1 - - state = "max iterations" - if hard_failure: - raise GMRESError(state) - else: - return GMRESResult( - solution=x, - residual_norms=residual_norms, - iteration_count=iteration, - success=False, - state=state) - -# }}} - - -# {{{ progress reporting - -class ResidualPrinter(Generic[ArrayOrContainerT]): - count: int - inner_product: InnerProduct[ArrayOrContainerT] - - def __init__( - self, - inner_product: InnerProduct[ArrayOrContainerT] | None = None - ) -> None: - if inner_product is None: - inner_product = structured_vdot - - self.count = 0 - self.inner_product = inner_product - - def __call__(self, resid: ArrayOrContainerT | None) -> None: - import sys - if resid is not None: - norm = np.sqrt(self.inner_product(resid, resid)) - sys.stdout.write(f"IT {self.count:8d} {abs(norm):.8e}\n") - else: - sys.stdout.write(f"IT {self.count:8d}\n") - - self.count += 1 - sys.stdout.flush() - -# }}} - - -# {{{ entrypoint - -def gmres( - op: CallableOperator[ArrayOrContainerT] | HasMatVec[ArrayOrContainerT], - rhs: ArrayOrContainerT, - restart: int | None = None, - tol: float | None = None, - x0: ArrayOrContainerT | None = None, - inner_product: InnerProduct[ArrayOrContainerT] | None = None, - maxiter: int | None = None, - hard_failure: bool | None = None, - no_progress_factor: float | None = None, - stall_iterations: int | None = None, - callback: Callable[[ArrayOrContainerT], None] | None = None, - progress: bool = False, - require_monotonicity: bool = True) -> GMRESResult[ArrayOrContainerT]: - """Solve a linear system :math:`Ax = b` using GMRES with restarts. - - :arg op: a callable to evaluate :math:`A(x)`. - :arg rhs: the right hand side :math:`b`. - :arg restart: the maximum number of iteration after which GMRES algorithm - needs to be restarted - :arg tol: the required decrease in residual norm (relative to the *rhs*). - :arg x0: an initial guess for the iteration (a zero array is used by default). - :arg inner_product: a callable with an interface compatible with - :func:`numpy.vdot` that returns a host scalar. - :arg maxiter: the maximum number of iterations permitted. - :arg hard_failure: if *True*, raise :exc:`GMRESError` in case of failure. - :arg stall_iterations: number of iterations with residual decrease - below *no_progress_factor* indicates stall. Set to ``0`` to disable - stall detection. - """ - if inner_product is None: - from pytential.symbolic.execution import ( - _find_array_context_from_args_in_context, - ) - try: - actx = _find_array_context_from_args_in_context({ - "rhs": rhs, "x0": x0, - }, supplied_array_context=getattr(op, "array_context", None)) - except (ValueError, TypeError): - actx = None - - inner_product = partial(structured_vdot, array_context=actx) - - if callback is None: - if progress: - callback = ResidualPrinter(inner_product) - else: - callback = None - - return _gmres(op, rhs, restart=restart, tol=tol, x0=x0, - dot=inner_product, - maxiter=maxiter, hard_failure=hard_failure, - no_progress_factor=no_progress_factor, - stall_iterations=stall_iterations, callback=callback, - require_monotonicity=require_monotonicity) - - -# }}} - -# }}} - - -# {{{ direct solve - -def lu(op, rhs, show_spectrum=False): - import numpy.linalg as la - - from sumpy.tools import build_matrix - mat = build_matrix(op) - - print(f"condition number: {la.cond(mat)}") - if show_spectrum: - ev = la.eigvals(mat) - import matplotlib.pyplot as pt - pt.plot(ev.real, ev.imag, "o") - pt.show() +__all__ = [ + "gmres", +] - return la.solve(mat, rhs) -# }}} +warn("This module will go away in 2027. Use arraycontext.linalg.solve instead.", + DeprecationWarning, + stacklevel=2) # vim: fdm=marker diff --git a/pytential/solve.py b/pytential/solve.py deleted file mode 100644 index 146815a12..000000000 --- a/pytential/solve.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from warnings import warn - -from pytential.linalg.gmres import * # noqa: F403 - - -warn( - "pytential.solve is deprecated and will be removed in 2023. Use " - "pytential.linalg.gmres instead.", DeprecationWarning, - stacklevel=1) diff --git a/pytential/symbolic/execution.py b/pytential/symbolic/execution.py index 25cdeb1f6..77763482b 100644 --- a/pytential/symbolic/execution.py +++ b/pytential/symbolic/execution.py @@ -289,7 +289,7 @@ def map_inverse(self, expr: pp.IterativeInverse): **{var_name: self.rec(var_expr) for var_name, var_expr in expr.extra_vars.items()}) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres rhs = self.rec(expr.rhs) return gmres(scipy_op, rhs) diff --git a/test/test_beltrami.py b/test/test_beltrami.py index 7bfd9c8c0..41d613dbe 100644 --- a/test/test_beltrami.py +++ b/test/test_beltrami.py @@ -245,7 +245,7 @@ def test_beltrami_convergence( actx, b=solution.source(actx, density_discr), **solution.context) - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres result = gmres( scipy_op, rhs, x0=rhs, diff --git a/test/test_maxwell.py b/test/test_maxwell.py index a9f89d0ab..fc8998460 100644 --- a/test/test_maxwell.py +++ b/test/test_maxwell.py @@ -366,7 +366,7 @@ def eval_inc_field_at(places, source=None, target=None): "hard_failure": True, "stall_iterations": 50, "no_progress_factor": 1.05} - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_j_op.scipy_op(actx, "jt", np.complex128, **knl_kwargs), j_rhs, **gmres_settings) diff --git a/test/test_scalar_int_eq.py b/test/test_scalar_int_eq.py index 74dabe2ee..645e93ec9 100644 --- a/test/test_scalar_int_eq.py +++ b/test/test_scalar_int_eq.py @@ -242,7 +242,7 @@ def run_int_eq_test(actx, from pytential.qbx import QBXTargetAssociationFailedError try: - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_result = gmres( bound_op.scipy_op(actx, "u", dtype, **case.knl_concrete_kwargs), rhs, diff --git a/test/test_stokes.py b/test/test_stokes.py index a5e7f2a89..03bd34772 100644 --- a/test/test_stokes.py +++ b/test/test_stokes.py @@ -214,7 +214,7 @@ def run_exterior_stokes(actx_factory, *, # {{{ solve - from pytential.linalg.gmres import gmres + from arraycontext.linalg.solve import gmres gmres_tol = 1.0e-9 result = gmres( bound_op.scipy_op(actx, "sigma", np.float64, **op_context), diff --git a/test/test_tools.py b/test/test_tools.py index b5c886ffe..1b5bf8e49 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -42,33 +42,6 @@ ]) -# {{{ test_gmres - -def test_gmres(): - rng = np.random.default_rng(seed=42) - - n = 200 - A = ( - n * (np.eye(n) + 2j * np.eye(n)) - + rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))) - - true_sol = rng.normal(size=n) + 1j * rng.normal(size=n) - b = np.dot(A, true_sol) - - A_func = lambda x: np.dot(A, x) # noqa - A_func.shape = A.shape - A_func.dtype = A.dtype - - from pytential.linalg.gmres import ResidualPrinter, gmres - tol = 1e-6 - sol = gmres(A_func, b, callback=ResidualPrinter(), - maxiter=5*n, tol=tol).solution - - assert la.norm(true_sol - sol) / la.norm(sol) < tol - -# }}} - - # {{{ test_interpolatory_error_reporting def test_interpolatory_error_reporting(actx_factory: ArrayContextFactory):