Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/parcels/_core/_windowed_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Transparent rolling time-window cache for lazy (dask-backed) field data.

Assumptions / current limits:
* ``time`` is the leading dimension of the field (true for both the SGRID and
UGRID ingestion paths; the structured path transposes to ``(time, ...)``).
* Valid while the requested time indices stay within the resident window
(i.e. all particles share the clock). A sample that requests time indices
spanning more than the retained levels would force reloads.
"""

from __future__ import annotations

import numpy as np
import xarray as xr
from dask import is_dask_collection

# xarray / uxarray ``isel`` keyword arguments that are NOT dimension indexers.
_NON_INDEXER_KWARGS = frozenset({"drop", "missing_dims", "ignore_grid"})


class WindowedArray:
"""Wrap a lazy DataArray so ``isel`` loads/caches/evicts time levels as NumPy."""

def __init__(self, data: xr.DataArray, time_dim: str = "time", max_levels: int | None = None):
if data.dims[0] != time_dim:
raise ValueError(f"WindowedArray expects {time_dim!r} as the leading dimension, got {data.dims}")
self._data = data
self._tdim = time_dim
self._cache: dict[int, np.ndarray] = {} # time index -> NumPy slab (remaining dims)
self._max = max_levels
# diagnostics
self.loads = 0
self.bytes_read = 0
self._slab_bytes = int(np.prod(data.isel({time_dim: 0}).shape)) * data.dtype.itemsize

# -- transparency: forward everything we don't override -------------------
def __getattr__(self, name):
# __getattr__ only fires for misses; reach _data without recursing.
return getattr(object.__getattribute__(self, "_data"), name)

def __repr__(self):
return (
f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, "
f"loads={self.loads})\n{self._data!r}"
)

# -- window management ----------------------------------------------------
def _read_level(self, lvl: int) -> np.ndarray:
"""Bulk, sequential read of one time level into NumPy (the dask->NumPy step)."""
return np.asarray(self._data.isel({self._tdim: int(lvl)}).values)

def _ensure(self, levels: np.ndarray) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _ensure(self, levels: np.ndarray) -> None:
def _ensure(self, levels: np.ndarray) -> None:
if self._max < len(levels):
raise ValueError("Trying to ensure more levels than we can hold.")

for lvl in levels:
lvl = int(lvl)
if lvl not in self._cache:
self._cache[lvl] = self._read_level(lvl)
self.loads += 1
self.bytes_read += self._slab_bytes
# retire stale levels (the clock only moves forward across the window)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you checked whether this also works in time-backward (dt<0) mode?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not :)

lo = int(np.min(levels))
for old in [k for k in self._cache if k < lo]:
del self._cache[old]
if self._max is not None and len(self._cache) > self._max:
for old in sorted(self._cache)[: len(self._cache) - self._max]:
del self._cache[old]

# -- intercepted indexing -------------------------------------------------
def isel(self, indexers: dict | None = None, **kwargs):
sel = dict(indexers) if indexers is not None else {}
sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS})

if self._tdim not in sel:
# no time selection -> nothing to window; preserve control kwargs
return self._data.isel(indexers, **kwargs)

t_ind = sel[self._tdim]
t_vals = np.asarray(t_ind.values if isinstance(t_ind, xr.DataArray) else t_ind)
levels = np.unique(t_vals)
self._ensure(levels)

# stack the resident levels into one small NumPy block; remap to local indices
block = np.stack([self._cache[int(l)] for l in levels]) # (nlevels, *rest)
nda = xr.DataArray(block, dims=self._data.dims) # NumPy-backed, original dim order
local = np.searchsorted(levels, t_vals)
sel[self._tdim] = xr.DataArray(local, dims=getattr(t_ind, "dims", ()))
return nda.isel(sel) # plain vectorised gather in NumPy (no ignore_grid needed)


def maybe_windowed(data: xr.DataArray, max_levels: int | None = None):
"""Wrap dask-backed, field data in a ``WindowedArray``; else pass through.

