-
Notifications
You must be signed in to change notification settings - Fork 176
Issue 2656 windowed array #2671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
af1c9df
4e8190d
1139091
b4cde58
8136bf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| """Transparent rolling time-window cache for lazy (dask-backed) field data. | ||
|
|
||
| Assumptions / current limits: | ||
| * ``time`` is the leading dimension of the field (true for both the SGRID and | ||
| UGRID ingestion paths; the structured path transposes to ``(time, ...)``). | ||
| * Valid while the requested time indices stay within the resident window | ||
| (i.e. all particles share the clock). A sample that requests time indices | ||
| spanning more than the retained levels would force reloads. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| import xarray as xr | ||
| from dask import is_dask_collection | ||
|
|
||
| # xarray / uxarray ``isel`` keyword arguments that are NOT dimension indexers. | ||
| _NON_INDEXER_KWARGS = frozenset({"drop", "missing_dims", "ignore_grid"}) | ||
|
|
||
|
|
||
class WindowedArray:
    """Wrap a lazy DataArray so ``isel`` loads/caches/evicts time levels as NumPy.

    Every attribute that is not defined here is forwarded to the wrapped
    DataArray, so the wrapper can stand in for it wherever only attribute
    access and ``isel`` are used.

    Parameters
    ----------
    data : xr.DataArray
        The wrapped array. Its first dimension must be ``time_dim``.
    time_dim : str, optional
        Name of the leading (time) dimension. Defaults to ``"time"``.
    max_levels : int, optional
        Cap on the number of time levels kept resident. ``None`` (default)
        means no cap. The cap never evicts a level that the current request
        needs, so a single request spanning more levels than ``max_levels``
        temporarily exceeds it.

    Attributes
    ----------
    loads : int
        Number of time levels read into the cache so far (diagnostic).
    bytes_read : int
        ``loads`` multiplied by the size in bytes of one time level (diagnostic).

    Raises
    ------
    ValueError
        If ``data`` has no dimensions or its leading dimension is not ``time_dim``.
    """

    def __init__(self, data: xr.DataArray, time_dim: str = "time", max_levels: int | None = None):
        # Guard against 0-d input as well, so the check cannot raise IndexError.
        if not data.dims or data.dims[0] != time_dim:
            raise ValueError(f"WindowedArray expects {time_dim!r} as the leading dimension, got {data.dims}")
        self._data = data
        self._tdim = time_dim
        self._cache: dict[int, np.ndarray] = {}  # time index -> NumPy slab (remaining dims)
        self._max = max_levels
        # diagnostics
        self.loads = 0
        self.bytes_read = 0
        # Size of one time level, taken from the shape alone: no indexing
        # operation is built and a zero-length time axis is not a problem.
        self._slab_bytes = int(np.prod(data.shape[1:])) * data.dtype.itemsize

    # -- transparency: forward everything we don't override -------------------
    def __getattr__(self, name):
        # __getattr__ only fires for misses; reach _data without recursing.
        return getattr(object.__getattribute__(self, "_data"), name)

    def __repr__(self):
        return (
            f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, "
            f"loads={self.loads})\n{self._data!r}"
        )

    # -- window management ----------------------------------------------------
    def _read_level(self, lvl: int) -> np.ndarray:
        """Bulk, sequential read of one time level into NumPy (the dask->NumPy step)."""
        return np.asarray(self._data.isel({self._tdim: int(lvl)}).values)

    def _ensure(self, levels: np.ndarray) -> None:
        """Make every requested time level resident, then trim the cache.

        Parameters
        ----------
        levels : np.ndarray
            Integer time indices needed by the current request. An empty
            request is a no-op.
        """
        wanted = [int(lvl) for lvl in np.asarray(levels).ravel()]
        if not wanted:
            return

        for lvl in wanted:
            if lvl not in self._cache:
                self._cache[lvl] = self._read_level(lvl)
                self.loads += 1
                self.bytes_read += self._slab_bytes

        # Retire stale levels (the clock only moves forward across the window).
        # NOTE(review): a reviewer asked whether this also works for
        # time-backward runs and the author had not checked. In that case the
        # stale levels lie *above* the window and are not retired by this step;
        # only the ``max_levels`` cap below removes them.
        lo = min(wanted)
        for old in [k for k in self._cache if k < lo]:
            del self._cache[old]

        if self._max is not None and len(self._cache) > self._max:
            # Only levels that the current request does not need may be evicted;
            # otherwise the caller's lookup of a requested level would fail.
            # Everything left is >= lo, so the spare levels are the ones ahead
            # of (or in gaps inside) the window; drop the highest first.
            needed = set(wanted)
            spare = sorted((k for k in self._cache if k not in needed), reverse=True)
            for old in spare[: len(self._cache) - self._max]:
                del self._cache[old]

    # -- intercepted indexing -------------------------------------------------
    def isel(self, indexers: dict | None = None, **kwargs):
        """Index the field, serving time selections from the resident NumPy window.

        Parameters
        ----------
        indexers : dict, optional
            Mapping of dimension name to indexer, as for ``DataArray.isel``.
        **kwargs
            Further dimension indexers and/or the control keywords listed in
            ``_NON_INDEXER_KWARGS``.

        Returns
        -------
        xr.DataArray
            If the time dimension is not indexed, is indexed with a ``slice``,
            or is indexed with an empty indexer, the result of the wrapped
            array's own ``isel``. Otherwise a NumPy-backed DataArray built from
            the cached levels; it carries the original dimension names only
            (no coordinates are attached).
        """
        sel = dict(indexers) if indexers is not None else {}
        sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS})

        if self._tdim not in sel:
            # no time selection -> nothing to window; preserve control kwargs
            return self._data.isel(indexers, **kwargs)

        t_ind = sel[self._tdim]
        if isinstance(t_ind, slice):
            # A slice has no discrete level list to cache; defer to the wrapped array.
            return self._data.isel(indexers, **kwargs)

        t_vals = np.asarray(t_ind.values if isinstance(t_ind, xr.DataArray) else t_ind)
        if t_vals.size == 0:
            # Nothing to load or stack for an empty request.
            return self._data.isel(indexers, **kwargs)

        levels = np.unique(t_vals)
        self._ensure(levels)

        # stack the resident levels into one small NumPy block; remap to local indices
        block = np.stack([self._cache[int(lvl)] for lvl in levels])  # (nlevels, *rest)
        nda = xr.DataArray(block, dims=self._data.dims)  # NumPy-backed, original dim order
        local = np.searchsorted(levels, t_vals)

        t_dims = getattr(t_ind, "dims", None)
        if t_dims is not None:
            # Labelled indexer: keep its dimensions so the gather stays vectorised.
            sel[self._tdim] = xr.DataArray(local, dims=t_dims)
        elif local.ndim == 0:
            sel[self._tdim] = int(local)
        else:
            # Plain array-like indexer: hand the remapped positions over as-is.
            sel[self._tdim] = local

        # ``ignore_grid`` is uxarray-specific and not needed for a plain NumPy
        # gather; the xarray control keywords are forwarded unchanged.
        control = {k: kwargs[k] for k in ("drop", "missing_dims") if k in kwargs}
        return nda.isel(sel, **control)
|
|
||
|
|
||
def maybe_windowed(data: xr.DataArray, max_levels: int | None = None, time_dim: str = "time"):
    """Wrap dask-backed field data in a ``WindowedArray``; else pass through.

    NumPy-backed fields (already resident) and fields without a leading
    ``time_dim`` dimension are returned unchanged, so existing eager workflows
    are unaffected. Already-wrapped data is returned unchanged.

    Parameters
    ----------
    data : xr.DataArray or WindowedArray
        Field data to (maybe) wrap.
    max_levels : int, optional
        Forwarded to ``WindowedArray`` as the cap on resident time levels.
    time_dim : str, optional
        Name of the time dimension that must lead ``data.dims`` for the data
        to be wrapped. Defaults to ``"time"``.

    Returns
    -------
    xr.DataArray or WindowedArray
        A ``WindowedArray`` around ``data`` when it is dask-backed with
        ``time_dim`` as its first dimension; otherwise ``data`` itself.
    """
    if isinstance(data, WindowedArray):
        return data
    if not is_dask_collection(data.data):
        return data
    # WindowedArray rejects data whose leading dimension is not the time
    # dimension; check here so such fields pass through instead of raising.
    if not data.dims or data.dims[0] != time_dim:
        return data
    return WindowedArray(data, time_dim=time_dim, max_levels=max_levels)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| import xgcm | ||
|
|
||
| import parcels._sgrid as sgrid | ||
| from parcels._core._windowed_array import maybe_windowed | ||
| from parcels._core.field import Field, VectorField | ||
| from parcels._core.utils.string import _assert_str_and_python_varname | ||
| from parcels._core.utils.time import get_datetime_type_calendar | ||
|
|
@@ -126,6 +127,35 @@ def add_field(self, field: Field, name: str | None = None): | |
|
|
||
| self.fields[name] = field | ||
|
|
||
def to_windowed_arrays(self, *, max_levels: int | None = None):
    """Route dask-backed field data through a rolling time-window cache.

    This is an opt-in optimization intended for forward-marching simulations
    in which all particles share a single clock. After the call, ``isel`` on
    a dask-backed field samples a resident NumPy window rather than
    re-reading chunks and paying the dask scheduling overhead on every
    kernel step.

    Fields backed by NumPy are not modified, so calling this method more
    than once is safe.

    Parameters
    ----------
    max_levels : int, optional
        Upper bound on the number of time levels kept resident per field.
        With ``None`` (default) every level that the advancing clock still
        brackets is retained.

    Returns
    -------
    FieldSet
        ``self``, to allow chaining.
    """
    for entry in self.fields.values():
        # A vector field holds its data on the component fields.
        if isinstance(entry, VectorField):
            parts = (entry.U, entry.V, entry.W)
        else:
            parts = (entry,)
        for part in parts:
            if part is None:
                continue
            part.data = maybe_windowed(part.data, max_levels=max_levels)
    return self
|
Comment on lines
+130
to
+157
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): | ||
| """Wrapper function to add a Field that is constant in space, | ||
| useful e.g. when using constant horizontal diffusivity | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This class would also work on
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The uxinterpolators changes here actually are from the #2666 base - in order for this PR to be possible, we need the unstructured bits to use vectorized |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,46 @@ | ||||||
| """Tests for the transparent rolling time-window cache (WindowedArray).""" | ||||||
|
|
||||||
| import numpy as np | ||||||
| import pytest | ||||||
|
|
||||||
| from parcels import FieldSet, ParticleSet | ||||||
| from parcels._core._windowed_array import WindowedArray | ||||||
| from parcels._datasets.structured.generated import simple_UV_dataset | ||||||
| from parcels.kernels import AdvectionRK2 | ||||||
|
|
||||||
|
|
||||||
def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels():
    """Wrapping twice must keep the same wrapper and honour ``max_levels``."""
    dataset = simple_UV_dataset(mesh="flat")
    fieldset = FieldSet.from_sgrid_conventions(dataset.chunk({"time": 1}), mesh="flat")

    fieldset.to_windowed_arrays(max_levels=3)
    wrapped = fieldset.U.data
    assert isinstance(wrapped, WindowedArray)
    assert wrapped._max == 3

    # a second call must hand back the very same wrapper, not wrap it again
    fieldset.to_windowed_arrays(max_levels=3)
    assert fieldset.U.data is wrapped
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("mesh", ["flat", "spherical"])
def test_dask_advection_matches_numpy(mesh):
    """An identical advection must give identical trajectories whether the field
    is numpy-backed or dask-backed (windowed).
    """
    ds = simple_UV_dataset(mesh=mesh)
    ds["U"].data[:] = 1.0  # steady zonal flow -> in-bounds, deterministic

    def run(windowed):
        """Advect ten particles and return their final ``(lon, lat)`` arrays.

        With ``windowed`` true the dataset is chunked along ``time`` (making it
        dask-backed) and the field set is switched to windowed arrays;
        otherwise the dataset is used as-is.
        """
        # Branch on the helper's own parameter (the original referenced an
        # undefined name here, which raised NameError).
        d = ds.chunk({"time": 1}) if windowed else ds
        fs = FieldSet.from_sgrid_conventions(d, mesh=mesh)
        if windowed:
            fs.to_windowed_arrays()
        pset = ParticleSet(fs, lon=np.zeros(10), lat=np.linspace(-10, 10, 10))
        pset.execute(AdvectionRK2, runtime=7200, dt=np.timedelta64(15, "m"))
        return np.array(pset.lon), np.array(pset.lat)

    lon_np, lat_np = run(False)
    lon_dk, lat_dk = run(True)
    np.testing.assert_allclose(lon_dk, lon_np, atol=1e-9)
    np.testing.assert_allclose(lat_dk, lat_np, atol=1e-9)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.