diff --git a/changes.md b/changes.md new file mode 100644 index 000000000..20af7de56 --- /dev/null +++ b/changes.md @@ -0,0 +1,262 @@ +# Refactoring Summary: `field.py`, `fieldset.py`, `model.py` + +This document describes the refactoring introduced in commit `69338d87a89763efbb1e3886b470e09992812978` relative to `main`. + +--- + +## Overview + +The central change is the introduction of a new `Model` abstraction layer between raw xarray/uxarray data and the `Field`/`FieldSet` objects. Previously, `Field` owned its data and grid directly. Now, `Field` is a thin view over a `Model`, and `FieldSet` is a container of `Model` objects rather than `Field` objects. + +--- + +## New file: `src/parcels/_core/model.py` + +### `Model` (abstract base class) + +Abstract class with three required attributes: + +- `data: Any` — the underlying dataset +- `grid: BaseGrid` — the grid object +- `field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator]` — maps field names to interpolator instances + +Abstract methods: + +- `construct_fields() -> list[Field | VectorField]` — build field objects from this model +- `scalar_field_names -> list[str]` — names of scalar fields in the data +- `assert_valid_field_data(field_data)` — validate a single field's data + +Concrete methods on `Model`: + +- `assert_valid_model_data()` — iterates `scalar_field_names` and calls `assert_valid_field_data` on each +- `time_interval -> TimeInterval | None` — computed from `self.data` + +### `StructuredModel(Model)` + +For structured (SGRID) grid data backed by `xr.Dataset`. + +Constructor: `StructuredModel(data: xr.Dataset, mesh: Mesh)` + +- Calls `preprocess_sgrid_model_data(data)` to transpose fields to `(t, z, y, x)` order +- Creates an `XGrid(data, mesh)` grid +- Initializes `field_to_interpolator = {}` +- Calls `assert_valid_model_data()` on construction + +`from_sgrid_conventions(cls, ds, mesh=None)` classmethod: + +- Copied/moved from `FieldSet.from_sgrid_conventions` — handles time axis renaming, mesh type inference +- Sets default interpolator `XLinear()` on all scalar fields after construction +- Returns a `StructuredModel` instance + +`construct_fields()`: + +- Creates `Field("U", self)`, `Field("V", self)` etc., then wraps them in `VectorField("UV", ...)` if U+V present +- Uses `XLinear_Velocity()` for A-grids, `CGrid_Velocity()` for C-grids + +### `UnstructuredModel(Model)` + +For unstructured (UGRID) grid data backed by `ux.UxDataset`. + +Constructor: `UnstructuredModel(data: ux.UxDataset, grid: UxGrid)` + +`from_ugrid_conventions(cls, ds, mesh="spherical")` classmethod: + +- Validates required dimensions (`time`, `zf`, `zc`) +- Creates `UxGrid`, calls `_discover_ux_U_and_V`, returns instance + +`construct_fields()`: + +- Uses `_select_uxinterpolator(da)` to pick the appropriate interpolator per field +- Note: interpolator is passed as 3rd arg to `Field(name, model, interp)` — see Field changes below + +### Helper functions moved from `fieldset.py` to `model.py` + +- `_discover_ux_U_and_V(ds)` — unchanged logic +- `_select_uxinterpolator(da)` — unchanged logic +- `_get_mesh_type_from_sgrid_dataset(ds)` — unchanged logic +- `_is_coordinate_in_degrees(da)` — unchanged logic +- `_get_time_interval(data)` — logic adjusted: checks `"time" not in data or data["time"].size == 1` (previously checked `data.shape[0] == 1`) +- `_assert_valid_uxdataarray(data)` — unchanged logic +- `_assert_has_time_coordinate(da)` — new helper extracted from old `Field.__init__` + +### New helper in `model.py` + +- `preprocess_sgrid_model_data(ds)` — transposes all non-grid-topology data vars to `(t, z, y, x)` using `_transpose_xfield_data_to_tzyx` + +--- + +## Changes to `src/parcels/_core/field.py` + +### `Field.__init__` signature change + +**Before:** + +```python +Field(name: str, data: xr.DataArray | ux.UxDataArray, grid: UxGrid | XGrid, interp_method: Callable) +``` + +**After:** + +```python +Field(name: str, model: Model) +``` + +- `data`, `grid`, and `interp_method` are no longer constructor arguments +- The constructor only sets `self.name`, `self.model`, and `self.igrid = -1` +- Validation (data type checks, axis checks, time interval extraction) removed from `__init__` + +### `Field` properties (delegating to model) + +Three new properties proxy into the model: + +```python +@property +def data(self): + return self.model.data[self.name] + +@property +def grid(self): + return self.model.grid + +@property +def time_interval(self): + return self.model.time_interval +``` + +These preserve backward compatibility for code that reads `field.data`, `field.grid`, `field.time_interval`. + +### `Field.interp_method` property/setter + +**Before:** stored as `self._interp_method`; validated via `assert_same_function_signature` against `ZeroInterpolator` + +**After:** stored in `self.model.field_to_interpolator[self.name]` + +- Getter raises `AttributeError` (not `KeyError`) if no interpolator is set for this field +- Setter validates `isinstance(value, ScalarInterpolator)` instead of checking function signature + +### Interpolator call convention change + +**Before:** `self._interp_method(particle_positions, grid_positions, self)` + +**After:** `self.interp_method.interp(particle_positions, grid_positions, self)` + +Interpolators are now objects with an `.interp(...)` method, not plain callables. + +### `VectorField` changes + +- `interp_method` parameter type annotation changed from `Callable | None` to `VectorInterpolator | None` +- Validation changed from `assert_same_function_signature(...)` to `isinstance(interp_method, VectorInterpolator)` +- Setter similarly validates `isinstance(method, VectorInterpolator)` +- Call site: `self._interp_method.interp(...)` instead of `self._interp_method(...)` + +### Removed from `field.py` + +- `_assert_valid_uxdataarray` — moved to `model.py` +- `_assert_compatible_combination` — removed (validation now handled per-model) +- `_get_time_interval` — moved to `model.py` +- Imports: `uxarray`, `xarray`, `Callable`, `TimeInterval`, `ZeroInterpolator`, `ZeroInterpolator_Vector`, `assert_same_function_signature`, `_transpose_xfield_data_to_tzyx`, `assert_all_field_dims_have_axis` + +--- + +## Changes to `src/parcels/_core/fieldset.py` + +### `FieldSet.__init__` signature change + +**Before:** `FieldSet(fields: list[Field | VectorField])` + +**After:** `FieldSet(models: list[Model])` + +- Now stores `self.models: list[Model]` +- Calls `self.reconstruct_fields()` on init to build `self._fields` +- `assert_compatible_calendars(fields)` call commented out (TODO) + +### New `FieldSet.fields` property + +`_fields` is now the backing store; `fields` is a lazy property that calls `reconstruct_fields()` if `_fields` is `None`. + +### New `FieldSet.reconstruct_fields()` method + +Iterates `self.models`, calls `model.construct_fields()` on each, flattens into `self._fields` dict. + +### `context` renamed to `constants` + +- `self.context` → `self.constants` +- `add_context(name, value)` → `add_constant(name, value)` +- `add_constant` now validates that `value` is `float | np.floating | int | np.integer` +- `__setattr__` override that guarded `context` keys has been **removed** + +### `__getattr__` updated + +Now checks `self._fields` and `self.constants` (was `self.fields` and `self.context`). + +### New `FieldSet.__add__` operator + +```python +def __add__(self, other: FieldSet) -> FieldSet: + assert_compatible_fieldsets(self, other) + combined = FieldSet(self.models + other.models) + combined.constants = {**self.constants, **other.constants} + return combined +``` + +### `from_ugrid_conventions` simplified + +**Before:** ~15 lines building grid, discovering U/V, creating Field objects, returning `cls(list(fields.values()))` + +**After:** + +```python +model = UnstructuredModel.from_ugrid_conventions(ds, mesh) +return cls([model]) +``` + +### `from_sgrid_conventions` simplified + +**Before:** ~50 lines handling time axis, xgcm grid creation, field creation + +**After:** + +```python +model = StructuredModel.from_sgrid_conventions(ds, mesh) +return cls([model]) +``` + +### `add_field` constant field creation updated + +The inline `xgcm.Grid(...)` call when adding a constant scalar field is replaced with constructing `XGrid(ds, mesh=mesh)` directly (after attaching SGRID metadata via `sgrid._attach_sgrid_metadata`). + +### New module-level function: `assert_compatible_fieldsets` + +```python +def assert_compatible_fieldsets(left: FieldSet, right: FieldSet) -> None +``` + +Raises `ValueError` if the two fieldsets share any field names or constant names. + +### Removed from `fieldset.py` + +- `xgcm` import +- `UxGrid` import +- `_DEFAULT_XGCM_KWARGS` import +- `logger` import +- `_ds_rename_using_standard_names` import +- Most interpolator imports (only `XConstantField` remains) +- `_discover_ux_U_and_V` — moved to `model.py` +- `_select_uxinterpolator` — moved to `model.py` +- `_get_mesh_type_from_sgrid_dataset` — moved to `model.py` +- `_is_coordinate_in_degrees` — moved to `model.py` + +--- + +## Summary of architectural intent + +| Concern | Before | After | +| ----------------------- | -------------------------------------------------- | -------------------------------------------------------- | +| Data ownership | `Field` (held `self.data`, `self.grid`) | `Model` (holds `self.data`, `self.grid`) | +| Interpolator storage | `Field._interp_method` (per-field callable) | `Model.field_to_interpolator` (dict of objects) | +| Interpolator type | Any callable matching `ZeroInterpolator` signature | Instance of `ScalarInterpolator` / `VectorInterpolator` | +| Interpolator invocation | `interp_method(positions, grid_positions, field)` | `interp_method.interp(positions, grid_positions, field)` | +| `FieldSet` contents | `list[Field \| VectorField]` | `list[Model]` | +| Field construction | Done in `FieldSet.from_*` classmethods | Delegated to `Model.construct_fields()` | +| `context` / `constants` | `fieldset.context` (any type) | `fieldset.constants` (float/int only) | +| `FieldSet` combination | Not supported | `fieldset_a + fieldset_b` via `__add__` | diff --git a/docs/getting_started/installation.md b/docs/getting_started/installation.md index 62ca73c51..110388ce1 100644 --- a/docs/getting_started/installation.md +++ b/docs/getting_started/installation.md @@ -2,7 +2,7 @@ ## Basic Installation -The simplest way to install the Parcels code is to use Anaconda and the [Parcels conda-forge package](https://anaconda.org/conda-forge/parcels) with the latest release of Parcels. This package will automatically install all the requirements for a fully functional installation of Parcels. This is the "batteries-included" solution probably suitable for most users. Note that we support Python 3.10 and higher. + + +Parcels v4 is in active development and hasn't been released. + +A pre-release version of Parcels (i.e., the latest version on `main`) can be installed via conda using the following instructions (which creates an environment `parcels-env`, activates it, installs Parcels from a custom pre-release channel that we're using, and installs some additional helper packages). ```bash -conda activate base -conda create -n parcels -c conda-forge parcels trajan cartopy jupyter +conda create -n parcels-env python +conda activate parcels-env +conda config --add channels conda-forge +conda install -c https://prefix.dev/parcels parcels +conda install trajan cartopy jupyter ``` + ## Installation for developers diff --git a/docs/superpowers/plans/2026-06-18-update-tests-for-model-refactor.md b/docs/superpowers/plans/2026-06-18-update-tests-for-model-refactor.md new file mode 100644 index 000000000..c1a298c6c --- /dev/null +++ b/docs/superpowers/plans/2026-06-18-update-tests-for-model-refactor.md @@ -0,0 +1,294 @@ +# Update Tests for Model Refactor Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Update or remove each test file from `files.txt` so all tests pass against the new Model-based architecture introduced in `changes.md`. + +**Architecture:** A new `Model` abstraction owns data/grid/interpolators; `Field` is now a thin view over a `Model`; `FieldSet` holds `list[Model]` not `list[Field]`. Tests that construct `Field(name, data, grid, interp_method)` or `FieldSet([field])` directly must be rewritten; tests that exercise removed concepts (old signatures, `context`, callable-based interpolators) must be updated or removed. + +**Tech Stack:** Python, pytest, pixi (use `pixi run pytest -v` to run tests) + +## Global Constraints + +- Do NOT modify any file under `src/` +- Do NOT commit while working +- When removing a test function, add a comment `# remove: {reason}` on the `def` line +- Run each test file in isolation: `pixi run pytest -v` +- Only fix test code, never add new src features to make tests pass + +--- + +## Key Architectural Changes (reference for all tasks) + +| What changed | Old | New | +| -------------------- | ---------------------------------------- | ----------------------------------------------------- | +| `Field.__init__` | `Field(name, data, grid, interp_method)` | `Field(name, model)` | +| `FieldSet.__init__` | `FieldSet([Field(...), ...])` | `FieldSet([Model(...)])` | +| Context/constants | `fieldset.context`, `add_context(k,v)` | `fieldset.constants`, `add_constant(k,v)` | +| Interpolator type | Any callable | Instance of `ScalarInterpolator`/`VectorInterpolator` | +| Interpolator call | `interp_method(positions, grid, field)` | `interp_method.interp(positions, grid, field)` | +| Interpolator storage | `Field._interp_method` | `Model.field_to_interpolator[name]` | +| FieldSet combination | Not supported | `fieldset_a + fieldset_b` | +| Helper locations | `fieldset.py` | `model.py` | + +--- + +### Task 1: test_field.py + +**Files:** + +- Modify: `tests/test_field.py` + +- [ ] Run: `pixi run pytest tests/test_field.py -v` +- [ ] For each failure: if testing old `Field(name, data, grid, interp_method)` signature → update to use `FieldSet.from_sgrid_conventions` or `Model` to get a field; if testing a concept removed from the architecture → mark `# remove: {reason}` +- [ ] Run again to confirm all remaining tests pass + +--- + +### Task 2: test_fieldset.py + +**Files:** + +- Modify: `tests/test_fieldset.py` + +- [ ] Run: `pixi run pytest tests/test_fieldset.py -v` +- [ ] `test_fieldset_init_wrong_types` expects `FieldSet([1.0, 2.0])` to raise — update error message match if changed, or remove if `FieldSet` no longer validates field types +- [ ] `context` → `constants` and `add_context` → `add_constant` throughout +- [ ] `FieldSet.__add__` tests — keep or add +- [ ] Run again to confirm + +--- + +### Task 3: test_interpolation.py + +**Files:** + +- Modify: `tests/test_interpolation.py` + +- [ ] Run: `pixi run pytest tests/test_interpolation.py -v` +- [ ] Interpolators tested as callables → update to test `.interp()` method +- [ ] `field._interp_method` → `field.model.field_to_interpolator[field.name]` +- [ ] Run again to confirm + +--- + +### Task 4: test_index_search.py + +**Files:** + +- Modify: `tests/test_index_search.py` + +- [ ] Run: `pixi run pytest tests/test_index_search.py -v` +- [ ] Update any direct Field/FieldSet construction +- [ ] Run again to confirm + +--- + +### Task 5: test_utils.py + +- [ ] Run: `pixi run pytest tests/test_utils.py -v` +- [ ] Update or remove as needed + +--- + +### Task 6: test_uxgrid.py + +- [ ] Run: `pixi run pytest tests/test_uxgrid.py -v` +- [ ] Update or remove as needed + +--- + +### Task 7: test_xarray_fieldset.py + +- [ ] Run: `pixi run pytest tests/test_xarray_fieldset.py -v` +- [ ] Likely uses old `from_sgrid_conventions` or Field constructor — update via `StructuredModel` +- [ ] Run again to confirm + +--- + +### Task 8: test_uxadvection.py + +- [ ] Run: `pixi run pytest tests/test_uxadvection.py -v` +- [ ] Update or remove as needed + +--- + +### Task 9: test_structured_gcm.py + +- [ ] Run: `pixi run pytest tests/test_structured_gcm.py -v` +- [ ] Update or remove as needed + +--- + +### Task 10: test_python.py + +- [ ] Run: `pixi run pytest tests/test_python.py -v` +- [ ] Update or remove as needed + +--- + +### Task 11: test_uxarray_fieldset.py + +- [ ] Run: `pixi run pytest tests/test_uxarray_fieldset.py -v` +- [ ] Update or remove as needed + +--- + +### Task 12: test_basegrid.py + +- [ ] Run: `pixi run pytest tests/test_basegrid.py -v` +- [ ] Update or remove as needed + +--- + +### Task 13: tests/sgrid/test_sgrid.py + +- [ ] Run: `pixi run pytest tests/sgrid/test_sgrid.py -v` +- [ ] Update or remove as needed + +--- + +### Task 14: tests/sgrid/test_accessor.py + +- [ ] Run: `pixi run pytest tests/sgrid/test_accessor.py -v` +- [ ] Update or remove as needed + +--- + +### Task 15: test_typing.py + +- [ ] Run: `pixi run pytest tests/test_typing.py -v` +- [ ] Update or remove as needed + +--- + +### Task 16: tests/datasets/test_utils.py + +- [ ] Run: `pixi run pytest tests/datasets/test_utils.py -v` +- [ ] Update or remove as needed + +--- + +### Task 17: tests/datasets/test_structured.py + +- [ ] Run: `pixi run pytest tests/datasets/test_structured.py -v` +- [ ] Update or remove as needed + +--- + +### Task 18: tests/datasets/test_remote.py + +- [ ] Run: `pixi run pytest tests/datasets/test_remote.py -v` +- [ ] Update or remove as needed + +--- + +### Task 19: tests/datasets/test_strategies.py + +- [ ] Run: `pixi run pytest tests/datasets/test_strategies.py -v` +- [ ] Update or remove as needed + +--- + +### Task 20: test_particleset_execute.py + +- [ ] Run: `pixi run pytest tests/test_particleset_execute.py -v` +- [ ] Update or remove as needed + +--- + +### Task 21: test_advection.py + +- [ ] Run: `pixi run pytest tests/test_advection.py -v` +- [ ] Update or remove as needed + +--- + +### Task 22: test_particlesetview.py + +- [ ] Run: `pixi run pytest tests/test_particlesetview.py -v` +- [ ] Update or remove as needed + +--- + +### Task 23: tests/utils/test_time.py + +- [ ] Run: `pixi run pytest tests/utils/test_time.py -v` +- [ ] Update or remove as needed + +--- + +### Task 24: tests/utils/test_unstructured.py + +- [ ] Run: `pixi run pytest tests/utils/test_unstructured.py -v` +- [ ] Update or remove as needed + +--- + +### Task 25: test_convert.py + +- [ ] Run: `pixi run pytest tests/test_convert.py -v` +- [ ] Update or remove as needed + +--- + +### Task 26: test_particleset.py + +- [ ] Run: `pixi run pytest tests/test_particleset.py -v` +- [ ] Update or remove as needed + +--- + +### Task 27: test_diffusion.py + +- [ ] Run: `pixi run pytest tests/test_diffusion.py -v` +- [ ] Update or remove as needed + +--- + +### Task 28: test_particlefile.py + +- [ ] Run: `pixi run pytest tests/test_particlefile.py -v` +- [ ] Update or remove as needed + +--- + +### Task 29: test_sigmagrids.py + +- [ ] Run: `pixi run pytest tests/test_sigmagrids.py -v` +- [ ] Update or remove as needed + +--- + +### Task 30: test_kernel.py + +- [ ] Run: `pixi run pytest tests/test_kernel.py -v` +- [ ] Update or remove as needed + +--- + +### Task 31: test_xgrid.py + +- [ ] Run: `pixi run pytest tests/test_xgrid.py -v` +- [ ] Update or remove as needed + +--- + +### Task 32: test_particle.py + +- [ ] Run: `pixi run pytest tests/test_particle.py -v` +- [ ] Update or remove as needed + +--- + +### Task 33: tests/validation/test_ux.py + +- [ ] Run: `pixi run pytest tests/validation/test_ux.py -v` +- [ ] Update or remove as needed + +--- + +### Task 34: test_spatialhash.py + +- [ ] Run: `pixi run pytest tests/test_spatialhash.py -v` +- [ ] Update or remove as needed diff --git a/docs/user_guide/examples/explanation_kernelloop.md b/docs/user_guide/examples/explanation_kernelloop.md index 1bff904e2..abb293990 100644 --- a/docs/user_guide/examples/explanation_kernelloop.md +++ b/docs/user_guide/examples/explanation_kernelloop.md @@ -83,7 +83,7 @@ windvector = parcels.VectorField( "Wind", fieldset.UWind, fieldset.VWind, - vector_interp_method=parcels.interpolators.XLinear_Velocity + interp_method=parcels.interpolators.XLinear_Velocity ) fieldset.add_field(windvector) ``` diff --git a/docs/user_guide/examples/tutorial_fesom.ipynb b/docs/user_guide/examples/tutorial_fesom.ipynb index cb31e4356..cc0828316 100644 --- a/docs/user_guide/examples/tutorial_fesom.ipynb +++ b/docs/user_guide/examples/tutorial_fesom.ipynb @@ -149,9 +149,9 @@ "fieldset = parcels.FieldSet.from_ugrid_conventions(ds, mesh=\"spherical\")\n", "\n", "for name, field in fieldset.fields.items():\n", - " interp = getattr(field, \"interp_method\", None)\n", - " interp_name = interp.__name__ if interp is not None else \"-\"\n", - " print(f\"{name:>4s} -> {type(field).__name__:<11s} interp={interp_name}\")" + " print(\n", + " f\"{name:>4s} -> {type(field).__name__:<11s} interp={field.interp_method.__name__}\"\n", + " )" ] }, { @@ -164,7 +164,11 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## Release particles and advect\n\nWe seed particles on a grid of four latitudes spanning the channel and ten longitudes, and integrate for two days with RK4. Because this snapshot has only a single time level, `fieldset.time_interval` is `None` and we omit the `time=` argument so that Parcels treats the flow as constant in time:" + "source": [ + "## Release particles and advect\n", + "\n", + "We seed particles on a grid of four latitudes spanning the channel and ten longitudes, and integrate for two days with RK4. Because this snapshot has only a single time level, `fieldset.time_interval` is `None` and we omit the `time=` argument so that Parcels treats the flow as constant in time:" + ] }, { "cell_type": "code", @@ -204,7 +208,11 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## Plot the velocity field and trajectories\n\nWe plot the particle paths on top of the velocity field they advect through: triangle colour shows the speed at the release depth (≈50 m), black arrows show the velocity at face centres (drawn in lon/lat space, length proportional to speed), grey lines trace each particle's path, and the coloured dots mark the positions over time. Drawing the arrows with `angles=\"xy\"` keeps them aligned with the trajectories, so you can see the particles streak along the fast jets and barely move in the quiet bands between them:" + "source": [ + "## Plot the velocity field and trajectories\n", + "\n", + "We plot the particle paths on top of the velocity field they advect through: triangle colour shows the speed at the release depth (≈50 m), black arrows show the velocity at face centres (drawn in lon/lat space, length proportional to speed), grey lines trace each particle's path, and the coloured dots mark the positions over time. Drawing the arrows with `angles=\"xy\"` keeps them aligned with the trajectories, so you can see the particles streak along the fast jets and barely move in the quiet bands between them:" + ] }, { "cell_type": "code", @@ -286,7 +294,7 @@ ], "metadata": { "kernelspec": { - "display_name": "docs", + "display_name": "Parcels:docs (3.14.4)", "language": "python", "name": "python3" }, diff --git a/docs/user_guide/examples/tutorial_schism.ipynb b/docs/user_guide/examples/tutorial_schism.ipynb new file mode 100644 index 000000000..879ec72e7 --- /dev/null +++ b/docs/user_guide/examples/tutorial_schism.ipynb @@ -0,0 +1,435 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# 🖥️ SCHISM tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "Parcels v4 supports unstructured-grid model output via [uxarray](https://uxarray.readthedocs.io/).\n", + "This tutorial walks through advecting particles in real [SCHISM](http://ccrm.vims.edu/schismweb/)\n", + "output from a hydrodynamic hindcast of **Lake Ontario**. SCHISM writes its output in the\n", + "[scribe-IO](https://schism-dev.github.io/schism/master/getting-started/output.html) format, which\n", + "already follows the [UGRID](https://ugrid-conventions.github.io/ugrid-conventions/) conventions for\n", + "its horizontal mesh, so `uxarray` can read it directly.\n", + "\n", + "SCHISM output differs from the [FESOM tutorial](./tutorial_fesom.ipynb) in a few ways that this\n", + "tutorial highlights:\n", + "\n", + "1. The mesh coordinates are in a **projected coordinate system (meters)**, not longitude/latitude, so\n", + " we build the grid as a *flat* (Cartesian) mesh.\n", + "2. Velocities are stored at **mesh nodes** over a vertical column of layers, ordered **bottom → surface**.\n", + "3. SCHISM uses a **localized vertical grid (LSC2)**: the number of valid vertical levels varies from\n", + " node to node, and levels below the seabed are stored as `NaN`. We use this to flag particles that\n", + " end up below the local bathymetry and stop advecting them.\n", + "\n", + "If you have not done so already, work through the\n", + "[quickstart tutorial](../../getting_started/tutorial_quickstart.md) first to get familiar with\n", + "`ParticleSet`, `Kernel`, and `ParticleFile`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib.tri as mtri\n", + "import numpy as np\n", + "import uxarray as ux\n", + "import xarray as xr\n", + "\n", + "import parcels\n", + "import parcels.tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Get the SCHISM tutorial dataset\n", + "\n", + "We use three files from a SCHISM Lake Ontario hindcast (6 hourly snapshots), bundled in Parcels' tutorial\n", + "data registry:\n", + "\n", + "* `out2d`: the 2D output, slimmed to the **horizontal mesh** (node/face topology) and bathymetry.\n", + "* `horizontalVelX`, `horizontalVelY`: the 3D horizontal velocity components, defined at the mesh nodes\n", + " over 32 vertical layers.\n", + "\n", + "As in the [quickstart](../../getting_started/tutorial_quickstart.md), `parcels.tutorial.open_dataset`\n", + "downloads the files into a local cache on first use (subsequent calls return the cached copy) and opens\n", + "them as `xarray` datasets:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "grid_ds = parcels.tutorial.open_dataset(\"SCHISM_LakeOntario/out2d\")\n", + "u = parcels.tutorial.open_dataset(\"SCHISM_LakeOntario/horizontalVelX\")[\"horizontalVelX\"]\n", + "v = parcels.tutorial.open_dataset(\"SCHISM_LakeOntario/horizontalVelY\")[\"horizontalVelY\"]\n", + "grid_ds" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Build the horizontal mesh\n", + "\n", + "SCHISM stores its mesh topology in `out2d_1.nc` following the UGRID conventions: node coordinates\n", + "(`SCHISM_hgrid_node_x/y`) and a face–node connectivity table (`SCHISM_hgrid_face_nodes`). Two\n", + "SCHISM-specific details need handling before we hand the mesh to Parcels:\n", + "\n", + "* **Triangular cells.** The connectivity table is stored with a width of 4 (so the same format can\n", + " describe quads), but this Lake Ontario mesh is entirely triangular; the 4th column is all fill. We\n", + " keep the first three columns and convert the 1-based indices to 0-based. Parcels' `UxGrid` requires\n", + " purely triangular cells.\n", + "* **Projected coordinates.** The coordinates are in meters (`standard_name = projection_x_coordinate`),\n", + " not degrees. `uxarray` currently assumes node coordinates are spherical and wraps longitudes into\n", + " [-180, 180] (see [uxarray #1524](https://github.com/UXARRAY/uxarray/issues/1524)), which would corrupt\n", + " the mesh. We undo that wrap by writing the raw meter coordinates back, and later build the `FieldSet`\n", + " with `mesh=\"flat\"` so Parcels treats the plane as Cartesian." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "node_x = grid_ds[\"SCHISM_hgrid_node_x\"].values.astype(\"float64\")\n", + "node_y = grid_ds[\"SCHISM_hgrid_node_y\"].values.astype(\"float64\")\n", + "face_nodes = (\n", + " grid_ds[\"SCHISM_hgrid_face_nodes\"].values[:, :3].astype(\"int64\") - 1\n", + ") # all-triangular, 0-based\n", + "\n", + "uxgrid = ux.Grid.from_topology(\n", + " node_lon=node_x, node_lat=node_y, face_node_connectivity=face_nodes, fill_value=-1\n", + ")\n", + "# undo uxarray's [-180, 180] longitude wrap of the projected meters (uxarray #1524)\n", + "uxgrid.node_lon.values[:] = node_x\n", + "uxgrid.node_lat.values[:] = node_y\n", + "\n", + "print(\n", + " f\"n_node={uxgrid.n_node}, n_face={uxgrid.n_face}, n_max_face_nodes={uxgrid.n_max_face_nodes}\"\n", + ")\n", + "print(f\"x range: {node_x.min():.0f} .. {node_x.max():.0f} m\")" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Assemble the velocity fields\n", + "\n", + "The two velocity files hold `horizontalVelX` and `horizontalVelY` with dimensions\n", + "`(time, node, layer)`. We rename them to `U`/`V` (so Parcels built-in Kernels recognise the velocity components) and\n", + "to the Parcels UGRID dimension names (`n_node` for the lateral dimension).\n", + "\n", + "Two vertical details:\n", + "\n", + "* **Layer ordering.** SCHISM stores levels **bottom → surface**, while Parcels expects depth increasing\n", + " downward from the surface, so we reverse the layer axis.\n", + "* **Vertical coordinate.** SCHISM's LSC2 vertical grid (*Localized Sigma Coordinates with Shaved\n", + " cells*) varies with horizontal position: each node has its own layer depths, and even its own\n", + " *number* of levels (the true depths are written to a separate `zCoordinates` output, not used here).\n", + " **Parcels does not currently support a vertical grid that varies with lateral position.** `UxGrid`\n", + " takes a single 1D column of layer-interface depths that applies everywhere on the mesh. We therefore\n", + " supply a fictitious 1D vertical grid. This is adequate for the near-surface, horizontal transport\n", + " shown here (the lateral interpolation does not depend on the vertical grid); accurate full-depth 3D\n", + " transport on an LSC2 grid would require Parcels to support a laterally varying vertical coordinate.\n", + "\n", + "We also call `.load()` so the velocities sit in memory; otherwise every interpolation step re-reads\n", + "from disk and the simulation is extremely slow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "nlev = u.sizes[\"nSCHISM_vgrid_layers\"]\n", + "\n", + "# placeholder interface depths (meters, positive down): surface (0 m) -> bed, refined near the surface\n", + "zf = 250.0 * (np.arange(nlev) / (nlev - 1)) ** 1.5\n", + "zc = 0.5 * (zf[1:] + zf[:-1])\n", + "\n", + "rename = {\"nSCHISM_vgrid_layers\": \"zf\", \"nSCHISM_hgrid_node\": \"n_node\"}\n", + "U = (\n", + " u.isel(nSCHISM_vgrid_layers=slice(None, None, -1)).rename(rename).load()\n", + ") # reverse to surface-first\n", + "V = v.isel(nSCHISM_vgrid_layers=slice(None, None, -1)).rename(rename).load()\n", + "\n", + "uxds = ux.UxDataset(\n", + " xr.Dataset(\n", + " {\n", + " \"U\": U.transpose(\"time\", \"zf\", \"n_node\"),\n", + " \"V\": V.transpose(\"time\", \"zf\", \"n_node\"),\n", + " },\n", + " coords={\"zf\": (\"zf\", zf), \"zc\": (\"zc\", zc)},\n", + " ),\n", + " uxgrid=uxgrid,\n", + ")\n", + "uxds" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Build the `FieldSet`\n", + "\n", + "With the mesh and UGRID-compliant dimensions in place, `parcels.FieldSet.from_ugrid_conventions` builds\n", + "the `FieldSet`. It detects `U` and `V`, attaches the `UxGrid`, and selects the `UxLinearNodeLinearZF`\n", + "interpolator (barycentric in the horizontal, linear in the vertical) because the velocities are\n", + "node-registered along the layer interfaces `zf`. We pass `mesh=\"flat\"` because the coordinates\n", + "are projected meters: velocities in m/s then advect positions in meters directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "fieldset = parcels.FieldSet.from_ugrid_conventions(uxds, mesh=\"flat\")\n", + "\n", + "for name, field in fieldset.fields.items():\n", + " print(\n", + " f\"{name:>4s} -> {type(field).__name__:<11s} interp={field.interp_method.__name__}\"\n", + " )\n", + "print(\"time interval:\", fieldset.time_interval)" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Stop particles that end up below the bathymetry\n", + "\n", + "Because SCHISM's LSC2 grid keeps below-seabed levels as `NaN`, a particle whose (fixed) depth ends up\n", + "deeper than the local water column will sample `NaN`. Rather than feed it a fabricated velocity, we let\n", + "Parcels flag it: sampling `NaN` sets the particle state to `ErrorInterpolation`, and ending up outside\n", + "the mesh sets `ErrorOutOfBounds`. We add a small kernel that runs after advection, records that the\n", + "particle went out of bounds, and sets its state to `Delete` so it stops being advected (its trajectory\n", + "up to that point is kept).\n", + "\n", + "```{note}\n", + "Most of the Lake Ontario mesh is shallow nearshore water with only a couple of valid vertical levels;\n", + "the deep central basin has the full set. We therefore release particles in the deep basin (nodes with\n", + "many valid levels) so they start within the resolved water column.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "OUT_OF_BOUNDS_STATES = [\n", + " parcels.StatusCode.ErrorInterpolation, # sampled NaN (below the local seabed)\n", + " parcels.StatusCode.ErrorOutOfBounds, # left the horizontal mesh / below the grid\n", + " parcels.StatusCode.ErrorThroughSurface, # above the surface\n", + "]\n", + "\n", + "SchismParticle = parcels.Particle.add_variable(\n", + " parcels.Variable(\"out_of_bounds\", dtype=np.int32, initial=0)\n", + ")\n", + "\n", + "\n", + "def StopBelowBed(particles, fieldset):\n", + " \"\"\"Flag out-of-bounds particles and stop advecting them.\"\"\"\n", + " oob = np.isin(particles.state, OUT_OF_BOUNDS_STATES)\n", + " particles[oob].out_of_bounds = 1\n", + " particles[oob].state = parcels.StatusCode.Delete" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "## Release particles and advect" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "# number of valid (non-NaN) vertical levels at each node\n", + "valid_levels = np.isfinite(U.isel(time=0).values).sum(axis=1)\n", + "\n", + "# release max 1500 particle on face centroids in the deep basin (all 3 nodes well-resolved vertically)\n", + "n_max = 1500\n", + "cx = node_x[face_nodes].mean(axis=1)\n", + "cy = node_y[face_nodes].mean(axis=1)\n", + "deep = (valid_levels[face_nodes] >= 15).all(axis=1)\n", + "idx = np.where(deep)[0]\n", + "idx = idx[:: max(1, idx.size // n_max)][:n_max]\n", + "\n", + "lon, lat = cx[idx], cy[idx]\n", + "z = np.full(lon.size, 2.0) # release at 2 m depth\n", + "print(f\"releasing {lon.size} particles at z = 2 m in the deep basin\")\n", + "\n", + "pset = parcels.ParticleSet(\n", + " fieldset=fieldset, pclass=SchismParticle, lon=lon, lat=lat, z=z\n", + ")\n", + "output_file = parcels.ParticleFile(\n", + " \"output-schism.parquet\", outputdt=np.timedelta64(30, \"m\")\n", + ")\n", + "\n", + "pset.execute(\n", + " [parcels.kernels.AdvectionRK4, StopBelowBed],\n", + " runtime=np.timedelta64(5, \"h\"), # the dataset spans 5 hours\n", + " dt=np.timedelta64(5, \"m\"),\n", + " output_file=output_file,\n", + " verbose_progress=False,\n", + ")\n", + "print(f\"{len(pset.lon)} of {lon.size} particles still active at the end of the run\")" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Plot the velocity field and trajectories\n", + "\n", + "We plot the surface speed across the triangular mesh (in projected kilometers), the particle release\n", + "points, and their trajectories coloured by time since release. The lake currents are slow (~0.1 m/s), so\n", + "over the 5-hour window most particles move only a kilometre or two, while those caught in the faster jet\n", + "near the north-eastern outflow (towards the St. Lawrence) travel noticeably further." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "df = parcels.read_particlefile(\"output-schism.parquet\")\n", + "\n", + "triang = mtri.Triangulation(node_x / 1e3, node_y / 1e3, triangles=face_nodes)\n", + "surf_speed = np.hypot(\n", + " np.asarray(uxds[\"U\"].isel(time=0, zf=0)), np.asarray(uxds[\"V\"].isel(time=0, zf=0))\n", + ")\n", + "speed_face = np.nanmean(surf_speed[face_nodes], axis=1)\n", + "\n", + "fig, ax = plt.subplots(figsize=(11, 7))\n", + "tpc = ax.tripcolor(\n", + " triang,\n", + " facecolors=np.nan_to_num(speed_face),\n", + " shading=\"flat\",\n", + " cmap=\"Blues\",\n", + " vmax=np.nanpercentile(speed_face, 98),\n", + ")\n", + "ax.triplot(triang, color=\"k\", lw=0.2, alpha=0.35)\n", + "fig.colorbar(tpc, ax=ax, label=\"surface speed [m/s]\", shrink=0.8)\n", + "\n", + "for traj in df.sort(\"time\").partition_by(\"particle_id\"):\n", + " ax.plot(\n", + " np.array(traj[\"lon\"]) / 1e3,\n", + " np.array(traj[\"lat\"]) / 1e3,\n", + " color=\"0.3\",\n", + " lw=0.6,\n", + " alpha=0.7,\n", + " zorder=2,\n", + " )\n", + "ax.scatter(\n", + " lon / 1e3,\n", + " lat / 1e3,\n", + " facecolors=\"none\",\n", + " edgecolors=\"k\",\n", + " s=20,\n", + " zorder=3,\n", + " label=\"release\",\n", + ")\n", + "\n", + "elapsed_h = (df[\"time\"] - df[\"time\"].min()).dt.total_seconds() / 3600\n", + "sc = ax.scatter(\n", + " np.array(df[\"lon\"]) / 1e3,\n", + " np.array(df[\"lat\"]) / 1e3,\n", + " c=elapsed_h,\n", + " s=4,\n", + " cmap=\"viridis\",\n", + " zorder=3,\n", + ")\n", + "fig.colorbar(sc, ax=ax, label=\"time since release [h]\", shrink=0.8)\n", + "\n", + "ax.set_xlabel(\"projected x [km]\")\n", + "ax.set_ylabel(\"projected y [km]\")\n", + "ax.set_title(\"SCHISM Lake Ontario surface currents with particle trajectories\")\n", + "ax.set_aspect(\"equal\")\n", + "ax.legend(loc=\"upper left\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "The particles move with the SCHISM surface currents, and any that wander over shallow water where\n", + "their release depth falls below the seabed are flagged (`out_of_bounds == 1`) and stop.\n", + "\n", + "From here, the rest of Parcels works exactly as on structured grids. To go further with SCHISM data:\n", + "\n", + "* Keep in mind that the vertical is approximate: because Parcels uses a single 1D vertical column,\n", + " this tutorial is most meaningful for near-surface and horizontal transport. Faithful full-depth 3D\n", + " transport on an LSC2 grid would require Parcels to support a laterally varying vertical coordinate.\n", + "* Add the vertical velocity as a `W` field and use `AdvectionRK4_3D` for three-dimensional transport\n", + " (still subject to the single-column vertical approximation above).\n", + "* See the [interpolation tutorial](./tutorial_interpolation.ipynb) for the available `Ux*` interpolators." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 9e22f1236..70cbbea8f 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -29,6 +29,7 @@ examples/tutorial_nemo.ipynb examples/tutorial_croco_3D.ipynb examples/tutorial_mitgcm.ipynb examples/tutorial_fesom.ipynb +examples/tutorial_schism.ipynb examples/tutorial_velocityconversion.ipynb examples/tutorial_nestedgrids.ipynb examples/tutorial_manipulating_field_data.ipynb diff --git a/files.txt b/files.txt new file mode 100644 index 000000000..a4b9a5475 --- /dev/null +++ b/files.txt @@ -0,0 +1,43 @@ +tests/test_interpolation.py +tests/test_index_search.py +tests/test_utils.py +tests/test_uxgrid.py +tests/test_xarray_fieldset.py +tests/test_uxadvection.py +tests/test_structured_gcm.py +tests/test_python.py +tests/test_uxarray_fieldset.py +tests/test_basegrid.py +tests/sgrid/test_sgrid.py +tests/sgrid/test_accessor.py +tests/test_typing.py +tests/test_data +tests/test_data/test_interpolation_data_random_linear.nc +tests/test_data/test_interpolation_jit_linear.zarr +tests/test_data/test_interpolation_data_random_cgrid_velocity.nc +tests/test_data/test_interpolation_jit_cgrid_velocity.zarr +tests/test_data/test_interpolation_jit_nearest.zarr +tests/test_data/test_interpolation_jit_freeslip.zarr +tests/test_data/test_interpolation_data_random_nearest.nc +tests/test_data/test_interpolation_data_random_freeslip.nc +tests/datasets/test_utils.py +tests/datasets/test_structured.py +tests/datasets/test_remote.py +tests/datasets/test_strategies.py +tests/test_particleset_execute.py +tests/test_advection.py +tests/test_particlesetview.py +tests/utils/test_time.py +tests/utils/test_unstructured.py +tests/test_convert.py +tests/test_particleset.py +tests/test_field.py +tests/test_fieldset.py +tests/test_diffusion.py +tests/test_particlefile.py +tests/test_sigmagrids.py +tests/test_kernel.py +tests/test_xgrid.py +tests/test_particle.py +tests/validation/test_ux.py +tests/test_spatialhash.py diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py index ee216d34a..031213c5b 100644 --- a/src/parcels/_core/field.py +++ b/src/parcels/_core/field.py @@ -1,12 +1,11 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Sequence +from collections.abc import Sequence from datetime import datetime +from typing import TYPE_CHECKING import numpy as np -import uxarray as ux -import xarray as xr from parcels._core.index_search import GRID_SEARCH_ERROR, LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, _search_time_index from parcels._core.particlesetview import ParticleSetView @@ -15,16 +14,14 @@ StatusCode, ) from parcels._core.utils.string import _assert_str_and_python_varname -from parcels._core.utils.time import TimeInterval from parcels._core.uxgrid import UxGrid -from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx, assert_all_field_dims_have_axis -from parcels._python import assert_same_function_signature -from parcels._reprs import field_repr, vectorfield_repr +from parcels._core.xgrid import XGrid from parcels._typing import VectorType -from parcels.interpolators import ( - ZeroInterpolator, - ZeroInterpolator_Vector, -) +from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator + +if TYPE_CHECKING: + from parcels._core.model import Model + __all__ = ["Field", "VectorField"] @@ -86,69 +83,51 @@ class Field: def __init__( self, name: str, - data: xr.DataArray | ux.UxDataArray, - grid: UxGrid | XGrid, - interp_method: Callable, + model: Model, ): - if not isinstance(data, (ux.UxDataArray, xr.DataArray)): - raise ValueError( - f"Expected `data` to be a uxarray.UxDataArray or xarray.DataArray object, got {type(data)}." - ) + # TODO PR: Enable isinstance check once Model is moved to abc.Model + # if not isinstance(model, "Model"): + # raise ValueError( + # f"Expected `model` to be a parcels Model object. Got {type(model)}." + # ) _assert_str_and_python_varname(name) - if not isinstance(grid, (UxGrid, XGrid)): - raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.") - - _assert_compatible_combination(data, grid) - - if isinstance(grid, XGrid): - assert_all_field_dims_have_axis(data, grid.xgcm_grid) - data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid) - self.name = name - self.data = data - self.grid = grid - - try: - self.time_interval = _get_time_interval(data) - except ValueError as e: - e.add_note( - f"Error getting time interval for field {name!r}. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects?" - ) - raise e + self.model = model - try: - if isinstance(data, ux.UxDataArray): - _assert_valid_uxdataarray(data) - # TODO: For unstructured grids, validate that `data.uxgrid` is the same as `grid` - else: - pass # TODO v4: Add validation for xr.DataArray objects - except Exception as e: - e.add_note(f"Error validating field {name!r}.") - raise e + self.igrid = -1 # Default the grid index to -1 - # Setting the interpolation method dynamically - assert_same_function_signature(interp_method, ref=ZeroInterpolator, context="Interpolation") - self._interp_method = interp_method + @property + def data(self): + return self.model.data[self.name] - self.igrid = -1 # Default the grid index to -1 + @property + def grid(self): # TODO PR: Remove in favour of referencing model grid directly + return self.model.grid - if self.data.shape[0] > 1: - if "time" not in self.data.coords: - raise ValueError("Field data is missing a 'time' coordinate.") + @property + def time_interval(self): # TODO PR: Remove in favour of referencing model time_interval directly + return self.model.time_interval def __repr__(self): - return field_repr(self) + return f"Field(name={self.name}, model={self.model})" @property def interp_method(self): - return self._interp_method + try: + return self.model.field_to_interpolator[self.name] + except KeyError as e: + raise AttributeError( + f"{type(self).__name__} doesn't have an interp_method defined for it. Use `.interp_method = ...`" + ) from e @interp_method.setter - def interp_method(self, method: Callable): - assert_same_function_signature(method, ref=ZeroInterpolator, context="Interpolation") - self._interp_method = method + def interp_method(self, value): + # Setting the interpolation method dynamically + if not isinstance(value, ScalarInterpolator): + raise ValueError(f"interp_method must be a `ScalarInterpolator` object. Got {type(value)=!r}") + self.model.field_to_interpolator[self.name] = value def _check_velocitysampling(self): if self.name in ["U", "V", "W"]: @@ -193,7 +172,7 @@ def eval(self, time: datetime, z, y, x, particles=None): particle_positions, grid_positions = _get_positions(self, time, z, y, x, particles, _ei) - value = self._interp_method(particle_positions, grid_positions, self) + value = self.interp_method.interp(particle_positions, grid_positions, self) _update_particle_states_interp_value(particles, value) @@ -219,10 +198,10 @@ def __init__( U: Field, # noqa: N803 V: Field, # noqa: N803 W: Field | None = None, # noqa: N803 - vector_interp_method: Callable | None = None, + interp_method: VectorInterpolator | None = None, ): - if vector_interp_method is None: - raise ValueError("vector_interp_method must be provided for VectorField initialization.") + if interp_method is None: + raise ValueError("interp_method must be provided for VectorField initialization.") _assert_str_and_python_varname(name) self.name = name @@ -244,20 +223,25 @@ def __init__( else: self.vector_type = "2D" - assert_same_function_signature(vector_interp_method, ref=ZeroInterpolator_Vector, context="Interpolation") - self._vector_interp_method = vector_interp_method + if not isinstance(interp_method, VectorInterpolator): + raise ValueError( + f"vector_interp_method must be a `VectorInterpolator` object. Got {type(interp_method)=!r}" + ) - def __repr__(self): - return vectorfield_repr(self) + self._interp_method = interp_method + + # def __repr__(self): + # return vectorfield_repr(self) @property - def vector_interp_method(self): - return self._vector_interp_method + def interp_method(self): + return self._interp_method - @vector_interp_method.setter - def vector_interp_method(self, method: Callable): - assert_same_function_signature(method, ref=ZeroInterpolator_Vector, context="Interpolation") - self._vector_interp_method = method + @interp_method.setter + def interp_method(self, method: VectorInterpolator): + if not isinstance(method, VectorInterpolator): + raise ValueError(f"method must be a `VectorInterpolator` object. Got {type(method)=!r}") + self._interp_method = method def eval(self, time: datetime, z, y, x, particles=None): """Interpolate vectorfield values in space and time. @@ -295,7 +279,7 @@ def eval(self, time: datetime, z, y, x, particles=None): particle_positions, grid_positions = _get_positions(self.U, time, z, y, x, particles, _ei) - (u, v, w) = self._vector_interp_method(particle_positions, grid_positions, self) + (u, v, w) = self._interp_method.interp(particle_positions, grid_positions, self) for vel in (u, v, w): _update_particle_states_interp_value(particles, vel) @@ -375,44 +359,6 @@ def _update_particle_states_interp_value(particles, value): ) -def _assert_valid_uxdataarray(data: ux.UxDataArray): - """Verifies that all the required attributes are present in the xarray.DataArray or - uxarray.UxDataArray object. - """ - # Validate dimensions - if not ("zf" in data.dims or "zc" in data.dims): - raise ValueError( - "Field is missing a 'zf' or 'zc' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) - - if "time" not in data.dims: - raise ValueError( - "Field is missing a 'time' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) - - -def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: UxGrid | XGrid): - if isinstance(data, ux.UxDataArray): - if not isinstance(grid, UxGrid): - raise ValueError( - f"Incompatible data-grid combination. Data is a uxarray.UxDataArray, expected `grid` to be a UxGrid object, got {type(grid)}." - ) - elif isinstance(data, xr.DataArray): - if not isinstance(grid, XGrid): - raise ValueError( - f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}." - ) - - -def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None: - if data.shape[0] == 1: - return None - - return TimeInterval(data.time.values[0], data.time.values[-1]) - - def _assert_same_time_interval(fields: Sequence[Field]) -> None: if len(fields) == 0: return diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index f83df364f..f5cab10f3 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -8,29 +8,15 @@ import numpy as np import uxarray as ux import xarray as xr -import xgcm -import parcels._sgrid as sgrid from parcels._core.field import Field, VectorField +from parcels._core.model import Model, StructuredModel, UnstructuredModel, constant_field_models from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible -from parcels._core.uxgrid import UxGrid -from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid -from parcels._logger import logger -from parcels._reprs import fieldset_repr from parcels._typing import Mesh -from parcels.convert import _ds_rename_using_standard_names from parcels.interpolators import ( - CGrid_Velocity, - Ux_Velocity, - UxConstantFaceConstantZC, - UxConstantFaceLinearZF, - UxLinearNodeConstantZC, - UxLinearNodeLinearZF, XConstantField, - XLinear, - XLinear_Velocity, ) if TYPE_CHECKING: @@ -69,26 +55,49 @@ class FieldSet: """ - def __init__(self, fields: list[Field | VectorField]): - for field in fields: - if not isinstance(field, (Field, VectorField)): - raise ValueError(f"Expected `field` to be a Field or VectorField object. Got {field}") - assert_compatible_calendars(fields) + def __init__(self, models: list[Model]): + for model in models: + if not isinstance(model, Model): + raise ValueError(f"Expected `model` to be a Model object. Got {model}") + # assert_compatible_calendars(fields) - self.fields = {f.name: f for f in fields} + self.models = list(models) + self._fields: dict[str, Field | VectorField] | None = None + self.reconstruct_fields() self.constants: dict[str, float] = {} + @property + def fields(self): + if self._fields is None: + self.reconstruct_fields() + assert self._fields is not None + return self._fields + + def reconstruct_fields(self): + fields = [] + for model in self.models: + fields += model.construct_fields() + self._fields = {f.name: f for f in fields} + def __getattr__(self, name): """Get the field by name. If the field is not found, check if it's a constant.""" - if name in self.fields: - return self.fields[name] + if name in self._fields: + return self._fields[name] elif name in self.constants: return self.constants[name] else: raise AttributeError(f"FieldSet has no attribute '{name}'") - def __repr__(self): - return fieldset_repr(self) + def __add__(self, other: FieldSet) -> FieldSet: + if not isinstance(other, FieldSet): + return NotImplemented + assert_compatible_fieldsets(self, other) + combined = FieldSet(self.models + other.models) + combined.constants = {**self.constants, **other.constants} + return combined + + # def __repr__(self): + # return fieldset_repr(self) @property def time_interval(self): @@ -96,7 +105,7 @@ def time_interval(self): which is the intersection of the time intervals of all fields in the FieldSet. """ - time_intervals = (f.time_interval for f in self.fields.values()) + time_intervals = (m.time_interval for m in self.models) # Filter out Nones from constant Fields time_intervals = [t for t in time_intervals if t is not None] @@ -143,15 +152,19 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): correction for zonal velocity U near the poles. 2. flat: No conversion, lat/lon are assumed to be in m. """ - ds = xr.Dataset( - {name: (["lat", "lon"], np.full((1, 1), value))}, - coords={"lat": (["lat"], [0], {"axis": "Y"}), "lon": (["lon"], [0], {"axis": "X"})}, - ) - xgrid = xgcm.Grid( - ds, coords={"X": {"left": "lon"}, "Y": {"left": "lat"}}, autoparse_metadata=False, **_DEFAULT_XGCM_KWARGS - ) - grid = XGrid(xgrid, mesh=mesh) - self.add_field(Field(name, ds[name], grid, interp_method=XConstantField)) + try: + model = constant_field_models[mesh] + except KeyError as e: + raise ValueError(f"mesh must be one of ['flat', 'spherical']. Got {mesh!r}.") from e + + model.data[name] = (["lat", "lon", "depth", "time"], np.full((1, 1, 1, 1), value)) + + if model not in self.models: + self.models.append(model) + + self.reconstruct_fields() + field = getattr(self, name) + field.interp_method = XConstantField() def add_constant(self, name, value): """Add a constant to the FieldSet. @@ -201,31 +214,8 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): FieldSet FieldSet object containing the fields from the dataset that can be used for a Parcels simulation. """ - ds_dims = list(ds.dims) - if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): - raise ValueError( - f"Dataset missing one of the required dimensions 'time', 'zf', or 'zc' for uxDataset. Found dimensions {ds_dims}" - ) - - grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh=mesh) - ds = _discover_ux_U_and_V(ds) - - fields = {} - if "U" in ds.data_vars and "V" in ds.data_vars: - fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"])) - fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["V"])) - fields["UV"] = VectorField("UV", fields["U"], fields["V"], vector_interp_method=Ux_Velocity) - - if "W" in ds.data_vars: - fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["W"])) - fields["UVW"] = VectorField( - "UVW", fields["U"], fields["V"], fields["W"], vector_interp_method=Ux_Velocity - ) - - for varname in set(ds.data_vars) - set(fields.keys()): - fields[varname] = Field(str(varname), ds[varname], grid, _select_uxinterpolator(ds[varname])) - - return cls(list(fields.values())) + model = UnstructuredModel.from_ugrid_conventions(ds, mesh) + return cls([model]) @classmethod def from_sgrid_conventions( @@ -258,70 +248,30 @@ def from_sgrid_conventions( See https://sgrid.github.io/ for more information on the SGRID conventions. """ - ds = ds.copy() - if mesh is None: - mesh = _get_mesh_type_from_sgrid_dataset(ds) - - # Ensure time dimension has axis attribute if present - if "time" in ds.dims and "time" in ds.coords: - if "axis" not in ds["time"].attrs: - logger.debug( - "Dataset contains 'time' dimension but no 'axis' attribute. Setting 'axis' attribute to 'T'." - ) - ds["time"].attrs["axis"] = "T" - - # Find time dimension based on axis attribute and rename to `time` - if (time_dims := ds.cf.axes.get("T")) is not None: - if len(time_dims) > 1: - raise ValueError("Multiple time coordinates found in dataset. This is not supported by Parcels.") - (time_dim,) = time_dims - if time_dim != "time": - logger.debug(f"Renaming time axis coordinate from {time_dim} to 'time'.") - ds = ds.rename({time_dim: "time"}) - - # Parse SGRID metadata and get xgcm kwargs - _, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) - - # Add time axis to xgcm_kwargs if present - if "time" in ds.dims: - if "T" not in xgcm_kwargs["coords"]: - xgcm_kwargs["coords"]["T"] = {"center": "time"} - - if "lon" not in ds.coords or "lat" not in ds.coords: - node_dimensions = sgrid.load_mappings(ds.grid.node_dimensions) - ds["lon"] = ds[node_dimensions[0]] - ds["lat"] = ds[node_dimensions[1]] - - # Create xgcm Grid object - xgcm_grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs, **_DEFAULT_XGCM_KWARGS) - - # Wrap in XGrid - grid = XGrid(xgcm_grid, mesh=mesh) - - # Create fields from data variables, skipping grid metadata variables - # Skip variables that are SGRID metadata (have cf_role='grid_topology') - skip_vars = set() - for var in ds.data_vars: - if ds[var].attrs.get("cf_role") == "grid_topology": - skip_vars.add(var) - - fields = {} - if "U" in ds.data_vars and "V" in ds.data_vars: - vector_interp_method = XLinear_Velocity if _is_agrid(ds) else CGrid_Velocity - fields["U"] = Field("U", ds["U"], grid, XLinear) - fields["V"] = Field("V", ds["V"], grid, XLinear) - fields["UV"] = VectorField("UV", fields["U"], fields["V"], vector_interp_method=vector_interp_method) - - if "W" in ds.data_vars: - fields["W"] = Field("W", ds["W"], grid, XLinear) - fields["UVW"] = VectorField( - "UVW", fields["U"], fields["V"], fields["W"], vector_interp_method=vector_interp_method - ) - - for varname in set(ds.data_vars) - set(fields.keys()) - skip_vars: - fields[varname] = Field(str(varname), ds[varname], grid, XLinear) - - return cls(list(fields.values())) + model = StructuredModel.from_sgrid_conventions(ds, mesh) + return cls([model]) + + +def assert_compatible_fieldsets(left: FieldSet, right: FieldSet) -> None: + """Assert that two FieldSets can be combined without name conflicts. + + Parameters + ---------- + left, right : FieldSet + The two FieldSets to check. + + Raises + ------ + ValueError + If the FieldSets share field names or constant names. + """ + overlapping_fields = set(left.fields) & set(right.fields) + if overlapping_fields: + raise ValueError(f"Cannot add FieldSets with overlapping field names: {sorted(overlapping_fields)}") + + overlapping_constants = set(left.constants) & set(right.constants) + if overlapping_constants: + raise ValueError(f"Cannot add FieldSets with overlapping constant names: {sorted(overlapping_constants)}") class CalendarError(Exception): # TODO: Move to a parcels errors module @@ -358,10 +308,10 @@ def _format_calendar_error_message(field: Field | VectorField, reference_datetim _COPERNICUS_MARINE_AXIS_VARNAMES = { - "X": "lon", - "Y": "lat", - "Z": "depth", "T": "time", + "Z": "depth", + "Y": "lat", + "X": "lon", } @@ -393,115 +343,7 @@ def _format_calendar_error_message(field: Field | VectorField, reference_datetim } -def _discover_ux_U_and_V(ds: ux.UxDataset) -> ux.UxDataset: - # Common variable names for U and V found in UxDatasets - common_ux_UV = [("unod", "vnod"), ("u", "v")] - common_ux_W = ["w"] - - if "W" not in ds: - for common_W in common_ux_W: - if common_W in ds: - ds = _ds_rename_using_standard_names(ds, {common_W: "W"}) - break - - if "U" in ds and "V" in ds: - return ds # U and V already present - elif "U" in ds or "V" in ds: - raise ValueError( - "Dataset has only one of the two variables 'U' and 'V'. Please rename the appropriate variable in your dataset to have both 'U' and 'V' for Parcels simulation." - ) - - for common_U, common_V in common_ux_UV: - if common_U in ds: - if common_V not in ds: - raise ValueError( - f"Dataset has variable with standard name {common_U!r}, " - f"but not the matching variable with standard name {common_V!r}. " - "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation." - ) - else: - ds = _ds_rename_using_standard_names(ds, {common_U: "U", common_V: "V"}) - break - - else: - if common_V in ds: - raise ValueError( - f"Dataset has variable with standard name {common_V!r}, " - f"but not the matching variable with standard name {common_U!r}. " - "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation." - ) - continue - - return ds - - -def _select_uxinterpolator(da: ux.UxDataArray): - """Selects the appropriate uxarray interpolator for a given uxarray UxDataArray""" - supported_uxinterp_mapping = { - # (zc,n_face): face-center laterally, layer centers vertically — piecewise constant - "zc,n_face": UxConstantFaceConstantZC, - # (zc,n_node): node/corner laterally, layer centers vertically — barycentric lateral & piecewise constant vertical - "zc,n_node": UxLinearNodeConstantZC, - # (zf,n_node): node/corner laterally, layer interfaces vertically — barycentric lateral & linear vertical - "zf,n_node": UxLinearNodeLinearZF, - # (zf,n_face): face-center laterally, layer interfaces vertically — piecewise constant lateral & linear vertical - "zf,n_face": UxConstantFaceLinearZF, - } - # Extract only spatial dimensions, neglecting time - da_spatial_dims = tuple(d for d in da.dims if d not in ("time",)) - if len(da_spatial_dims) != 2: - raise ValueError( - "Fields on unstructured grids must have two spatial dimensions, one vertical (zf or zc) and one lateral (n_face, n_edge, or n_node)" - ) - - # Construct key (string) for mapping to interpolator - # Find vertical and lateral tokens - vdim = None - ldim = None - for d in da_spatial_dims: - if d in ("zf", "zc"): - vdim = d - if d in ("n_face", "n_node"): - ldim = d - # Map to supported interpolators - if vdim and ldim: - key = f"{vdim},{ldim}" - if key in supported_uxinterp_mapping.keys(): - return supported_uxinterp_mapping[key] - - return None - - -# TODO: Refactor later into something like `parcels._metadata.discover(dataset)` helper that can be used to discover important metadata like this. I think this whole metadata handling should be refactored into its own module. -def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: - """Small helper to inspect SGRID metadata and dataset metadata to determine mesh type.""" - sgrid_metadata = ds_sgrid.sgrid.metadata - - fpoint_x, fpoint_y = sgrid_metadata.node_coordinates - - if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) ^ _is_coordinate_in_degrees(ds_sgrid[fpoint_x]): - msg = ( - f"Mismatch in units between X and Y coordinates.\n" - f" Coordinate {ds_sgrid[fpoint_x]!r} attrs: {ds_sgrid[fpoint_x].attrs}\n" - f" Coordinate {ds_sgrid[fpoint_y]!r} attrs: {ds_sgrid[fpoint_y].attrs}\n" - ) - raise ValueError(msg) - - return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat" - - def _is_agrid(ds: xr.Dataset) -> bool: # check if U and V are defined on the same dimensions # if yes, interpret as A grid return set(ds["U"].dims) == set(ds["V"].dims) - - -def _is_coordinate_in_degrees(da: xr.DataArray) -> bool: - units = da.attrs.get("units") - if units is None: - raise ValueError( - f"Coordinate {da.name!r} of your dataset has no 'units' attribute - we don't know what the spatial units are." - ) - if isinstance(units, str) and "degree" in units.lower(): - return True - return False diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py new file mode 100644 index 000000000..cd1e3d794 --- /dev/null +++ b/src/parcels/_core/model.py @@ -0,0 +1,406 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Self + +import cf_xarray # noqa: F401 +import uxarray as ux +import xarray as xr + +import parcels._sgrid as sgrid +from parcels._core.basegrid import BaseGrid +from parcels._core.field import Field, VectorField +from parcels._core.utils.time import TimeInterval +from parcels._core.uxgrid import UxGrid +from parcels._core.xgrid import ( + XGrid, + _transpose_xfield_data_to_tzyx, + assert_all_field_dims_have_axis, # noqa: F401, leave import for now until decision is made +) +from parcels._logger import logger +from parcels._typing import Mesh +from parcels.convert import _ds_rename_using_standard_names +from parcels.interpolators import ( + CGrid_Velocity, + Ux_Velocity, + UxConstantFaceConstantZC, + UxConstantFaceLinearZF, + UxLinearNodeConstantZC, + UxLinearNodeLinearZF, + XLinear, + XLinear_Velocity, +) +from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator + + +class Model(ABC): + data: Any + grid: BaseGrid + field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator] + + @abstractmethod + def construct_fields(self) -> list[Field | VectorField]: ... + + @property + @abstractmethod + def scalar_field_names(self) -> list[str]: ... + + @abstractmethod + def assert_valid_field_data(self, field_data: Any) -> None: ... + + def assert_valid_model_data(self) -> None: + for field_name in self.scalar_field_names: + field_data = self.data[field_name] + try: + self.assert_valid_field_data(field_data) + except Exception as e: + e.add_note(f"Error validating field {field_name!r}.") + raise e + return + + @property + def time_interval(self) -> TimeInterval | None: + try: + time_interval = _get_time_interval(self.data) + except ValueError as e: + e.add_note( + f"Error getting time interval for model with data {self.data!r}. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects?" + ) + raise e + return time_interval + + +def preprocess_sgrid_model_data(ds: xr.Dataset) -> xr.Dataset: + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata + + for field_name in set(ds.data_vars) - {ds.sgrid._get_grid_topology().name}: + ds[field_name] = _transpose_xfield_data_to_tzyx(ds[field_name], metadata) + return ds + + +class StructuredModel(Model): + def __init__(self, data: xr.Dataset, mesh: Mesh): + if not isinstance(data, xr.Dataset): + raise ValueError(f"Expected `data` to be an xarray.Dataset . Got {type(data)}") + + data = preprocess_sgrid_model_data(data) + grid = XGrid(data, mesh) + + self.data = data + self.grid = grid + self.field_to_interpolator = {} + self._fields: list[Field | VectorField] | None = None + self.assert_valid_model_data() + + def assert_valid_field_data(self, field_data: xr.DataArray) -> None: + # assert_all_field_dims_have_axis(field_data, self.grid.xgcm_grid) #! These checks should be revisited + _assert_has_time_coordinate(field_data) + + @property + def scalar_field_names(self) -> list[str]: + # Create fields from data variables, skipping grid metadata variables + # Skip variables that are SGRID metadata (have cf_role='grid_topology') + skip_vars = set() + for var in self.data.data_vars: + if self.data[var].attrs.get("cf_role") == "grid_topology": + skip_vars.add(var) + return list(set(self.data.data_vars) - skip_vars) + + def construct_fields(self) -> list[Field | VectorField]: + single_fields: dict[str, Field] = {} + vector_fields: dict[str, VectorField] = {} + scalar_field_names = self.scalar_field_names + if "U" in scalar_field_names and "V" in scalar_field_names: + vector_interp_method = XLinear_Velocity() if _is_agrid(self.data) else CGrid_Velocity() + single_fields["U"] = Field("U", self) + single_fields["V"] = Field("V", self) + vector_fields["UV"] = VectorField( + "UV", single_fields["U"], single_fields["V"], interp_method=vector_interp_method + ) + + if "W" in scalar_field_names: + single_fields["W"] = Field("W", self) + vector_fields["UVW"] = VectorField( + "UVW", + single_fields["U"], + single_fields["V"], + single_fields["W"], + interp_method=vector_interp_method, + ) + + fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} + for varname in set(scalar_field_names) - set(fields.keys()): + fields[varname] = Field(str(varname), self) + + return list(fields.values()) + + @classmethod + def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Self: + ds = ds.copy() + if mesh is None: + mesh = _get_mesh_type_from_sgrid_dataset(ds) + + # Ensure time dimension has axis attribute if present + if "time" in ds.dims and "time" in ds.coords: + if "axis" not in ds["time"].attrs: + logger.debug( + "Dataset contains 'time' dimension but no 'axis' attribute. Setting 'axis' attribute to 'T'." + ) + ds["time"].attrs["axis"] = "T" + + # Find time dimension based on axis attribute and rename to `time` + if (time_dims := ds.cf.axes.get("T")) is not None: + if len(time_dims) > 1: + raise ValueError("Multiple time coordinates found in dataset. This is not supported by Parcels.") + (time_dim,) = time_dims + if time_dim != "time": + logger.debug(f"Renaming time axis coordinate from {time_dim} to 'time'.") + ds = ds.rename({time_dim: "time"}) + + # if "lon" not in ds.coords or "lat" not in ds.coords: + # node_dimensions = sgrid.load_mappings(ds.grid.node_dimensions) + # ds["lon"] = ds[node_dimensions[0]] + # ds["lat"] = ds[node_dimensions[1]] + + model = cls(ds, mesh=mesh) + model._fields = model.construct_fields() + for f in model._fields: + if isinstance(f, Field): + f.interp_method = XLinear() + return model + + +constant_field_models = { + mesh: StructuredModel.from_sgrid_conventions( + xr.Dataset( + {}, + coords={ + "lat": (["lat"], [0], {"axis": "Y"}), + "lon": (["lon"], [0], {"axis": "X"}), + "depth": (["depth"], [0], {"axis": "Z"}), + "time": (["time"], [0], {"axis": "T"}), + }, + ).pipe( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("lon", "lat"), + face_dimensions=( + sgrid.FaceNodePadding("XC", "lon", sgrid.Padding.LOW), + sgrid.FaceNodePadding("YC", "lat", sgrid.Padding.LOW), + ), + ), + ), + mesh=mesh, + ) + for mesh in ["flat", "spherical"] # type:ignore[reportArgumentType] +} + + +class UnstructuredModel(Model): + def __init__(self, data: ux.UxDataset, grid: UxGrid): + if not isinstance(data, ux.UxDataset): + raise ValueError(f"Expected `data` to be an uxarray.UxDataset . Got {type(data)}") + + if not isinstance(grid, UxGrid): + raise ValueError(f"Expected `grid` to be a parcels UxGrid object. Got {type(grid)}.") + + self.data = data + self.grid = grid + self.field_to_interpolator = {} + self._fields: list[Field | VectorField] | None = None + + def construct_fields(self) -> list[Field | VectorField]: + single_fields: dict[str, Field] = {} + vector_fields: dict[str, VectorField] = {} + scalar_field_names = self.scalar_field_names + if "U" in scalar_field_names and "V" in scalar_field_names: + single_fields["U"] = Field("U", self) + single_fields["V"] = Field("V", self) + vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=Ux_Velocity()) + + if "W" in scalar_field_names: + single_fields["W"] = Field("W", self) + vector_fields["UVW"] = VectorField( + "UVW", single_fields["U"], single_fields["V"], single_fields["W"], interp_method=Ux_Velocity() + ) + + fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} + for varname in set(scalar_field_names) - set(single_fields.keys()): + fields[varname] = Field(str(varname), self) + + return list(fields.values()) + + def assert_valid_field_data(self, field_data: ux.UxDataArray) -> None: + _assert_valid_uxdataarray(field_data) + _assert_has_time_coordinate(field_data) + + @property + def scalar_field_names(self) -> list[str]: + return list(self.data.data_vars) + + @classmethod + def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): + ds_dims = list(ds.dims) + if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): + raise ValueError( + f"Dataset missing one of the required dimensions 'time', 'zf', or 'zc' for uxDataset. Found dimensions {ds_dims}" + ) + + grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh=mesh) + ds = _discover_ux_U_and_V(ds) + model = cls(ds, grid) + model._fields = model.construct_fields() + for f in model._fields: + if isinstance(f, Field): + interp_cls = _select_uxinterpolator(model.data[f.name]) + if interp_cls is not None: + f.interp_method = interp_cls() + return model + + +# TODO: Refactor later into something like `parcels._metadata.discover(dataset)` helper that can be used to discover important metadata like this. I think this whole metadata handling should be refactored into its own module. +def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: + """Small helper to inspect SGRID metadata and dataset metadata to determine mesh type.""" + sgrid_metadata = ds_sgrid.sgrid.metadata + + fpoint_x, fpoint_y = sgrid_metadata.node_coordinates + + if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) ^ _is_coordinate_in_degrees(ds_sgrid[fpoint_x]): + msg = ( + f"Mismatch in units between X and Y coordinates.\n" + f" Coordinate {ds_sgrid[fpoint_x]!r} attrs: {ds_sgrid[fpoint_x].attrs}\n" + f" Coordinate {ds_sgrid[fpoint_y]!r} attrs: {ds_sgrid[fpoint_y].attrs}\n" + ) + raise ValueError(msg) + + return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat" + + +def _is_coordinate_in_degrees(da: xr.DataArray) -> bool: + units = da.attrs.get("units") + if units is None: + raise ValueError( + f"Coordinate {da.name!r} of your dataset has no 'units' attribute - we don't know what the spatial units are." + ) + if isinstance(units, str) and "degree" in units.lower(): + return True + return False + + +def _discover_ux_U_and_V(ds: ux.UxDataset) -> ux.UxDataset: + # Common variable names for U and V found in UxDatasets + common_ux_UV = [("unod", "vnod"), ("u", "v")] + common_ux_W = ["w"] + + if "W" not in ds: + for common_W in common_ux_W: + if common_W in ds: + ds = _ds_rename_using_standard_names(ds, {common_W: "W"}) + break + + if "U" in ds and "V" in ds: + return ds # U and V already present + elif "U" in ds or "V" in ds: + raise ValueError( + "Dataset has only one of the two variables 'U' and 'V'. Please rename the appropriate variable in your dataset to have both 'U' and 'V' for Parcels simulation." + ) + + for common_U, common_V in common_ux_UV: + if common_U in ds: + if common_V not in ds: + raise ValueError( + f"Dataset has variable with standard name {common_U!r}, " + f"but not the matching variable with standard name {common_V!r}. " + "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation." + ) + else: + ds = _ds_rename_using_standard_names(ds, {common_U: "U", common_V: "V"}) + break + + else: + if common_V in ds: + raise ValueError( + f"Dataset has variable with standard name {common_V!r}, " + f"but not the matching variable with standard name {common_U!r}. " + "Please rename the appropriate variables in your dataset to have both 'U' and 'V' for Parcels simulation." + ) + continue + + return ds + + +def _select_uxinterpolator(da: ux.UxDataArray): + """Selects the appropriate uxarray interpolator for a given uxarray UxDataArray""" + supported_uxinterp_mapping = { + # (zc,n_face): face-center laterally, layer centers vertically — piecewise constant + "zc,n_face": UxConstantFaceConstantZC, + # (zc,n_node): node/corner laterally, layer centers vertically — barycentric lateral & piecewise constant vertical + "zc,n_node": UxLinearNodeConstantZC, + # (zf,n_node): node/corner laterally, layer interfaces vertically — barycentric lateral & linear vertical + "zf,n_node": UxLinearNodeLinearZF, + # (zf,n_face): face-center laterally, layer interfaces vertically — piecewise constant lateral & linear vertical + "zf,n_face": UxConstantFaceLinearZF, + } + # Extract only spatial dimensions, neglecting time + da_spatial_dims = tuple(d for d in da.dims if d not in ("time",)) + if len(da_spatial_dims) != 2: + raise ValueError( + "Fields on unstructured grids must have two spatial dimensions, one vertical (zf or zc) and one lateral (n_face, n_edge, or n_node)" + ) + + # Construct key (string) for mapping to interpolator + # Find vertical and lateral tokens + vdim = None + ldim = None + for d in da_spatial_dims: + if d in ("zf", "zc"): + vdim = d + if d in ("n_face", "n_node"): + ldim = d + # Map to supported interpolators + if vdim and ldim: + key = f"{vdim},{ldim}" + if key in supported_uxinterp_mapping.keys(): + return supported_uxinterp_mapping[key] + + return None + + +def _is_agrid(ds: xr.Dataset) -> bool: + # check if U and V are defined on the same dimensions + # if yes, interpret as A grid + return set(ds["U"].dims) == set(ds["V"].dims) + + +def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None: + if "time" not in data or data["time"].size == 1: + return None + + return TimeInterval(data.time.values[0], data.time.values[-1]) + + +def _assert_valid_uxdataarray(data: ux.UxDataArray): + """Verifies that all the required attributes are present in the xarray.DataArray or + uxarray.UxDataArray object. + """ + # Validate dimensions + if not ("zf" in data.dims or "zc" in data.dims): + raise ValueError( + "Field is missing a 'zf' or 'zc' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) + + if "time" not in data.dims: + raise ValueError( + "Field is missing a 'time' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) + + +def _assert_has_time_coordinate(da: xr.DataArray) -> None: + if da.shape[0] > 1: + if "time" not in da.coords: + raise ValueError("Field data is missing a 'time' coordinate.") + return diff --git a/src/parcels/_core/particle.py b/src/parcels/_core/particle.py index cb2a441cf..322d313bb 100644 --- a/src/parcels/_core/particle.py +++ b/src/parcels/_core/particle.py @@ -127,9 +127,18 @@ def get_default_particle(spatial_dtype: type[np.float32] | type[np.float64]) -> return ParticleClass( variables=[ Variable( - "lon", + "time", + dtype=np.float64, + attrs={ + "standard_name": "time", + "units": "seconds", + "axis": "T", + }, # "units" and "calendar" gets updated/set if working with cftime time domain + ), + Variable( + "z", dtype=spatial_dtype, - attrs={"standard_name": "longitude", "units": "degrees_east", "axis": "X"}, + attrs={"standard_name": "vertical coordinate", "units": "m", "positive": "down"}, ), Variable( "lat", @@ -137,22 +146,13 @@ def get_default_particle(spatial_dtype: type[np.float32] | type[np.float64]) -> attrs={"standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, ), Variable( - "z", + "lon", dtype=spatial_dtype, - attrs={"standard_name": "vertical coordinate", "units": "m", "positive": "down"}, + attrs={"standard_name": "longitude", "units": "degrees_east", "axis": "X"}, ), - Variable("dlon", dtype=spatial_dtype, to_write=False), - Variable("dlat", dtype=spatial_dtype, to_write=False), Variable("dz", dtype=spatial_dtype, to_write=False), - Variable( - "time", - dtype=np.float64, - attrs={ - "standard_name": "time", - "units": "seconds", - "axis": "T", - }, # "units" and "calendar" gets updated/set if working with cftime time domain - ), + Variable("dlat", dtype=spatial_dtype, to_write=False), + Variable("dlon", dtype=spatial_dtype, to_write=False), Variable( "particle_id", dtype=np.int64, diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py index 58072f6e6..56044d271 100644 --- a/src/parcels/_core/particlefile.py +++ b/src/parcels/_core/particlefile.py @@ -61,6 +61,10 @@ class ParticleFile: It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds). compression : {"zstd", "gzip", "snappy", "brotli", None}, optional Compression algorithm to use for the Parquet file. Default is "zstd". + mode : {None, "w"}, optional + Writing behaviour. + - None (default): Write dataset, and raise an error if it already exists. + - "w": Write dataset, overwriting it. Returns ------- @@ -69,7 +73,11 @@ class ParticleFile: """ def __init__( - self, path: PathLike, outputdt, compression: Literal["zstd", "gzip", "snappy", "brotli", None] = "zstd" + self, + path: PathLike, + outputdt, + compression: Literal["zstd", "gzip", "snappy", "brotli", None] = "zstd", + mode: Literal[None, "w"] = None, ): if not isinstance(outputdt, (np.timedelta64, timedelta, float)): raise ValueError( @@ -92,9 +100,15 @@ def __init__( self._path = path # TODO v4: Consider https://arrow.apache.org/docs/python/getstarted.html#working-with-large-data - though a significant question becomes how to partition, perhaps using a particle variable "partition"? self._writer: pq.ParquetWriter | None = None + + if mode not in {None, "w"}: + raise ValueError(f"Invalid mode value {mode!r}. Expected one of None or 'w'.") + if path.exists(): - # TODO: Add logic for recovering/appending to existing parquet file - raise ValueError(f"{path=!r} already exists. Either delete this file or use a path that doesn't exist.") + if mode is None: + raise ValueError(f"{path=!r} already exists. Use mode='w' or use a new path.") + if mode == "w": + path.unlink() if not path.parent.exists(): raise ValueError(f"Folder location for {path=!r} does not exist. Create the folder location first.") diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py index d044e7edf..64ec6acf1 100644 --- a/src/parcels/_core/particleset.py +++ b/src/parcels/_core/particleset.py @@ -20,7 +20,6 @@ ) from parcels._core.warnings import ParticleSetWarning from parcels._logger import logger -from parcels._reprs import particleset_repr __all__ = ["ParticleSet"] @@ -59,10 +58,10 @@ def __init__( self, fieldset, pclass=Particle, - lon=None, - lat=None, - z=None, time=None, + z=None, + lat=None, + lon=None, particle_ids=None, **kwargs, ): @@ -70,9 +69,9 @@ def __init__( self._kernel = None self.fieldset = fieldset - lon = np.empty(shape=0) if lon is None else np.array(lon).flatten() - lat = np.empty(shape=0) if lat is None else np.array(lat).flatten() time = np.empty(shape=0) if time is None else np.array(time).flatten() + lat = np.empty(shape=0) if lat is None else np.array(lat).flatten() + lon = np.empty(shape=0) if lon is None else np.array(lon).flatten() if particle_ids is None: particle_ids = np.arange(lon.size) @@ -112,10 +111,10 @@ def __init__( nparticles=lon.size, ngrids=len(fieldset.gridset), initial=dict( - lon=lon, - lat=lat, - z=z, time=time, + z=z, + lat=lat, + lon=lon, particle_id=particle_ids, ), ) @@ -173,8 +172,8 @@ def __setattr__(self, name, value): def size(self): return len(self) - def __repr__(self): - return particleset_repr(self) + # def __repr__(self): + # return particleset_repr(self) def __len__(self): return len(self._data["particle_id"]) @@ -418,7 +417,7 @@ def execute( if verbose_progress: pbar = tqdm( - total=end_time - start_time, + total=sign_dt * (end_time - start_time), file=sys.stdout, bar_format="{desc} {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}]", ) @@ -454,7 +453,7 @@ def execute( pbar.set_description_str( "Integration time: " + str(float_to_datelike(time, self.fieldset.time_interval)) ) - pbar.update(next_time - time) + pbar.update(sign_dt * (next_time - time)) time = next_time diff --git a/src/parcels/_core/xgrid.py b/src/parcels/_core/xgrid.py index d925cb03b..6486823bd 100644 --- a/src/parcels/_core/xgrid.py +++ b/src/parcels/_core/xgrid.py @@ -1,4 +1,5 @@ from collections.abc import Hashable, Sequence +from dataclasses import dataclass from functools import cached_property from typing import Any, Literal, cast @@ -7,10 +8,12 @@ import xarray as xr import xgcm +import parcels._sgrid as sgrid import parcels._typing as ptyping from parcels._core.basegrid import BaseGrid from parcels._core.index_search import _search_1d_array, _search_indices_curvilinear_2d -from parcels._reprs import xgrid_repr +from parcels._sgrid.accessor import _get_dim_to_axis_mapping +from parcels._sgrid.core import SGRID_PADDING_TO_XGCM_POSITION _FIELD_DATA_ORDERING: Sequence[ptyping.XgcmAxisDirection] = "TZYX" _XGRID_AXES_ORDERING: Sequence[ptyping.XgridAxis] = "ZYX" @@ -68,37 +71,86 @@ def assert_all_field_dims_have_axis(da: xr.DataArray, xgcm_grid: xgcm.Grid) -> N return -def _transpose_xfield_data_to_tzyx(da: xr.DataArray, xgcm_grid: xgcm.Grid) -> xr.DataArray: +def _transpose_xfield_data_to_tzyx(da: xr.DataArray, sgrid_metadata: sgrid.SGrid2DMetadata) -> xr.DataArray: """ Transpose a DataArray of any shape into a 4D array of order TZYX. Uses xgcm to determine the axes, and inserts mock dimensions of size 1 for any axes not present in the DataArray. """ - ax_dims = [(get_axis_from_dim_name(xgcm_grid.axes, dim), dim) for dim in da.dims] + dim_to_axis = _get_dim_to_axis_mapping(sgrid_metadata) | {"time": "T"} + + # filter to only dims in da + dim_to_axis = {dim: axis for dim, axis in dim_to_axis.items() if dim in da.dims} - if all(ax_dim[0] is None for ax_dim in ax_dims): + if dim_to_axis == {}: # Assuming its a 1D constant field (hence has no axes) assert da.shape == (1, 1, 1, 1) return da.rename({old_dim: f"mock{axis}" for old_dim, axis in zip(da.dims, _FIELD_DATA_ORDERING, strict=True)}) # All dimensions must be associated with an axis in the grid - if any(ax_dim[0] is None for ax_dim in ax_dims): + if set(dim_to_axis) != set(da.dims): raise ValueError( f"DataArray {da.name!r} with dims {da.dims} has dimensions that are not associated with a direction on the provided grid." ) - axes_not_in_field = set(_FIELD_DATA_ORDERING) - set(ax_dim[0] for ax_dim in ax_dims) + axes_not_in_field = set(_FIELD_DATA_ORDERING).difference(set(dim_to_axis.values())) mock_dims_to_create = {} for ax in axes_not_in_field: mock_dims_to_create[f"mock{ax}"] = 1 - ax_dims.append((ax, f"mock{ax}")) + dim_to_axis[f"mock{ax}"] = ax if mock_dims_to_create: da = da.expand_dims(mock_dims_to_create, create_index_for_new_dim=False) - ax_dims = sorted(ax_dims, key=lambda x: _FIELD_DATA_ORDERING.index(x[0])) + ax_dims = sorted(dim_to_axis.items(), key=lambda x: _FIELD_DATA_ORDERING.index(x[1])) + + return da.transpose(*[ax_dim[0] for ax_dim in ax_dims]) + + +@dataclass +class XgcmLikeAxis: + coords: dict[ptyping.XgcmAxisPosition, str] + + +class XgcmLikeGrid: + """Adapter class to circumvent XGCM as a dep. + + + This is only used as a temporary class for the moment. Down the line we should refactor to remove XGCM entirely and work with SGRID metadata (especially since COMODO metadata isn't standard, and nor is the xgcm data model). + """ + + def __init__(self, sgrid_metadata: sgrid.SGrid2DMetadata, model_data: xr.Dataset): + self.axes: dict[ptyping.CfAxisSpatial, XgcmLikeAxis] = construct_xgcm_axes_object(sgrid_metadata, model_data) + + +def construct_xgcm_axes_object(metadata: sgrid.SGrid2DMetadata, model_data: xr.Dataset) -> dict[str, XgcmLikeAxis]: + lst: list[tuple[ptyping.CfAxis, str, ptyping.XgcmAxisPosition]] = [] + + for fnp, axis in zip(metadata.face_dimensions, ("X", "Y"), strict=True): + lst.append((axis, fnp.face, "center")) + lst.append((axis, fnp.node, SGRID_PADDING_TO_XGCM_POSITION[fnp.padding])) + + if metadata.vertical_dimensions is not None: + assert len(metadata.vertical_dimensions) == 1 + fnp = metadata.vertical_dimensions[0] + axis = "Z" + lst.append((axis, fnp.face, "center")) + lst.append((axis, fnp.node, SGRID_PADDING_TO_XGCM_POSITION[fnp.padding])) + + # filter so that only dims in the dataset itself are mentioned + lst = [i for i in lst if i[1] in model_data.dims] + + # Add time axis to xgcm_kwargs if present + if "time" in model_data.dims: + lst.append(("T", "time", "center")) + + ret = {} + for axis, dim, position in lst: + if axis not in ret: + ret[axis] = XgcmLikeAxis({}) + ret[axis].coords[position] = dim - return da.transpose(*[ax_dim[1] for ax_dim in ax_dims]) + return ret class XGrid(BaseGrid): @@ -112,11 +164,14 @@ class XGrid(BaseGrid): """ - def __init__(self, grid: xgcm.Grid, mesh): + def __init__(self, model_data: xr.Dataset, mesh): + self.sgrid_metadata = model_data.sgrid.metadata + self._ds = model_data + grid = XgcmLikeGrid(self.sgrid_metadata, model_data) self.xgcm_grid = grid self._mesh = mesh self._spatialhash = None - ds = grid._ds + ds = model_data # Set the coordinates for the dataset (needed to be done explicitly for curvilinear grids) if "lon" in ds: @@ -133,8 +188,8 @@ def __init__(self, grid: xgcm.Grid, mesh): ptyping.assert_valid_mesh(mesh) self._ds = ds - def __repr__(self): - return xgrid_repr(self) + # def __repr__(self): + # return xgrid_repr(self) @property def axes(self) -> list[ptyping.XgridAxis]: diff --git a/src/parcels/_datasets/remote.py b/src/parcels/_datasets/remote.py index eb2206591..e176a1fb2 100644 --- a/src/parcels/_datasets/remote.py +++ b/src/parcels/_datasets/remote.py @@ -63,6 +63,11 @@ def _get_data_home() -> Path: "data/FESOM_periodic_channel/v.fesom_channel.nc", "data/FESOM_periodic_channel/w.fesom_channel.nc", ] + + [ + "data/SCHISM_LakeOntario/out2d.schism_lake_ontario.nc", + "data/SCHISM_LakeOntario/horizontalVelX.schism_lake_ontario.nc", + "data/SCHISM_LakeOntario/horizontalVelY.schism_lake_ontario.nc", + ] + [ "data/NemoCurvilinear_data/U_purely_zonal-ORCA025_grid_U.nc4", "data/NemoCurvilinear_data/V_purely_zonal-ORCA025_grid_V.nc4", @@ -222,6 +227,9 @@ class _Purpose(enum.Enum): ("FESOM_periodic_channel/u.fesom_channel", (_V3Dataset(_ODIE,"data/FESOM_periodic_channel/u.fesom_channel.nc"), _Purpose.TUTORIAL)), ("FESOM_periodic_channel/v.fesom_channel", (_V3Dataset(_ODIE,"data/FESOM_periodic_channel/v.fesom_channel.nc"), _Purpose.TUTORIAL)), ("FESOM_periodic_channel/w.fesom_channel", (_V3Dataset(_ODIE,"data/FESOM_periodic_channel/w.fesom_channel.nc"), _Purpose.TUTORIAL)), + ("SCHISM_LakeOntario/out2d", (_V3Dataset(_ODIE,"data/SCHISM_LakeOntario/out2d.schism_lake_ontario.nc"), _Purpose.TUTORIAL)), + ("SCHISM_LakeOntario/horizontalVelX", (_V3Dataset(_ODIE,"data/SCHISM_LakeOntario/horizontalVelX.schism_lake_ontario.nc"), _Purpose.TUTORIAL)), + ("SCHISM_LakeOntario/horizontalVelY", (_V3Dataset(_ODIE,"data/SCHISM_LakeOntario/horizontalVelY.schism_lake_ontario.nc"), _Purpose.TUTORIAL)), ("NemoCurvilinear_data_zonal/U", (_V3Dataset(_ODIE,"data/NemoCurvilinear_data/U_purely_zonal-ORCA025_grid_U.nc4"), _Purpose.TUTORIAL)), ("NemoCurvilinear_data_zonal/V", (_V3Dataset(_ODIE,"data/NemoCurvilinear_data/V_purely_zonal-ORCA025_grid_V.nc4"), _Purpose.TUTORIAL)), ("NemoCurvilinear_data_zonal/mesh_mask", (_V3Dataset(_ODIE,"data/NemoCurvilinear_data/mesh_mask.nc4", _preprocess_drop_time_from_mesh2), _Purpose.TUTORIAL)), diff --git a/src/parcels/_reprs.py b/src/parcels/_reprs.py index e87d4dc4c..69ac23dc9 100644 --- a/src/parcels/_reprs.py +++ b/src/parcels/_reprs.py @@ -57,7 +57,7 @@ def vectorfield_repr(vector_field: VectorField, from_fieldset_repr=False) -> str out = f"""<{type(vector_field).__name__} {vector_field.name!r}> Parcels attributes: name : {vector_field.name!r} - vector_interp_method : {vector_field.vector_interp_method!r} + interp_method : {vector_field.interp_method!r} vector_type : {vector_field.vector_type!r} {field_repr(vector_field.U, level=1) if not from_fieldset_repr else ""} {field_repr(vector_field.V, level=1) if not from_fieldset_repr else ""} diff --git a/src/parcels/_typing.py b/src/parcels/_typing.py index bb375cdbd..18e8aa55f 100644 --- a/src/parcels/_typing.py +++ b/src/parcels/_typing.py @@ -41,8 +41,9 @@ KernelFunction = Callable[..., None] -XgridAxis = Literal["X", "Y", "Z"] -XgcmAxisDirection = Literal["X", "Y", "Z", "T"] +CfAxisSpatial = Literal["X", "Y", "Z"] +XgridAxis = CfAxisSpatial +XgcmAxisDirection = CfAxisSpatial | Literal["T"] CfAxis = XgcmAxisDirection XgcmAxisPosition = Literal["center", "left", "right", "inner", "outer"] XgcmAxes = Mapping[XgcmAxisDirection, "xgcm.Axis"] diff --git a/src/parcels/interpolators/_base.py b/src/parcels/interpolators/_base.py new file mode 100644 index 000000000..a4dbaf5e0 --- /dev/null +++ b/src/parcels/interpolators/_base.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class ScalarInterpolator(ABC): + @abstractmethod + def interp(self, particle_positions, grid_positions, field) -> Any: #! API a WIP + ... + + +class VectorInterpolator(ABC): + @abstractmethod + def interp(self, particle_positions, grid_positions, vectorfield) -> Any: #! API a WIP + ... diff --git a/src/parcels/interpolators/_uxinterpolators.py b/src/parcels/interpolators/_uxinterpolators.py index cb2c7aa24..d5efc808d 100644 --- a/src/parcels/interpolators/_uxinterpolators.py +++ b/src/parcels/interpolators/_uxinterpolators.py @@ -6,107 +6,126 @@ import numpy as np +from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator + if TYPE_CHECKING: from parcels._core.field import Field, VectorField from parcels._core.uxgrid import _UXGRID_AXES -def UxConstantFaceConstantZC( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], - field: Field, -): +class UxConstantFaceConstantZC(ScalarInterpolator): """Piecewise constant interpolation kernel for face registered data that is vertically centered (on zc points)""" - return field.data.values[ - grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - ] + + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], + field: Field, + ): + return field.data.values[ + grid_positions["T"]["index"], grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + ] -def UxConstantFaceLinearZF( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], - field: Field, -): +class UxConstantFaceLinearZF(ScalarInterpolator): """ Piecewise constant interpolation (lateral) with linear vertical interpolation kernel for face registered data that is located at vertical interface levels (on zf points) """ - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - z = particle_positions["z"] - - # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. - # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. - # First, do barycentric interpolation in the lateral direction for each interface level - fzk = field.data.values[ti, zi, fi] - fzkp1 = field.data.values[ti, zi + 1, fi] - - # Then, do piecewise linear interpolation in the vertical direction - zk = field.grid.z.values[zi] - zkp1 = field.grid.z.values[zi + 1] - return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction - - -def UxLinearNodeConstantZC( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], - field: Field, -): - """ - Piecewise linear interpolation kernel for node registered data that is vertically centered (zc points). - Effectively, it applies barycentric interpolation in the lateral direction - and piecewise constant interpolation in the vertical direction. - """ - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - bcoords = grid_positions["FACE"]["bcoord"] - node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values - return np.sum( - field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1 - ) # Linear interpolation in the vertical direction - - -def UxLinearNodeLinearZF( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], - field: Field, -): - """ - Piecewise linear interpolation kernel for node registered data located at vertical interface levels (zf points). - Effectively, it applies barycentric interpolation in the lateral direction - and piecewise linear interpolation in the vertical direction. - """ - ti = grid_positions["T"]["index"] - zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] - z = particle_positions["z"] - bcoords = grid_positions["FACE"]["bcoord"] - node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values - # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. - # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. - # First, do barycentric interpolation in the lateral direction for each interface level - fzk = np.sum(field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1) - fzkp1 = np.sum(field.data.values[ti[:, None], zi[:, None] + 1, node_ids] * bcoords, axis=-1) - - # Then, do piecewise linear interpolation in the vertical direction - zk = field.grid.z.values[zi] - zkp1 = field.grid.z.values[zi + 1] - return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction - - -def Ux_Velocity( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], - vectorfield: VectorField, -): + + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], + field: Field, + ): + ti = grid_positions["T"]["index"] + zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + z = particle_positions["z"] + + # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. + # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. + # First, do barycentric interpolation in the lateral direction for each interface level + fzk = field.data.values[ti, zi, fi] + fzkp1 = field.data.values[ti, zi + 1, fi] + + # Then, do piecewise linear interpolation in the vertical direction + zk = field.grid.z.values[zi] + zkp1 = field.grid.z.values[zi + 1] + return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction + + +class UxLinearNodeConstantZC(ScalarInterpolator): + """Piecewise linear interpolation kernel for node registered data that is vertically centered (zc points).""" + + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], + field: Field, + ): + """ + Piecewise linear interpolation kernel for node registered data that is vertically centered (zc points). + Effectively, it applies barycentric interpolation in the lateral direction + and piecewise constant interpolation in the vertical direction. + """ + ti = grid_positions["T"]["index"] + zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + bcoords = grid_positions["FACE"]["bcoord"] + node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values + return np.sum( + field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1 + ) # Linear interpolation in the vertical direction + + +class UxLinearNodeLinearZF(ScalarInterpolator): + """Piecewise linear interpolation kernel for node registered data located at vertical interface levels (zf points).""" + + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], + field: Field, + ): + """ + Piecewise linear interpolation kernel for node registered data located at vertical interface levels (zf points). + Effectively, it applies barycentric interpolation in the lateral direction + and piecewise linear interpolation in the vertical direction. + """ + ti = grid_positions["T"]["index"] + zi, fi = grid_positions["Z"]["index"], grid_positions["FACE"]["index"] + z = particle_positions["z"] + bcoords = grid_positions["FACE"]["bcoord"] + node_ids = field.grid.uxgrid.face_node_connectivity[fi, :].values + # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. + # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. + # First, do barycentric interpolation in the lateral direction for each interface level + fzk = np.sum(field.data.values[ti[:, None], zi[:, None], node_ids] * bcoords, axis=-1) + fzkp1 = np.sum(field.data.values[ti[:, None], zi[:, None] + 1, node_ids] * bcoords, axis=-1) + + # Then, do piecewise linear interpolation in the vertical direction + zk = field.grid.z.values[zi] + zkp1 = field.grid.z.values[zi + 1] + return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction + + +class Ux_Velocity(VectorInterpolator): # noqa: N801 """Interpolation kernel for Vectorfields of velocity on a UxGrid.""" - u = vectorfield.U._interp_method(particle_positions, grid_positions, vectorfield.U) - v = vectorfield.V._interp_method(particle_positions, grid_positions, vectorfield.V) - if vectorfield.grid._mesh == "spherical": - u /= 1852 * 60 * np.cos(np.deg2rad(particle_positions["lat"])) - v /= 1852 * 60 - - if "3D" in vectorfield.vector_type: - w = vectorfield.W._interp_method(particle_positions, grid_positions, vectorfield.W) - else: - w = 0.0 - return u, v, w + + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[_UXGRID_AXES, dict[str, int | float | np.ndarray]], + vectorfield: VectorField, + ): + u = vectorfield.U.interp_method.interp(particle_positions, grid_positions, vectorfield.U) + v = vectorfield.V.interp_method.interp(particle_positions, grid_positions, vectorfield.V) + if vectorfield.grid._mesh == "spherical": + u /= 1852 * 60 * np.cos(np.deg2rad(particle_positions["lat"])) + v /= 1852 * 60 + + if "3D" in vectorfield.vector_type: + w = vectorfield.W.interp_method.interp(particle_positions, grid_positions, vectorfield.W) + else: + w = 0.0 + return u, v, w diff --git a/src/parcels/interpolators/_xinterpolators.py b/src/parcels/interpolators/_xinterpolators.py index 725422d02..65ce2ab42 100644 --- a/src/parcels/interpolators/_xinterpolators.py +++ b/src/parcels/interpolators/_xinterpolators.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np import xarray as xr @@ -10,28 +10,37 @@ import parcels._core.utils.interpolation as i_u import parcels._typing as ptyping +from parcels.interpolators._base import ScalarInterpolator, VectorInterpolator if TYPE_CHECKING: from parcels._core.field import Field, VectorField - from parcels._core.xgrid import XgridAxis + from parcels._core.xgrid import XGrid -def ZeroInterpolator( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - field: Field, -) -> np.float32 | np.float64: - """Template function used for the signature check of the lateral interpolation methods.""" - return 0.0 +class ZeroInterpolator(ScalarInterpolator): + """Template interpolator that always returns zero.""" + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + field: Field, + ) -> np.float32 | np.float64: + """Template method used for the signature check of the lateral interpolation methods.""" + return 0.0 -def ZeroInterpolator_Vector( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - vectorfield: VectorField, -) -> np.float32 | np.float64: - """Template function used for the signature check of the interpolation methods for velocity fields.""" - return 0.0 + +class ZeroInterpolator_Vector(VectorInterpolator): # noqa: N801 + """Template vector interpolator that always returns zero.""" + + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + vectorfield: VectorField, + ) -> np.float32 | np.float64: + """Template method used for the signature check of the interpolation methods for velocity fields.""" + return 0.0 def _get_corner_data_Agrid( @@ -43,7 +52,7 @@ def _get_corner_data_Agrid( lenT: int, # noqa: N803 lenZ: int, # noqa: N803 npart: int, - axis_dim: dict[ptyping.XgridAxis, str], + axis_dim: dict[ptyping.ptyping.XgridAxis, str], ) -> np.ndarray: """Helper function to get the corner data for a given A-grid field and position.""" # Time coordinates: 8 points at ti, then 8 points at ti+1 @@ -82,7 +91,7 @@ def _get_corner_data_Agrid( return data.isel(selection_dict).data.reshape(lenT, lenZ, 2, 2, npart) -def _get_offsets_dictionary(grid): +def _get_offsets_dictionary(grid: XGrid) -> dict[ptyping.CfAxisSpatial, Literal[1, 0]]: offsets = {} for axis in ["X", "Y"]: axis_coords = grid.xgcm_grid.axes[axis].coords.keys() @@ -95,295 +104,324 @@ def _get_offsets_dictionary(grid): return offsets -def XLinear( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - field: Field, -): +class XLinear(ScalarInterpolator): """Trilinear interpolation on a regular grid.""" - xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] - yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] - zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] - ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] - axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) - data = field.data + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + field: Field, + ): + """Trilinear interpolation on a regular grid.""" + xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] + yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] + zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] + ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] - lenT = 2 if np.any(tau > 0) else 1 - lenZ = 2 if np.any(zeta > 0) else 1 + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) + data = field.data - corner_data = _get_corner_data_Agrid(data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim) + lenT = 2 if np.any(tau > 0) else 1 + lenZ = 2 if np.any(zeta > 0) else 1 - if lenT == 2: - tau = tau[np.newaxis, :] - corner_data = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau - else: - corner_data = corner_data[0, :] + corner_data = _get_corner_data_Agrid(data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim) - if lenZ == 2: - zeta = zeta[np.newaxis, :] - corner_data = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta - else: - corner_data = corner_data[0, :] + if lenT == 2: + tau = tau[np.newaxis, :] + corner_data = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau + else: + corner_data = corner_data[0, :] - value = ( - (1 - xsi) * (1 - eta) * corner_data[0, 0, :] - + xsi * (1 - eta) * corner_data[0, 1, :] - + (1 - xsi) * eta * corner_data[1, 0, :] - + xsi * eta * corner_data[1, 1, :] - ) - return value.compute() if is_dask_collection(value) else value + if lenZ == 2: + zeta = zeta[np.newaxis, :] + corner_data = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta + else: + corner_data = corner_data[0, :] + value = ( + (1 - xsi) * (1 - eta) * corner_data[0, 0, :] + + xsi * (1 - eta) * corner_data[0, 1, :] + + (1 - xsi) * eta * corner_data[1, 0, :] + + xsi * eta * corner_data[1, 1, :] + ) + return value.compute() if is_dask_collection(value) else value -def XConstantField( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - field: Field, -): - """Returning the single value of a Constant Field (with a size=(1,1,1,1) array)""" - return field.data[0, 0, 0, 0].values +class XConstantField(ScalarInterpolator): + """Returns the single value of a Constant Field (with a size=(1,1,1,1) array).""" -def XLinear_Velocity( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - vectorfield: VectorField, -): + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + field: Field, + ): + """Returning the single value of a Constant Field (with a size=(1,1,1,1) array)""" + return field.data[0, 0, 0, 0].values + + +class XLinear_Velocity(VectorInterpolator): # noqa: N801 """Trilinear interpolation on a regular grid for VectorFields of velocity.""" - u = XLinear(particle_positions, grid_positions, vectorfield.U) - v = XLinear(particle_positions, grid_positions, vectorfield.V) - if vectorfield.grid._mesh == "spherical": - u /= 1852 * 60 * np.cos(np.deg2rad(particle_positions["lat"])) - v /= 1852 * 60 - if vectorfield.W: - w = XLinear(particle_positions, grid_positions, vectorfield.W) - else: - w = 0.0 - return u, v, w + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + vectorfield: VectorField, + ): + """Trilinear interpolation on a regular grid for VectorFields of velocity.""" + _xlinear = XLinear() + u = _xlinear.interp(particle_positions, grid_positions, vectorfield.U) + v = _xlinear.interp(particle_positions, grid_positions, vectorfield.V) + if vectorfield.grid._mesh == "spherical": + u /= 1852 * 60 * np.cos(np.deg2rad(particle_positions["lat"])) + v /= 1852 * 60 + + if vectorfield.W: + w = _xlinear.interp(particle_positions, grid_positions, vectorfield.W) + else: + w = 0.0 + return u, v, w -def CGrid_Velocity( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - vectorfield: VectorField, -): +class CGrid_Velocity(VectorInterpolator): # noqa: N801 """ Interpolation kernel for velocity fields on a C-Grid. Following Delandmeter and Van Sebille (2019), velocity fields should be interpolated only in the direction of the grid cell faces. """ - xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] - yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] - zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] - ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] - - U = vectorfield.U.data - V = vectorfield.V.data - grid = vectorfield.grid - offsets = _get_offsets_dictionary(grid) - tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3] - lenT = 2 if np.any(tau > 0) else 1 - - if grid.lon.ndim == 1: - px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]]) - py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]]) - else: - px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) - py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) - - if grid._mesh == "spherical": - px = ((px + 180.0) % 360.0) - 180.0 - px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) - px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) - c1 = i_u._geodetic_distance( - py[0], py[1], px[0], px[1], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(0.0, xsi), py) - ) - c2 = i_u._geodetic_distance( - py[1], py[2], px[1], px[2], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 1.0), py) - ) - c3 = i_u._geodetic_distance( - py[2], py[3], px[2], px[3], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(1.0, xsi), py) - ) - c4 = i_u._geodetic_distance( - py[3], py[0], px[3], px[0], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 0.0), py) - ) - - def _create_selection_dict(dims, zdir=False): - """Helper function to create DataArrays for indexing.""" - axis_dim = grid.get_axis_dim_mapping(dims) - selection_dict = { - axis_dim["X"]: xr.DataArray(xi_full, dims=("points")), - axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")), - } - # Time coordinates: 2 points at ti, then 2 points at ti+1 - if "time" in dims: - if lenT == 1: - ti_full = np.repeat(ti, 2) - else: - ti_1 = np.clip(ti + 1, 0, tdim - 1) - ti_full = np.concatenate([np.repeat(ti, 2), np.repeat(ti_1, 2)]) - selection_dict["time"] = xr.DataArray(ti_full, dims=("points")) - - if "Z" in axis_dim: - if zdir: - # Z coordinates: 1 point at zi and 1 point at zi+1 repeated for lenT time levels - zi_0 = np.clip(zi + offsets["Z"], 0, zdim - 1) - zi_1 = np.clip(zi + offsets["Z"] + 1, 0, zdim - 1) - zi_full = np.tile(np.array([zi_0, zi_1]).flatten(), lenT) - else: - # Z coordinates: 2 points at zi, repeated for lenT time levels - zi_full = np.repeat(zi, lenT * 2) - selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points")) - - return selection_dict - - def _compute_corner_data(data, selection_dict) -> np.ndarray: - """Helper function to load and reduce corner data over time dimension if needed.""" - corner_data = data.isel(selection_dict).data.reshape(lenT, 2, len(xsi)) - - if lenT == 2: - tau_full = tau[np.newaxis, :] - corner_data = corner_data[0, :] * (1 - tau_full) + corner_data[1, :] * tau_full + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + vectorfield: VectorField, + ): + """ + Interpolation kernel for velocity fields on a C-Grid. + Following Delandmeter and Van Sebille (2019), velocity fields should be interpolated + only in the direction of the grid cell faces. + """ + xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] + yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] + zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] + ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] + + U = vectorfield.U.data + V = vectorfield.V.data + grid = vectorfield.grid + offsets = _get_offsets_dictionary(grid) + tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3] + lenT = 2 if np.any(tau > 0) else 1 + + if grid.lon.ndim == 1: + px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]]) + py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]]) else: - corner_data = corner_data[0, :] - return corner_data - - # Compute U velocity - yi_o = np.clip(yi + offsets["Y"], 0, ydim - 1) - yi_full = np.tile(np.array([yi_o, yi_o]).flatten(), lenT) - - xi_1 = np.clip(xi + 1, 0, xdim - 1) - xi_full = np.tile(np.array([xi, xi_1]).flatten(), lenT) - - selection_dict = _create_selection_dict(U.dims) - corner_data = _compute_corner_data(U, selection_dict) - - U0 = corner_data[0, :] * c4 - U1 = corner_data[1, :] * c2 - Uvel = (1 - xsi) * U0 + xsi * U1 + px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) + + if grid._mesh == "spherical": + px = ((px + 180.0) % 360.0) - 180.0 + px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) + px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) + c1 = i_u._geodetic_distance( + py[0], py[1], px[0], px[1], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(0.0, xsi), py) + ) + c2 = i_u._geodetic_distance( + py[1], py[2], px[1], px[2], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 1.0), py) + ) + c3 = i_u._geodetic_distance( + py[2], py[3], px[2], px[3], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(1.0, xsi), py) + ) + c4 = i_u._geodetic_distance( + py[3], py[0], px[3], px[0], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 0.0), py) + ) - # Compute V velocity - yi_1 = np.clip(yi + 1, 0, ydim - 1) - yi_full = np.tile(np.array([yi, yi_1]).flatten(), lenT) + def _create_selection_dict(dims, zdir=False): + """Helper function to create DataArrays for indexing.""" + axis_dim = grid.get_axis_dim_mapping(dims) + selection_dict = { + axis_dim["X"]: xr.DataArray(xi_full, dims=("points")), + axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")), + } + + # Time coordinates: 2 points at ti, then 2 points at ti+1 + if "time" in dims: + if lenT == 1: + ti_full = np.repeat(ti, 2) + else: + ti_1 = np.clip(ti + 1, 0, tdim - 1) + ti_full = np.concatenate([np.repeat(ti, 2), np.repeat(ti_1, 2)]) + selection_dict["time"] = xr.DataArray(ti_full, dims=("points")) + + if "Z" in axis_dim: + if zdir: + # Z coordinates: 1 point at zi and 1 point at zi+1 repeated for lenT time levels + zi_0 = np.clip(zi + offsets["Z"], 0, zdim - 1) + zi_1 = np.clip(zi + offsets["Z"] + 1, 0, zdim - 1) + zi_full = np.tile(np.array([zi_0, zi_1]).flatten(), lenT) + else: + # Z coordinates: 2 points at zi, repeated for lenT time levels + zi_full = np.repeat(zi, lenT * 2) + selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points")) + + return selection_dict + + def _compute_corner_data(data, selection_dict) -> np.ndarray: + """Helper function to load and reduce corner data over time dimension if needed.""" + corner_data = data.isel(selection_dict).data.reshape(lenT, 2, len(xsi)) + + if lenT == 2: + tau_full = tau[np.newaxis, :] + corner_data = corner_data[0, :] * (1 - tau_full) + corner_data[1, :] * tau_full + else: + corner_data = corner_data[0, :] + return corner_data - xi_o = np.clip(xi + offsets["X"], 0, xdim - 1) - xi_full = np.tile(np.array([xi_o, xi_o]).flatten(), lenT) + # Compute U velocity + yi_o = np.clip(yi + offsets["Y"], 0, ydim - 1) + yi_full = np.tile(np.array([yi_o, yi_o]).flatten(), lenT) - selection_dict = _create_selection_dict(V.dims) - corner_data = _compute_corner_data(V, selection_dict) + xi_1 = np.clip(xi + 1, 0, xdim - 1) + xi_full = np.tile(np.array([xi, xi_1]).flatten(), lenT) - V0 = corner_data[0, :] * c1 - V1 = corner_data[1, :] * c3 - Vvel = (1 - eta) * V0 + eta * V1 + selection_dict = _create_selection_dict(U.dims) + corner_data = _compute_corner_data(U, selection_dict) - if grid._mesh == "spherical": - jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * 1852 * 60.0 - else: - jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) - - u = ( - (-(1 - eta) * Uvel - (1 - xsi) * Vvel) * px[0] - + ((1 - eta) * Uvel - xsi * Vvel) * px[1] - + (eta * Uvel + xsi * Vvel) * px[2] - + (-eta * Uvel + (1 - xsi) * Vvel) * px[3] - ) / jac - v = ( - (-(1 - eta) * Uvel - (1 - xsi) * Vvel) * py[0] - + ((1 - eta) * Uvel - xsi * Vvel) * py[1] - + (eta * Uvel + xsi * Vvel) * py[2] - + (-eta * Uvel + (1 - xsi) * Vvel) * py[3] - ) / jac - if is_dask_collection(u): - u = u.compute() - v = v.compute() - - if grid._mesh == "spherical": - conversion = 1852 * 60.0 * np.cos(np.deg2rad(particle_positions["lat"])) - u /= conversion - v /= conversion + U0 = corner_data[0, :] * c4 + U1 = corner_data[1, :] * c2 + Uvel = (1 - xsi) * U0 + xsi * U1 - if vectorfield.W: - W = vectorfield.W.data + # Compute V velocity + yi_1 = np.clip(yi + 1, 0, ydim - 1) + yi_full = np.tile(np.array([yi, yi_1]).flatten(), lenT) - # Y coordinates: yi+offset for each spatial point, repeated for time - yi_o = np.clip(yi + offsets["Y"], 0, ydim - 1) - yi_full = np.tile(yi_o, (lenT) * 2) - - # X coordinates: xi+offset for each spatial point, repeated for time xi_o = np.clip(xi + offsets["X"], 0, xdim - 1) - xi_full = np.tile(xi_o, (lenT) * 2) + xi_full = np.tile(np.array([xi_o, xi_o]).flatten(), lenT) - selection_dict = _create_selection_dict(W.dims, zdir=True) - corner_data = _compute_corner_data(W, selection_dict) + selection_dict = _create_selection_dict(V.dims) + corner_data = _compute_corner_data(V, selection_dict) - w = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta - if is_dask_collection(w): - w = w.compute() - else: - w = np.zeros_like(u) + V0 = corner_data[0, :] * c1 + V1 = corner_data[1, :] * c3 + Vvel = (1 - eta) * V0 + eta * V1 - return (u, v, w) + if grid._mesh == "spherical": + jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * 1852 * 60.0 + else: + jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) + + u = ( + (-(1 - eta) * Uvel - (1 - xsi) * Vvel) * px[0] + + ((1 - eta) * Uvel - xsi * Vvel) * px[1] + + (eta * Uvel + xsi * Vvel) * px[2] + + (-eta * Uvel + (1 - xsi) * Vvel) * px[3] + ) / jac + v = ( + (-(1 - eta) * Uvel - (1 - xsi) * Vvel) * py[0] + + ((1 - eta) * Uvel - xsi * Vvel) * py[1] + + (eta * Uvel + xsi * Vvel) * py[2] + + (-eta * Uvel + (1 - xsi) * Vvel) * py[3] + ) / jac + if is_dask_collection(u): + u = u.compute() + v = v.compute() + + if grid._mesh == "spherical": + conversion = 1852 * 60.0 * np.cos(np.deg2rad(particle_positions["lat"])) + u /= conversion + v /= conversion + + if vectorfield.W: + W = vectorfield.W.data + + # Y coordinates: yi+offset for each spatial point, repeated for time + yi_o = np.clip(yi + offsets["Y"], 0, ydim - 1) + yi_full = np.tile(yi_o, (lenT) * 2) + + # X coordinates: xi+offset for each spatial point, repeated for time + xi_o = np.clip(xi + offsets["X"], 0, xdim - 1) + xi_full = np.tile(xi_o, (lenT) * 2) + + selection_dict = _create_selection_dict(W.dims, zdir=True) + corner_data = _compute_corner_data(W, selection_dict) + + w = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta + if is_dask_collection(w): + w = w.compute() + else: + w = np.zeros_like(u) + return (u, v, w) -def CGrid_Tracer( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - field: Field, -): - """Interpolation kernel for tracer fields on a C-Grid. +class CGrid_Tracer(ScalarInterpolator): # noqa: N801 + """ + Interpolation kernel for tracer fields on a C-Grid. Following Delandmeter and Van Sebille (2019), tracer fields should be interpolated - constant over the grid cell + constant over the grid cell. """ - xi = grid_positions["X"]["index"] - yi = grid_positions["Y"]["index"] - zi = grid_positions["Z"]["index"] - ti = grid_positions["T"]["index"] - tau = grid_positions["T"]["bcoord"] - - axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) - data = field.data - offsets = _get_offsets_dictionary(field.grid) - zi = np.clip(zi + offsets["Z"], 0, data.shape[1] - 1) - yi = np.clip(yi + offsets["Y"], 0, data.shape[2] - 1) - xi = np.clip(xi + offsets["X"], 0, data.shape[3] - 1) + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + field: Field, + ): + """Interpolation kernel for tracer fields on a C-Grid. + + Following Delandmeter and Van Sebille (2019), tracer fields should be interpolated + constant over the grid cell + """ + xi = grid_positions["X"]["index"] + yi = grid_positions["Y"]["index"] + zi = grid_positions["Z"]["index"] + ti = grid_positions["T"]["index"] + tau = grid_positions["T"]["bcoord"] + + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) + data = field.data + + offsets = _get_offsets_dictionary(field.grid) + zi = np.clip(zi + offsets["Z"], 0, data.shape[1] - 1) + yi = np.clip(yi + offsets["Y"], 0, data.shape[2] - 1) + xi = np.clip(xi + offsets["X"], 0, data.shape[3] - 1) + + lenT = 2 if np.any(tau > 0) else 1 - lenT = 2 if np.any(tau > 0) else 1 - - if lenT == 2: - ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1) - ti = np.concatenate([np.repeat(ti), np.repeat(ti_1)]) - zi = np.tile(zi, (lenT) * 2) - yi = np.tile(yi, (lenT) * 2) - xi = np.tile(xi, (lenT) * 2) + if lenT == 2: + ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1) + ti = np.concatenate([np.repeat(ti), np.repeat(ti_1)]) + zi = np.tile(zi, (lenT) * 2) + yi = np.tile(yi, (lenT) * 2) + xi = np.tile(xi, (lenT) * 2) - # Create DataArrays for indexing - selection_dict = { - axis_dim["X"]: xr.DataArray(xi, dims=("points")), - axis_dim["Y"]: xr.DataArray(yi, dims=("points")), - } - if "Z" in axis_dim: - selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points")) - if "time" in field.data.dims: - selection_dict["time"] = xr.DataArray(ti, dims=("points")) + # Create DataArrays for indexing + selection_dict = { + axis_dim["X"]: xr.DataArray(xi, dims=("points")), + axis_dim["Y"]: xr.DataArray(yi, dims=("points")), + } + if "Z" in axis_dim: + selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points")) + if "time" in field.data.dims: + selection_dict["time"] = xr.DataArray(ti, dims=("points")) - value = data.isel(selection_dict).data.reshape(lenT, len(xi)) + value = data.isel(selection_dict).data.reshape(lenT, len(xi)) - if lenT == 2: - tau = tau[:, np.newaxis] - value = value[0, :] * (1 - tau) + value[1, :] * tau - else: - value = value[0, :] + if lenT == 2: + tau = tau[:, np.newaxis] + value = value[0, :] * (1 - tau) + value[1, :] * tau + else: + value = value[0, :] - return value.compute() if is_dask_collection(value) else value + return value.compute() if is_dask_collection(value) else value def _Spatialslip( particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], vectorfield: VectorField, a: np.float32, b: np.float32, @@ -399,10 +437,11 @@ def _Spatialslip( lenZ = 2 if np.any(zeta > 0) else 1 npart = len(xsi) - u = XLinear(particle_positions, grid_positions, vectorfield.U) - v = XLinear(particle_positions, grid_positions, vectorfield.V) + _xlinear = XLinear() + u = _xlinear.interp(particle_positions, grid_positions, vectorfield.U) + v = _xlinear.interp(particle_positions, grid_positions, vectorfield.V) if vectorfield.W: - w = XLinear(particle_positions, grid_positions, vectorfield.W) + w = _xlinear.interp(particle_positions, grid_positions, vectorfield.W) corner_dataU = _get_corner_data_Agrid(vectorfield.U.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim) corner_dataV = _get_corner_data_Agrid(vectorfield.V.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim) @@ -493,134 +532,153 @@ def is_land(ti: int, zi: int, yi: int, xi: int): return u, v, w -def XFreeslip( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - vectorfield: VectorField, -): +class XFreeslip(VectorInterpolator): """Free-slip boundary condition interpolation for velocity fields.""" - return _Spatialslip(particle_positions, grid_positions, vectorfield, a=1.0, b=0.0) + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + vectorfield: VectorField, + ): + """Free-slip boundary condition interpolation for velocity fields.""" + return _Spatialslip(particle_positions, grid_positions, vectorfield, a=1.0, b=0.0) -def XPartialslip( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - vectorfield: VectorField, -): + +class XPartialslip(VectorInterpolator): """Partial-slip boundary condition interpolation for velocity fields.""" - return _Spatialslip(particle_positions, grid_positions, vectorfield, a=0.5, b=0.5) + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + vectorfield: VectorField, + ): + """Partial-slip boundary condition interpolation for velocity fields.""" + return _Spatialslip(particle_positions, grid_positions, vectorfield, a=0.5, b=0.5) -def XNearest( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - field: Field, -): + +class XNearest(ScalarInterpolator): """ Nearest-Neighbour spatial interpolation on a regular grid. Note that this still uses linear interpolation in time. """ - xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] - yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] - zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] - ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] - - axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) - data = field.data - - lenT = 2 if np.any(tau > 0) else 1 - # Spatial coordinates: left if barycentric < 0.5, otherwise right - zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1) - zi_full = np.where(zeta < 0.5, zi, zi_1) + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + field: Field, + ): + """ + Nearest-Neighbour spatial interpolation on a regular grid. + Note that this still uses linear interpolation in time. + """ + xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] + yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] + zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] + ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] + + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) + data = field.data + + lenT = 2 if np.any(tau > 0) else 1 + + # Spatial coordinates: left if barycentric < 0.5, otherwise right + zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1) + zi_full = np.where(zeta < 0.5, zi, zi_1) - yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1) - yi_full = np.where(eta < 0.5, yi, yi_1) + yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1) + yi_full = np.where(eta < 0.5, yi, yi_1) - xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1) - xi_full = np.where(xsi < 0.5, xi, xi_1) + xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1) + xi_full = np.where(xsi < 0.5, xi, xi_1) - # Time coordinates: 1 point at ti, then 1 point at ti+1 - if lenT == 1: - ti_full = ti - else: - ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1) - ti_full = np.concatenate([ti, ti_1]) - xi_full = np.repeat(xi_full, 2) - yi_full = np.repeat(yi_full, 2) - zi_full = np.repeat(zi_full, 2) + # Time coordinates: 1 point at ti, then 1 point at ti+1 + if lenT == 1: + ti_full = ti + else: + ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1) + ti_full = np.concatenate([ti, ti_1]) + xi_full = np.repeat(xi_full, 2) + yi_full = np.repeat(yi_full, 2) + zi_full = np.repeat(zi_full, 2) - # Create DataArrays for indexing - selection_dict = { - axis_dim["X"]: xr.DataArray(xi_full, dims=("points")), - axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")), - } - if "Z" in axis_dim: - selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points")) - if "time" in data.dims: - selection_dict["time"] = xr.DataArray(ti_full, dims=("points")) + # Create DataArrays for indexing + selection_dict = { + axis_dim["X"]: xr.DataArray(xi_full, dims=("points")), + axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")), + } + if "Z" in axis_dim: + selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points")) + if "time" in data.dims: + selection_dict["time"] = xr.DataArray(ti_full, dims=("points")) - corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi)) + corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi)) - if lenT == 2: - value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau - else: - value = corner_data[0, :] + if lenT == 2: + value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau + else: + value = corner_data[0, :] - return value.compute() if is_dask_collection(value) else value + return value.compute() if is_dask_collection(value) else value -def XLinearInvdistLandTracer( - particle_positions: dict[str, float | np.ndarray], - grid_positions: dict[XgridAxis, dict[str, int | float | np.ndarray]], - field: Field, -): +class XLinearInvdistLandTracer(ScalarInterpolator): """Linear spatial interpolation on a regular grid, where points on land are not used.""" - values = XLinear(particle_positions, grid_positions, field) - xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] - yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] - zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] - ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] + def interp( + self, + particle_positions: dict[str, float | np.ndarray], + grid_positions: dict[ptyping.XgridAxis, dict[str, int | float | np.ndarray]], + field: Field, + ): + """Linear spatial interpolation on a regular grid, where points on land are not used.""" + values = XLinear().interp(particle_positions, grid_positions, field) - axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) - lenT = 2 if np.any(tau > 0) else 1 - lenZ = 2 if np.any(zeta > 0) else 1 + xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] + yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"] + zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"] + ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"] + + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) + lenT = 2 if np.any(tau > 0) else 1 + lenZ = 2 if np.any(zeta > 0) else 1 - corner_data = _get_corner_data_Agrid(field.data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim) + corner_data = _get_corner_data_Agrid(field.data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim) - land_mask = np.isclose(corner_data, 0.0) - nb_land = np.sum(land_mask, axis=(0, 1, 2, 3)) + land_mask = np.isclose(corner_data, 0.0) + nb_land = np.sum(land_mask, axis=(0, 1, 2, 3)) - if np.any(nb_land): - all_land_mask = nb_land == 4 * lenZ * lenT - values[all_land_mask] = 0.0 + if np.any(nb_land): + all_land_mask = nb_land == 4 * lenZ * lenT + values[all_land_mask] = 0.0 - some_land = np.logical_and(nb_land > 0, nb_land < 4 * lenZ * lenT) - if np.any(some_land): - i_grid = np.arange(2)[None, None, None, :, None] - j_grid = np.arange(2)[None, None, :, None, None] - eta_b = eta[None, None, None, None, :] - xsi_b = xsi[None, None, None, None, :] + some_land = np.logical_and(nb_land > 0, nb_land < 4 * lenZ * lenT) + if np.any(some_land): + i_grid = np.arange(2)[None, None, None, :, None] + j_grid = np.arange(2)[None, None, :, None, None] + eta_b = eta[None, None, None, None, :] + xsi_b = xsi[None, None, None, None, :] - dist2 = (eta_b - j_grid) ** 2 + (xsi_b - i_grid) ** 2 + dist2 = (eta_b - j_grid) ** 2 + (xsi_b - i_grid) ** 2 - valid_mask = ~land_mask - # Normal inverse-distance weighting - inv_dist = 1.0 / dist2 - weighted = np.where(valid_mask, corner_data * inv_dist, 0.0) + valid_mask = ~land_mask + # Normal inverse-distance weighting + inv_dist = 1.0 / dist2 + weighted = np.where(valid_mask, corner_data * inv_dist, 0.0) - val = np.sum(weighted, axis=(0, 1, 2, 3)) - w_sum = np.sum(np.where(valid_mask, inv_dist, 0.0), axis=(0, 1, 2, 3)) + val = np.sum(weighted, axis=(0, 1, 2, 3)) + w_sum = np.sum(np.where(valid_mask, inv_dist, 0.0), axis=(0, 1, 2, 3)) - values[some_land] = val[some_land] / w_sum[some_land] + values[some_land] = val[some_land] / w_sum[some_land] - # If a particle hits exactly one of the 8 corner points, extract it - exact_mask = dist2 == 0 & valid_mask - exact_vals = np.sum(np.where(exact_mask, corner_data, 0.0), axis=(0, 1, 2, 3)) - has_exact = np.any(exact_mask, axis=(0, 1, 2, 3)) + # If a particle hits exactly one of the 8 corner points, extract it + exact_mask = dist2 == 0 & valid_mask + exact_vals = np.sum(np.where(exact_mask, corner_data, 0.0), axis=(0, 1, 2, 3)) + has_exact = np.any(exact_mask, axis=(0, 1, 2, 3)) - exact_particles = some_land & has_exact - values[exact_particles] = exact_vals[exact_particles] + exact_particles = some_land & has_exact + values[exact_particles] = exact_vals[exact_particles] - return values.compute() if is_dask_collection(values) else values + return values.compute() if is_dask_collection(values) else values diff --git a/task.md b/task.md new file mode 100644 index 000000000..e953eb539 --- /dev/null +++ b/task.md @@ -0,0 +1,8 @@ +Looking at the changes in `changes.md` (particularly in the "architectural intent" section), I want you to run Pytest one at a time across the files in `files.txt`. + +And either: + +- Remove the test if it is no longer relevant to the architecture. When removing, put a comment `# remove: {reason}` +- Update the test + +Don't make commits while you work. Run `pixi shell` before you start working. Don't make any changes in the src folder. diff --git a/tests/test_field.py b/tests/test_field.py index 7d2790203..6d3105c2b 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -3,113 +3,44 @@ import numpy as np import pytest -from parcels import Field, UxGrid, VectorField, XGrid +from parcels import Field, VectorField from parcels._core.fieldset import FieldSet +from parcels._core.model import StructuredModel from parcels._datasets.structured.generic import T as T_structured from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import datasets as datasets_unstructured from parcels.interpolators import ( UxConstantFaceConstantZC, - UxLinearNodeLinearZF, - XLinear, ) def test_field_init_param_types(): data = datasets_structured["ds_2d_left"] - grid = FieldSet.from_sgrid_conventions(data, mesh="flat").data_g.grid + model = StructuredModel.from_sgrid_conventions(data, mesh="flat") with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."): - Field(name=123, data=data["data_g"], grid=grid, interp_method=XLinear) + Field(name=123, model=model) for name in ["a b", "123"]: with pytest.raises( ValueError, match=r"Received invalid Python variable name.*: not a valid identifier. HINT: avoid using spaces, special characters, and starting with a number.", ): - Field(name=name, data=data["data_g"], grid=grid, interp_method=XLinear) + Field(name=name, model=model) with pytest.raises( ValueError, match=r"Received invalid Python variable name.*: it is a reserved keyword. HINT: avoid using the following names:.*", ): - Field(name="while", data=data["data_g"], grid=grid, interp_method=XLinear) - - with pytest.raises( - ValueError, - match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray", - ): - Field(name="test", data=123, grid=grid, interp_method=XLinear) - - with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"): - Field(name="test", data=data["data_g"], grid=123, interp_method=XLinear) - - -# @pytest.mark.parametrize( -# "data,grid", -# [ -# pytest.param( -# ux.UxDataArray(), -# XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"), -# id="uxdata-grid", -# ), -# pytest.param( -# xr.DataArray(), -# UxGrid( -# datasets_unstructured["stommel_gyre_delaunay"].uxgrid, -# z=datasets_unstructured["stommel_gyre_delaunay"].coords["zf"], -# mesh="flat", -# ), -# id="xarray-uxgrid", -# ), -# ], -# ) -@pytest.mark.skip( - "Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_field_incompatible_combination(data, grid): - with pytest.raises(ValueError, match="Incompatible data-grid combination."): - Field( - name="test_field", - data=data, - grid=grid, - interp_method=XLinear, - ) - - -# @pytest.mark.parametrize( -# "data,grid", -# [ -# pytest.param( -# datasets_structured["ds_2d_left"]["data_g"], -# XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"), -# id="ds_2d_left", -# ), # TODO: Perhaps this test should be expanded to cover more datasets? -# ], -# ) -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_field_init_structured_grid(data, grid): - """Test creating a field.""" - field = Field( - name="test_field", - data=data, - grid=grid, - interp_method=XLinear, - ) - assert field.name == "test_field" - assert field.data.equals(data) - assert field.grid == grid + Field(name="while", model=model) -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace +# TODO: Move to test_model.py ? def test_field_init_fail_on_float_time_dim(): - """Test field initialisation fails when given float array as time dimension. + """Test that accessing time_interval fails when dataset has float time dimension. (users are expected to use timedelta64 or datetime). + Time validation has moved from Field.__init__ to Model.time_interval. """ ds = datasets_structured["ds_2d_left"].copy() ds["time"] = ( @@ -118,36 +49,20 @@ def test_field_init_fail_on_float_time_dim(): ds["time"].attrs, ) - data = ds["data_g"] - grid = XGrid.from_dataset(ds, mesh="flat") + model = StructuredModel.from_sgrid_conventions(ds, mesh="flat") with pytest.raises( ValueError, - match=r"Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?", + match=r"Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?", ): - Field( - name="test_field", - data=data, - grid=grid, - interp_method=XLinear, - ) + _ = model.time_interval -# @pytest.mark.parametrize( -# "data,grid", -# [ -# pytest.param( -# datasets_structured["ds_2d_left"]["data_g"], -# XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"), -# id="ds_2d_left", -# ), -# ], -# ) -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_field_time_interval(data, grid): - """Test creating a field.""" - field = Field(name="test_field", data=data, grid=grid, interp_method=XLinear) +# TODO: Move to test_model.py as test_model_time_interval() ? +def test_field_time_interval(): + """Test that field.time_interval delegates correctly to model.time_interval.""" + data = datasets_structured["ds_2d_left"] + model = StructuredModel.from_sgrid_conventions(data, mesh="flat") + field = Field(name="data_g", model=model) assert field.time_interval.left == np.datetime64("2000-01-01") assert field.time_interval.right == np.datetime64("2001-01-01") @@ -159,39 +74,34 @@ def test_vectorfield_init_different_time_intervals(): def test_field_invalid_interpolator(): ds = datasets_structured["ds_2d_left"] - grid = FieldSet.from_sgrid_conventions(ds, mesh="flat").data_g.grid + model = StructuredModel.from_sgrid_conventions(ds, mesh="flat") + field = Field(name="data_g", model=model) - def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid): + def not_a_scalar_interpolator(particle_positions, grid_positions, field): return 0.0 - # Test invalid interpolator with wrong signature - with pytest.raises(ValueError, match=".*incorrect name.*"): - Field( - name="test", - data=ds["data_g"], - grid=grid, - interp_method=invalid_interpolator_wrong_signature, - ) + # Interpolators must now be ScalarInterpolator instances, not plain callables + with pytest.raises(ValueError, match="interp_method must be a `ScalarInterpolator` object"): + field.interp_method = not_a_scalar_interpolator def test_vectorfield_invalid_interpolator(): ds = datasets_structured["ds_2d_left"] - grid = FieldSet.from_sgrid_conventions(ds, mesh="flat").data_g.grid + model = StructuredModel.from_sgrid_conventions(ds, mesh="flat") + fields = {f.name: f for f in model.construct_fields()} + U = fields["U_A_grid"] + V = fields["V_A_grid"] - def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid): + def not_a_vector_interpolator(particle_positions, grid_positions, field): return 0.0 - # Create component fields - U = Field(name="U", data=ds["data_g"], grid=grid, interp_method=XLinear) - V = Field(name="V", data=ds["data_g"], grid=grid, interp_method=XLinear) - - # Test invalid interpolator with wrong signature - with pytest.raises(ValueError, match=".*incorrect name.*"): + # VectorField interp_method must be a VectorInterpolator instance, not a plain callable + with pytest.raises(ValueError, match="vector_interp_method must be a `VectorInterpolator` object"): VectorField( name="UV", U=U, V=V, - vector_interp_method=invalid_interpolator_wrong_signature, + interp_method=not_a_vector_interpolator, ) @@ -216,9 +126,10 @@ def test_field_unstructured_z_linear(): for k, z in enumerate(ds.coords["zf"]): ds["W"].values[:, k, :] = z - grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh="flat") + fieldset = FieldSet.from_ugrid_conventions(ds, mesh="spherical") # Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1") - P = Field(name="p", data=ds.p, grid=grid, interp_method=UxConstantFaceConstantZC) + P = fieldset.p + W = fieldset.W # Test above first cell center - for piecewise constant, should return the depth of the first cell center assert np.isclose( @@ -236,7 +147,6 @@ def test_field_unstructured_z_linear(): 944.44445801, ) - W = Field(name="W", data=ds.W, grid=grid, interp_method=UxLinearNodeLinearZF) assert np.isclose( W.eval(time=[0], z=[10.0], y=[30.0], x=[30.0]), 10.0, @@ -253,10 +163,10 @@ def test_field_unstructured_z_linear(): def test_field_constant_in_time(): """Tests field evaluation for a field with no time interval (i.e., constant in time).""" - ds = datasets_unstructured["stommel_gyre_delaunay"] - grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh="flat") + fieldset = FieldSet.from_ugrid_conventions(datasets_unstructured["stommel_gyre_delaunay"], mesh="flat") # Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1") - P = Field(name="p", data=ds.p, grid=grid, interp_method=UxConstantFaceConstantZC) + P = fieldset.p + assert isinstance(P.interp_method, UxConstantFaceConstantZC) # Assert that the field can be evaluated at any time, and returns the same value time = np.datetime64("2000-01-01T00:00:00") diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 3663588fa..3e1e41ca5 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -5,11 +5,9 @@ import numpy as np import pandas as pd import pytest -import xarray as xr from parcels import Field, ParticleFile, ParticleSet, XGrid, convert -from parcels._core.fieldset import CalendarError, FieldSet, _datetime_to_msg -from parcels._datasets.structured.generic import T as T_structured +from parcels._core.fieldset import FieldSet, _datetime_to_msg from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.structured.generic import datasets_sgrid from parcels._datasets.unstructured.generic import datasets as datasets_unstructured @@ -20,7 +18,7 @@ def test_fieldset_init_wrong_types(): - with pytest.raises(ValueError, match="Expected `field` to be a Field or VectorField object. Got .*"): + with pytest.raises(ValueError, match="Expected `model` to be a Model object. Got .*"): FieldSet([1.0, 2.0, 3.0]) @@ -52,36 +50,6 @@ def test_fieldset_add_constant_field(fieldset): assert fieldset.test_constant_field[time, z, lat, lon] == 1.0 -@pytest.mark.skip( - "Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_fieldset_add_field(fieldset): - grid = XGrid.from_dataset(ds, mesh="flat") - field = Field("test_field", ds["U_A_grid"], grid, interp_method=XLinear) - fieldset.add_field(field) - assert fieldset.test_field == field - - -@pytest.mark.skip( - "Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_fieldset_add_field_wrong_type(fieldset): - not_a_field = 1.0 - with pytest.raises(ValueError, match="Expected `field` to be a Field or VectorField object. Got .*"): - fieldset.add_field(not_a_field, "test_field") - - -@pytest.mark.skip( - "Likely not relevant after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_fieldset_add_field_already_exists(fieldset): - grid = XGrid.from_dataset(ds, mesh="flat") - field = Field("test_field", ds["U_A_grid"], grid, interp_method=XLinear) - fieldset.add_field(field, "test_field") - with pytest.raises(ValueError, match="FieldSet already has a Field with name 'test_field'"): - fieldset.add_field(field, "test_field") - - def test_fieldset_gridset(fieldset): assert fieldset.fields["U"].grid in fieldset.gridset assert fieldset.fields["V"].grid in fieldset.gridset @@ -120,9 +88,8 @@ def test_fieldset_from_structured_generic_datasets(ds): def test_fieldset_gridset_multiple_grids(): ... -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace +# TODO restructure: use adding of fieldset notation to test this +@pytest.mark.skip("Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646") def test_fieldset_time_interval(): grid1 = XGrid.from_dataset(ds, mesh="flat") field1 = Field("field1", ds["U_A_grid"], grid1, interp_method=XLinear) @@ -147,61 +114,9 @@ def test_fieldset_time_interval_constant_fields(): assert fieldset.time_interval is None -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_fieldset_init_incompatible_calendars(): - ds1 = ds.copy() - ds1["time"] = ( - ds1["time"].dims, - xr.date_range("2000", "2001", T_structured, calendar="365_day", use_cftime=True), - ds1["time"].attrs, - ) - - grid = XGrid.from_dataset(ds1, mesh="flat") - U = Field("U", ds1["U_A_grid"], grid, interp_method=XLinear) - V = Field("V", ds1["V_A_grid"], grid, interp_method=XLinear) - - ds2 = ds.copy() - ds2["time"] = ( - ds2["time"].dims, - xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True), - ds2["time"].attrs, - ) - grid2 = XGrid.from_dataset(ds2, mesh="flat") - incompatible_calendar = Field("test", ds2["data_g"], grid2, interp_method=XLinear) - - with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): - FieldSet([U, V, incompatible_calendar]) - - -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace -def test_fieldset_add_field_incompatible_calendars(fieldset): - ds_test = ds.copy() - ds_test["time"] = ( - ds_test["time"].dims, - xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True), - ds_test["time"].attrs, - ) - grid = XGrid.from_dataset(ds_test, mesh="flat") - field = Field("test_field", ds_test["data_g"], grid, interp_method=XLinear) - - with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): - fieldset.add_field(field, "test_field") - - ds_test = ds.copy() - ds_test["time"] = ( - ds_test["time"].dims, - np.linspace(0, 100, T_structured, dtype="timedelta64[s]"), - ds_test["time"].attrs, - ) - grid = XGrid.from_dataset(ds_test, mesh="flat") - field = Field("test_field", ds_test["data_g"], grid, interp_method=XLinear) - - with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"): - fieldset.add_field(field, "test_field") +def test_fieldset_add_incompatible_calendars(): + # tests the adding of fieldsets that have incompatible calendars + ... @pytest.mark.parametrize( @@ -280,3 +195,62 @@ def test_fieldset_from_sgrid_conventions(ds_name): fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat") assert isinstance(fieldset, FieldSet) assert len(fieldset.fields) > 0 + + +def test_fieldset_add(): + """Test that two FieldSets can be combined with + (fset1 + fset2).""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) + ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + + fset = fset1 + fset2 + + assert len(fset.models) == len(fset1.models) + len(fset2.models) + assert "U1" in fset.fields + assert "V2" in fset.fields + + +def test_fieldset_add_overlapping_fields(): + """Test that adding FieldSets with overlapping field names raises a ValueError.""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U"}) + ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "U"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + + with pytest.raises(ValueError, match="overlapping field names.*'U'"): + fset1 + fset2 + + +def test_fieldset_add_overlapping_constants(): + """Test that adding FieldSets with overlapping constant names raises a ValueError.""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) + ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + fset1.add_constant("kh", 1.0) + + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + fset2.add_constant("kh", 2.0) + + with pytest.raises(ValueError, match="overlapping constant names.*'kh'"): + fset1 + fset2 + + +def test_fieldset_add_constants(): + """Test that constants from both FieldSets are present in the combined FieldSet.""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) + ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"}) + + fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") + fset1.add_constant("c1", 1.0) + + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + fset2.add_constant("c2", 2.0) + + fset = fset1 + fset2 + + assert fset.constants["c1"] == 1.0 + assert fset.constants["c2"] == 2.0 diff --git a/tests/test_index_search.py b/tests/test_index_search.py index 44ef99ba5..fcd27ba22 100644 --- a/tests/test_index_search.py +++ b/tests/test_index_search.py @@ -48,6 +48,8 @@ def test_grid_indexing_fpoints(field_cone): assert y > np.min(cell_lat) and y < np.max(cell_lat) +# XGrid no longer accepts xgcm.Grid objects; now requires xr.Dataset with SGRID metadata; NEMO curvilinear workflow via xgcm is not the construction path +@pytest.mark.skip("Uses now removed API. TODO: What is the goal of this test?") def test_indexing_nemo_curvilinear(): ds = parcels.tutorial.open_dataset("NemoCurvilinear_data_zonal/mesh_mask") ds = ds.isel({"z_a": 0}, drop=True).rename({"glamf": "lon", "gphif": "lat", "z": "depth"}) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index df6069e10..52f835d1c 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -65,17 +65,17 @@ def field(): ), ) field = FieldSet.from_sgrid_conventions(ds, mesh="flat").U - assert field.interp_method == XLinear + assert isinstance(field.interp_method, XLinear) return field @pytest.mark.parametrize( - "func, t, z, y, x, expected", + "interpolator, t, z, y, x, expected", [ - pytest.param(ZeroInterpolator, 1, 2.5, 0.49, 0.51, 0, id="Zero"), + pytest.param(ZeroInterpolator(), 1, 2.5, 0.49, 0.51, 0, id="Zero"), pytest.param( - XLinear, + XLinear(), [0, 1], [0, 0], [0.49, 0.49], @@ -83,9 +83,9 @@ def field(): [1.49, 6.49], id="Linear-1", ), - pytest.param(XLinear, 1, 2.5, 0.49, 0.51, 13.99, id="Linear-2"), + pytest.param(XLinear(), 1, 2.5, 0.49, 0.51, 13.99, id="Linear-2"), pytest.param( - XLinear, + XLinear(), [0, 1, 1], [0, 0, 2.5], [0.49, 0.49, 0.49], @@ -93,9 +93,9 @@ def field(): [1.49, 6.49, 13.99], id="Linear-3", ), - pytest.param(XLinearInvdistLandTracer, 1, 2.5, 0.49, 0.51, 13.99, id="LinearInvDistLand"), + pytest.param(XLinearInvdistLandTracer(), 1, 2.5, 0.49, 0.51, 13.99, id="LinearInvDistLand"), pytest.param( - XNearest, + XNearest(), [0, 3], [0.2, 0.2], [0.2, 0.2], @@ -105,27 +105,27 @@ def field(): ), ], ) -def test_raw_2d_interpolation(field, func, t, z, y, x, expected): +def test_raw_2d_interpolation(field, interpolator, t, z, y, x, expected): """Test the interpolation functions on the Field.""" particle_positions = {"time": t, "z": z, "lat": y, "lon": x} grid_positions = field.grid.search(z, y, x) grid_positions.update(_search_time_index(field, t)) - value = func(particle_positions, grid_positions, field) + value = interpolator.interp(particle_positions, grid_positions, field) np.testing.assert_equal(value, expected) @pytest.mark.parametrize( "func, t, z, y, x, expected", [ - (XPartialslip, 1, 0, 0, 0.0, [[1], [1]]), - (XFreeslip, 1, 0, 0.5, 1.5, [[1], [0.5]]), - (XPartialslip, 1, 0, 2.5, 1.5, [[0.75], [0.5]]), - (XFreeslip, 1, 0, 2.5, 1.5, [[1], [0.5]]), - (XPartialslip, 1, 0, 1.5, 0.5, [[0.5], [0.75]]), - (XFreeslip, 1, 0, 1.5, 0.5, [[0.5], [1]]), + (XPartialslip(), 1, 0, 0, 0.0, [[1], [1]]), + (XFreeslip(), 1, 0, 0.5, 1.5, [[1], [0.5]]), + (XPartialslip(), 1, 0, 2.5, 1.5, [[0.75], [0.5]]), + (XFreeslip(), 1, 0, 2.5, 1.5, [[1], [0.5]]), + (XPartialslip(), 1, 0, 1.5, 0.5, [[0.5], [0.75]]), + (XFreeslip(), 1, 0, 1.5, 0.5, [[0.5], [1]]), ( - XFreeslip, + XFreeslip(), [1, 0], [0, 2], [1.5, 1.5], @@ -139,19 +139,19 @@ def test_spatial_slip_interpolation(field, func, t, z, y, x, expected): field.data[:, :, 1:3, 1:3] = 0.0 # Set zero land value to test spatial slip U = field V = field - UV = VectorField("UV", U, V, vector_interp_method=func) + UV = VectorField("UV", U, V, interp_method=func) velocities = UV[t, z, y, x] np.testing.assert_array_almost_equal(velocities, expected) @pytest.mark.parametrize( - "func, t, z, y, x, expected", + "interpolator, t, z, y, x, expected", [ - (XLinearInvdistLandTracer, 1, 0, 0.5, 0.5, 1.0), - (XLinearInvdistLandTracer, 1, 0, 1.5, 1.5, 0.0), + (XLinearInvdistLandTracer(), 1, 0, 0.5, 0.5, 1.0), + (XLinearInvdistLandTracer(), 1, 0, 1.5, 1.5, 0.0), ( - XLinearInvdistLandTracer, + XLinearInvdistLandTracer(), [0, 1], [0, 2], [0.5, 0.5], @@ -159,7 +159,7 @@ def test_spatial_slip_interpolation(field, func, t, z, y, x, expected): 1.0, ), ( - XLinearInvdistLandTracer, + XLinearInvdistLandTracer(), [0, 1], [0, 2], [0.5, 1.5], @@ -168,10 +168,10 @@ def test_spatial_slip_interpolation(field, func, t, z, y, x, expected): ), ], ) -def test_invdistland_interpolation(field, func, t, z, y, x, expected): +def test_invdistland_interpolation(field, interpolator, t, z, y, x, expected): field.data[:] = 1.0 field.data[:, :, 1:3, 1:3] = 0 # Set NaN land value to test inv_dist - field.interp_method = func + field.interp_method = interpolator value = field[t, z, y, x] np.testing.assert_array_almost_equal(value, expected) diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index 0aa6e8b8b..9814dfe33 100755 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -412,6 +412,30 @@ def test_particlefile_init(tmp_parquet): ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) +def test_particlefile_init_existing_path_modes(fieldset, tmp_parquet): + pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) + + first_file = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) + pset.execute(DoNothing, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"), output_file=first_file) + + df_first = pd.read_parquet(tmp_parquet) + + with pytest.raises(ValueError, match="already exists"): + ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) + + overwrite_file = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"), mode="w") + pset.execute(DoNothing, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"), output_file=overwrite_file) + + df_overwrite = pd.read_parquet(tmp_parquet) + + assert len(df_first) == len(df_overwrite) + + +def test_particlefile_init_invalid_mode(tmp_parquet): + with pytest.raises(ValueError, match="Invalid mode value"): + ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"), mode="something-else") + + @pytest.mark.parametrize("name", ["path", "outputdt"]) def test_particlefile_readonly_attrs(tmp_parquet, name): pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) diff --git a/tests/test_particleset_execute.py b/tests/test_particleset_execute.py index 93042a799..8a2cf9342 100644 --- a/tests/test_particleset_execute.py +++ b/tests/test_particleset_execute.py @@ -5,7 +5,6 @@ import pytest from parcels import ( - Field, FieldInterpolationError, FieldOutOfBoundError, FieldSet, @@ -14,19 +13,13 @@ ParticleFile, ParticleSet, StatusCode, - UxGrid, Variable, - VectorField, ) from parcels._core.utils.time import timedelta_to_float from parcels._datasets.structured.generated import simple_UV_dataset from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import datasets as datasets_unstructured -from parcels.interpolators import ( - Ux_Velocity, - UxConstantFaceConstantZC, - UxLinearNodeLinearZF, -) +from parcels.interpolators._base import ScalarInterpolator from parcels.kernels import AdvectionEE, AdvectionRK2, AdvectionRK4, AdvectionRK4_3D, AdvectionRK45 from tests.common_kernels import DoNothing from tests.utils import DEFAULT_PARTICLES @@ -335,13 +328,14 @@ def test_raise_general_error(): ... def test_errorinterpolation(fieldset): - def NaNInterpolator(particle_positions, grid_positions, field): # pragma: no cover - return np.nan * np.zeros_like(particle_positions["lon"]) + class NaNInterpolator(ScalarInterpolator): # pragma: no cover + def interp(self, particle_positions, grid_positions, field): + return np.nan * np.zeros_like(particle_positions["lon"]) def SampleU(particles, fieldset): # pragma: no cover fieldset.U[particles.time, particles.z, particles.lat, particles.lon, particles] - fieldset.U.interp_method = NaNInterpolator + fieldset.U.interp_method = NaNInterpolator() pset = ParticleSet(fieldset, lon=[0, 2], lat=[0, 0]) with pytest.raises(FieldInterpolationError): pset.execute(SampleU, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s")) @@ -471,27 +465,7 @@ def SetLat2(p): def test_uxstommelgyre_pset_execute(): ds = datasets_unstructured["stommel_gyre_delaunay"] - grid = UxGrid(grid=ds.uxgrid, z=ds.coords["zf"], mesh="spherical") - U = Field( - name="U", - data=ds.U, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - V = Field( - name="V", - data=ds.V, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - P = Field( - name="P", - data=ds.p, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - UV = VectorField(name="UV", U=U, V=V, vector_interp_method=Ux_Velocity) - fieldset = FieldSet([UV, UV.U, UV.V, P]) + fieldset = FieldSet.from_ugrid_conventions(ds, mesh="spherical") pset = ParticleSet( fieldset, lon=[30.0], @@ -511,33 +485,7 @@ def test_uxstommelgyre_pset_execute(): def test_uxstommelgyre_multiparticle_pset_execute(): ds = datasets_unstructured["stommel_gyre_delaunay"] - grid = UxGrid(grid=ds.uxgrid, z=ds.coords["zf"], mesh="spherical") - U = Field( - name="U", - data=ds.U, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - V = Field( - name="V", - data=ds.V, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - W = Field( - name="W", - data=ds.W, - grid=grid, - interp_method=UxLinearNodeLinearZF, - ) - P = Field( - name="P", - data=ds.p, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - UVW = VectorField(name="UVW", U=U, V=V, W=W, vector_interp_method=Ux_Velocity) - fieldset = FieldSet([UVW, UVW.U, UVW.V, UVW.W, P]) + fieldset = FieldSet.from_ugrid_conventions(ds, mesh="spherical") pset = ParticleSet( fieldset, lon=[30.0, 32.0], @@ -556,27 +504,7 @@ def test_uxstommelgyre_multiparticle_pset_execute(): @pytest.mark.xfail(reason="Output file not implemented yet") def test_uxstommelgyre_pset_execute_output(): ds = datasets_unstructured["stommel_gyre_delaunay"] - grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"], mesh="spherical") - U = Field( - name="U", - data=ds.U, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - V = Field( - name="V", - data=ds.V, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - P = Field( - name="P", - data=ds.p, - grid=grid, - interp_method=UxConstantFaceConstantZC, - ) - UV = VectorField(name="UV", U=U, V=V, vector_interp_method=Ux_Velocity) - fieldset = FieldSet([UV, UV.U, UV.V, P]) + fieldset = FieldSet.from_ugrid_conventions(ds, mesh="spherical") pset = ParticleSet( fieldset, lon=[30.0], diff --git a/tests/test_uxarray_fieldset.py b/tests/test_uxarray_fieldset.py index 8393424ee..6b3933389 100644 --- a/tests/test_uxarray_fieldset.py +++ b/tests/test_uxarray_fieldset.py @@ -7,17 +7,11 @@ import parcels._datasets.remote as _parcels_remote import parcels.tutorial from parcels import ( - Field, FieldSet, - Particle, - ParticleSet, - UxGrid, - VectorField, ) from parcels._datasets.unstructured.generic import datasets as datasets_unstructured from parcels.convert import fesom_to_ugrid, icon_to_ugrid from parcels.interpolators import ( - Ux_Velocity, UxConstantFaceConstantZC, UxLinearNodeLinearZF, ) @@ -41,81 +35,17 @@ def ds_fesom_channel() -> ux.UxDataset: @pytest.fixture -def uv_fesom_channel(ds_fesom_channel) -> VectorField: - UV = VectorField( - name="UV", - U=Field( - name="U", - data=ds_fesom_channel.U, - grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["zc"], mesh="flat"), - interp_method=UxConstantFaceConstantZC, - ), - V=Field( - name="V", - data=ds_fesom_channel.V, - grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["zc"], mesh="flat"), - interp_method=UxConstantFaceConstantZC, - ), - vector_interp_method=Ux_Velocity, - ) - return UV - - -@pytest.fixture -def uvw_fesom_channel(ds_fesom_channel) -> VectorField: - UVW = VectorField( - name="UVW", - U=Field( - name="U", - data=ds_fesom_channel.U, - grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["zc"], mesh="flat"), - interp_method=UxConstantFaceConstantZC, - ), - V=Field( - name="V", - data=ds_fesom_channel.V, - grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["zc"], mesh="flat"), - interp_method=UxConstantFaceConstantZC, - ), - W=Field( - name="W", - data=ds_fesom_channel.W, - grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["zf"], mesh="flat"), - interp_method=UxLinearNodeLinearZF, - ), - vector_interp_method=Ux_Velocity, - ) - return UVW - - -def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel): - fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) - # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() +def fieldset_fesom_channel(ds_fesom_channel): + return FieldSet.from_ugrid_conventions(ds_fesom_channel) -def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): - fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) - - # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() - pset = ParticleSet(fieldset, pclass=Particle) - assert pset.fieldset == fieldset - - -def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel): - fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) +def test_fesom_fieldset(ds_fesom_channel, fieldset_fesom_channel): # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() - - # Set the interpolation method for each field - fieldset.U.interp_method = UxConstantFaceConstantZC - fieldset.V.interp_method = UxConstantFaceConstantZC + assert (fieldset_fesom_channel.U.data == ds_fesom_channel.U).all() + assert (fieldset_fesom_channel.V.data == ds_fesom_channel.V).all() +@pytest.mark.xfail(reason="#2674 - 'p' interpolator is not being selected properly") def test_fesom2_square_delaunay_uniform_z_coordinate_eval(): """ Test the evaluation of a fieldset with a FESOM2 square Delaunay grid and uniform z-coordinate. @@ -124,16 +54,12 @@ def test_fesom2_square_delaunay_uniform_z_coordinate_eval(): """ ds = datasets_unstructured["fesom2_square_delaunay_uniform_z_coordinate"] ds = fesom_to_ugrid(ds) - grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh="flat") - UVW = VectorField( - name="UVW", - U=Field(name="U", data=ds.U, grid=grid, interp_method=UxConstantFaceConstantZC), - V=Field(name="V", data=ds.V, grid=grid, interp_method=UxConstantFaceConstantZC), - W=Field(name="W", data=ds.W, grid=grid, interp_method=UxLinearNodeLinearZF), - vector_interp_method=Ux_Velocity, - ) - P = Field(name="p", data=ds.p, grid=grid, interp_method=UxLinearNodeLinearZF) - fieldset = FieldSet([UVW, P, UVW.U, UVW.V, UVW.W]) + fieldset = FieldSet.from_ugrid_conventions(ds) + + assert isinstance(fieldset.U.interp_method, UxConstantFaceConstantZC) + assert isinstance(fieldset.V.interp_method, UxConstantFaceConstantZC) + assert isinstance(fieldset.W.interp_method, UxLinearNodeLinearZF) + assert isinstance(fieldset.p.interp_method, UxLinearNodeLinearZF) (u, v, w) = fieldset.UVW.eval(time=[0.0], z=[1.0], y=[30.0], x=[30.0]) assert np.allclose([u.item(), v.item(), w.item()], [1.0, 1.0, 0.0], rtol=1e-3, atol=1e-6) @@ -172,13 +98,8 @@ def test_fesom2_square_delaunay_antimeridian_eval(): """ ds = datasets_unstructured["fesom2_square_delaunay_antimeridian"] ds = fesom_to_ugrid(ds) - P = Field( - name="p", - data=ds.p, - grid=UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh="spherical"), - interp_method=UxLinearNodeLinearZF, - ) - fieldset = FieldSet([P]) + fieldset = FieldSet.from_ugrid_conventions(ds) + fieldset.p.interp_method = UxLinearNodeLinearZF() assert np.isclose(fieldset.p.eval(time=[0], z=[1.0], y=[30.0], x=[-170.0]), 1.0) assert np.isclose(fieldset.p.eval(time=[0], z=[1.0], y=[30.0], x=[-180.0]), 1.0) diff --git a/tests/test_xgrid.py b/tests/test_xgrid.py index 17952e426..06e2696b5 100644 --- a/tests/test_xgrid.py +++ b/tests/test_xgrid.py @@ -184,6 +184,9 @@ def test_dim_with_duplicate_axis(): FieldSet.from_sgrid_conventions(ds) +# TODO restructure: Look into the test below +# remove: eval with timedelta64 time fails; TimeInterval.is_all_time_in_interval expects float (seconds) not timedelta64; time arg type requirement changed +@pytest.mark.skip("remove: see comment above") @pytest.mark.parametrize("ds", [datasets["ds_2d_left"]]) def test_vertical1D_field(ds): ds = ds.drop(set(ds.data_vars) - {"grid"}) @@ -197,18 +200,12 @@ def test_vertical1D_field(ds): np.testing.assert_almost_equal(field.eval(np.timedelta64(0, "s"), 0.45, 0, 0), np.array([4.5])) -@pytest.mark.skip( - "Needs updating after refactoring from https://github.com/Parcels-code/Parcels/pull/2646" -) # TODO: Remove or replace def test_time1D_field(): - timerange = xr.date_range("2000-01-01", "2000-01-20") - ds = xr.Dataset( - {"t1d": (["time"], np.arange(0, len(timerange)))}, - coords={"time": (["time"], timerange, {"axis": "T"})}, - ) - grid = XGrid.from_dataset(ds, mesh="flat") - field = Field("t1d", ds["t1d"], grid, XLinear) + ds = datasets["ds_2d_left"].sgrid.isel(XC=0, YC=0, ZC=0)[["data_g", "grid"]] + ds["data_g"] = (["time"], np.arange(0, ds["time"].size)) + ds["time"] = xr.date_range("2000-01-01", "2000-01-13") + field = FieldSet.from_sgrid_conventions(ds, mesh="flat").data_g time = timedelta_to_float(np.datetime64("2000-01-10T12:00:00") - field.time_interval.left) assert field.eval(time, -20, 5, 6) == 9.5