NumPy-backed fields (already resident) and fields without a leading ``time``
dimension are returned unchanged, so existing eager workflows are unaffected.
Already-wrapped data is returned unchanged.
"""
if isinstance(data, WindowedArray):
return data
if is_dask_collection(data.data):
return WindowedArray(data, max_levels=max_levels)
return data
30 changes: 30 additions & 0 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import xgcm

import parcels._sgrid as sgrid
from parcels._core._windowed_array import maybe_windowed
from parcels._core.field import Field, VectorField
from parcels._core.utils.string import _assert_str_and_python_varname
from parcels._core.utils.time import get_datetime_type_calendar
Expand Down Expand Up @@ -126,6 +127,35 @@ def add_field(self, field: Field, name: str | None = None):

self.fields[name] = field

def to_windowed_arrays(self, *, max_levels: int | None = None):
"""Wrap dask-backed field data in a rolling time-window cache.

Opt-in optimization for forward-marching simulations where all particles
share a single clock. For each dask-backed field ``isel`` then samples
a resident NumPy window instead of re-reading chunks and paying the
dask scheduling overhead on every kernel step.

NumPy-backed fields are left unchanged, so this is safe to call
more than once.

Parameters
----------
max_levels : int, optional
Cap on the number of time levels kept resident per field. ``None``
(default) retains every level that the advancing clock still brackets.

Returns
-------
FieldSet
``self``, to allow chaining.
"""
for field in self.fields.values():
components = (field.U, field.V, field.W) if isinstance(field, VectorField) else (field,)
for component in components:
if component is not None:
component.data = maybe_windowed(component.data, max_levels=max_levels)
return self
Comment on lines +130 to +157

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I think that this is actually incompatible with #2646 :'( since that refactors things so that there isn't a field.data object anymore.

Do you think that it would be possible to wrap this WindowArray approach similar to what is done in #2668 ?


def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"):
"""Wrapper function to add a Field that is constant in space,
useful e.g. when using constant horizontal diffusivity
Expand Down
86 changes: 68 additions & 18 deletions src/parcels/interpolators/_uxinterpolators.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class would also work on xgrids, right? Test it there too?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The uxinterpolators changes here actually are from the #2666 base - in order for this PR to be possible, we need the unstructured bits to use vectorized isel

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr
from dask import is_dask_collection

if TYPE_CHECKING:
from parcels._core.field import Field, VectorField
Expand All @@ -17,9 +19,19 @@ def UxConstantFaceConstantZC(
field: Field,
):
"""Piecewise constant interpolation kernel for face registered data that is vertically centered (on zc points)"""
return field.data.values[
# Broadcast the per-axis indices to a common (npart,) shape (``ti`` may be scalar for time-constant fields)
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
]
)

tdim, zdim, fdim = field.data.dims
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(zi, dims="points"),
fdim: xr.DataArray(fi, dims="points"),
}
value = field.data.isel(selection_dict, ignore_grid=True).data
return value.compute() if is_dask_collection(value) else value


def UxConstantFaceLinearZF(
Expand All @@ -31,15 +43,28 @@ def UxConstantFaceLinearZF(
Piecewise constant interpolation (lateral) with linear vertical interpolation kernel for face registered data
that is located at vertical interface levels (on zf points)
"""
ti = grid_positions["T"]["index"]
zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
)
z = particle_positions["z"]

tdim, zdim, fdim = field.data.dims

def _zsample(z_index):
"""Pointwise ``isel`` of the face values at a single vertical interface level."""
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(z_index, dims="points"),
fdim: xr.DataArray(fi, dims="points"),
}
value = field.data.isel(selection_dict, ignore_grid=True).data
return value.compute() if is_dask_collection(value) else value

# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
# First, do barycentric interpolation in the lateral direction for each interface level
fzk = field.data.values[ti, zi, fi]
fzkp1 = field.data.values[ti, zi + 1, fi]
fzk = _zsample(zi)
fzkp1 = _zsample(zi + 1)

# Then, do piecewise linear interpolation in the vertical direction
zk = field.grid.z.values[zi]
Expand All @@ -57,13 +82,22 @@ def UxLinearNodeConstantZC(
Effectively, it applies barycentric interpolation in the lateral direction
and piecewise constant interpolation in the vertical direction.
"""
ti = grid_positions["T"]["index"]
zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
bcoords = grid_positions["FACE"]["bcoord"]
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
)
bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes"))
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values
return np.sum(
field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1
) # Linear interpolation in the vertical direction

tdim, zdim, ndim = field.data.dims
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(zi, dims="points"),
ndim: xr.DataArray(node_ids, dims=("points", "nodes")),
}

