From af1c9df8942c9e14c951d9f3909f7f202893b255 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 16 Jun 2026 04:32:06 -0400 Subject: [PATCH 1/5] Transition uxinterpolators to vectorized .isel() to support lazy loading --- src/parcels/interpolators/_uxinterpolators.py | 91 +++++++++++++++---- 1 file changed, 73 insertions(+), 18 deletions(-) diff --git a/src/parcels/interpolators/_uxinterpolators.py b/src/parcels/interpolators/_uxinterpolators.py index cb2c7aa24..b4c232ebe 100644 --- a/src/parcels/interpolators/_uxinterpolators.py +++ b/src/parcels/interpolators/_uxinterpolators.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING import numpy as np +import xarray as xr +from dask import is_dask_collection if TYPE_CHECKING: from parcels._core.field import Field, VectorField @@ -17,9 +19,21 @@ def UxConstantFaceConstantZC( field: Field, ): """Piecewise constant interpolation kernel for face registered data that is vertically centered (on zc points)""" - return field.data.values[ + # Broadcast the per-axis indices to a common (npart,) shape (``ti`` may be scalar for time-constant fields) + ti, zi, fi = np.broadcast_arrays( grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - ] + ) + + # Index lazily (one pointwise ``isel`` per particle) instead of materializing the whole field + # with ``.values``, so dask-backed (out-of-core) fields are never fully read into memory. + tdim, zdim, fdim = field.data.dims + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(zi, dims="points"), + fdim: xr.DataArray(fi, dims="points"), + } + value = field.data.isel(selection_dict, ignore_grid=True).data + return value.compute() if is_dask_collection(value) else value def UxConstantFaceLinearZF( @@ -31,15 +45,28 @@ def UxConstantFaceLinearZF( Piecewise constant interpolation (lateral) with linear vertical interpolation kernel for face registered data that is located at vertical interface levels (on zf points) """ - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ti, zi, fi = np.broadcast_arrays( + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ) z = particle_positions["z"] + tdim, zdim, fdim = field.data.dims + + def _zsample(z_index): + """Pointwise ``isel`` of the face values at a single vertical interface level.""" + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(z_index, dims="points"), + fdim: xr.DataArray(fi, dims="points"), + } + value = field.data.isel(selection_dict, ignore_grid=True).data + return value.compute() if is_dask_collection(value) else value + # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. # First, do barycentric interpolation in the lateral direction for each interface level - fzk = field.data.values[ti, zi, fi] - fzkp1 = field.data.values[ti, zi + 1, fi] + fzk = _zsample(zi) + fzkp1 = _zsample(zi + 1) # Then, do piecewise linear interpolation in the vertical direction zk = field.grid.z.values[zi] @@ -57,13 +84,25 @@ def UxLinearNodeConstantZC( Effectively, it applies barycentric interpolation in the lateral direction and piecewise constant interpolation in the vertical direction. """ - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - bcoords = grid_positions["FACE"]["bcoord"] + ti, zi, fi = np.broadcast_arrays( + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ) + bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes")) node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values - return np.sum( - field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1 - ) # Linear interpolation in the vertical direction + + # The node ids form a (npart, n_nodes_per_face) block; index it pointwise with ``isel`` so only the + # touched nodes are read (rather than materializing the whole field with ``.values``). + tdim, zdim, ndim = field.data.dims + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(zi, dims="points"), + ndim: xr.DataArray(node_ids, dims=("points", "nodes")), + } + # Reduce over the "nodes" dimension by name (rather than a positional axis) so the result does + # not depend on the dimension order that vectorized ``isel`` happens to return. + node_data = field.data.isel(selection_dict, ignore_grid=True) + value = (node_data * bcoords).sum("nodes").data # Barycentric interpolation in the lateral direction + return value.compute() if is_dask_collection(value) else value def UxLinearNodeLinearZF( @@ -76,21 +115,37 @@ def UxLinearNodeLinearZF( Effectively, it applies barycentric interpolation in the lateral direction and piecewise linear interpolation in the vertical direction. """ - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ti, zi, fi = np.broadcast_arrays( + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ) z = particle_positions["z"] - bcoords = grid_positions["FACE"]["bcoord"] + bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes")) node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values + + tdim, zdim, ndim = field.data.dims + + def _zsample(z_index): + """Barycentric (lateral) interpolation of the node values at a single vertical interface level.""" + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(z_index, dims="points"), + ndim: xr.DataArray(node_ids, dims=("points", "nodes")), + } + # Reduce over the "nodes" dimension by name so the result is independent of ``isel`` dim order. + node_data = field.data.isel(selection_dict, ignore_grid=True) + return (node_data * bcoords).sum("nodes").data + # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. # First, do barycentric interpolation in the lateral direction for each interface level - fzk = np.sum(field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1) - fzkp1 = np.sum(field.data.values[ti[:, None], zi[:, None] + 1, node_ids] * bcoords, axis=-1) + fzk = _zsample(zi) + fzkp1 = _zsample(zi + 1) # Then, do piecewise linear interpolation in the vertical direction zk = field.grid.z.values[zi] zkp1 = field.grid.z.values[zi + 1] - return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction + value = (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction + return value.compute() if is_dask_collection(value) else value def Ux_Velocity( From 4e8190d484764aabcf7ea47403871bcb86772ac2 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 16 Jun 2026 04:41:23 -0400 Subject: [PATCH 2/5] Remove uninformative developer commments --- src/parcels/interpolators/_uxinterpolators.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/parcels/interpolators/_uxinterpolators.py b/src/parcels/interpolators/_uxinterpolators.py index b4c232ebe..f434ac96e 100644 --- a/src/parcels/interpolators/_uxinterpolators.py +++ b/src/parcels/interpolators/_uxinterpolators.py @@ -24,8 +24,6 @@ def UxConstantFaceConstantZC( grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] ) - # Index lazily (one pointwise ``isel`` per particle) instead of materializing the whole field - # with ``.values``, so dask-backed (out-of-core) fields are never fully read into memory. tdim, zdim, fdim = field.data.dims selection_dict = { tdim: xr.DataArray(ti, dims="points"), @@ -90,16 +88,13 @@ def UxLinearNodeConstantZC( bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes")) node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values - # The node ids form a (npart, n_nodes_per_face) block; index it pointwise with ``isel`` so only the - # touched nodes are read (rather than materializing the whole field with ``.values``). tdim, zdim, ndim = field.data.dims selection_dict = { tdim: xr.DataArray(ti, dims="points"), zdim: xr.DataArray(zi, dims="points"), ndim: xr.DataArray(node_ids, dims=("points", "nodes")), } - # Reduce over the "nodes" dimension by name (rather than a positional axis) so the result does - # not depend on the dimension order that vectorized ``isel`` happens to return. + node_data = field.data.isel(selection_dict, ignore_grid=True) value = (node_data * bcoords).sum("nodes").data # Barycentric interpolation in the lateral direction return value.compute() if is_dask_collection(value) else value From 113909179deb4b76c460b3e2b8d429aa2320740e Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 16 Jun 2026 08:58:06 -0400 Subject: [PATCH 3/5] Add WindowedArray class, fieldset.to_windowed_arrays, and basic tests --- src/parcels/_core/_windowed_array.py | 98 ++++++++++++++++++++++++++++ src/parcels/_core/fieldset.py | 35 ++++++++++ tests/test_windowed_array.py | 46 +++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 src/parcels/_core/_windowed_array.py create mode 100644 tests/test_windowed_array.py diff --git a/src/parcels/_core/_windowed_array.py b/src/parcels/_core/_windowed_array.py new file mode 100644 index 000000000..a70b7c905 --- /dev/null +++ b/src/parcels/_core/_windowed_array.py @@ -0,0 +1,98 @@ +"""Transparent rolling time-window cache for lazy (dask-backed) field data. + +Assumptions / current limits: + * ``time`` is the leading dimension of the field (true for both the SGRID and + UGRID ingestion paths; the structured path transposes to ``(time, ...)``). + * Valid while the requested time indices stay within the resident window + (i.e. all particles share the clock). A sample that requests time indices + spanning more than the retained levels would force reloads. +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr +from dask import is_dask_collection + +# xarray / uxarray ``isel`` keyword arguments that are NOT dimension indexers. +_NON_INDEXER_KWARGS = frozenset({"drop", "missing_dims", "ignore_grid"}) + + +class WindowedArray: + """Wrap a lazy DataArray so ``isel`` loads/caches/evicts time levels as NumPy.""" + + def __init__(self, data: xr.DataArray, time_dim: str = "time", max_levels: int | None = None): + if data.dims[0] != time_dim: + raise ValueError(f"WindowedArray expects {time_dim!r} as the leading dimension, got {data.dims}") + self._data = data + self._tdim = time_dim + self._cache: dict[int, np.ndarray] = {} # time index -> NumPy slab (remaining dims) + self._max = max_levels + # diagnostics + self.loads = 0 + self.bytes_read = 0 + self._slab_bytes = int(np.prod(data.isel({time_dim: 0}).shape)) * data.dtype.itemsize + + # -- transparency: forward everything we don't override ------------------- + def __getattr__(self, name): + # __getattr__ only fires for misses; reach _data without recursing. + return getattr(object.__getattribute__(self, "_data"), name) + + def __repr__(self): + return (f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, " + f"loads={self.loads})\n{self._data!r}") + + # -- window management ---------------------------------------------------- + def _read_level(self, lvl: int) -> np.ndarray: + """Bulk, sequential read of one time level into NumPy (the dask->NumPy step).""" + return np.asarray(self._data.isel({self._tdim: int(lvl)}).values) + + def _ensure(self, levels: np.ndarray) -> None: + for lvl in levels: + lvl = int(lvl) + if lvl not in self._cache: + self._cache[lvl] = self._read_level(lvl) + self.loads += 1 + self.bytes_read += self._slab_bytes + # retire stale levels (the clock only moves forward across the window) + lo = int(np.min(levels)) + for old in [k for k in self._cache if k < lo]: + del self._cache[old] + if self._max is not None and len(self._cache) > self._max: + for old in sorted(self._cache)[: len(self._cache) - self._max]: + del self._cache[old] + + # -- intercepted indexing ------------------------------------------------- + def isel(self, indexers: dict | None = None, **kwargs): + sel = dict(indexers) if indexers is not None else {} + sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS}) + + if self._tdim not in sel: + # no time selection -> nothing to window; preserve control kwargs + return self._data.isel(indexers, **kwargs) + + t_ind = sel[self._tdim] + t_vals = np.asarray(t_ind.values if isinstance(t_ind, xr.DataArray) else t_ind) + levels = np.unique(t_vals) + self._ensure(levels) + + # stack the resident levels into one small NumPy block; remap to local indices + block = np.stack([self._cache[int(l)] for l in levels]) # (nlevels, *rest) + nda = xr.DataArray(block, dims=self._data.dims) # NumPy-backed, original dim order + local = np.searchsorted(levels, t_vals) + sel[self._tdim] = xr.DataArray(local, dims=getattr(t_ind, "dims", ())) + return nda.isel(sel) # plain vectorised gather in NumPy (no ignore_grid needed) + + +def maybe_windowed(data: xr.DataArray, max_levels: int | None = None): + """Wrap dask-backed, field data in a ``WindowedArray``; else pass through. + + NumPy-backed fields (already resident) and fields without a leading ``time`` + dimension are returned unchanged, so existing eager workflows are unaffected. + Already-wrapped data is returned unchanged. + """ + if isinstance(data, WindowedArray): + return data + if is_dask_collection(data.data): + return WindowedArray(data, max_levels=max_levels) + return data diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 4844ff210..d4015a5e5 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -11,6 +11,7 @@ import xgcm import parcels._sgrid as sgrid +from parcels._core._windowed_array import maybe_windowed from parcels._core.field import Field, VectorField from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar @@ -126,6 +127,40 @@ def add_field(self, field: Field, name: str | None = None): self.fields[name] = field + def to_windowed_arrays(self, *, max_levels: int | None = None): + """Wrap dask-backed field data in a rolling time-window cache. + + Opt-in optimization for forward-marching simulations where all particles + share a single clock. For each dask-backed, time-leading field, ``isel`` + then samples a resident NumPy window (each time level loaded once and + evicted as the clock advances) instead of re-reading chunks and paying the + dask scheduling overhead on every kernel step. The wrapper is transparent: + it forwards all attributes, intercepts only ``isel``, and returns + NumPy-backed results so the interpolators' ``is_dask_collection()`` guards + skip ``.compute()``. + + NumPy-backed (eager) fields and fields without a leading ``time`` dimension + are left unchanged, and re-wrapping is idempotent, so this is safe to call + more than once. + + Parameters + ---------- + max_levels : int, optional + Cap on the number of time levels kept resident per field. ``None`` + (default) retains every level that the advancing clock still brackets. + + Returns + ------- + FieldSet + ``self``, to allow chaining. + """ + for field in self.fields.values(): + components = (field.U, field.V, field.W) if isinstance(field, VectorField) else (field,) + for component in components: + if component is not None: + component.data = maybe_windowed(component.data, max_levels=max_levels) + return self + def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity diff --git a/tests/test_windowed_array.py b/tests/test_windowed_array.py new file mode 100644 index 000000000..8c5f3f3c5 --- /dev/null +++ b/tests/test_windowed_array.py @@ -0,0 +1,46 @@ +"""Tests for the transparent rolling time-window cache (WindowedArray).""" + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from parcels import FieldSet, ParticleSet +from parcels._core._windowed_array import WindowedArray, maybe_windowed +from parcels._datasets.structured.generated import simple_UV_dataset +from parcels.kernels import AdvectionRK2 + + +def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels(): + ds = simple_UV_dataset(mesh="flat") + fs = FieldSet.from_sgrid_conventions(ds.chunk({"time": 1}), mesh="flat") + + fs.to_windowed_arrays(max_levels=3) + first = fs.U.data + assert isinstance(first, WindowedArray) + assert first._max == 3 + + # re-wrapping returns the same object (idempotent), not a nested wrapper + fs.to_windowed_arrays(max_levels=3) + assert fs.U.data is first + +@pytest.mark.parametrize("mesh", ["flat", "spherical"]) +def test_dask_advection_matches_numpy(mesh): + """An identical advection must give identical trajectories whether the field + is numpy-backed or dask-backed (windowed).""" + ds = simple_UV_dataset(mesh=mesh) + ds["U"].data[:] = 1.0 # steady zonal flow -> in-bounds, deterministic + + def run(windowed): + d = ds.chunk({"time": 1}) if chunked else ds + fs = FieldSet.from_sgrid_conventions(d, mesh=mesh) + if windowed: + fs.to_windowed_arrays() + pset = ParticleSet(fs, lon=np.zeros(10), lat=np.linspace(-10, 10, 10)) + pset.execute(AdvectionRK2, runtime=7200, dt=np.timedelta64(15, "m")) + return np.array(pset.lon), np.array(pset.lat) + + lon_np, lat_np = run(False) + lon_dk, lat_dk = run(True) + np.testing.assert_allclose(lon_dk, lon_np, atol=1e-9) + np.testing.assert_allclose(lat_dk, lat_np, atol=1e-9) From b4cde58123c2bcb8e93d300309785bb9ee493f39 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 16 Jun 2026 09:08:10 -0400 Subject: [PATCH 4/5] Copy edit verbose claude doc-strings --- src/parcels/_core/fieldset.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index d4015a5e5..f61429607 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -131,16 +131,11 @@ def to_windowed_arrays(self, *, max_levels: int | None = None): """Wrap dask-backed field data in a rolling time-window cache. Opt-in optimization for forward-marching simulations where all particles - share a single clock. For each dask-backed, time-leading field, ``isel`` - then samples a resident NumPy window (each time level loaded once and - evicted as the clock advances) instead of re-reading chunks and paying the - dask scheduling overhead on every kernel step. The wrapper is transparent: - it forwards all attributes, intercepts only ``isel``, and returns - NumPy-backed results so the interpolators' ``is_dask_collection()`` guards - skip ``.compute()``. - - NumPy-backed (eager) fields and fields without a leading ``time`` dimension - are left unchanged, and re-wrapping is idempotent, so this is safe to call + share a single clock. For each dask-backed field ``isel`` then samples + a resident NumPy window instead of re-reading chunks and paying the + dask scheduling overhead on every kernel step. + + NumPy-backed fields are left unchanged, so this is safe to call more than once. Parameters From 8136bf541e769968672a7c4942f1e61acadaaa8b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:09:53 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/parcels/_core/_windowed_array.py | 8 +++++--- src/parcels/_core/fieldset.py | 2 +- tests/test_windowed_array.py | 8 ++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/parcels/_core/_windowed_array.py b/src/parcels/_core/_windowed_array.py index a70b7c905..80d322000 100644 --- a/src/parcels/_core/_windowed_array.py +++ b/src/parcels/_core/_windowed_array.py @@ -39,8 +39,10 @@ def __getattr__(self, name): return getattr(object.__getattribute__(self, "_data"), name) def __repr__(self): - return (f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, " - f"loads={self.loads})\n{self._data!r}") + return ( + f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, " + f"loads={self.loads})\n{self._data!r}" + ) # -- window management ---------------------------------------------------- def _read_level(self, lvl: int) -> np.ndarray: @@ -78,7 +80,7 @@ def isel(self, indexers: dict | None = None, **kwargs): # stack the resident levels into one small NumPy block; remap to local indices block = np.stack([self._cache[int(l)] for l in levels]) # (nlevels, *rest) - nda = xr.DataArray(block, dims=self._data.dims) # NumPy-backed, original dim order + nda = xr.DataArray(block, dims=self._data.dims) # NumPy-backed, original dim order local = np.searchsorted(levels, t_vals) sel[self._tdim] = xr.DataArray(local, dims=getattr(t_ind, "dims", ())) return nda.isel(sel) # plain vectorised gather in NumPy (no ignore_grid needed) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index f61429607..9dab20db5 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -133,7 +133,7 @@ def to_windowed_arrays(self, *, max_levels: int | None = None): Opt-in optimization for forward-marching simulations where all particles share a single clock. For each dask-backed field ``isel`` then samples a resident NumPy window instead of re-reading chunks and paying the - dask scheduling overhead on every kernel step. + dask scheduling overhead on every kernel step. NumPy-backed fields are left unchanged, so this is safe to call more than once. diff --git a/tests/test_windowed_array.py b/tests/test_windowed_array.py index 8c5f3f3c5..3310d6bec 100644 --- a/tests/test_windowed_array.py +++ b/tests/test_windowed_array.py @@ -1,12 +1,10 @@ """Tests for the transparent rolling time-window cache (WindowedArray).""" -import dask.array as da import numpy as np import pytest -import xarray as xr from parcels import FieldSet, ParticleSet -from parcels._core._windowed_array import WindowedArray, maybe_windowed +from parcels._core._windowed_array import WindowedArray from parcels._datasets.structured.generated import simple_UV_dataset from parcels.kernels import AdvectionRK2 @@ -24,10 +22,12 @@ def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels(): fs.to_windowed_arrays(max_levels=3) assert fs.U.data is first + @pytest.mark.parametrize("mesh", ["flat", "spherical"]) def test_dask_advection_matches_numpy(mesh): """An identical advection must give identical trajectories whether the field - is numpy-backed or dask-backed (windowed).""" + is numpy-backed or dask-backed (windowed). + """ ds = simple_UV_dataset(mesh=mesh) ds["U"].data[:] = 1.0 # steady zonal flow -> in-bounds, deterministic