diff --git a/pyproject.toml b/pyproject.toml index 9b7b2bd..60632a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "point-collocation" -version = "0.8.0" +version = "0.9.0" description = "Point-based lat/lon/time matchups against cloud-hosted NetCDF/Zarr granules" readme = "README.md" license = { file = "LICENSE" } diff --git a/src/point_collocation/core/engine.py b/src/point_collocation/core/engine.py index 944a9d4..2ce161c 100644 --- a/src/point_collocation/core/engine.py +++ b/src/point_collocation/core/engine.py @@ -7,7 +7,8 @@ * Open each granule individually (never ``open_mfdataset``) to minimise cloud I/O and avoid memory leaks. * Extract the requested variables at each point's location/time using - nearest-neighbor selection (gridded) or kdtree (non-gridded, e.g. swath). + vectorised nearest-neighbour selection (gridded, ``axis`` method) or + k-d tree index (non-gridded/swath data). * Collect results into a ``pandas.DataFrame`` with one row per (point, granule) pair. @@ -59,7 +60,7 @@ # Re-export geolocation pairs for callers that import them from this module. from point_collocation.core._open_method import _GEOLOC_PAIRS # noqa: F401 -_VALID_SPATIAL_METHODS = {"nearest", "xoak-kdtree", "kdtree", "auto", "xoak-haversine"} +_VALID_SPATIAL_METHODS = {"axis", "xoak-kdtree", "kdtree", "auto", "xoak-haversine"} # Time dimension names used as a fallback when cf_xarray is not installed or # when the dataset lacks CF-convention axis/units attributes. Tried in order. @@ -136,17 +137,22 @@ def matchup( * ``"auto"`` *(default)* — automatically selects the best method based on the dimensionality of the geolocation coordinates: - - **1-D coordinates** (regular/gridded data): uses ``"nearest"`` - (``ds.sel(..., method="nearest")``). If ``"nearest"`` fails for - any reason, falls back to ``"kdtree"`` automatically. - - **2-D coordinates** (irregular/swath data): uses ``"kdtree"``. 
+ - **1-D coordinates** (regular/gridded data, both lat and lon + are 1-D): uses ``"axis"`` (vectorised + ``ds.sel(..., method="nearest")`` over all points at once). + If ``"axis"`` fails for any reason, falls back to ``"kdtree"`` + automatically. + - **2-D coordinates** (irregular/swath data, or either coordinate + is 2-D): uses ``"kdtree"``. ``xoak-kdtree`` and ``xoak-haversine`` are never selected automatically; set them explicitly if needed. - * ``"nearest"`` — ``ds.sel(..., method="nearest")`` directly. - Requires 1-D coordinate arrays; raises :exc:`ValueError` with a - suggestion to use ``"auto"`` or ``"kdtree"`` for 2-D coordinates. + * ``"axis"`` — Vectorised ``ds.sel(..., method="nearest")`` for all + points in a single call. Requires 1-D (regular-grid) coordinate + arrays for both latitude and longitude; raises :exc:`ValueError` + with a suggestion to use ``"auto"`` or ``"kdtree"`` for 2-D + coordinates. * ``"kdtree"`` — xarray's built-in :class:`xarray.indexes.NDPointIndex` with the default ``ScipyKDTreeAdapter``. Works with both 1-D and 2-D coordinate @@ -475,15 +481,16 @@ def _check_spatial_compat( ) -> None: """Raise if lat/lon dimensionality is incompatible with *spatial_method*. - Only validates for ``spatial_method="nearest"``, which requires 1-D - coordinate arrays. ``spatial_method="xoak-kdtree"``, + Only validates for ``spatial_method="axis"``, which requires 1-D + coordinate arrays for both latitude and longitude. + ``spatial_method="xoak-kdtree"``, ``spatial_method="xoak-haversine"``, ``spatial_method="kdtree"``, and ``spatial_method="auto"`` work with both 1-D and 2-D arrays and are not validated here. Uses only metadata (``dims``) — does **not** load array data. 
""" - if spatial_method != "nearest": + if spatial_method != "axis": return lon_var = ds.coords[lon_name] if lon_name in ds.coords else ds[lon_name] @@ -494,7 +501,7 @@ def _check_spatial_compat( if lon_ndim != 1 or lat_ndim != 1: raise ValueError( - f"spatial_method='nearest' requires 1-D geolocation arrays, but found " + f"spatial_method='axis' requires 1-D geolocation arrays, but found " f"{lon_name!r} with dims={tuple(lon_var.dims)} and " f"{lat_name!r} with dims={tuple(lat_var.dims)}. " "Use spatial_method='auto' or spatial_method='kdtree' for 2-D " @@ -709,12 +716,12 @@ def _execute_plan( # Track whether we have already validated spatial compat on the first granule. spatial_checked = False - # For "auto" spatial_method, the effective method ("nearest" or "kdtree") + # For "auto" spatial_method, the effective method ("axis" or "kdtree") # is determined on the first opened granule based on lat/lon dimensionality. # For explicit methods this always equals spatial_method. effective_spatial: str = spatial_method - # When auto resolves to "nearest" on 1-D data, allow one fallback to - # "kdtree" per granule if nearest extraction fails. + # When auto resolves to "axis" on 1-D data, allow one fallback to + # "kdtree" per granule if axis extraction fails. auto_1d_fallback: bool = (spatial_method == "auto") # For "auto" open_method, probe only the first granule to determine whether @@ -800,17 +807,24 @@ def _execute_plan( if lat_name in ds.coords else ds[lat_name] ) - if lat_var_check.ndim == 1: - effective_spatial = "nearest" + lon_var_check = ( + ds.coords[lon_name] + if lon_name in ds.coords + else ds[lon_name] + ) + if lat_var_check.ndim == 1 and lon_var_check.ndim == 1: + effective_spatial = "axis" # auto_1d_fallback already True; keep it so - # that a nearest failure falls back to kdtree. + # that an axis failure falls back to kdtree. 
else: effective_spatial = "kdtree" auto_1d_fallback = False if not silent: + lat_ndim = lat_var_check.ndim + lon_ndim = lon_var_check.ndim print( f"spatial_method='auto': using '{effective_spatial}' " - f"(lat/lon dims: {lat_var_check.ndim}-D)" + f"(lat/lon dims: {lat_ndim}-D/{lon_ndim}-D)" ) else: effective_spatial = spatial_method @@ -835,7 +849,7 @@ def _execute_plan( # to the spatial extent of the query points before building # the k-d tree. A global granule with only a few scattered # points would otherwise cause the index to cover the entire global - # grid, which is very slow. Skip this step for "nearest" (1-D) + # grid, which is very slow. Skip this step for "axis" (1-D) # since it does not build an index. if effective_spatial in ("xoak-kdtree", "xoak-haversine", "kdtree"): lat_var = ds.coords[lat_name] if lat_name in ds.coords else ds[lat_name] @@ -880,8 +894,8 @@ def _execute_plan( output_rows.extend(rows_for_granule) batch_rows.extend(rows_for_granule) elif auto_1d_fallback: - # auto resolved to "nearest" on 1-D coords. Try - # nearest for each point; if it fails, fall back to + # auto resolved to "axis" on 1-D coords. Try + # axis batch extraction; if it fails, fall back to # ndpoint for the whole granule (and all future ones). def _make_row(pt_idx: object) -> dict: r = plan.points.loc[pt_idx].to_dict() @@ -893,10 +907,9 @@ def _make_row(pt_idx: object) -> dict: rows_for_granule = [_make_row(idx) for idx in pt_indices] try: - for row in rows_for_granule: - _extract_nearest(ds, row, variables, lon_name, lat_name, time_dim, additional_axes=resolved_add_axes) - except Exception as _nearest_exc: - # nearest failed; rebuild clean rows and retry with kdtree. + _extract_axis_batch(ds, rows_for_granule, variables, lon_name, lat_name, time_dim, additional_axes=resolved_add_axes) + except Exception as _axis_exc: + # axis failed; rebuild clean rows and retry with kdtree. 
def _extract_axis_batch(
    ds: xr.Dataset,
    rows: list[dict],
    variables: list[str],
    lon_name: str,
    lat_name: str,
    time_dim: str | None = None,
    *,
    additional_axes: "dict[str, dict[str, str]] | None" = None,
) -> None:
    """Vectorized extraction using ``ds.sel(..., method='nearest')`` for 1-D lat/lon.

    Batches all points into a single ``.sel()`` call per variable using
    xarray's vectorised indexing API (every indexer is a ``DataArray`` that
    shares one points dimension).  This requires 1-D (regular-grid)
    coordinate arrays for both latitude and longitude.

    Modifies each dict in *rows* in-place, including ``granule_lat`` and
    ``granule_lon`` for the matched grid location.  ``granule_time`` is set by
    the caller from granule metadata before this function is called.

    Parameters
    ----------
    ds:
        Dataset to extract from.
    rows:
        One dict per query point.  Each must contain ``"lat"`` and ``"lon"``;
        ``"time"`` and any additional-axis columns are read with ``.get``.
        An empty list makes the call a no-op.
    variables:
        Names of the data variables to extract.
    lon_name, lat_name:
        Names of the longitude / latitude coordinates in *ds*.
    time_dim:
        Name of the time dimension in *ds*, as detected by
        :func:`_find_time_dim`.  When not ``None``, each variable is
        squeezed or nearest-selected along this dimension after spatial
        selection so that the result is always free of the time axis.
    additional_axes:
        Resolved additional 1D matching axes from
        :func:`_resolve_additional_axes_for_ds`.  Each entry maps an axis
        name to ``{"points_col": ..., "source_coord": ...}``.  When provided,
        a vectorised nearest-neighbour selection is also performed along each
        axis, but only if every row carries a non-``None`` value for it.

    Raises
    ------
    ValueError
        If a variable still has more than one unmatched dimension after
        spatial, time and additional-axis selection.  Any other
        ``ValueError`` raised while extracting a variable is propagated
        unchanged as well.

    Notes
    -----
    Failures are handled per variable: any exception other than
    ``ValueError`` sets that variable to NaN on *every* row.  Values for a
    variable are written to *rows* only after all points were extracted
    successfully, so a failure never leaves partially expanded columns
    behind.  If the matched grid coordinates cannot be determined,
    ``granule_lat`` / ``granule_lon`` are set to NaN on every row.
    """
    if additional_axes is None:
        additional_axes = {}
    if not rows:
        return

    # Use a unique dimension name for the points indexer to avoid conflicts
    # with any existing coordinate in the dataset.
    pts_dim = "pc_points_idx"
    lat_pts = xr.DataArray([row["lat"] for row in rows], dims=[pts_dim])
    lon_pts = xr.DataArray([row["lon"] for row in rows], dims=[pts_dim])

    # Extract the actual matched coordinates (nearest-neighbour grid positions).
    # Best effort: the matched position is informational, so a failure here
    # must not prevent the variable extraction below.
    try:
        matched_lats = ds.coords[lat_name].sel({lat_name: lat_pts}, method="nearest").values
        matched_lons = ds.coords[lon_name].sel({lon_name: lon_pts}, method="nearest").values
        for i, row in enumerate(rows):
            row["granule_lat"] = float(matched_lats[i])
            row["granule_lon"] = float(matched_lons[i])
    except Exception:
        for row in rows:
            row["granule_lat"] = float("nan")
            row["granule_lon"] = float("nan")

    # Build the base vectorised selection dict.
    base_sel: dict = {lat_name: lat_pts, lon_name: lon_pts}

    # Add additional axes as vectorised indexers when all row values are present.
    for info in additional_axes.values():
        axis_values = [row.get(info["points_col"]) for row in rows]
        if all(value is not None for value in axis_values):
            base_sel[info["source_coord"]] = xr.DataArray(axis_values, dims=[pts_dim])

    # Pre-compute a vectorised time indexer (one timestamp per point).
    # When the variable being extracted has a time dimension, including it in
    # the .sel() call means that .load() only reads one time step per point
    # instead of all time steps — a significant reduction in I/O for datasets
    # like MERRA-2 3-hourly files where each granule contains several time
    # steps (e.g. shape (time=8, lev=72, lat=361, lon=576)).
    time_pts: "xr.DataArray | None" = None
    if time_dim is not None and time_dim not in base_sel:
        raw_times = [row.get("time") for row in rows]
        if all(stamp is not None for stamp in raw_times):
            try:
                time_pts = xr.DataArray(
                    [pd.Timestamp(stamp) for stamp in raw_times],
                    dims=[pts_dim],
                )
            except Exception:
                pass  # fall back to per-point _select_time after .load()

    for var in variables:
        try:
            data = ds[var]
            # Build a per-variable selection dict.  Add time as a vectorised
            # indexer when (a) a valid time indexer was pre-computed above,
            # and (b) this variable actually has the time dimension.
            # ``time_pts`` is only ever set when ``time_dim`` is not None.
            var_sel = dict(base_sel)
            sel_has_time = time_dim in var_sel
            if time_pts is not None and not sel_has_time and time_dim in data.dims:
                var_sel[time_dim] = time_pts
                sel_has_time = True

            # Load the selected data into memory in a single operation before
            # iterating over individual points.  Without this, each
            # ``float(point_data)`` or ``to_series()`` call inside the loop
            # would trigger a full recomputation of the dask graph for the
            # entire selection (all N points at once), i.e. O(N²) I/O.
            selected = data.sel(var_sel, method="nearest").load()

            # Time still has to be reduced per point when it was not part of
            # the vectorised selection above.
            needs_time_pick = time_dim is not None and not sel_has_time

            # Stage the per-row results and commit them only once every point
            # succeeded, so that a failure part-way through cannot leave stale
            # expanded columns on some of the rows.
            staged: list[dict] = []
            for i, row in enumerate(rows):
                point_data = selected.isel({pts_dim: i})
                if needs_time_pick:
                    point_data = _select_time(point_data, time_dim, row.get("time"))
                if point_data.ndim == 0:
                    staged.append({var: float(point_data)})
                elif point_data.ndim == 1:
                    # Single leftover dim (e.g. wavelength): expand into
                    # coord-keyed columns (Rrs_346, Rrs_348, …).
                    # The bare placeholder column is dropped later in
                    # _execute_plan when the expanded columns are detected.
                    update = {var: float("nan")}  # placeholder dropped later
                    for coord_val, val in point_data.to_series().items():
                        update[f"{var}_{int(coord_val)}"] = float(val)
                    staged.append(update)
                else:
                    raise ValueError(
                        f"Variable {var!r} still has {point_data.ndim} unmatched "
                        f"dimensions {list(point_data.dims)!r} after spatial and "
                        "additional-axis selection. "
                        "Add the appropriate axes to coord_spec to match them, or "
                        "request a variable with fewer dimensions."
                    )
        except ValueError:
            # Configuration problems (e.g. unmatched dimensions) must reach
            # the caller instead of being masked as missing data.
            raise
        except Exception:
            for row in rows:
                row[var] = float("nan")
        else:
            for row, update in zip(rows, staged):
                row.update(update)
_check_spatial_compat ds = xr.Dataset( @@ -5023,10 +5023,10 @@ def test_nearest_2d_raises(self) -> None: "lat": (["nrows", "ncols"], [[0.0]]), } ) - with pytest.raises(ValueError, match="spatial_method='nearest'"): - _check_spatial_compat(ds, "lon", "lat", "nearest") + with pytest.raises(ValueError, match="spatial_method='axis'"): + _check_spatial_compat(ds, "lon", "lat", "axis") - def test_nearest_2d_error_mentions_auto(self) -> None: + def test_axis_2d_error_mentions_auto(self) -> None: from point_collocation.core.engine import _check_spatial_compat ds = xr.Dataset( @@ -5036,7 +5036,7 @@ def test_nearest_2d_error_mentions_auto(self) -> None: } ) with pytest.raises(ValueError, match="auto"): - _check_spatial_compat(ds, "lon", "lat", "nearest") + _check_spatial_compat(ds, "lon", "lat", "axis") def test_auto_any_dims_ok(self) -> None: from point_collocation.core.engine import _check_spatial_compat @@ -5202,10 +5202,10 @@ def test_swath_matchup_returns_nearest_value( assert len(result) == 1 assert not math.isnan(result.loc[0, "sst"]) - def test_nearest_with_2d_data_raises( + def test_axis_with_2d_data_raises( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch ) -> None: - """spatial_method='nearest' with 2-D lat/lon raises a clear ValueError.""" + """spatial_method='axis' with 2-D lat/lon raises a clear ValueError.""" nc_path = str(tmp_path / "swath.nc") _make_l2_swath_dataset(nrows=4, ncols=5).to_netcdf(nc_path, engine="netcdf4") @@ -5233,11 +5233,11 @@ def test_nearest_with_2d_data_raises( time_buffer=pd.Timedelta(0), ) - with pytest.raises(ValueError, match="spatial_method='nearest'"): + with pytest.raises(ValueError, match="spatial_method='axis'"): pc.matchup( p, open_method="dataset", - spatial_method="nearest", + spatial_method="axis", open_dataset_kwargs={"engine": "netcdf4"}, ) @@ -6401,8 +6401,7 @@ def _make_granule_meta(self) -> "GranuleMeta": def test_auto_is_default( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch ) -> None: - 
"""Calling matchup() without spatial_method uses 'auto' (1-D coords → nearest).""" - pytest.importorskip("scipy") + """Calling matchup() without spatial_method uses 'auto' (1-D coords → axis).""" nc_path = str(tmp_path / "grid.nc") _make_l3_dataset([-90.0, 0.0, 90.0], [-180.0, 0.0, 180.0]).to_netcdf(nc_path, engine="netcdf4") @@ -6427,10 +6426,10 @@ def test_auto_is_default( assert "sst" in result.columns assert not math.isnan(result.loc[0, "sst"]) - def test_auto_1d_routes_to_nearest( + def test_auto_1d_routes_to_axis( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch ) -> None: - """auto with 1-D coords routes to 'nearest' (no scipy/xoak required).""" + """auto with 1-D coords routes to 'axis' (no scipy/xoak required).""" nc_path = str(tmp_path / "grid.nc") _make_l3_dataset([-90.0, 0.0, 90.0], [-180.0, 0.0, 180.0], seed=5).to_netcdf( nc_path, engine="netcdf4" @@ -6450,7 +6449,7 @@ def test_auto_1d_routes_to_nearest( source_kwargs={"short_name": "TEST"}, time_buffer=pd.Timedelta(0), ) - # auto + 1D coords → should produce the same result as explicit nearest + # auto + 1D coords → should produce the same result as explicit axis result_auto = pc.matchup( p, open_method="dataset", variables=["sst"], spatial_method="auto", open_dataset_kwargs={"engine": "netcdf4"}, @@ -6464,11 +6463,11 @@ def test_auto_1d_routes_to_nearest( source_kwargs={"short_name": "TEST"}, time_buffer=pd.Timedelta(0), ) - result_nearest = pc.matchup( + result_axis = pc.matchup( p2, open_method="dataset", variables=["sst"], - spatial_method="nearest", open_dataset_kwargs={"engine": "netcdf4"}, + spatial_method="axis", open_dataset_kwargs={"engine": "netcdf4"}, ) - assert result_auto.loc[0, "sst"] == pytest.approx(result_nearest.loc[0, "sst"]) + assert result_auto.loc[0, "sst"] == pytest.approx(result_axis.loc[0, "sst"]) def test_auto_2d_routes_to_kdtree( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch @@ -6555,7 +6554,7 @@ def test_auto_prints_resolved_method( ) -> None: 
"""auto prints a one-line message showing the resolved spatial method and dims.""" pytest.importorskip("scipy") - # Test 1-D path (nearest) + # Test 1-D path (axis) nc_path_1d = str(tmp_path / "grid.nc") _make_l3_dataset([-90.0, 0.0, 90.0], [-180.0, 0.0, 180.0]).to_netcdf( nc_path_1d, engine="netcdf4" @@ -6579,7 +6578,7 @@ def test_auto_prints_resolved_method( open_dataset_kwargs={"engine": "netcdf4"}) captured = capsys.readouterr() assert "spatial_method='auto'" in captured.out - assert "'nearest'" in captured.out + assert "'axis'" in captured.out assert "1-D" in captured.out # Test 2-D path (kdtree) @@ -6630,7 +6629,7 @@ def test_explicit_method_does_not_print_auto_message( source_kwargs={"short_name": "TEST"}, time_buffer=pd.Timedelta(0), ) - pc.matchup(p, open_method="dataset", spatial_method="nearest", silent=False, + pc.matchup(p, open_method="dataset", spatial_method="axis", silent=False, open_dataset_kwargs={"engine": "netcdf4"}) captured = capsys.readouterr() assert "spatial_method='auto'" not in captured.out @@ -6651,10 +6650,10 @@ def test_auto_invalid_string_raises(self) -> None: with pytest.raises(ValueError, match="spatial_method"): pc.matchup(p, spatial_method="bogus") - def test_explicit_nearest_with_2d_raises_useful_message( + def test_explicit_axis_with_2d_raises_useful_message( self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch ) -> None: - """Explicit nearest with 2-D coords raises ValueError mentioning 'auto'/'kdtree'.""" + """Explicit axis with 2-D coords raises ValueError mentioning 'auto'/'kdtree'.""" nc_path = str(tmp_path / "swath.nc") _make_l2_swath_dataset(nrows=4, ncols=5).to_netcdf(nc_path, engine="netcdf4") mock_ea = MagicMock() @@ -6675,7 +6674,7 @@ def test_explicit_nearest_with_2d_raises_useful_message( ) with pytest.raises(ValueError, match="auto"): pc.matchup( - p, open_method="dataset", spatial_method="nearest", + p, open_method="dataset", spatial_method="axis", open_dataset_kwargs={"engine": "netcdf4"}, ) @@ -7674,7 
+7673,7 @@ def test_matchup_accepts_coord_spec_param( p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest", + spatial_method="axis", coord_spec={"time": {"source": "auto", "points": "auto"}}, ) assert "sst" in result.columns @@ -7693,7 +7692,7 @@ def test_matchup_with_depth_additional_axis( p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest", + spatial_method="axis", coord_spec=coord_spec, ) assert "temp" in result.columns @@ -7712,7 +7711,7 @@ def test_matchup_without_depth_expands_columns( p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest", + spatial_method="axis", ) # Should have expanded columns like temp_0, temp_10, temp_50, temp_100 expanded_cols = [c for c in result.columns if c.startswith("temp_")] @@ -7751,7 +7750,7 @@ def test_matchup_prints_coord_spec_summary_when_not_silent( p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest", + spatial_method="axis", silent=False, ) captured = capsys.readouterr() @@ -7878,7 +7877,7 @@ def test_plan_with_latitude_longitude_normalised_for_matching( p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest", + spatial_method="axis", ) assert "sst" in result.columns @@ -8057,11 +8056,11 @@ def test_coord_spec_none_uses_defaults( mock_ea.open.return_value = [nc_path] r1 = pc.matchup(p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest", coord_spec=None) + spatial_method="axis", coord_spec=None) mock_ea.open.return_value = [nc_path] r2 = pc.matchup(p, open_method="dataset", open_dataset_kwargs={"engine": "netcdf4"}, - spatial_method="nearest") + spatial_method="axis") assert list(r1.columns) == list(r2.columns) assert len(r1) == len(r2) @@ -8307,3 +8306,490 @@ def test_consistent_sources_in_both_no_conflict( assert isinstance(ds, xr.Dataset) assert "grid_lat" in ds.coords 
assert "grid_lon" in ds.coords + + +# --------------------------------------------------------------------------- +# Tests for _extract_axis_batch (spatial_method="axis") +# --------------------------------------------------------------------------- + +class TestExtractAxisBatch: + """Unit tests for _extract_axis_batch() and spatial_method='axis'.""" + + def _make_granule_meta(self) -> "GranuleMeta": + return GranuleMeta( + granule_id="https://example.com/test.nc", + begin=pd.Timestamp("2023-06-01T00:00:00Z"), + end=pd.Timestamp("2023-06-01T23:59:59Z"), + bbox=(-180.0, -90.0, 180.0, 90.0), + result_index=0, + ) + + def _make_plan( + self, + nc_path: str, + monkeypatch: pytest.MonkeyPatch, + pts: pd.DataFrame, + ) -> "Plan": + mock_ea = MagicMock() + mock_ea.open.return_value = [nc_path] + monkeypatch.setitem(__import__("sys").modules, "earthaccess", mock_ea) + gm = self._make_granule_meta() + return Plan( + points=pts, + results=[object()], + granules=[gm], + point_granule_map={i: [0] for i in pts.index}, + variables=["sst"], + source_kwargs={"short_name": "TEST"}, + time_buffer=pd.Timedelta(0), + ) + + # ------------------------------------------------------------------ + # Direct unit tests for _extract_axis_batch + # ------------------------------------------------------------------ + + def test_basic_vectorized_extraction(self) -> None: + """_extract_axis_batch returns correct values for multiple points.""" + from point_collocation.core.engine import _extract_axis_batch + + lats = [-90.0, 0.0, 90.0] + lons = [-180.0, 0.0, 180.0] + sst_data = np.arange(9, dtype=np.float32).reshape(3, 3) + ds = xr.Dataset( + {"sst": (["lat", "lon"], sst_data)}, + coords={"lat": lats, "lon": lons}, + ) + rows = [ + {"lat": -90.0, "lon": -180.0}, + {"lat": 0.0, "lon": 0.0}, + {"lat": 90.0, "lon": 180.0}, + ] + _extract_axis_batch(ds, rows, ["sst"], "lon", "lat") + # Each point should get the exact grid value + assert rows[0]["sst"] == pytest.approx(float(sst_data[0, 0])) + assert 
rows[1]["sst"] == pytest.approx(float(sst_data[1, 1])) + assert rows[2]["sst"] == pytest.approx(float(sst_data[2, 2])) + + def test_granule_lat_lon_set(self) -> None: + """_extract_axis_batch sets granule_lat and granule_lon.""" + from point_collocation.core.engine import _extract_axis_batch + + ds = xr.Dataset( + {"sst": (["lat", "lon"], [[1.0, 2.0], [3.0, 4.0]])}, + coords={"lat": [0.0, 10.0], "lon": [0.0, 10.0]}, + ) + rows = [{"lat": 1.0, "lon": 1.0}] # nearest is (0, 0) + _extract_axis_batch(ds, rows, ["sst"], "lon", "lat") + assert rows[0]["granule_lat"] == pytest.approx(0.0) + assert rows[0]["granule_lon"] == pytest.approx(0.0) + + def test_empty_rows_is_noop(self) -> None: + """_extract_axis_batch with empty rows does not raise.""" + from point_collocation.core.engine import _extract_axis_batch + + ds = xr.Dataset( + {"sst": (["lat", "lon"], [[1.0]])}, + coords={"lat": [0.0], "lon": [0.0]}, + ) + _extract_axis_batch(ds, [], ["sst"], "lon", "lat") # no error + + def test_single_point(self) -> None: + """_extract_axis_batch works for a single point.""" + from point_collocation.core.engine import _extract_axis_batch + + ds = xr.Dataset( + {"sst": (["lat", "lon"], [[5.0, 6.0]])}, + coords={"lat": [0.0], "lon": [-10.0, 10.0]}, + ) + rows = [{"lat": 0.1, "lon": -9.9}] + _extract_axis_batch(ds, rows, ["sst"], "lon", "lat") + assert rows[0]["sst"] == pytest.approx(5.0) + + def test_many_points(self) -> None: + """_extract_axis_batch handles many points in one call.""" + from point_collocation.core.engine import _extract_axis_batch + + n = 50 + lats = np.linspace(-90, 90, 10) + lons = np.linspace(-180, 180, 10) + sst_data = np.random.default_rng(42).uniform(20, 30, (10, 10)).astype(np.float32) + ds = xr.Dataset( + {"sst": (["lat", "lon"], sst_data)}, + coords={"lat": lats, "lon": lons}, + ) + # Use random query points + rng = np.random.default_rng(99) + rows = [ + {"lat": float(rng.uniform(-90, 90)), "lon": float(rng.uniform(-180, 180))} + for _ in range(n) + ] + 
_extract_axis_batch(ds, rows, ["sst"], "lon", "lat") + # All rows must have sst set (not NaN) + for row in rows: + assert not math.isnan(row["sst"]) + + def test_without_time_dimension(self) -> None: + """_extract_axis_batch handles dataset without time dimension.""" + from point_collocation.core.engine import _extract_axis_batch + + ds = xr.Dataset( + {"sst": (["lat", "lon"], [[10.0, 11.0], [12.0, 13.0]])}, + coords={"lat": [0.0, 1.0], "lon": [0.0, 1.0]}, + ) + rows = [{"lat": 0.0, "lon": 0.0, "time": pd.Timestamp("2023-06-01")}] + _extract_axis_batch(ds, rows, ["sst"], "lon", "lat", time_dim=None) + assert rows[0]["sst"] == pytest.approx(10.0) + + def test_with_singleton_time_dimension(self) -> None: + """_extract_axis_batch squeezes singleton time dimension.""" + from point_collocation.core.engine import _extract_axis_batch + + ds = xr.Dataset( + {"sst": (["time", "lat", "lon"], [[[10.0, 11.0], [12.0, 13.0]]])}, + coords={ + "time": pd.to_datetime(["2023-06-01"]), + "lat": [0.0, 1.0], + "lon": [0.0, 1.0], + }, + ) + rows = [{"lat": 1.0, "lon": 1.0, "time": pd.Timestamp("2023-06-01")}] + _extract_axis_batch(ds, rows, ["sst"], "lon", "lat", time_dim="time") + assert rows[0]["sst"] == pytest.approx(13.0) + + def test_with_multiple_time_steps_selects_nearest(self) -> None: + """_extract_axis_batch selects nearest time step per point.""" + from point_collocation.core.engine import _extract_axis_batch + + times = pd.to_datetime(["2023-06-01", "2023-06-02", "2023-06-03"]) + sst_data = np.array([[[10.0]], [[20.0]], [[30.0]]]) # shape (3, 1, 1) + ds = xr.Dataset( + {"sst": (["time", "lat", "lon"], sst_data)}, + coords={"time": times, "lat": [0.0], "lon": [0.0]}, + ) + rows = [ + {"lat": 0.0, "lon": 0.0, "time": pd.Timestamp("2023-06-02")}, + {"lat": 0.0, "lon": 0.0, "time": pd.Timestamp("2023-06-03")}, + ] + _extract_axis_batch(ds, rows, ["sst"], "lon", "lat", time_dim="time") + assert rows[0]["sst"] == pytest.approx(20.0) + assert rows[1]["sst"] == pytest.approx(30.0) + 
def test_time_dimension_excluded_from_loaded_array(self) -> None:
    """Vectorized time selection: loaded array has no time dim per point.

    With multiple time steps in the dataset, the time dimension must be
    reduced *before* .load() (by including it in the .sel() call as a
    vectorised indexer).  This test checks that each point gets the
    correct per-point time step *and* that the result contains values
    from distinct time steps (not e.g. all from index 0), which would
    only be possible if time was selected vectorially per point.

    The check is value-based: every time step holds a distinct constant,
    so the value written into each row identifies the time step used.
    """
    from point_collocation.core.engine import _extract_axis_batch

    times = pd.to_datetime(
        ["2023-06-01T00:00", "2023-06-01T03:00", "2023-06-01T06:00",
         "2023-06-01T09:00", "2023-06-01T12:00", "2023-06-01T15:00",
         "2023-06-01T18:00", "2023-06-01T21:00"]
    )
    # 8 time steps, 3 lat, 3 lon — each time step has a distinct constant value
    sst_vals = np.arange(1.0, 9.0).reshape(8, 1, 1) * np.ones((8, 3, 3))
    ds = xr.Dataset(
        {"sst": (["time", "lat", "lon"], sst_vals.astype(np.float32))},
        coords={"time": times, "lat": [-1.0, 0.0, 1.0], "lon": [-1.0, 0.0, 1.0]},
    )
    # Each row requests a different time step.
    rows = [
        {"lat": 0.0, "lon": 0.0, "time": times[2]},  # -> value 3.0
        {"lat": 0.0, "lon": 0.0, "time": times[5]},  # -> value 6.0
        {"lat": 0.0, "lon": 0.0, "time": times[7]},  # -> value 8.0
    ]
    # Results are read back from the row dicts after the call.
    _extract_axis_batch(ds, rows, ["sst"], "lon", "lat", time_dim="time")
    assert rows[0]["sst"] == pytest.approx(3.0)
    assert rows[1]["sst"] == pytest.approx(6.0)
    assert rows[2]["sst"] == pytest.approx(8.0)


def test_time_dimension_with_extra_axis(self) -> None:
    """Vectorized time + extra-axis selection both work together.

    The data encode ``time_idx * 100 + lev_idx`` so that each expected
    value identifies both the time step and the level that were selected.
    """
    from point_collocation.core.engine import _extract_axis_batch

    times = pd.to_datetime(["2023-06-01", "2023-06-02"])
    levels = [100.0, 500.0, 1000.0]
    # shape (time=2, lev=3, lat=2, lon=2)
    # value = time_idx * 100 + lev_idx
    data = np.array(
        [[[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[2, 2], [2, 2]]],
         [[[100, 100], [100, 100]], [[101, 101], [101, 101]], [[102, 102], [102, 102]]]],
        dtype=np.float32,
    )
    ds = xr.Dataset(
        {"omega": (["time", "lev", "lat", "lon"], data)},
        coords={
            "time": times,
            "lev": levels,
            "lat": [0.0, 1.0],
            "lon": [0.0, 1.0],
        },
    )
    rows = [
        {"lat": 0.0, "lon": 0.0, "time": times[0], "lev_val": 500.0},   # time=0, lev=1 -> 1
        {"lat": 0.0, "lon": 0.0, "time": times[1], "lev_val": 1000.0},  # time=1, lev=2 -> 102
    ]
    # The points column ("lev_val") is deliberately named differently from
    # the dataset coordinate ("lev").
    additional_axes = {"lev": {"points_col": "lev_val", "source_coord": "lev"}}
    _extract_axis_batch(
        ds, rows, ["omega"], "lon", "lat",
        time_dim="time", additional_axes=additional_axes,
    )
    assert rows[0]["omega"] == pytest.approx(1.0)
    assert rows[1]["omega"] == pytest.approx(102.0)


def test_with_extra_axis(self) -> None:
    """_extract_axis_batch handles additional axes (e.g. depth).

    Both rows sit on the same grid cell and differ only in depth, so the
    two expected values come from the first two depth levels of that cell.
    """
    from point_collocation.core.engine import _extract_axis_batch

    depths = [0.0, 10.0, 20.0]
    temp_data = np.array(
        [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    )  # shape (lat=2, lon=2, depth=3)
    ds = xr.Dataset(
        {"temp": (["lat", "lon", "depth"], temp_data)},
        coords={"lat": [0.0, 1.0], "lon": [0.0, 1.0], "depth": depths},
    )
    rows = [
        {"lat": 0.0, "lon": 0.0, "depth": 0.0},
        {"lat": 0.0, "lon": 0.0, "depth": 10.0},
    ]
    additional_axes = {"depth": {"points_col": "depth", "source_coord": "depth"}}
    _extract_axis_batch(ds, rows, ["temp"], "lon", "lat", additional_axes=additional_axes)
    assert rows[0]["temp"] == pytest.approx(1.0)
    assert rows[1]["temp"] == pytest.approx(2.0)


def test_expands_wavelength_dimension(self) -> None:
    """_extract_axis_batch expands a leftover wavelength dimension into columns.

    No additional axes are passed, so ``wavelength`` is left over after the
    lat/lon selection; the test expects one ``Rrs_<wavelength>`` key per
    wavelength in the row.
    """
    from point_collocation.core.engine import _extract_axis_batch

    wavelengths = [412, 443, 490]
    # Rrs has dims (lat, lon, wavelength)
    rrs_data = np.zeros((2, 2, 3), dtype=np.float32)
    rrs_data[0, 0, :] = [0.01, 0.02, 0.03]
    rrs_data[0, 1, :] = [0.04, 0.05, 0.06]
    rrs_data[1, 0, :] = [0.07, 0.08, 0.09]
    rrs_data[1, 1, :] = [0.10, 0.11, 0.12]
    ds = xr.Dataset(
        {"Rrs": (["lat", "lon", "wavelength"], rrs_data)},
        coords={"lat": [0.0, 1.0], "lon": [0.0, 10.0], "wavelength": wavelengths},
    )
    rows = [{"lat": 0.0, "lon": 0.0}]
    _extract_axis_batch(ds, rows, ["Rrs"], "lon", "lat")
    # Should expand into Rrs_412, Rrs_443, Rrs_490
    assert "Rrs_412" in rows[0]
    assert "Rrs_443" in rows[0]
    assert "Rrs_490" in rows[0]
    assert rows[0]["Rrs_412"] == pytest.approx(0.01)
    assert rows[0]["Rrs_443"] == pytest.approx(0.02)
    assert rows[0]["Rrs_490"] == pytest.approx(0.03)


# ------------------------------------------------------------------
# Integration tests via pc.matchup()
# ------------------------------------------------------------------

def test_axis_method_1d_grid(
    self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """spatial_method='axis' works correctly on a 1-D regular grid.

    All query points lie exactly on grid nodes; the test checks that every
    point gets a non-NaN ``sst`` and that the first row's matched
    ``granule_lat``/``granule_lon`` equal the requested coordinates.
    """
    nc_path = str(tmp_path / "grid.nc")
    lats = [-90.0, -45.0, 0.0, 45.0, 90.0]
    lons = [-180.0, -90.0, 0.0, 90.0, 180.0]
    ds = _make_l3_dataset(lats, lons, seed=7)
    ds.to_netcdf(nc_path, engine="netcdf4")

    pts = pd.DataFrame(
        {
            "lat": [0.0, -45.0, 90.0],
            "lon": [0.0, -90.0, 180.0],
            "time": pd.to_datetime(["2023-06-01"] * 3),
        }
    )
    p = self._make_plan(nc_path, monkeypatch, pts)
    result = pc.matchup(
        p,
        open_method="dataset",
        variables=["sst"],
        spatial_method="axis",
        open_dataset_kwargs={"engine": "netcdf4"},
    )
    assert "sst" in result.columns
    assert len(result) == 3
    assert not result["sst"].isna().any()
    # Verify matched lat/lon are grid coords
    assert result["granule_lat"].iloc[0] == pytest.approx(0.0)
    assert result["granule_lon"].iloc[0] == pytest.approx(0.0)


def test_axis_with_time_dimension(
    self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """spatial_method='axis' correctly handles datasets with a time dimension.

    Two points at the same location but different dates are mapped to the
    same granule; both must yield a non-NaN ``sst``.
    """
    import sys

    nc_path = str(tmp_path / "timegrid.nc")
    lats = [-45.0, 0.0, 45.0]
    lons = [-90.0, 0.0, 90.0]
    times = ["2023-06-01", "2023-06-02", "2023-06-03"]
    ds = _make_l3_time_dataset(lats, lons, times, seed=11)
    ds.to_netcdf(nc_path, engine="netcdf4")

    # Replace the earthaccess module so that its open() returns the local file.
    mock_ea = MagicMock()
    mock_ea.open.return_value = [nc_path]
    monkeypatch.setitem(sys.modules, "earthaccess", mock_ea)

    pts = pd.DataFrame(
        {
            "lat": [0.0, 0.0],
            "lon": [0.0, 0.0],
            "time": pd.to_datetime(["2023-06-01", "2023-06-03"]),
        }
    )
    gm = self._make_granule_meta()
    p = Plan(
        points=pts,
        results=[object()],
        granules=[gm],
        point_granule_map={0: [0], 1: [0]},
        variables=["sst"],
        source_kwargs={"short_name": "TEST"},
        time_buffer=pd.Timedelta(0),
    )
    result = pc.matchup(
        p,
        open_method="dataset",
        variables=["sst"],
        spatial_method="axis",
        open_dataset_kwargs={"engine": "netcdf4"},
    )
    assert "sst" in result.columns
    assert len(result) == 2
    assert not result["sst"].isna().any()


def test_axis_without_time_dimension(
    self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """spatial_method='axis' handles dataset without any time dimension.

    The points still carry a ``time`` column; the test only requires that
    both points produce a non-NaN ``sst``.
    """
    nc_path = str(tmp_path / "notimegrid.nc")
    ds = _make_l3_dataset([-45.0, 0.0, 45.0], [-90.0, 0.0, 90.0], seed=3)
    ds.to_netcdf(nc_path, engine="netcdf4")

    pts = pd.DataFrame(
        {
            "lat": [0.0, -45.0],
            "lon": [0.0, -90.0],
            "time": pd.to_datetime(["2023-06-01", "2023-06-01"]),
        }
    )
    p = self._make_plan(nc_path, monkeypatch, pts)
    result = pc.matchup(
        p,
        open_method="dataset",
        variables=["sst"],
        spatial_method="axis",
        open_dataset_kwargs={"engine": "netcdf4"},
    )
    assert "sst" in result.columns
    assert len(result) == 2
    assert not result["sst"].isna().any()


def test_axis_matches_auto_for_1d_coords(
    self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """spatial_method='axis' and 'auto' return the same values for 1-D coords.

    The same granule and points are run once per method through a fresh
    plan, and the ``sst`` values are compared point by point.
    """
    import sys

    nc_path = str(tmp_path / "grid.nc")
    ds = _make_l3_dataset([-90.0, 0.0, 90.0], [-180.0, 0.0, 180.0], seed=55)
    ds.to_netcdf(nc_path, engine="netcdf4")

    pts = pd.DataFrame(
        {
            "lat": [0.0, -90.0],
            "lon": [0.0, 180.0],
            "time": pd.to_datetime(["2023-06-01", "2023-06-01"]),
        }
    )

    def _run(method: str) -> pd.DataFrame:
        """Run pc.matchup with *method* on a fresh plan and mocked earthaccess."""
        mock_ea = MagicMock()
        mock_ea.open.return_value = [nc_path]
        monkeypatch.setitem(sys.modules, "earthaccess", mock_ea)
        gm = self._make_granule_meta()
        p = Plan(
            points=pts.copy(),  # each run gets its own copy of the points
            results=[object()],
            granules=[gm],
            point_granule_map={0: [0], 1: [0]},
            variables=["sst"],
            source_kwargs={"short_name": "TEST"},
            time_buffer=pd.Timedelta(0),
        )
        return pc.matchup(
            p, open_method="dataset", variables=["sst"],
            spatial_method=method, open_dataset_kwargs={"engine": "netcdf4"},
        )

    r_axis = _run("axis")
    r_auto = _run("auto")
    # Compare by point label; pts has the default 0..n-1 index.
    for idx in pts.index:
        assert r_axis.loc[idx, "sst"] == pytest.approx(r_auto.loc[idx, "sst"])


def test_auto_uses_axis_for_1d_coords(
    self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """spatial_method='auto' with 1-D coords resolves to 'axis' (no scipy needed).

    The resolution is observed through stdout: with ``silent=False`` the
    captured output must contain the quoted method name ``'axis'``.
    """
    import sys

    nc_path = str(tmp_path / "grid.nc")
    _make_l3_dataset([-90.0, 0.0, 90.0], [-180.0, 0.0, 180.0]).to_netcdf(
        nc_path, engine="netcdf4"
    )
    mock_ea = MagicMock()
    mock_ea.open.return_value = [nc_path]
    monkeypatch.setitem(sys.modules, "earthaccess", mock_ea)

    pts = pd.DataFrame(
        {"lat": [0.0], "lon": [0.0], "time": pd.to_datetime(["2023-06-01T12:00:00"])}
    )
    p = Plan(
        points=pts,
        results=[object()],
        granules=[self._make_granule_meta()],
        point_granule_map={0: [0]},
        source_kwargs={"short_name": "TEST"},
        time_buffer=pd.Timedelta(0),
    )
    pc.matchup(p, open_method="dataset", spatial_method="auto", silent=False,
               variables=["sst"], open_dataset_kwargs={"engine": "netcdf4"})
    captured = capsys.readouterr()
    assert "'axis'" in captured.out


def test_auto_2d_coords_uses_kdtree(
    self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """spatial_method='auto' with 2-D coords resolves to 'kdtree'.

    Skipped when scipy is unavailable.  The query point is taken from an
    existing swath pixel, and the test requires a non-NaN ``sst`` for it.
    """
    import sys

    pytest.importorskip("scipy")
    nc_path = str(tmp_path / "swath.nc")
    ds_swath = _make_l2_swath_dataset(nrows=4, ncols=5, seed=42)
    ds_swath.to_netcdf(nc_path, engine="netcdf4")
    mock_ea = MagicMock()
    mock_ea.open.return_value = [nc_path]
    monkeypatch.setitem(sys.modules, "earthaccess", mock_ea)

    # Query exactly at the swath pixel in row 1, column 2.
    lat_val = float(ds_swath["lat"].values[1, 2])
    lon_val = float(ds_swath["lon"].values[1, 2])
    pts = pd.DataFrame(
        {"lat": [lat_val], "lon": [lon_val], "time": pd.to_datetime(["2023-06-01T12:00:00"])}
    )
    gm = self._make_granule_meta()
    p = Plan(
        points=pts,
        results=[object()],
        granules=[gm],
        point_granule_map={0: [0]},
        variables=["sst"],
        source_kwargs={"short_name": "TEST"},
        time_buffer=pd.Timedelta(0),
    )
    result = pc.matchup(
        p, open_method="dataset", spatial_method="auto", variables=["sst"],
        open_dataset_kwargs={"engine": "netcdf4"},
    )
    assert "sst" in result.columns
    assert not math.isnan(result.loc[0, "sst"])