node_data = field.data.isel(selection_dict, ignore_grid=True)
value = (node_data * bcoords).sum("nodes").data # Barycentric interpolation in the lateral direction
return value.compute() if is_dask_collection(value) else value


def UxLinearNodeLinearZF(
Expand All @@ -76,21 +110,37 @@ def UxLinearNodeLinearZF(
Effectively, it applies barycentric interpolation in the lateral direction
and piecewise linear interpolation in the vertical direction.
"""
ti = grid_positions["T"]["index"]
zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
ti, zi, fi = np.broadcast_arrays(
grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"]
)
z = particle_positions["z"]
bcoords = grid_positions["FACE"]["bcoord"]
bcoords = xr.DataArray(grid_positions["FACE"]["bcoord"], dims=("points", "nodes"))
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values

tdim, zdim, ndim = field.data.dims

def _zsample(z_index):
"""Barycentric (lateral) interpolation of the node values at a single vertical interface level."""
selection_dict = {
tdim: xr.DataArray(ti, dims="points"),
zdim: xr.DataArray(z_index, dims="points"),
ndim: xr.DataArray(node_ids, dims=("points", "nodes")),
}
# Reduce over the "nodes" dimension by name so the result is independent of ``isel`` dim order.
node_data = field.data.isel(selection_dict, ignore_grid=True)
return (node_data * bcoords).sum("nodes").data

# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
# First, do barycentric interpolation in the lateral direction for each interface level
fzk = np.sum(field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1)
fzkp1 = np.sum(field.data.values[ti[:, None], zi[:, None] + 1, node_ids] * bcoords, axis=-1)
fzk = _zsample(zi)
fzkp1 = _zsample(zi + 1)

# Then, do piecewise linear interpolation in the vertical direction
zk = field.grid.z.values[zi]
zkp1 = field.grid.z.values[zi + 1]
return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction
value = (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction
return value.compute() if is_dask_collection(value) else value


def Ux_Velocity(
Expand Down
46 changes: 46 additions & 0 deletions tests/test_windowed_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Tests for the transparent rolling time-window cache (WindowedArray)."""

import numpy as np
import pytest

from parcels import FieldSet, ParticleSet
from parcels._core._windowed_array import WindowedArray
from parcels._datasets.structured.generated import simple_UV_dataset
from parcels.kernels import AdvectionRK2


def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels():
ds = simple_UV_dataset(mesh="flat")
fs = FieldSet.from_sgrid_conventions(ds.chunk({"time": 1}), mesh="flat")

fs.to_windowed_arrays(max_levels=3)
first = fs.U.data
assert isinstance(first, WindowedArray)
assert first._max == 3

# re-wrapping returns the same object (idempotent), not a nested wrapper
fs.to_windowed_arrays(max_levels=3)
assert fs.U.data is first


@pytest.mark.parametrize("mesh", ["flat", "spherical"])
def test_dask_advection_matches_numpy(mesh):
"""An identical advection must give identical trajectories whether the field
is numpy-backed or dask-backed (windowed).
"""
ds = simple_UV_dataset(mesh=mesh)
ds["U"].data[:] = 1.0 # steady zonal flow -> in-bounds, deterministic

def run(windowed):
d = ds.chunk({"time": 1}) if chunked else ds

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
d = ds.chunk({"time": 1}) if chunked else ds
d = ds.chunk({"time": 1}) if windowed else ds

fs = FieldSet.from_sgrid_conventions(d, mesh=mesh)
if windowed:
fs.to_windowed_arrays()
pset = ParticleSet(fs, lon=np.zeros(10), lat=np.linspace(-10, 10, 10))
pset.execute(AdvectionRK2, runtime=7200, dt=np.timedelta64(15, "m"))
return np.array(pset.lon), np.array(pset.lat)

lon_np, lat_np = run(False)
lon_dk, lat_dk = run(True)
np.testing.assert_allclose(lon_dk, lon_np, atol=1e-9)
np.testing.assert_allclose(lat_dk, lat_np, atol=1e-9)
Loading