Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "point-collocation"
version = "0.8.0"
version = "0.9.0"
description = "Point-based lat/lon/time matchups against cloud-hosted NetCDF/Zarr granules"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down
227 changes: 195 additions & 32 deletions src/point_collocation/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
* Open each granule individually (never ``open_mfdataset``) to minimise
cloud I/O and avoid memory leaks.
* Extract the requested variables at each point's location/time using
nearest-neighbor selection (gridded) or kdtree (non-gridded, e.g. swath).
vectorised nearest-neighbour selection (gridded, ``axis`` method) or
k-d tree index (non-gridded/swath data).
* Collect results into a ``pandas.DataFrame`` with one row per
(point, granule) pair.

Expand Down Expand Up @@ -59,7 +60,7 @@
# Re-export geolocation pairs for callers that import them from this module.
from point_collocation.core._open_method import _GEOLOC_PAIRS # noqa: F401

_VALID_SPATIAL_METHODS = {"nearest", "xoak-kdtree", "kdtree", "auto", "xoak-haversine"}
_VALID_SPATIAL_METHODS = {"axis", "xoak-kdtree", "kdtree", "auto", "xoak-haversine"}

# Time dimension names used as a fallback when cf_xarray is not installed or
# when the dataset lacks CF-convention axis/units attributes. Tried in order.
Expand Down Expand Up @@ -136,17 +137,22 @@ def matchup(
* ``"auto"`` *(default)* — automatically selects the best method
based on the dimensionality of the geolocation coordinates:

- **1-D coordinates** (regular/gridded data): uses ``"nearest"``
(``ds.sel(..., method="nearest")``). If ``"nearest"`` fails for
any reason, falls back to ``"kdtree"`` automatically.
- **2-D coordinates** (irregular/swath data): uses ``"kdtree"``.
- **1-D coordinates** (regular/gridded data, both lat and lon
are 1-D): uses ``"axis"`` (vectorised
``ds.sel(..., method="nearest")`` over all points at once).
If ``"axis"`` fails for any reason, falls back to ``"kdtree"``
automatically.
- **2-D coordinates** (irregular/swath data, or either coordinate
is 2-D): uses ``"kdtree"``.

``xoak-kdtree`` and ``xoak-haversine`` are never selected
automatically; set them explicitly if needed.

* ``"nearest"`` — ``ds.sel(..., method="nearest")`` directly.
Requires 1-D coordinate arrays; raises :exc:`ValueError` with a
suggestion to use ``"auto"`` or ``"kdtree"`` for 2-D coordinates.
* ``"axis"`` — Vectorised ``ds.sel(..., method="nearest")`` for all
points in a single call. Requires 1-D (regular-grid) coordinate
arrays for both latitude and longitude; raises :exc:`ValueError`
with a suggestion to use ``"auto"`` or ``"kdtree"`` for 2-D
coordinates.
* ``"kdtree"`` — xarray's built-in
:class:`xarray.indexes.NDPointIndex` with the default
``ScipyKDTreeAdapter``. Works with both 1-D and 2-D coordinate
Expand Down Expand Up @@ -475,15 +481,16 @@ def _check_spatial_compat(
) -> None:
"""Raise if lat/lon dimensionality is incompatible with *spatial_method*.

Only validates for ``spatial_method="nearest"``, which requires 1-D
coordinate arrays. ``spatial_method="xoak-kdtree"``,
Only validates for ``spatial_method="axis"``, which requires 1-D
coordinate arrays for both latitude and longitude.
``spatial_method="xoak-kdtree"``,
``spatial_method="xoak-haversine"``, ``spatial_method="kdtree"``, and
``spatial_method="auto"`` work with both 1-D and 2-D arrays and are not
validated here.

Uses only metadata (``dims``) — does **not** load array data.
"""
if spatial_method != "nearest":
if spatial_method != "axis":
return

lon_var = ds.coords[lon_name] if lon_name in ds.coords else ds[lon_name]
Expand All @@ -494,7 +501,7 @@ def _check_spatial_compat(

if lon_ndim != 1 or lat_ndim != 1:
raise ValueError(
f"spatial_method='nearest' requires 1-D geolocation arrays, but found "
f"spatial_method='axis' requires 1-D geolocation arrays, but found "
f"{lon_name!r} with dims={tuple(lon_var.dims)} and "
f"{lat_name!r} with dims={tuple(lat_var.dims)}. "
"Use spatial_method='auto' or spatial_method='kdtree' for 2-D "
Expand Down Expand Up @@ -709,12 +716,12 @@ def _execute_plan(
# Track whether we have already validated spatial compat on the first granule.
spatial_checked = False

# For "auto" spatial_method, the effective method ("nearest" or "kdtree")
# For "auto" spatial_method, the effective method ("axis" or "kdtree")
# is determined on the first opened granule based on lat/lon dimensionality.
# For explicit methods this always equals spatial_method.
effective_spatial: str = spatial_method
# When auto resolves to "nearest" on 1-D data, allow one fallback to
# "kdtree" per granule if nearest extraction fails.
# When auto resolves to "axis" on 1-D data, allow one fallback to
# "kdtree" per granule if axis extraction fails.
auto_1d_fallback: bool = (spatial_method == "auto")

# For "auto" open_method, probe only the first granule to determine whether
Expand Down Expand Up @@ -800,17 +807,24 @@ def _execute_plan(
if lat_name in ds.coords
else ds[lat_name]
)
if lat_var_check.ndim == 1:
effective_spatial = "nearest"
lon_var_check = (
ds.coords[lon_name]
if lon_name in ds.coords
else ds[lon_name]
)
if lat_var_check.ndim == 1 and lon_var_check.ndim == 1:
effective_spatial = "axis"
# auto_1d_fallback already True; keep it so
# that a nearest failure falls back to kdtree.
# that an axis failure falls back to kdtree.
else:
effective_spatial = "kdtree"
auto_1d_fallback = False
if not silent:
lat_ndim = lat_var_check.ndim
lon_ndim = lon_var_check.ndim
print(
f"spatial_method='auto': using '{effective_spatial}' "
f"(lat/lon dims: {lat_var_check.ndim}-D)"
f"(lat/lon dims: {lat_ndim}-D/{lon_ndim}-D)"
)
else:
effective_spatial = spatial_method
Expand All @@ -835,7 +849,7 @@ def _execute_plan(
# to the spatial extent of the query points before building
# the k-d tree. A global granule with only a few scattered
# points would otherwise cause the index to cover the entire global
# grid, which is very slow. Skip this step for "nearest" (1-D)
# grid, which is very slow. Skip this step for "axis" (1-D)
# since it does not build an index.
if effective_spatial in ("xoak-kdtree", "xoak-haversine", "kdtree"):
lat_var = ds.coords[lat_name] if lat_name in ds.coords else ds[lat_name]
Expand Down Expand Up @@ -880,8 +894,8 @@ def _execute_plan(
output_rows.extend(rows_for_granule)
batch_rows.extend(rows_for_granule)
elif auto_1d_fallback:
# auto resolved to "nearest" on 1-D coords. Try
# nearest for each point; if it fails, fall back to
# auto resolved to "axis" on 1-D coords. Try
# axis batch extraction; if it fails, fall back to
# ndpoint for the whole granule (and all future ones).
def _make_row(pt_idx: object) -> dict:
r = plan.points.loc[pt_idx].to_dict()
Expand All @@ -893,10 +907,9 @@ def _make_row(pt_idx: object) -> dict:

rows_for_granule = [_make_row(idx) for idx in pt_indices]
try:
for row in rows_for_granule:
_extract_nearest(ds, row, variables, lon_name, lat_name, time_dim, additional_axes=resolved_add_axes)
except Exception as _nearest_exc:
# nearest failed; rebuild clean rows and retry with kdtree.
_extract_axis_batch(ds, rows_for_granule, variables, lon_name, lat_name, time_dim, additional_axes=resolved_add_axes)
except Exception as _axis_exc:
# axis failed; rebuild clean rows and retry with kdtree.
rows_for_granule = [_make_row(idx) for idx in pt_indices]
# Apply slicing for kdtree on 1-D coords.
pt_lats = [float(plan.points.loc[idx]["lat"]) for idx in pt_indices]
Expand All @@ -911,24 +924,27 @@ def _make_row(pt_idx: object) -> dict:
auto_1d_fallback = False
except Exception as nd_exc:
raise ValueError(
"spatial_method='auto' tried both 'nearest' and 'kdtree' "
"spatial_method='auto' tried both 'axis' and 'kdtree' "
"for a granule with 1-D lat/lon coordinates, but both "
"failed. Check that the dataset has valid geolocation "
f"coordinates. 'nearest' error: {_nearest_exc!r}; "
f"coordinates. 'axis' error: {_axis_exc!r}; "
f"'kdtree' error: {nd_exc!r}"
) from nd_exc
output_rows.extend(rows_for_granule)
batch_rows.extend(rows_for_granule)
else:
# explicit spatial_method="axis": batch all points in one call.
rows_for_granule = []
for pt_idx in pt_indices:
row = plan.points.loc[pt_idx].to_dict()
if not has_user_pc_id:
row["pc_id"] = pt_idx
row["granule_id"] = gm.granule_id
row["granule_time"] = granule_time
_extract_nearest(ds, row, variables, lon_name, lat_name, time_dim, additional_axes=resolved_add_axes)
output_rows.append(row)
batch_rows.append(row)
rows_for_granule.append(row)
_extract_axis_batch(ds, rows_for_granule, variables, lon_name, lat_name, time_dim, additional_axes=resolved_add_axes)
output_rows.extend(rows_for_granule)
batch_rows.extend(rows_for_granule)

batch_matched_points += len(pt_indices)

Expand Down Expand Up @@ -1261,6 +1277,153 @@ def _drop_nan_geoloc(
return stacked.isel({"__pc__": valid})


def _extract_axis_batch(
    ds: xr.Dataset,
    rows: list[dict],
    variables: list[str],
    lon_name: str,
    lat_name: str,
    time_dim: str | None = None,
    *,
    additional_axes: "dict[str, dict[str, str]] | None" = None,
) -> None:
    """Vectorised extraction using ``ds.sel(..., method='nearest')`` for 1-D lat/lon.

    Batches all points into a single ``.sel()`` call using xarray's
    vectorised indexing API. This requires 1-D (regular-grid) coordinate
    arrays for both latitude and longitude.

    Modifies each dict in *rows* in-place, including ``granule_lat`` and
    ``granule_lon`` for the matched grid location. ``granule_time`` is set by
    the caller from granule metadata before this function is called.

    Parameters
    ----------
    time_dim:
        Name of the time dimension in *ds*, as detected by
        :func:`_find_time_dim`. When not ``None``, each variable is
        squeezed or nearest-selected along this dimension after spatial
        selection so that the result is always free of the time axis.
    additional_axes:
        Resolved additional 1D matching axes from
        :func:`_resolve_additional_axes_for_ds`. Each entry maps an axis
        name to ``{"points_col": ..., "source_coord": ...}``. When provided,
        a vectorised nearest-neighbour selection is also performed along each
        axis.
    """
    if additional_axes is None:
        additional_axes = {}
    if not rows:
        # Empty batch: nothing to select, nothing to mutate.
        return

    lats = [row["lat"] for row in rows]
    lons = [row["lon"] for row in rows]

    # Use a unique dimension name for the points indexer to avoid conflicts
    # with any existing coordinate in the dataset.
    pts_dim = "pc_points_idx"
    lat_pts = xr.DataArray(lats, dims=[pts_dim])
    lon_pts = xr.DataArray(lons, dims=[pts_dim])

    # Extract the actual matched coordinates (nearest-neighbour grid positions).
    # NOTE(review): unlike other call sites in this module, this indexes
    # ``ds.coords`` directly with no ``else ds[lat_name]`` fallback; if the
    # geolocation names are not registered as coordinates, the resulting
    # KeyError is swallowed below and the matched positions become NaN.
    try:
        matched_lats = ds.coords[lat_name].sel({lat_name: lat_pts}, method="nearest").values
        matched_lons = ds.coords[lon_name].sel({lon_name: lon_pts}, method="nearest").values
        for i, row in enumerate(rows):
            row["granule_lat"] = float(matched_lats[i])
            row["granule_lon"] = float(matched_lons[i])
    except Exception:
        # Best-effort only: matched positions are informational, so any
        # failure here degrades to NaN rather than aborting extraction.
        for row in rows:
            row["granule_lat"] = float("nan")
            row["granule_lon"] = float("nan")

    # Build the base vectorised selection dict (one DataArray indexer per axis).
    base_sel: dict[str, xr.DataArray] = {lat_name: lat_pts, lon_name: lon_pts}

    # Add additional axes as vectorised indexers when all row values are present.
    for axis_name, info in additional_axes.items():
        pts_col = info["points_col"]
        src_coord = info["source_coord"]
        ax_vals = [row.get(pts_col) for row in rows]
        if all(v is not None for v in ax_vals):
            base_sel[src_coord] = xr.DataArray(ax_vals, dims=[pts_dim])

    # Pre-compute a vectorised time indexer (one timestamp per point).
    # When the variable being extracted has a time dimension, including it in
    # the .sel() call means that .load() only reads one time step per point
    # instead of all time steps — a significant reduction in I/O for datasets
    # like MERRA-2 3-hourly files where each granule contains several time
    # steps (e.g. shape (time=8, lev=72, lat=361, lon=576)).
    _time_pts_da: "xr.DataArray | None" = None
    if time_dim is not None and time_dim not in base_sel:
        _raw_times = [row.get("time") for row in rows]
        if all(t is not None for t in _raw_times):
            try:
                _time_pts_da = xr.DataArray(
                    [pd.Timestamp(t) for t in _raw_times],
                    dims=[pts_dim],
                )
            except Exception:
                pass  # fall back to per-point _select_time after .load()

    for var in variables:
        try:
            # Build a per-variable selection dict. Add time as a vectorised
            # indexer when (a) a valid time indexer was pre-computed above,
            # and (b) this variable actually has the time dimension. This
            # avoids loading all time steps during .load() when only one is
            # needed per point.
            _var_sel = dict(base_sel)
            _sel_has_time = time_dim in _var_sel
            if (
                _time_pts_da is not None
                and time_dim is not None
                and not _sel_has_time
                and time_dim in ds[var].dims
            ):
                _var_sel[time_dim] = _time_pts_da
                _sel_has_time = True
            selected = ds[var].sel(_var_sel, method="nearest")
            # Load the selected data into memory in a single operation before
            # iterating over individual points. Without this, each
            # ``float(point_data)`` or ``to_series()`` call inside the loop
            # would trigger a full recomputation of the dask graph for the
            # entire ``selected`` array (all N points at once), resulting in
            # N × (cost of loading all N points) = O(N²) I/O. Calling
            # ``.load()`` here materialises the dask graph exactly once so
            # that the per-point loop operates on in-memory NumPy data.
            selected = selected.load()
            # selected has shape (n_points, ...) with pts_dim as leading dim
            # (time dimension already reduced when _sel_has_time is True).
            for i, row in enumerate(rows):
                point_data = selected.isel({pts_dim: i})
                if time_dim is not None and not _sel_has_time:
                    point_data = _select_time(point_data, time_dim, row.get("time"))
                if point_data.ndim == 0:
                    row[var] = float(point_data)
                elif point_data.ndim == 1:
                    # Single leftover dim (e.g. wavelength): expand into
                    # coord-keyed columns (Rrs_346, Rrs_348, …).
                    # The bare placeholder column is dropped later in
                    # _execute_plan when the expanded columns are detected.
                    row[var] = float("nan")  # placeholder dropped later
                    for coord_val, val in point_data.to_series().items():
                        row[f"{var}_{int(coord_val)}"] = float(val)
                else:
                    raise ValueError(
                        f"Variable {var!r} still has {point_data.ndim} unmatched "
                        f"dimensions {list(point_data.dims)!r} after spatial and "
                        "additional-axis selection. "
                        "Add the appropriate axes to coord_spec to match them, or "
                        "request a variable with fewer dimensions."
                    )
        except ValueError:
            # Propagate the deliberate unmatched-dimensions error raised above.
            # NOTE(review): this clause also re-raises incidental ValueErrors
            # from the ``int(coord_val)`` / ``float(val)`` conversions above —
            # confirm that is intended rather than NaN-filling like the
            # generic failure path below.
            raise
        except Exception:
            # Any other extraction failure for this variable degrades to NaN
            # for every point in the batch rather than aborting the granule.
            for row in rows:
                row[var] = float("nan")


def _extract_nearest(
ds: xr.Dataset,
row: dict,
Expand Down
Loading