diff --git a/.gitignore b/.gitignore index 9702592..b4c1227 100644 --- a/.gitignore +++ b/.gitignore @@ -21,8 +21,7 @@ test_manifest.yml /htmlcov /test-reports/ /tests/cli/test.cfg -/tests/cli/*.yaml -.eggs +/tests/cli/*.yaml.eggs .coverage *.xml simdb-coverage-report diff --git a/pyproject.toml b/pyproject.toml index 2cc00bd..d3090db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "numpy>=1.14", "pydantic>=2.10.6", "python-dateutil>=2.6", + "plotext>=5.3.2", "pyyaml>=3.13", "requests>=2.27.0", "semantic-version>=2.8", diff --git a/src/simdb/cli/commands/simulation.py b/src/simdb/cli/commands/simulation.py index 7f442ce..d28955a 100644 --- a/src/simdb/cli/commands/simulation.py +++ b/src/simdb/cli/commands/simulation.py @@ -16,7 +16,12 @@ from simdb.validation import ValidationError, Validator from . import check_meta_args, pass_config -from .utils import print_simulations +from .utils import ( + is_numeric_1d, + print_quantity, + print_simulations, + show_quantity_textual_plot, +) from .validators import validate_non_negative @@ -353,6 +358,69 @@ def simulation_query( ) +@simulation.command("data", cls=n_required_args_adaptor(2)) +@pass_config +@click.argument("remote", required=False) +@click.argument("sim_id") +@click.argument("ids_path") +@click.option("--username", help="Username used to authenticate with the remote.") +@click.option("--password", help="Password used to authenticate with the remote.") +def simulation_data( + config: Config, + remote: Optional[str], + sim_id: str, + ids_path: str, + username: Optional[str], + password: Optional[str], +): + """Fetch IDS field data for simulation SIM_ID (UUID or alias) from REMOTE. + + \b + IDS_PATH format: + ids_name[:]/path/to/field + + \b + Examples: + simdb sim data iter 4dd781b... profiles_1d[0]/grid/rho_tor_norm + simdb sim data 4dd781b... equilibrium:0/time_slice[0]/profiles_1d/psi + """ + api = RemoteAPI(remote, username, password, config) + + try: + result = api.get_simulation_data(sim_id, ids_path) + except Exception as err: + raise click.ClickException(str(err)) from err + + click.echo(f"simulation : {result['simulation']}") + click.echo(f"path : {result['path']} (occurrence {result['occurrence']})") + + coordinates = result.get("coordinates") or [] + plot_coordinate = next( + ( + coord + for coord in coordinates + if isinstance(coord.get("data"), list) + and isinstance(result["field"].get("data"), list) + and len(coord["data"]) == len(result["field"]["data"]) + ), + None, + ) + field_is_1d = is_numeric_1d(result["field"].get("data")) + if field_is_1d: + show_quantity_textual_plot( + result["field"], label="field", x_quantity=plot_coordinate + ) + else: + print_quantity(result["field"], label="field") + + if config.verbose and coordinates: + for coord in coordinates: + if field_is_1d and is_numeric_1d(coord.get("data")): + continue + if isinstance(coord.get("data"), list): + print_quantity(coord, label=f"coord {coord['name']}", show_stats=False) + + @simulation.command("validate", cls=n_required_args_adaptor(1)) @pass_config @click.argument("remote", required=False) diff --git a/src/simdb/cli/commands/utils.py b/src/simdb/cli/commands/utils.py index ab2c919..abd352f 100644 --- a/src/simdb/cli/commands/utils.py +++ b/src/simdb/cli/commands/utils.py @@ -1,7 +1,12 @@ +import importlib from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar import click +from rich.console import Console, Group +from rich.panel import Panel +from rich.table import Table +from rich.text import Text if TYPE_CHECKING: # Only importing these for type checking and documentation generation in order to @@ -10,6 +15,231 @@ else: Config = TypeVar("Config") +_RICH_CONSOLE = Console() + + +def _get_shape(data: Any) -> Tuple[int, ...]: + """Recursively compute shape of a nested list""" + if not isinstance(data, list): + return () + if not data: + return (0,) + return (len(data), *_get_shape(data[0])) + + +def _fmt_val(v: Any) -> str: + if isinstance(v, float): + return f"{v:.6g}" + return str(v) + + +def _fmt_row(row: list) -> str: + """Format a 1-D list with numpy-style head/tail truncation.""" + if len(row) <= 8: + return " ".join(_fmt_val(v) for v in row) + head = " ".join(_fmt_val(v) for v in row[:3]) + tail = " ".join(_fmt_val(v) for v in row[-3:]) + return f"{head} ... {tail}" + + +def _is_numeric(v: Any) -> bool: + return isinstance(v, (int, float)) and not isinstance(v, bool) + + +def is_numeric_1d(data: Any) -> bool: + return isinstance(data, list) and bool(data) and all(_is_numeric(v) for v in data) + + +def _quantity_axis_label(q: dict, fallback: str = "") -> str: + name = q.get("name") or fallback + units = q.get("units") or "-" + label = str(name).rsplit("/", 1)[-1] or str(name) + return f"{label} [{units}]" + + +def _build_array_body(data: list, shape: Tuple[int, ...]) -> str: + """Build string for 1-D or 2-D arrays.""" + if len(shape) == 1: + return f"[{_fmt_row(data)}]" + + if len(shape) == 2: + rows = data if len(data) <= 8 else [*data[:3], ..., *data[-3:]] + lines = [] + for row in rows: + if row is ...: + lines.append(" ...") + else: + lines.append(f" [{_fmt_row(row)}]") + inner = "\n".join(lines) + return f"[\n{inner}\n]" + + return f"<{len(shape)}-D array, shape {shape}>" + + +def _iter_numeric(data: Any) -> Iterable[float]: + """Yield all numeric leaf values from a nested list, skipping None.""" + if isinstance(data, list): + for item in data: + yield from _iter_numeric(item) + elif isinstance(data, (int, float)) and data is not None: + yield float(data) + + +def _compute_stats(data: Any) -> Optional[Dict[str, float]]: + """Return basic statistics for numeric data, or None if not applicable.""" + values = list(_iter_numeric(data)) + if len(values) < 2: + return None + n = len(values) + vmin = min(values) + vmax = max(values) + mean = sum(values) / n + std = (sum((x - mean) ** 2 for x in values) / n) ** 0.5 + sorted_v = sorted(values) + mid = n // 2 + median = sorted_v[mid] if n % 2 else (sorted_v[mid - 1] + sorted_v[mid]) / 2 + return { + "n": n, + "min": vmin, + "max": vmax, + "mean": mean, + "std": std, + "median": median, + } + + +def _stats_table(stats: Dict[str, float]) -> Table: + table = Table(show_header=True, header_style="bold", box=None, padding=(0, 2)) + for key in ("n", "min", "max", "mean", "std", "median"): + table.add_column(key, justify="right") + table.add_row( + str(int(stats["n"])), + _fmt_val(stats["min"]), + _fmt_val(stats["max"]), + _fmt_val(stats["mean"]), + _fmt_val(stats["std"]), + _fmt_val(stats["median"]), + ) + return table + + +def _plot_stats_table(stats: Dict[str, float], shape: Tuple[int, ...]) -> Table: + table = Table(show_header=True, header_style="bold", box=None, padding=(0, 2)) + for key in ("n", "min", "max", "mean", "std", "median"): + table.add_column(key, justify="right") + table.add_row( + str(int(stats["n"])), + _fmt_val(stats["min"]), + _fmt_val(stats["max"]), + _fmt_val(stats["mean"]), + _fmt_val(stats["std"]), + _fmt_val(stats["median"]), + ) + return table + + +def _plot_panel( + *, + plot: Text, + title: str, + units: str, + stats: Optional[Dict[str, float]], + shape: Tuple[int, ...], +) -> None: + content = plot + if stats: + content = Group(plot, _plot_stats_table(stats, shape)) + + _RICH_CONSOLE.print( + Panel( + content, + title=f"[bold]{title}[/bold] [dim]\\[{units}][/dim]", + subtitle=f"shape {shape}", + ) + ) + + +def show_quantity_textual_plot( + q: dict, + label: str = "", + x_quantity: Optional[dict] = None, +) -> None: + """Print line plot for a 1-D numeric QuantityData dict.""" + name = q["name"] + units = q["units"] or "-" + data = q["data"] + if not is_numeric_1d(data): + print_quantity(q, label=label) + return + + try: + plotext = importlib.import_module("plotext") + except ImportError: + print_quantity(q, label=label) + return + + y_values = [float(value) for value in data] + shape = _get_shape(data) + x_values = None + xlabel = "index [-]" + if ( + x_quantity + and is_numeric_1d(x_quantity.get("data")) + and len(x_quantity["data"]) == len(y_values) + ): + x_values = [float(value) for value in x_quantity["data"]] + xlabel = _quantity_axis_label(x_quantity, fallback="x") + + title = label or name + if x_values is None: + x_values = [float(index) for index in range(len(y_values))] + + console_width = _RICH_CONSOLE.size.width + plot_width = max(48, min(70, console_width - 12)) + + plotext.clear_figure() + plotext.theme("clear") + plotext.plotsize(plot_width, 18) + plotext.xlabel(xlabel) + plotext.ylabel(_quantity_axis_label(q, fallback=label or "field")) + plotext.plot(x_values, y_values, marker="braille", color="cyan") + plot = Text.from_ansi(plotext.build()) + stats = _compute_stats(y_values) + _plot_panel( + plot=plot, + title=title, + units=units, + stats=stats, + shape=shape, + ) + + +def print_quantity(q: dict, label: str = "", show_stats: bool = True) -> None: + """Print a QuantityData dict with array display and stats.""" + name = q["name"] + units = q["units"] or "-" + data = q["data"] + title = f"[bold]{label or name}[/bold] [dim]\\[{units}][/dim]" + + if not isinstance(data, list): + _RICH_CONSOLE.print(Panel(f"{_fmt_val(data)}", title=title, subtitle="scalar")) + return + + shape = _get_shape(data) + stats = _compute_stats(data) + array_body = _build_array_body(data, shape) + subtitle = f"shape ({shape[0]},)" if len(shape) == 1 else f"shape {shape}" + if show_stats and stats: + _RICH_CONSOLE.print( + Panel( + Group(array_body, _stats_table(stats)), + title=title, + subtitle=subtitle, + ) + ) + else: + _RICH_CONSOLE.print(Panel(array_body, title=title, subtitle=subtitle)) + def _flatten_dict(values: Dict) -> List[Tuple[str, str]]: items = [] diff --git a/src/simdb/cli/remote_api.py b/src/simdb/cli/remote_api.py index a81d082..28520b2 100644 --- a/src/simdb/cli/remote_api.py +++ b/src/simdb/cli/remote_api.py @@ -663,6 +663,11 @@ def delete_metadata(self, sim_id: str, key: str) -> List[str]: res = self.delete("simulation/metadata/" + sim_id, {"key": key}) return [data["value"] for data in res.json()] + @try_request + def get_simulation_data(self, sim_id: str, path: str) -> Dict[str, Any]: + res = self.get(f"simulation/{sim_id}/data", params={"path": path}) + return res.json() + @try_request def get_directory(self) -> str: res = self.get("staging_dir") diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 201e9bf..6b8d825 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -361,14 +361,19 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": outputs = checked_get(data, "outputs", list) simulation.outputs = [File.from_data(el) for el in outputs] if "metadata" in data: - metadata = checked_get(data, "metadata", list) - meta_dict = {} - for el in metadata: - if not isinstance(el, dict): - raise Exception("corrupted metadata element - expected dictionary") - if "element" in el and "value" in el: - meta_dict[el["element"]] = el["value"] - simulation._set_metadata_dict(meta_dict) + metadata = data.get("metadata") + if isinstance(metadata, list): + meta_dict = {} + for el in metadata: + if not isinstance(el, dict): + raise Exception( + "corrupted metadata element - expected dictionary" + ) + if "element" in el and "value" in el: + meta_dict[el["element"]] = el["value"] + simulation._set_metadata_dict(meta_dict) + elif isinstance(metadata, dict): + simulation._set_metadata_dict(metadata) return simulation @classmethod diff --git a/src/simdb/json.py b/src/simdb/json.py index 89301c8..a50c7bd 100644 --- a/src/simdb/json.py +++ b/src/simdb/json.py @@ -22,7 +22,10 @@ def _custom_hook(obj: Dict[str, str]) -> Any: return uuid.UUID(obj["hex"]) elif obj["_type"] == "numpy.ndarray": np_bytes = base64.decodebytes(obj["bytes"].encode()) - return np.frombuffer(np_bytes, dtype=obj["dtype"]) + arr = np.frombuffer(np_bytes, dtype=obj["dtype"]) + if "shape" in obj: + arr = arr.reshape(obj["shape"]) + return arr else: obj_type = obj["_type"] raise ValueError(f"Unknown type to deserialise {obj_type}.") @@ -49,4 +52,16 @@ def default(self, o: Any) -> Any: return {"_type": "uuid.UUID", "hex": o.hex} elif isinstance(o, enum.Enum): return o.value + elif isinstance(o, np.ndarray): + encoded_bytes = base64.b64encode(o.data).decode() + return { + "_type": "numpy.ndarray", + "dtype": o.dtype.name, + "shape": o.shape, + "bytes": encoded_bytes, + } + elif isinstance(o, np.integer): + return int(o) + elif isinstance(o, np.floating): + return float(o) return super().default(o) diff --git a/src/simdb/remote/apis/v1_2/__init__.py b/src/simdb/remote/apis/v1_2/__init__.py index 920f303..6b8bc01 100644 --- a/src/simdb/remote/apis/v1_2/__init__.py +++ b/src/simdb/remote/apis/v1_2/__init__.py @@ -10,6 +10,7 @@ from simdb.remote.core.typing import current_app from simdb.remote.models import StagingDirectoryResponse +from .simulation_data import api as data_ns from .simulations import api as sim_ns api = Api( @@ -31,7 +32,7 @@ ) api.add_namespace(sim_ns) -namespaces = [metadata_ns, watcher_ns, file_ns, sim_ns] +namespaces = [metadata_ns, watcher_ns, file_ns, sim_ns, data_ns] @api.route("/staging_dir", defaults={"sim_hex": None}) diff --git a/src/simdb/remote/apis/v1_2/simulation_data.py b/src/simdb/remote/apis/v1_2/simulation_data.py new file mode 100644 index 0000000..ff1c8b2 --- /dev/null +++ b/src/simdb/remote/apis/v1_2/simulation_data.py @@ -0,0 +1,189 @@ +"""IMAS simulation data endpoint: /data. + +TODO: Temporary solution to retrieve data (for IBEX backend) +""" + +from typing import Annotated, Any, NamedTuple + +import numpy as np +from flask_restx import Namespace, Resource +from imas.ids_defs import EMPTY_FLOAT +from imas.ids_primitive import IDSPrimitive + +from simdb.cli.manifest import DataObject +from simdb.database import DatabaseError +from simdb.imas.utils import ( + ImasError, + open_imas, +) +from simdb.remote.core.auth import User, requires_auth +from simdb.remote.core.pydantic_utils import ( + Query, + ResponseException, + ServerException, + pydantic_validate, +) +from simdb.remote.core.typing import current_app +from simdb.remote.models import ImasDataQueryParams, ImasDataResponse, QuantityData +from simdb.uri import URI + +api = Namespace("data", path="/") + + +# Helpers + + +def _to_python(value: Any) -> Any: + """Convert a value returned by IDSPrimitive.value to a JSON-serialisable + Python object.""" + if isinstance(value, np.ndarray): + flat = value.tolist() + + def _clean(v): + if isinstance(v, float) and ( + v != v or v == float("inf") or v == float("-inf") or v == EMPTY_FLOAT + ): + return None + if isinstance(v, list): + return [_clean(x) for x in v] + return v + + return _clean(flat) + return value + + +def _parse_ids_path(path: str) -> tuple: + """Parse ``ids_name[:occurrence][/ids_path]`` into a 3-tuple""" + head, _, ids_path = path.partition("/") + if ":" in head: + ids_name, occ_str = head.split(":", 1) + try: + occurrence = int(occ_str) + except ValueError as exc: + raise ValueError( + f"Invalid occurrence in path '{path}': '{occ_str}'" + ) from exc + else: + ids_name, occurrence = head, 0 + return ids_name, occurrence, ids_path + + +def _get_coordinates(node: IDSPrimitive, ids_name: str) -> list: + """Return a :class:`QuantityData` for each coordinate dimension of *node*.""" + coords = [] + for i in range(node.metadata.ndim): + coord = node.coordinates[i] + if isinstance(coord, IDSPrimitive): + data = ( + _to_python(coord.value) + if coord.has_value + else list(range(node.shape[i])) + ) + coords.append( + QuantityData( + name=f"{ids_name}/{coord._path}", + units=coord.metadata.units or "", + data=data, + ) + ) + else: + # Index-based coordinate: coord is already a numpy arange + coords.append( + QuantityData( + name=f"dim_{i + 1}", + units="", + data=coord.tolist(), + ) + ) + return coords + + +def _get_ids_node(entry, ids_name: str, occurrence: int, ids_path: str) -> IDSPrimitive: + """Return the :class:`IDSPrimitive` leaf node at *ids_path* inside *ids_name*.""" + ids_obj = entry.get( + ids_name, + occurrence, + lazy=True, + autoconvert=False, + ignore_unknown_dd_version=True, + ) + node = ids_obj[ids_path] if ids_path else ids_obj + if not isinstance(node, IDSPrimitive): + raise ValueError( + f"path does not point to a scalar/array leaf " + f"(reached {type(node).__name__}); add more path segments" + ) + if not node.has_value: + raise ValueError("field is not populated (no data written)") + return node + + +class _SimulationImasFile(NamedTuple): + simulation: Any + imas_file: Any + + +def _get_simulation_and_imas_file(sim_id: str) -> _SimulationImasFile: + try: + simulation = current_app.db.get_simulation(sim_id) + except DatabaseError as exc: + raise ResponseException(str(exc), 404) from exc + + imas_outputs = [f for f in simulation.outputs if f.type == DataObject.Type.IMAS] + if not imas_outputs: + raise ResponseException(f"Simulation {sim_id} has no IMAS output files", 404) + + return _SimulationImasFile(simulation, imas_outputs[0]) + + +# Endpoints + + +@api.route("/simulation//data") +class SimulationImasData(Resource): + @requires_auth() + @pydantic_validate(api) + def get( + self, + sim_id: str, + user: User, + params: Annotated[ImasDataQueryParams, Query()], + ) -> ImasDataResponse: + """Return the value at a given IDS path for a simulation's IMAS output.""" + result = _get_simulation_and_imas_file(sim_id) + + try: + ids_name, occurrence, ids_path = _parse_ids_path(params.path) + except ValueError as exc: + raise ResponseException(str(exc)) from exc + + try: + imas_uri = URI(str(result.imas_file.uri)) + if imas_uri.authority.host and "cache_mode" not in imas_uri.query: + imas_uri.query.set("cache_mode", "none") + entry = open_imas(imas_uri) + with entry: + node = _get_ids_node(entry, ids_name, occurrence, ids_path) + coordinates = _get_coordinates(node, ids_name) + field = QuantityData( + name=f"{ids_name}/{node._path}", + units=node.metadata.units or "", + data=_to_python(node.value), + ) + except (ValueError, AttributeError, IndexError, KeyError) as exc: + raise ResponseException(f"Invalid IDS path '{params.path}': {exc}") from exc + except ImasError as exc: + raise ServerException(f"Failed to open IMAS data: {exc}") from exc + except Exception as exc: + msg = str(exc) + if "is empty" in msg or "not found" in msg.lower(): + raise ResponseException(msg, 404) from exc + raise ServerException(msg) from exc + + return ImasDataResponse( + simulation=str(result.simulation.uuid), + path=params.path, + occurrence=occurrence, + field=field, + coordinates=coordinates, + ) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 0a8218c..60e063b 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -26,7 +26,9 @@ BeforeValidator, ConfigDict, Field, + InstanceOf, PlainSerializer, + field_validator, model_validator, ) from pydantic import ( @@ -108,21 +110,6 @@ class RangeValue(BaseModel): max: float -MetadataValue = Union[ - CustomUUID, - str, - int, - float, - bool, - list, - RangeValue, - dict[str, Any], - None, -] -"""Supported types for simulation metadata values. Numpy arrays and regular arrays -containing numeric data are automatically converted to RangeValue.""" - - class StatusPatchData(BaseModel): """Post data for updating simulation status.""" @@ -181,6 +168,53 @@ def __getitem__(self, item) -> FileData: return self.root[item] +def _deserialize_numpy(v: Any) -> Any: + if isinstance(v, np.ndarray): + return v + if isinstance(v, dict) and v.get("_type") == "numpy.ndarray": + np_bytes = base64.b64decode(v["bytes"].encode()) + arr = np.frombuffer(np_bytes, dtype=v["dtype"]) + if "shape" in v: + arr = arr.reshape(v["shape"]) + return arr + raise ValueError(f"Cannot deserialize {v} to np.ndarray") + + +def _serialize_numpy(o: np.ndarray) -> dict: + """Serialize numpy arrays to dict format for the web dashboard.""" + encoded_bytes = base64.b64encode(o.data).decode() + return { + "_type": "numpy.ndarray", + "dtype": o.dtype.name, + "shape": o.shape, + "bytes": encoded_bytes, + } + + +NumpyArray = Annotated[ + InstanceOf[np.ndarray], + BeforeValidator(_deserialize_numpy), + PlainSerializer(_serialize_numpy, return_type=dict), +] + + +MetadataValue = Union[ + CustomUUID, + str, + int, + float, + bool, + RangeValue, + list, + dict, + NumpyArray, + None, +] +"""Supported types for simulation metadata values. RangeValue, numpy arrays and +scalars are automatically converted to their plain Python equivalents before +validation.""" + + class MetadataData(BaseModel): """Key-value pair for simulation metadata.""" @@ -352,7 +386,7 @@ class SimulationListItem(BaseModel): """Alias of the simulation.""" datetime: str """Creation timestamp.""" - metadata: Optional[MetadataDataList] = None + metadata: MetadataDataList = MetadataDataList() """Simulation metadata.""" @@ -572,6 +606,48 @@ class StagingDirectoryResponse(BaseModel): """Path to the staging dir.""" +class ImasDataQueryParams(BaseModel): + """Query parameters for the IMAS field-data endpoint.""" + + path: str + """Full IDS path including IDS name and optional occurrence.""" + + @field_validator("path", mode="before") + @classmethod + def _strip_path(cls, v: Any) -> str: + v = str(v).strip() + if not v: + raise ValueError("must not be empty") + return v + + +class QuantityData(BaseModel): + """A named, unit-bearing data quantity (field value or coordinate).""" + + name: str + """IDS path of this quantity relative to the IDS root""" + units: str + """Physical units of the quantity""" + data: Any + """Data value: a Python scalar for 0-D quantities, or a nested list for + arrays. """ + + +class ImasDataResponse(BaseModel): + """Response from the IMAS field-data endpoint.""" + + simulation: str + """UUID of the simulation.""" + path: str + """Requested IDS path.""" + occurrence: int + """IDS occurrence index.""" + field: QuantityData + """The requested quantity""" + coordinates: List[QuantityData] + """Coordinates for each dimension of *field*, in dimension order.""" + + class ErrorResponse(BaseModel): """Response model for server errors.""" diff --git a/src/simdb/validation/validator.py b/src/simdb/validation/validator.py index 2daa21c..99f7edb 100644 --- a/src/simdb/validation/validator.py +++ b/src/simdb/validation/validator.py @@ -28,6 +28,43 @@ class CustomValidator(ValidatorBase): types_mapping = cast(Any, cerberus.Validator).types_mapping.copy() types_mapping["numpy"] = cerberus.TypeDefinition("numpy", (np.ndarray,), ()) + @staticmethod + def _range_value(value) -> Optional[np.ndarray]: + if isinstance(value, dict) and {"min", "max"} <= value.keys(): + return np.array([value["min"], value["max"]]) + if hasattr(value, "min") and hasattr(value, "max"): + return np.array([value.min, value.max]) + return None + + @staticmethod + def _value_preview(value, max_length: int = 200) -> str: + preview = repr(value) + if len(preview) > max_length: + preview = f"{preview[:max_length]}..." + return f"{type(value).__name__} {preview}" + + def _numeric_array(self, field, value) -> Optional[np.ndarray]: + if not isinstance(value, np.ndarray): + self._error( + field, + f"Value is not a numpy array: {self._value_preview(value)}", + ) + return None + try: + value = value.astype(float, copy=False) + except (TypeError, ValueError): + self._error( + field, + f"Values in numpy array must be numeric: {self._value_preview(value)}", + ) + return None + + value = value[~np.isnan(value)] + if value.size == 0: + self._error(field, "Values in numpy array are NaN or empty") + return None + return value + def _validate_exists(self, check_exists, field, value): """The rule's arguments are validated against this schema: {'type': ['string'], @@ -40,11 +77,9 @@ def _validate_min_value(self, min_value, field, value): {'type': 'float'} """ - if not isinstance(value, np.ndarray): - value = value[~np.isnan(value)] - if value.size == 0: - self._error(field, "Values in numpy array are NaN or empty") - self._error(field, "Value is not a numpy array") + value = self._numeric_array(field, value) + if value is None: + return if min_value is not None and value.min() < min_value: self._error(field, f"Minimum {value.min()} less than {min_value}") @@ -53,11 +88,9 @@ def _validate_max_value(self, max_value, field, value): {'type': 'float'} """ - if not isinstance(value, np.ndarray): - value = value[~np.isnan(value)] - if value.size == 0: - self._error(field, "Values in numpy array are NaN or empty") - self._error(field, "Value is not a numpy array") + value = self._numeric_array(field, value) + if value is None: + return if max_value is not None and value.max() > max_value: self._error(field, f"Maximum {value.max()} greater than {max_value}") @@ -65,9 +98,9 @@ def _compare(self, comparison, field, value, comparator: str, message: str): if comparison is None: return if isinstance(value, np.ndarray): - value = value[~np.isnan(value)] - if value.size == 0: - self._error(field, "Values in numpy array are NaN or empty") + value = self._numeric_array(field, value) + if value is None: + return if not getattr(value, comparator)(comparison).all(): self._error(field, f"Values are not {message} {comparison}") elif isinstance(value, float): @@ -112,6 +145,9 @@ def _normalize_coerce_float(cls, value): def _normalize_coerce_numpy(cls, value): if isinstance(value, np.ndarray): return value + range_value = cls._range_value(value) + if range_value is not None: + return range_value elif isinstance(value, str): return np.fromstring(value[1:-1], sep=" ") else: diff --git a/tests/cli/test_cli_simulation_command.py b/tests/cli/test_cli_simulation_command.py index 0120fc0..5f61c60 100644 --- a/tests/cli/test_cli_simulation_command.py +++ b/tests/cli/test_cli_simulation_command.py @@ -85,3 +85,55 @@ def test_simulation_validate_command(remote_api, get_local_db): runner = CliRunner() result = runner.invoke(cli, [f"--config-file={config_file}", "simulation"]) assert result.exception is None + + +@mock.patch("simdb.cli.commands.simulation.show_quantity_textual_plot") +@mock.patch("simdb.cli.commands.simulation.RemoteAPI") +def test_simulation_data_command(mock_remote_api_cls, mock_textual_plot): + """``simdb simulation data`` prints field info.""" + mock_api = mock_remote_api_cls.return_value + mock_api.get_simulation_data.return_value = { + "simulation": "a304a6955b3f11f1809bd4f5ef75ec04", + "path": "core_profiles/profiles_1d[0]/electrons/temperature", + "occurrence": 0, + "field": { + "name": "core_profiles/profiles_1d[0]/electrons/temperature", + "units": "eV", + "data": [1000.0, 1200.0, 900.0], + }, + "coordinates": [ + { + "name": "core_profiles/profiles_1d[0]/grid/rho_tor_norm", + "units": "", + "data": [0.0, 0.5, 1.0], + } + ], + } + + config_file = config_test_file() + runner = CliRunner() + result = runner.invoke( + cli, + [ + f"--config-file={config_file}", + "simulation", + "data", + "test_sim", + "core_profiles/profiles_1d[0]/electrons/temperature", + ], + ) + + assert result.exception is None, result.output + mock_api.get_simulation_data.assert_called_once_with( + "test_sim", "core_profiles/profiles_1d[0]/electrons/temperature" + ) + result_data = mock_api.get_simulation_data.return_value + mock_textual_plot.assert_called_once_with( + result_data["field"], + label="field", + x_quantity=result_data["coordinates"][0], + ) + assert "simulation : a304a6955b3f11f1809bd4f5ef75ec04" in result.output + assert "shape (3,)" not in result.output + assert "1000" not in result.output + assert "1200" not in result.output diff --git a/tests/remote/api/test_metadata.py b/tests/remote/api/test_metadata.py index dbd99ce..f3e5e32 100644 --- a/tests/remote/api/test_metadata.py +++ b/tests/remote/api/test_metadata.py @@ -51,6 +51,30 @@ def test_get_metadata_values(client): assert "machine-a" in rv.json or "machine-b" in rv.json +def test_get_metadata_list_value(client): + """Test that float lists are auto-converted to Range (new behavior).""" + list_data = [1.0, 2.5, 3.7] + simulation_data_1 = generate_simulation_data(metadata={"ip": list_data}) + rv_post_1 = post_simulation(client, simulation_data_1) + assert rv_post_1.status_code == 200 + + rv = client.get("/v1.2/metadata", headers=HEADERS) + assert rv.status_code == 200 + mkeys = MetadataKeyInfoList.model_validate_json(rv.data) + mkey = next((k for k in mkeys.root if k.name == "ip"), None) + assert mkey is not None, "ip key not found in metadata keys" + assert mkey.type == "Range" + + rv = client.get("/v1.2/metadata/ip", headers=HEADERS) + assert rv.status_code == 200 + mdata = MetadataValueList.model_validate_json(rv.data) + assert len(mdata.root) == 1 + a = mdata.root[0] + assert isinstance(a, RangeValue) + assert a.min == 1.0 + assert a.max == 3.7 + + def test_get_metadata_range_value(client): """Test metadata Range storage""" # Create a simulation with a range metadata value