Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions xrspatial/viewshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,7 +1560,7 @@ def _viewshed_cpu(
num_events = 3 * (n_rows * n_cols - 1)
event_list = np.zeros((num_events, 7), dtype=np.float64)

raster.data = raster.data.astype(np.float64)
raster.data = raster.data.astype(np.float64, copy=False)

_init_event_list(event_list=event_list, raster=raster.data,
vp_row=viewpoint_row, vp_col=viewpoint_col,
Expand Down Expand Up @@ -2167,7 +2167,9 @@ def _viewshed_dask(raster, x, y, observer_elev, target_elev):
cupy_backed = is_dask_cupy(raster)

# --- Tier B: full grid fits in memory → compute and run exact algo ---
r2_bytes = 280 * height * width
# Peak memory: event_list sort needs 2x 168*H*W + raster 8*H*W +
# visibility_grid 8*H*W ≈ 360 bytes/pixel, plus the computed raster.
r2_bytes = 360 * height * width + 8 * height * width # working + raster
avail = _available_memory_bytes()
if r2_bytes < 0.5 * avail:
raster_mem = raster.copy()
Expand Down
31 changes: 24 additions & 7 deletions xrspatial/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _extract_transect(raster, cells):
if has_dask_array():
import dask.array as da
if isinstance(data, da.Array):
data = data.compute()
# Only compute the needed cells, not the entire array
elevations = data.vindex[rows, cols].compute().astype(np.float64)
return elevations, x_coords, y_coords
if has_cuda_and_cupy() and is_cupy_array(data):
data = data.get()

Expand Down Expand Up @@ -217,7 +219,16 @@ def cumulative_viewshed(
if not observers:
raise ValueError("observers list must not be empty")

count = np.zeros(raster.shape, dtype=np.int32)
# Detect dask backend to keep accumulation lazy
_is_dask = False
if has_dask_array():
import dask.array as da
_is_dask = isinstance(raster.data, da.Array)

if _is_dask:
count = da.zeros(raster.shape, dtype=np.int32, chunks=raster.data.chunks)
else:
count = np.zeros(raster.shape, dtype=np.int32)

for obs in observers:
ox = obs['x']
Expand All @@ -229,11 +240,17 @@ def cumulative_viewshed(
vs = viewshed(raster, x=ox, y=oy, observer_elev=oe,
target_elev=te, max_distance=md)

vs_np = vs.values
count += (vs_np != INVISIBLE).astype(np.int32)

return xarray.DataArray(count, coords=raster.coords,
dims=raster.dims, attrs=raster.attrs)
vs_data = vs.data
if _is_dask and not isinstance(vs_data, da.Array):
vs_data = da.from_array(vs_data, chunks=raster.data.chunks)
count = count + (vs_data != INVISIBLE).astype(np.int32)

result = xarray.DataArray(count, coords=raster.coords,
dims=raster.dims, attrs=raster.attrs)
if _is_dask:
chunk_dict = dict(zip(raster.dims, raster.data.chunks))
result = result.chunk(chunk_dict)
return result


def visibility_frequency(
Expand Down
Loading