class WindowedArray:
    """Rolling time-window cache in front of a lazy, time-leading DataArray.

    ``isel`` calls that index the time dimension are served from NumPy slabs
    (one per time level) that are read once and kept in ``_cache``; every other
    attribute access is forwarded to the wrapped array.

    Parameters
    ----------
    data : xr.DataArray
        The wrapped array. Its first dimension must be ``time_dim``.
    time_dim : str
        Name of the time dimension (default ``"time"``).
    max_levels : int, optional
        Soft cap on the number of resident time levels. Levels needed by the
        request being served are never evicted, so the cache may temporarily
        exceed the cap when a single request spans more levels than allowed.

    Raises
    ------
    ValueError
        If ``time_dim`` is not the leading dimension of ``data``.
    """

    def __init__(self, data: xr.DataArray, time_dim: str = "time", max_levels: int | None = None):
        # ``not data.dims`` guards 0-d input, where ``dims[0]`` would raise IndexError.
        if not data.dims or data.dims[0] != time_dim:
            raise ValueError(f"WindowedArray expects {time_dim!r} as the leading dimension, got {data.dims}")
        self._data = data
        self._tdim = time_dim
        self._cache: dict[int, np.ndarray] = {}  # time index -> NumPy slab (remaining dims)
        self._max = max_levels
        # diagnostics
        self.loads = 0  # number of time levels read from the wrapped array
        self.bytes_read = 0  # loads * bytes per time level
        # Time is the leading axis, so one slab is everything after it. Using the
        # shape (rather than indexing level 0) also works for an empty time axis.
        self._slab_bytes = int(np.prod(data.shape[1:], dtype=np.int64)) * data.dtype.itemsize

    # -- transparency: forward everything we don't override -------------------
    def __getattr__(self, name):
        # __getattr__ only fires for misses; reach _data without recursing.
        return getattr(object.__getattribute__(self, "_data"), name)

    def __repr__(self):
        return (
            f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, "
            f"loads={self.loads})\n{self._data!r}"
        )

    # -- window management ----------------------------------------------------
    def _read_level(self, lvl: int) -> np.ndarray:
        """Read one full time level from the wrapped array into a NumPy array."""
        return np.asarray(self._data.isel({self._tdim: int(lvl)}).values)

    def _ensure(self, levels: np.ndarray) -> None:
        """Make every requested time level resident, then trim the cache.

        Levels below the smallest requested one are dropped (the clock is
        assumed to move forward). If ``max_levels`` is set, further levels are
        dropped lowest-first, but never a level that is part of ``levels``.
        An empty request is a no-op.
        """
        requested = {int(lvl) for lvl in np.ravel(levels)}
        if not requested:
            return

        for lvl in sorted(requested):
            if lvl not in self._cache:
                self._cache[lvl] = self._read_level(lvl)
                self.loads += 1
                self.bytes_read += self._slab_bytes

        # retire stale levels (the clock only moves forward across the window)
        lo = min(requested)
        for old in [k for k in self._cache if k < lo]:
            del self._cache[old]

        if self._max is not None:
            excess = len(self._cache) - self._max
            if excess > 0:
                # Only levels outside the current request are evictable; dropping a
                # requested level would make the caller's cache lookup fail.
                evictable = sorted(k for k in self._cache if k not in requested)
                for old in evictable[:excess]:
                    del self._cache[old]

    # -- intercepted indexing -------------------------------------------------
    def isel(self, indexers: dict | None = None, **kwargs):
        """Index the array, serving time-indexed requests from the resident window.

        Requests without a time indexer, with a ``slice`` time indexer, or with
        an empty time index are delegated unchanged to the wrapped array.
        Otherwise the result is a NumPy-backed ``xr.DataArray`` built from the
        cached levels; control keywords listed in ``_NON_INDEXER_KWARGS`` are
        not forwarded on that path.
        """
        sel = dict(indexers) if indexers is not None else {}
        sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS})

        if self._tdim not in sel or isinstance(sel[self._tdim], slice):
            # nothing to window (or a range request); preserve control kwargs
            return self._data.isel(indexers, **kwargs)

        t_ind = sel[self._tdim]
        t_vals = np.asarray(t_ind.values if isinstance(t_ind, xr.DataArray) else t_ind)
        levels = np.unique(t_vals)
        if levels.size == 0:
            # no points requested: there is no level to stack, let the wrapped array answer
            return self._data.isel(indexers, **kwargs)
        self._ensure(levels)

        # stack the resident levels into one small NumPy block; remap to local indices
        block = np.stack([self._cache[int(lvl)] for lvl in levels])  # (nlevels, *rest)
        nda = xr.DataArray(block, dims=self._data.dims)  # NumPy-backed, original dim order
        local = np.searchsorted(levels, t_vals)
        if isinstance(t_ind, xr.DataArray):
            sel[self._tdim] = xr.DataArray(local, dims=t_ind.dims)
        elif t_vals.ndim == 0:
            sel[self._tdim] = xr.DataArray(local, dims=())
        else:
            # plain array/list indexer: hand the remapped array straight to xarray
            sel[self._tdim] = local
        return nda.isel(sel)  # plain vectorised gather in NumPy (no ignore_grid needed)
