diff --git a/.claude/commands/rockout.md b/.claude/commands/rockout.md index b194d6d2..fe14287b 100644 --- a/.claude/commands/rockout.md +++ b/.claude/commands/rockout.md @@ -14,7 +14,8 @@ through all seven steps below. The prompt is: $ARGUMENTS 2. Pick labels from the repo's existing set. Always include the type label (`enhancement`, `bug`, or `proposal`). Add topical labels when they fit (e.g. `gpu`, `performance`, `focal tools`, `hydrology`, etc.). -3. Draft the title and body. Use the repo's issue templates as structure guides: +3. Draft the title and body. Use the repo's issue templates as structure guides + (skip the "Author of Proposal" field -- GitHub already shows the author): - Enhancement/proposal: follow `.github/ISSUE_TEMPLATE/feature-proposal.md` - Bug: follow `.github/ISSUE_TEMPLATE/bug_report.md` 4. **Run the body text through the `/humanizer` skill** before creating the issue diff --git a/.github/ISSUE_TEMPLATE/feature-proposal.md b/.github/ISSUE_TEMPLATE/feature-proposal.md index 7c8b0b62..266ea5af 100644 --- a/.github/ISSUE_TEMPLATE/feature-proposal.md +++ b/.github/ISSUE_TEMPLATE/feature-proposal.md @@ -7,7 +7,6 @@ assignees: '' --- -**Author of Proposal:** ## Reason or Problem Describe what the need for this new feature is or what problem this new feature will address. ## Proposal diff --git a/xrspatial/sieve.py b/xrspatial/sieve.py index 40b97561..1ed66d91 100644 --- a/xrspatial/sieve.py +++ b/xrspatial/sieve.py @@ -2,16 +2,19 @@ Given a categorical raster and a pixel-count threshold, replaces connected regions smaller than the threshold with the value of -their largest spatial neighbor. Pairs with classification functions -(``natural_breaks``, ``reclassify``, etc.) and ``polygonize`` for -cleaning results before vectorization. +their largest spatial neighbor that is already at or above the +threshold. Matches the single-pass semantics of GDAL's +``GDALSieveFilter`` / ``rasterio.features.sieve``. + +Pairs with classification functions (``natural_breaks``, +``reclassify``, etc.) and ``polygonize`` for cleaning results +before vectorization. Supports all four backends: numpy, cupy, dask+numpy, dask+cupy. """ from __future__ import annotations -import warnings from collections import defaultdict from typing import Sequence @@ -40,7 +43,6 @@ class cupy: ngjit, ) -_MAX_ITERATIONS = 50 # --------------------------------------------------------------------------- @@ -205,67 +207,75 @@ def _collect(a, b): def _sieve_numpy(data, threshold, neighborhood, skip_values): - """Replace connected regions smaller than *threshold* pixels.""" + """Single-pass sieve matching GDAL's ``GDALSieveFilter`` semantics. + + A small region is only merged into a neighbor whose size is + **>= threshold**. If no such neighbor exists the region stays. + Regions are processed smallest-first with in-place size updates + so that earlier merges can grow a neighbor above threshold for + later ones within the same pass. + """ result = data.astype(np.float64, copy=True) is_float = np.issubdtype(data.dtype, np.floating) valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool) skip_set = set(skip_values) if skip_values is not None else set() - for _ in range(_MAX_ITERATIONS): - region_map, region_val, uid = _label_connected( - result, valid, neighborhood - ) - region_size = np.bincount( - region_map.ravel(), minlength=uid - ).astype(np.int64) - - # Identify small regions eligible for merging - small_ids = [ - rid - for rid in range(1, uid) - if region_size[rid] < threshold - and region_val[rid] not in skip_set - ] - if not small_ids: - return result, True - - adjacency = _build_adjacency(region_map, neighborhood) - - # Process smallest regions first so they merge into larger neighbors - small_ids.sort(key=lambda r: region_size[r]) - - merged_any = False - for rid in small_ids: - if region_size[rid] == 0 or region_size[rid] >= threshold: - continue + region_map, region_val, uid = _label_connected( + result, valid, neighborhood + ) + region_size = np.bincount( + region_map.ravel(), minlength=uid + ).astype(np.int64) + + small_ids = [ + rid + for rid in range(1, uid) + if region_size[rid] < threshold + and region_val[rid] not in skip_set + ] + if not small_ids: + return result + + adjacency = _build_adjacency(region_map, neighborhood) + + # Process smallest regions first so earlier merges can grow + # a neighbor above threshold for later candidates. + small_ids.sort(key=lambda r: region_size[r]) + + for rid in small_ids: + if region_size[rid] == 0 or region_size[rid] >= threshold: + continue - neighbors = adjacency.get(rid) - if not neighbors: - continue # surrounded by nodata only + neighbors = adjacency.get(rid) + if not neighbors: + continue # surrounded by nodata only - largest_nid = max(neighbors, key=lambda n: region_size[n]) - mask = region_map == rid - result[mask] = region_val[largest_nid] + # Only merge into a neighbor that is already >= threshold. + valid_neighbors = [ + n for n in neighbors if region_size[n] >= threshold + ] + if not valid_neighbors: + continue - # Update tracking in place - region_map[mask] = largest_nid - region_size[largest_nid] += region_size[rid] - region_size[rid] = 0 + largest_nid = max(valid_neighbors, key=lambda n: region_size[n]) + mask = region_map == rid + result[mask] = region_val[largest_nid] - for n in neighbors: - if n != largest_nid: - adjacency[n].discard(rid) - adjacency[n].add(largest_nid) - adjacency.setdefault(largest_nid, set()).add(n) - if largest_nid in adjacency: - adjacency[largest_nid].discard(rid) - del adjacency[rid] - merged_any = True + # Update tracking in place + region_map[mask] = largest_nid + region_size[largest_nid] += region_size[rid] + region_size[rid] = 0 - if not merged_any: - return result, True + for n in neighbors: + if n != largest_nid: + adjacency[n].discard(rid) + adjacency[n].add(largest_nid) + adjacency.setdefault(largest_nid, set()).add(n) + if largest_nid in adjacency: + adjacency[largest_nid].discard(rid) + del adjacency[rid] - return result, False + return result # --------------------------------------------------------------------------- @@ -277,10 +287,10 @@ def _sieve_cupy(data, threshold, neighborhood, skip_values): """CuPy backend: transfer to CPU, sieve, transfer back.""" import cupy as cp - np_result, converged = _sieve_numpy( + np_result = _sieve_numpy( data.get(), threshold, neighborhood, skip_values ) - return cp.asarray(np_result), converged + return cp.asarray(np_result) # --------------------------------------------------------------------------- @@ -320,10 +330,10 @@ def _sieve_dask(data, threshold, neighborhood, skip_values): ) np_data = data.compute() - result, converged = _sieve_numpy( + result = _sieve_numpy( np_data, threshold, neighborhood, skip_values ) - return da.from_array(result, chunks=data.chunks), converged + return da.from_array(result, chunks=data.chunks) def _sieve_dask_cupy(data, threshold, neighborhood, skip_values): @@ -345,10 +355,10 @@ def _sieve_dask_cupy(data, threshold, neighborhood, skip_values): pass cp_data = data.compute() - result, converged = _sieve_cupy( + result = _sieve_cupy( cp_data, threshold, neighborhood, skip_values ) - return da.from_array(result, chunks=data.chunks), converged + return da.from_array(result, chunks=data.chunks) # --------------------------------------------------------------------------- @@ -367,7 +377,10 @@ def sieve( Identifies connected components of same-value pixels and replaces regions smaller than *threshold* pixels with the value of their - largest spatial neighbor. NaN pixels are always preserved. + largest spatial neighbor that is already at or above *threshold*. + Regions whose only neighbors are also below *threshold* are left + unchanged, matching GDAL's single-pass semantics. NaN pixels + are always preserved. Parameters ---------- @@ -417,6 +430,11 @@ def sieve( Notes ----- + Uses single-pass semantics matching GDAL's ``GDALSieveFilter``. + A small region is only merged into a neighbor whose current size + is >= *threshold*. If no such neighbor exists the region is left + unchanged. + This is a global operation: for dask-backed arrays the entire raster is computed into memory before sieving. Connected-component labeling cannot be performed on individual chunks because regions may span @@ -442,35 +460,21 @@ def sieve( data = raster.data if isinstance(data, np.ndarray): - out, converged = _sieve_numpy( - data, threshold, neighborhood, skip_values - ) + out = _sieve_numpy(data, threshold, neighborhood, skip_values) elif has_cuda_and_cupy() and is_cupy_array(data): - out, converged = _sieve_cupy( - data, threshold, neighborhood, skip_values - ) + out = _sieve_cupy(data, threshold, neighborhood, skip_values) elif da is not None and isinstance(data, da.Array): if is_dask_cupy(raster): - out, converged = _sieve_dask_cupy( + out = _sieve_dask_cupy( data, threshold, neighborhood, skip_values ) else: - out, converged = _sieve_dask( - data, threshold, neighborhood, skip_values - ) + out = _sieve_dask(data, threshold, neighborhood, skip_values) else: raise TypeError( f"Unsupported array type {type(data).__name__} for sieve()" ) - if not converged: - warnings.warn( - f"sieve() did not converge after {_MAX_ITERATIONS} iterations. " - f"The result may still contain regions smaller than " - f"threshold={threshold}.", - stacklevel=2, - ) - return DataArray( out, name=name, diff --git a/xrspatial/tests/test_sieve.py b/xrspatial/tests/test_sieve.py index 9e8e2409..1d56a1cc 100644 --- a/xrspatial/tests/test_sieve.py +++ b/xrspatial/tests/test_sieve.py @@ -111,13 +111,12 @@ def test_sieve_four_connectivity(backend): dtype=np.float64, ) raster = _make_raster(arr, backend) - # With 4-connectivity: each 1 and 2 forms its own 1-pixel region - # except center which is 1 pixel. All regions are size 1. - # threshold=2 should merge them all. + # With 4-connectivity each pixel is its own 1-pixel region. + # All regions are below threshold=2 and no neighbor is >= 2, + # so nothing merges (GDAL single-pass semantics). result = sieve(raster, threshold=2, neighborhood=4) data = _to_numpy(result) - # All pixels should end up the same value (merged into one) - assert len(np.unique(data)) == 1 + np.testing.assert_array_equal(data, arr) @pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) @@ -425,29 +424,26 @@ def test_sieve_numpy_dask_match(): # --------------------------------------------------------------------------- -# Convergence warning +# Single-pass: small regions with no above-threshold neighbor stay # --------------------------------------------------------------------------- -def test_sieve_convergence_warning(): - """Should warn when the iteration limit is reached.""" - from unittest.mock import patch - - # Create a raster where merging is artificially stalled by - # patching _MAX_ITERATIONS to 0 so the loop never runs. +def test_sieve_small_region_no_large_neighbor(): + """A small region whose only neighbors are also small stays unchanged.""" arr = np.array( [ - [1, 1, 1], - [1, 2, 1], - [1, 1, 1], + [1, 1, 2, 2], + [1, 1, 2, 2], + [3, 3, 4, 4], + [3, 3, 4, 4], ], dtype=np.float64, ) raster = _make_raster(arr, "numpy") - - with patch("xrspatial.sieve._MAX_ITERATIONS", 0): - with pytest.warns(UserWarning, match="did not converge"): - sieve(raster, threshold=2) + # All regions are size 4, threshold=5: no neighbor is >= 5. + result = sieve(raster, threshold=5) + data = _to_numpy(result) + np.testing.assert_array_equal(data, arr) # --------------------------------------------------------------------------- @@ -486,7 +482,7 @@ def test_sieve_noisy_classification(backend): @pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) def test_sieve_many_small_regions(backend): - """Checkerboard produces maximum region count; sieve should unify.""" + """Checkerboard: all regions size 1, no neighbor >= threshold.""" # 20x20 checkerboard: every pixel is its own 1-pixel region arr = np.zeros((20, 20), dtype=np.float64) arr[::2, ::2] = 1 @@ -498,5 +494,5 @@ def test_sieve_many_small_regions(backend): data = _to_numpy(result) # With 4-connectivity every pixel is isolated (size 1). - # threshold=2 forces all to merge. Result should be uniform. - assert len(np.unique(data)) == 1 + # No neighbor is >= threshold=2, so nothing merges (GDAL semantics). + np.testing.assert_array_equal(data, arr) diff --git a/xrspatial/tests/test_sieve_gdal_parity.py b/xrspatial/tests/test_sieve_gdal_parity.py new file mode 100644 index 00000000..e3eb2cb2 --- /dev/null +++ b/xrspatial/tests/test_sieve_gdal_parity.py @@ -0,0 +1,467 @@ +"""Cross-validate xrspatial.sieve() against rasterio.features.sieve() (GDAL). + +rasterio.features.sieve wraps GDAL's GDALSieveFilter, which is the accepted +reference for this operation. These tests run both implementations on +identical inputs and compare outputs. + +Requires rasterio; tests are skipped if it is not installed. + +Known behavioral differences +----------------------------- +* **Input dtype**: rasterio requires integer arrays (int16/int32/uint8/uint16). + xrspatial accepts float or int, converts to float64 internally. +* **Nodata**: rasterio uses a boolean ``mask`` (False = excluded). + xrspatial uses NaN. +* **Tie-breaking**: when a small region has multiple neighbors of the same + size, GDAL and xrspatial may pick different winners. Tests that would + depend on tie-breaking are designed to avoid ambiguity. +* **skip_values**: xrspatial-only feature, not tested here. +""" + +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr + +rasterio = pytest.importorskip("rasterio") +from rasterio.features import sieve as gdal_sieve + +from xrspatial.sieve import sieve as xs_sieve + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_both(arr_int, threshold, connectivity=4): + """Run both sieve implementations and return (xrspatial, gdal) arrays. + + Parameters + ---------- + arr_int : ndarray + 2D integer array (shared input for both implementations). + threshold : int + Minimum region size to retain. + connectivity : int + 4 or 8. + + Returns + ------- + xs_out : ndarray of int32 + xrspatial result cast back to int32 for comparison. + gdal_out : ndarray of int32 + rasterio/GDAL result. + """ + # xrspatial path: wrap in DataArray (float64 internally) + raster = xr.DataArray(arr_int.astype(np.float64), dims=["y", "x"]) + xs_result = xs_sieve(raster, threshold=threshold, neighborhood=connectivity) + xs_out = xs_result.values.astype(np.int32) + + # GDAL path: needs int32 + gdal_src = arr_int.astype(np.int32).copy() + gdal_out = gdal_sieve(gdal_src, size=threshold, connectivity=connectivity) + + return xs_out, gdal_out + + +def _run_both_with_nodata(arr_int, nodata_val, threshold, connectivity=4): + """Run both implementations with nodata handling. + + For xrspatial: replace nodata_val with NaN. + For GDAL: pass a mask where nodata pixels are False. + + Returns + ------- + xs_out, gdal_out : ndarrays with nodata_val restored in nodata positions. + """ + # xrspatial: use NaN for nodata + arr_float = arr_int.astype(np.float64) + arr_float[arr_int == nodata_val] = np.nan + raster = xr.DataArray(arr_float, dims=["y", "x"]) + xs_result = xs_sieve(raster, threshold=threshold, neighborhood=connectivity) + xs_vals = xs_result.values.copy() + # Restore nodata positions to nodata_val for comparison + xs_vals[np.isnan(xs_vals)] = nodata_val + xs_out = xs_vals.astype(np.int32) + + # GDAL: use mask + gdal_src = arr_int.astype(np.int32).copy() + mask = (arr_int != nodata_val).astype(np.uint8) + gdal_out = gdal_sieve(gdal_src, size=threshold, connectivity=connectivity, mask=mask) + + return xs_out, gdal_out + + +# --------------------------------------------------------------------------- +# Basic parity: noise removal +# --------------------------------------------------------------------------- + + +class TestBasicParity: + """Simple cases where both implementations should agree exactly.""" + + def test_single_pixel_noise(self): + """One isolated pixel in a uniform background.""" + arr = np.array( + [ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 3, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + ], + dtype=np.int32, + ) + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + # Both should have replaced the 3 with 1 + assert xs_out[2, 2] == 1 + + def test_two_classes_no_noise(self): + """Two large regions, nothing to sieve.""" + arr = np.zeros((6, 6), dtype=np.int32) + arr[:, :3] = 1 + arr[:, 3:] = 2 + xs_out, gdal_out = _run_both(arr, threshold=5) + np.testing.assert_array_equal(xs_out, gdal_out) + np.testing.assert_array_equal(xs_out, arr) + + def test_multiple_noise_pixels(self): + """Several isolated noise pixels in different locations.""" + arr = np.ones((8, 8), dtype=np.int32) + arr[1, 1] = 5 + arr[1, 6] = 7 + arr[6, 1] = 3 + arr[6, 6] = 9 + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + expected = np.ones((8, 8), dtype=np.int32) + np.testing.assert_array_equal(xs_out, expected) + + def test_small_cluster_removal(self): + """A 3-pixel cluster in a large background.""" + arr = np.ones((10, 10), dtype=np.int32) + # 3-pixel L-shape of value 2 + arr[2, 2] = 2 + arr[2, 3] = 2 + arr[3, 2] = 2 + xs_out, gdal_out = _run_both(arr, threshold=4) + np.testing.assert_array_equal(xs_out, gdal_out) + assert xs_out[2, 2] == 1 + + def test_threshold_exact_boundary(self): + """Region exactly at threshold size should be retained.""" + arr = np.ones((6, 6), dtype=np.int32) + # 4-pixel block of value 2 + arr[1, 1] = 2 + arr[1, 2] = 2 + arr[2, 1] = 2 + arr[2, 2] = 2 + xs_out, gdal_out = _run_both(arr, threshold=4) + np.testing.assert_array_equal(xs_out, gdal_out) + # Region of size 4 should survive threshold=4 + assert xs_out[1, 1] == 2 + + +# --------------------------------------------------------------------------- +# Threshold sweep +# --------------------------------------------------------------------------- + + +class TestThresholdSweep: + """Run multiple thresholds on the same input.""" + + @pytest.fixture() + def classified_grid(self): + """10x10 grid with a few distinct regions of known sizes.""" + arr = np.array( + [ + [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], + [1, 1, 1, 1, 2, 3, 3, 2, 2, 2], + [1, 1, 1, 1, 2, 3, 3, 2, 2, 2], + [4, 4, 1, 1, 2, 2, 2, 2, 2, 2], + [4, 4, 1, 1, 5, 5, 5, 2, 2, 2], + [4, 4, 4, 1, 5, 5, 5, 2, 2, 2], + [4, 4, 4, 4, 5, 5, 5, 2, 2, 2], + [4, 4, 4, 4, 5, 5, 5, 6, 6, 2], + [4, 4, 4, 4, 5, 5, 5, 6, 6, 2], + ], + dtype=np.int32, + ) + # Region sizes: 1=16, 2=34, 3=4, 4=16, 5=12, 6=4 + return arr + + @pytest.mark.parametrize("threshold", [1, 3, 5, 10, 13]) + def test_threshold(self, classified_grid, threshold): + xs_out, gdal_out = _run_both(classified_grid, threshold=threshold) + np.testing.assert_array_equal( + xs_out, + gdal_out, + err_msg=f"Mismatch at threshold={threshold}", + ) + + def test_threshold_50(self, classified_grid): + """threshold=50 exceeds all region sizes -- nothing merges.""" + xs_out, gdal_out = _run_both(classified_grid, threshold=50) + np.testing.assert_array_equal( + xs_out, + gdal_out, + err_msg="Mismatch at threshold=50", + ) + + +# --------------------------------------------------------------------------- +# Connectivity modes +# --------------------------------------------------------------------------- + + +class TestConnectivity: + """4-connectivity vs 8-connectivity parity.""" + + def test_diagonal_pixel_4conn(self): + """A pixel connected only diagonally is isolated under 4-conn.""" + arr = np.ones((5, 5), dtype=np.int32) + arr[2, 2] = 2 + arr[1, 1] = 2 # diagonal to (2,2) but not edge-adjacent + # With 4-conn: two separate 1-pixel regions of value 2 + xs_out, gdal_out = _run_both(arr, threshold=2, connectivity=4) + np.testing.assert_array_equal(xs_out, gdal_out) + # Both should be replaced + assert xs_out[2, 2] == 1 + assert xs_out[1, 1] == 1 + + def test_diagonal_pixel_8conn(self): + """A pixel connected diagonally is part of the same region under 8-conn.""" + arr = np.ones((5, 5), dtype=np.int32) + arr[2, 2] = 2 + arr[1, 1] = 2 + # With 8-conn: these two form a 2-pixel region + xs_out, gdal_out = _run_both(arr, threshold=2, connectivity=8) + np.testing.assert_array_equal(xs_out, gdal_out) + # Region of size 2 should survive threshold=2 + assert xs_out[2, 2] == 2 + assert xs_out[1, 1] == 2 + + def test_l_shape_4conn(self): + """L-shaped region under 4-connectivity.""" + arr = np.ones((6, 6), dtype=np.int32) + arr[1, 1] = 3 + arr[2, 1] = 3 + arr[3, 1] = 3 + arr[3, 2] = 3 + arr[3, 3] = 3 + # L-shape of 5 pixels, value 3 + xs_out, gdal_out = _run_both(arr, threshold=5, connectivity=4) + np.testing.assert_array_equal(xs_out, gdal_out) + # Size 5 = threshold, should survive + assert xs_out[1, 1] == 3 + + def test_l_shape_8conn(self): + """L-shape under 8-connectivity should be the same region.""" + arr = np.ones((6, 6), dtype=np.int32) + arr[1, 1] = 3 + arr[2, 1] = 3 + arr[3, 1] = 3 + arr[3, 2] = 3 + arr[3, 3] = 3 + xs_out, gdal_out = _run_both(arr, threshold=5, connectivity=8) + np.testing.assert_array_equal(xs_out, gdal_out) + assert xs_out[1, 1] == 3 + + +# --------------------------------------------------------------------------- +# Nodata handling +# --------------------------------------------------------------------------- + + +class TestNodata: + """Parity with nodata/mask regions.""" + + def test_nodata_border(self): + """Nodata ring around the edge, valid interior.""" + arr = np.full((8, 8), 0, dtype=np.int32) # 0 = nodata + arr[2:6, 2:6] = 1 + arr[3, 3] = 2 # single noise pixel in interior + xs_out, gdal_out = _run_both_with_nodata(arr, nodata_val=0, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + assert xs_out[3, 3] == 1 # noise replaced + + def test_nodata_hole(self): + """Nodata hole inside a region.""" + arr = np.ones((6, 6), dtype=np.int32) + arr[2, 2] = 0 # nodata + arr[2, 3] = 0 # nodata + arr[4, 4] = 2 # single noise pixel + xs_out, gdal_out = _run_both_with_nodata(arr, nodata_val=0, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + # Nodata pixels preserved + assert xs_out[2, 2] == 0 + assert xs_out[2, 3] == 0 + # Noise replaced + assert xs_out[4, 4] == 1 + + def test_nodata_splits_region(self): + """Nodata row splitting a region into two halves.""" + arr = np.ones((7, 4), dtype=np.int32) + arr[3, :] = 0 # nodata row + # Top half: 3*4=12 pixels of value 1 + # Bottom half: 3*4=12 pixels of value 1 + # These are separate regions because nodata splits them + arr[1, 1] = 2 # noise in top half + arr[5, 1] = 3 # noise in bottom half + xs_out, gdal_out = _run_both_with_nodata(arr, nodata_val=0, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + + +# --------------------------------------------------------------------------- +# Realistic classified raster +# --------------------------------------------------------------------------- + + +class TestRealisticRaster: + """Larger synthetic rasters that mimic real classification output.""" + + def test_quadrant_with_noise(self): + """4-class quadrant raster with 5% salt-and-pepper noise.""" + rng = np.random.RandomState(1168) + base = np.zeros((50, 50), dtype=np.int32) + base[:25, :25] = 1 + base[:25, 25:] = 2 + base[25:, :25] = 3 + base[25:, 25:] = 4 + + noise_mask = rng.random((50, 50)) < 0.05 + noise_vals = rng.choice([1, 2, 3, 4], size=(50, 50)).astype(np.int32) + noisy = base.copy() + noisy[noise_mask] = noise_vals[noise_mask] + + xs_out, gdal_out = _run_both(noisy, threshold=5) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_quadrant_with_noise_8conn(self): + """Same as above but with 8-connectivity.""" + rng = np.random.RandomState(1168) + base = np.zeros((50, 50), dtype=np.int32) + base[:25, :25] = 1 + base[:25, 25:] = 2 + base[25:, :25] = 3 + base[25:, 25:] = 4 + + noise_mask = rng.random((50, 50)) < 0.05 + noise_vals = rng.choice([1, 2, 3, 4], size=(50, 50)).astype(np.int32) + noisy = base.copy() + noisy[noise_mask] = noise_vals[noise_mask] + + xs_out, gdal_out = _run_both(noisy, threshold=5, connectivity=8) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_many_classes(self): + """10-class raster with scattered noise. + + A few pixels at band boundaries may differ due to tie-breaking + when two neighbor regions are close in size. We allow up to + 0.5% mismatch and verify both outputs removed the noise. + """ + rng = np.random.RandomState(42) + arr = np.zeros((100, 50), dtype=np.int32) + for i in range(10): + arr[i * 10 : (i + 1) * 10, :] = i + 1 + + noise_mask = rng.random((100, 50)) < 0.03 + noise_vals = rng.randint(1, 11, size=(100, 50)).astype(np.int32) + arr[noise_mask] = noise_vals[noise_mask] + + xs_out, gdal_out = _run_both(arr, threshold=5) + n_diff = np.sum(xs_out != gdal_out) + assert n_diff <= 0.005 * arr.size, ( + f"{n_diff} pixels differ ({n_diff / arr.size:.1%}), " + f"expected < 0.5% for tie-breaking differences only" + ) + # Both should have eliminated all noise pixels interior to bands + for i in range(10): + band = slice(i * 10 + 2, (i + 1) * 10 - 2) + assert np.all(xs_out[band, 2:-2] == i + 1) + assert np.all(gdal_out[band, 2:-2] == i + 1) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Tricky inputs where implementations might diverge.""" + + def test_single_value(self): + """Uniform raster -- nothing to do.""" + arr = np.full((10, 10), 7, dtype=np.int32) + xs_out, gdal_out = _run_both(arr, threshold=50) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_two_pixel_raster(self): + """Smallest possible non-trivial raster -- both below threshold.""" + arr = np.array([[1, 2]], dtype=np.int32) + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_single_row(self): + """1D-like raster (single row).""" + arr = np.array([[1, 1, 1, 2, 1, 1, 1, 1]], dtype=np.int32) + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + assert xs_out[0, 3] == 1 + + def test_single_column(self): + """Single-column raster.""" + arr = np.array([[1], [1], [1], [2], [1], [1]], dtype=np.int32) + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + assert xs_out[3, 0] == 1 + + def test_checkerboard_4conn(self): + """Checkerboard: all regions size 1, no neighbor >= threshold.""" + arr = np.zeros((8, 8), dtype=np.int32) + arr[::2, ::2] = 1 + arr[1::2, 1::2] = 1 + arr[arr == 0] = 2 + xs_out, gdal_out = _run_both(arr, threshold=2, connectivity=4) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_stripes(self): + """Alternating 1-pixel-wide stripes, all below threshold.""" + arr = np.zeros((6, 6), dtype=np.int32) + arr[:, 0::2] = 1 + arr[:, 1::2] = 2 + xs_out, gdal_out = _run_both(arr, threshold=7, connectivity=4) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_large_threshold_collapses_all(self): + """Threshold larger than any region -- nothing merges.""" + arr = np.array( + [ + [1, 1, 2, 2], + [1, 1, 2, 2], + [3, 3, 4, 4], + [3, 3, 4, 4], + ], + dtype=np.int32, + ) + xs_out, gdal_out = _run_both(arr, threshold=5) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_dtype_uint8(self): + """uint8 input should work for both.""" + arr = np.ones((5, 5), dtype=np.uint8) + arr[2, 2] = 3 + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) + + def test_dtype_int16(self): + """int16 input.""" + arr = np.ones((5, 5), dtype=np.int16) + arr[2, 2] = -1 + xs_out, gdal_out = _run_both(arr, threshold=2) + np.testing.assert_array_equal(xs_out, gdal_out) diff --git a/xrspatial/viewshed.py b/xrspatial/viewshed.py index 96b095f6..0e282052 100644 --- a/xrspatial/viewshed.py +++ b/xrspatial/viewshed.py @@ -1560,7 +1560,7 @@ def _viewshed_cpu( num_events = 3 * (n_rows * n_cols - 1) event_list = np.zeros((num_events, 7), dtype=np.float64) - raster.data = raster.data.astype(np.float64) + raster.data = raster.data.astype(np.float64, copy=False) _init_event_list(event_list=event_list, raster=raster.data, vp_row=viewpoint_row, vp_col=viewpoint_col, @@ -2167,7 +2167,9 @@ def _viewshed_dask(raster, x, y, observer_elev, target_elev): cupy_backed = is_dask_cupy(raster) # --- Tier B: full grid fits in memory → compute and run exact algo --- - r2_bytes = 280 * height * width + # Peak memory: event_list sort needs 2x 168*H*W + raster 8*H*W + + # visibility_grid 8*H*W ≈ 360 bytes/pixel, plus the computed raster. + r2_bytes = 360 * height * width + 8 * height * width # working + raster avail = _available_memory_bytes() if r2_bytes < 0.5 * avail: raster_mem = raster.copy() diff --git a/xrspatial/visibility.py b/xrspatial/visibility.py index ab4de453..65139ba0 100644 --- a/xrspatial/visibility.py +++ b/xrspatial/visibility.py @@ -61,7 +61,9 @@ def _extract_transect(raster, cells): if has_dask_array(): import dask.array as da if isinstance(data, da.Array): - data = data.compute() + # Only compute the needed cells, not the entire array + elevations = data.vindex[rows, cols].compute().astype(np.float64) + return elevations, x_coords, y_coords if has_cuda_and_cupy() and is_cupy_array(data): data = data.get() @@ -217,7 +219,16 @@ def cumulative_viewshed( if not observers: raise ValueError("observers list must not be empty") - count = np.zeros(raster.shape, dtype=np.int32) + # Detect dask backend to keep accumulation lazy + _is_dask = False + if has_dask_array(): + import dask.array as da + _is_dask = isinstance(raster.data, da.Array) + + if _is_dask: + count = da.zeros(raster.shape, dtype=np.int32, chunks=raster.data.chunks) + else: + count = np.zeros(raster.shape, dtype=np.int32) for obs in observers: ox = obs['x'] @@ -229,11 +240,17 @@ def cumulative_viewshed( vs = viewshed(raster, x=ox, y=oy, observer_elev=oe, target_elev=te, max_distance=md) - vs_np = vs.values - count += (vs_np != INVISIBLE).astype(np.int32) - - return xarray.DataArray(count, coords=raster.coords, - dims=raster.dims, attrs=raster.attrs) + vs_data = vs.data + if _is_dask and not isinstance(vs_data, da.Array): + vs_data = da.from_array(vs_data, chunks=raster.data.chunks) + count = count + (vs_data != INVISIBLE).astype(np.int32) + + result = xarray.DataArray(count, coords=raster.coords, + dims=raster.dims, attrs=raster.attrs) + if _is_dask: + chunk_dict = dict(zip(raster.dims, raster.data.chunks)) + result = result.chunk(chunk_dict) + return result def visibility_frequency(