Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions xbout/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,24 +542,58 @@ def collect(
ds : numpy.ndarray

"""
from os.path import join
from pathlib import Path as _Path

datapath = join(path, prefix + "*.nc")
datapath_glob = str(_Path(path) / (prefix + "*.nc"))

ds, _ = _auto_open_mfboutdataset(
datapath, keep_xboundaries=xguards, keep_yboundaries=yguards, info=info
)
# Fast path: use lazy loader which only opens one file for metadata.
# Falls back to open_mfdataset if the directory cannot be detected or
# the variable is not supported by the lazy loader.
try:
path_obj = _Path(path)
if path_obj.is_dir():
ds = lazyload.lazy_open_boutdataset(
path,
keep_xboundaries=xguards,
keep_yboundaries=yguards,
info=info,
prefix=prefix,
)
else:
raise ValueError("path is not a directory")

if varname not in ds:
raise KeyError(
"No variable, {} was found in {}.".format(varname, datapath_glob)
)

da = ds[varname]
dims = list(da.dims)

if varname not in ds:
raise KeyError("No variable, {} was found in {}.".format(varname, datapath))
except Exception:
# Fall back to the slow multi-file open
ds, _ = _auto_open_mfboutdataset(
datapath_glob,
keep_xboundaries=xguards,
keep_yboundaries=yguards,
info=info,
)

dims = list(ds.dims)
inds = [tind, xind, yind, zind]
if varname not in ds:
raise KeyError(
"No variable, {} was found in {}.".format(varname, datapath_glob)
)

da = ds[varname]
dims = list(ds.dims)

inds = {"t": tind, "x": xind, "y": yind, "z": zind}

selection = {}

# Convert indexing values to an isel suitable format
for dim, ind in zip(dims, inds):
for dim in dims:
ind = inds.get(dim)
if isinstance(ind, int):
indexer = [ind]
elif isinstance(ind, list):
Expand All @@ -570,25 +604,26 @@ def collect(
else:
indexer = None

if indexer:
if indexer is not None:
selection[dim] = indexer

try:
version = ds["BOUT_VERSION"]
except KeyError:
# If BOUT Version is not saved in the dataset
version = ds.attrs.get("metadata", {}).get("BOUT_VERSION", 0)
if version == 0 and "BOUT_VERSION" in ds:
version = float(ds["BOUT_VERSION"].values)
except Exception:
version = 0

# Subtraction of z-dimensional data occurs in boutdata.collect
# if BOUT++ version is old - same feature added here
if (version < 3.5) and ("z" in dims):
zsize = int(ds["nz"]) - 1
ds = ds.isel(z=slice(zsize))
zsize = int(ds.attrs.get("metadata", {}).get("nz", da.sizes["z"]))
da = da.isel(z=slice(zsize))

if selection:
ds = ds.isel(selection)
da = da.isel(selection)

result = ds[varname].values
result = da.values

# Close netCDF files to ensure they are not locked if collect is called again
ds.close()
Expand Down
Loading