def maybe_windowed(data: xr.DataArray, max_levels: int | None = None, time_dim: str = "time"):
    """Wrap dask-backed, time-leading field data in a ``WindowedArray``; else pass through.

    Returned unchanged:
      * data that is already a ``WindowedArray``;
      * data whose backing array is not a dask collection (already resident);
      * data whose leading dimension is not ``time_dim`` (``WindowedArray``
        would reject it with ``ValueError``).

    Parameters
    ----------
    data : xr.DataArray
        Field data to wrap.
    max_levels : int, optional
        Forwarded to ``WindowedArray`` as the cap on resident time levels.
    time_dim : str
        Name of the time dimension that must lead ``data.dims`` for wrapping.
    """
    if isinstance(data, WindowedArray):
        return data
    if not is_dask_collection(data.data):
        return data
    # Honour the documented pass-through instead of letting the constructor raise.
    if not data.dims or data.dims[0] != time_dim:
        return data
    return WindowedArray(data, time_dim=time_dim, max_levels=max_levels)
def to_windowed_arrays(self, *, max_levels: int | None = None):
    """Route the data of every field through ``maybe_windowed``.

    Each entry of ``self.fields`` is visited; a ``VectorField`` contributes its
    ``U``, ``V`` and ``W`` components (``None`` components are skipped), any
    other field contributes itself. The ``data`` attribute of each is replaced
    by the result of ``maybe_windowed``.

    Parameters
    ----------
    max_levels : int, optional
        Passed through to ``maybe_windowed`` for every component.

    Returns
    -------
    FieldSet
        ``self``, to allow chaining.
    """
    for entry in self.fields.values():
        if isinstance(entry, VectorField):
            parts = (entry.U, entry.V, entry.W)
        else:
            parts = (entry,)
        for part in parts:
            if part is None:
                continue
            part.data = maybe_windowed(part.data, max_levels=max_levels)
    return self
grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ti, zi, fi = np.broadcast_arrays( + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ) z = particle_positions["z"] + tdim, zdim, fdim = field.data.dims + + def _zsample(z_index): + """Pointwise ``isel`` of the face values at a single vertical interface level.""" + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(z_index, dims="points"), + fdim: xr.DataArray(fi, dims="points"), + } + value = field.data.isel(selection_dict, ignore_grid=True).data + return value.compute() if is_dask_collection(value) else value + # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. # First, do barycentric interpolation in the lateral direction for each interface level - fzk = field.data.values[ti, zi, fi] - fzkp1 = field.data.values[ti, zi + 1, fi] + fzk = _zsample(zi) + fzkp1 = _zsample(zi + 1) # Then, do piecewise linear interpolation in the vertical direction zk = field.grid.z.values[zi] @@ -57,13 +82,22 @@ def UxLinearNodeConstantZC( Effectively, it applies barycentric interpolation in the lateral direction and piecewise constant interpolation in the vertical direction. 
""" - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - bcoords = grid_positions["FACE"]["bcoord"] + ti, zi, fi = np.broadcast_arrays( + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ) + bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes")) node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values - return np.sum( - field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1 - ) # Linear interpolation in the vertical direction + + tdim, zdim, ndim = field.data.dims + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(zi, dims="points"), + ndim: xr.DataArray(node_ids, dims=("points", "nodes")), + } + + node_data = field.data.isel(selection_dict, ignore_grid=True) + value = (node_data * bcoords).sum("nodes").data # Barycentric interpolation in the lateral direction + return value.compute() if is_dask_collection(value) else value def UxLinearNodeLinearZF( @@ -76,21 +110,37 @@ def UxLinearNodeLinearZF( Effectively, it applies barycentric interpolation in the lateral direction and piecewise linear interpolation in the vertical direction. 
""" - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ti, zi, fi = np.broadcast_arrays( + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ) z = particle_positions["z"] - bcoords = grid_positions["FACE"]["bcoord"] + bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes")) node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values + + tdim, zdim, ndim = field.data.dims + + def _zsample(z_index): + """Barycentric (lateral) interpolation of the node values at a single vertical interface level.""" + selection_dict = { + tdim: xr.DataArray(ti, dims="points"), + zdim: xr.DataArray(z_index, dims="points"), + ndim: xr.DataArray(node_ids, dims=("points", "nodes")), + } + # Reduce over the "nodes" dimension by name so the result is independent of ``isel`` dim order. + node_data = field.data.isel(selection_dict, ignore_grid=True) + return (node_data * bcoords).sum("nodes").data + # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. 
def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels():
    """Wrapping twice must keep the same wrapper, and ``max_levels`` must reach it."""
    chunked = simple_UV_dataset(mesh="flat").chunk({"time": 1})
    fieldset = FieldSet.from_sgrid_conventions(chunked, mesh="flat")

    fieldset.to_windowed_arrays(max_levels=3)
    wrapped = fieldset.U.data
    assert isinstance(wrapped, WindowedArray)
    assert wrapped._max == 3

    # a second call must hand back the very same wrapper, not wrap it again
    fieldset.to_windowed_arrays(max_levels=3)
    assert fieldset.U.data is wrapped
@pytest.mark.parametrize("mesh", ["flat", "spherical"])
def test_dask_advection_matches_numpy(mesh):
    """An identical advection must give identical trajectories whether the field
    is numpy-backed or dask-backed (windowed).
    """
    ds = simple_UV_dataset(mesh=mesh)
    ds["U"].data[:] = 1.0  # steady zonal flow -> in-bounds, deterministic

    def run(windowed):
        """Advect ten particles; ``windowed`` selects the chunked + windowed path."""
        # The flag is ``windowed``; the previous ``chunked`` name was never defined.
        d = ds.chunk({"time": 1}) if windowed else ds
        fs = FieldSet.from_sgrid_conventions(d, mesh=mesh)
        if windowed:
            fs.to_windowed_arrays()
            # Guard against a silent pass-through that would make the comparison vacuous.
            assert isinstance(fs.U.data, WindowedArray)
        pset = ParticleSet(fs, lon=np.zeros(10), lat=np.linspace(-10, 10, 10))
        pset.execute(AdvectionRK2, runtime=7200, dt=np.timedelta64(15, "m"))
        return np.array(pset.lon), np.array(pset.lat)

    lon_np, lat_np = run(False)
    lon_dk, lat_dk = run(True)
    np.testing.assert_allclose(lon_dk, lon_np, atol=1e-9)
    np.testing.assert_allclose(lat_dk, lat_np, atol=1e-9)