diff --git a/PySpotObserver/QUICKSTART.md b/PySpotObserver/QUICKSTART.md index 7178474..b50765b 100644 --- a/PySpotObserver/QUICKSTART.md +++ b/PySpotObserver/QUICKSTART.md @@ -8,7 +8,14 @@ pip install -r requirements.txt ``` -2. **Install package in development mode:** +2. **Optional vision pipeline support:** + ```bash + pip install -e ".[vision]" + ``` + +The dependency source of truth is `setup.py`; `requirements.txt` installs the development extra without duplicating package lists. + +3. **Minimal install without dev tools (alternative to step 1):** ```bash pip install -e . ``` @@ -215,4 +222,6 @@ logging.basicConfig( **Import errors**: Make sure all dependencies are installed: `pip install -r requirements.txt` +**Vision pipeline errors**: Install `pip install -e ".[vision]"` and set `vision_model_path` in config, pass `--vision-model-path`, or set `PYSPOTOBSERVER_VISION_MODEL` + **No images received**: Check that cameras are not already in use by another client diff --git a/PySpotObserver/README.md b/PySpotObserver/README.md index 4363785..25875db 100644 --- a/PySpotObserver/README.md +++ b/PySpotObserver/README.md @@ -11,6 +11,7 @@ A clean, Pythonic interface for streaming camera data from Boston Dynamics Spot - **YAML Configuration**: Load settings from config files or pass as parameters - **Multi-stream**: Support for multiple concurrent camera streams - **Thread-safe**: Background streaming with thread-safe image buffering +- **Optional Vision Pipeline**: ONNX Runtime inference can be enabled explicitly without making it a base dependency ## Installation @@ -27,6 +28,14 @@ pip install -e . pip install -e ".[dev]" ``` +### With optional vision pipeline support + +```bash +pip install -e ".[vision]" +``` + +`requirements.txt` delegates to `setup.py` for a development install, so dependency declarations stay in one place. + ## Quick Start ### Basic Synchronous Usage @@ -97,6 +106,7 @@ username: "" password: "" image_buffer_size: 5 image_quality_percent: 100.0 +request_timeout_seconds: 10.0 ``` ## Architecture @@ -165,7 +175,7 @@ The Python implementation differs in these ways: - **No CUDA**: Uses CPU-only NumPy arrays instead of GPU memory - **FIFO Queue**: Simpler queue.Queue instead of custom circular buffer - **Threading**: Python threading instead of C++ jthread -- **No ML Pipeline**: Focuses on camera streaming only (no inference) +- **Optional ML Pipeline**: Inference is available through an optional ONNX Runtime extra - **Simplified**: Removes Unity plugin and DLL export complexity ## Requirements @@ -175,6 +185,7 @@ The Python implementation differs in these ways: - NumPy - OpenCV (opencv-python) - PyYAML +- ONNX Runtime GPU (`onnxruntime-gpu`) only when installing the `vision` extra ## Contributing diff --git a/PySpotObserver/examples/README.md b/PySpotObserver/examples/README.md index 80663eb..f223391 100644 --- a/PySpotObserver/examples/README.md +++ b/PySpotObserver/examples/README.md @@ -10,9 +10,10 @@ Install the package and dependencies first: ```bash pip install -r requirements.txt -pip install -e . ``` +For `--vision-pipeline`, also install `pip install -e ".[vision]"` and provide a model path with `--vision-model-path`, config, or `PYSPOTOBSERVER_VISION_MODEL`. + ## Configuration The examples load `examples/config_example.yaml` by default. Fill in at least: @@ -37,6 +38,7 @@ python examples/basic_streaming.py --robot-ip 192.168.80.3 --username --p - asynchronous streaming with `--async-mode` - one or two stream configurations, mirrored across one or two robots - optional OpenCV display +- optional ONNX vision pipeline with `--vision-pipeline` - optional timing summaries with `--print-timing` Show the full CLI: diff --git a/PySpotObserver/examples/basic_streaming.py b/PySpotObserver/examples/basic_streaming.py index 359b806..f15274c 100644 --- a/PySpotObserver/examples/basic_streaming.py +++ b/PySpotObserver/examples/basic_streaming.py @@ -14,9 +14,11 @@ from contextlib import AsyncExitStack, ExitStack from dataclasses import dataclass import logging +from pathlib import Path import time from typing import Sequence + import cv2 import numpy as np @@ -123,6 +125,22 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Print per-stream timing summary at the end of the run.", ) + parser.add_argument( + "--vision-pipeline", + action="store_true", + help="Run the vision pipeline on the Spot outputs" + ) + parser.add_argument( + "--vision-model-path", + type=Path, + help="ONNX model path for --vision-pipeline. Overrides config and environment.", + ) + parser.add_argument( + "--vision-provider", + action="append", + dest="vision_providers", + help="ONNX Runtime provider to request. Repeat to set provider preference order.", + ) return parser.parse_args() @@ -259,7 +277,10 @@ def run_sync(args: argparse.Namespace, specs: list[StreamSpec]) -> int: stream = streams[spec.label] fetch_start = time.perf_counter() - rgb_images, depth_images = stream.get_current_images(timeout=args.timeout) + rgb_images, depth_images = stream.get_current_images( + timeout=args.timeout, + run_pipeline=args.vision_pipeline, + ) fetch_elapsed = time.perf_counter() - fetch_start display_elapsed = 0.0 @@ -282,9 +303,10 @@ def run_sync(args: argparse.Namespace, specs: list[StreamSpec]) -> int: return 0 -async def fetch_stream_async(stream, timeout: float) -> FetchResult: +async def fetch_stream_async(stream, timeout: float, vision_pipeline: bool) -> FetchResult: fetch_start = time.perf_counter() - rgb_images, depth_images = await stream.async_get_current_images(timeout=timeout) + rgb_images, depth_images = await stream.async_get_current_images(timeout=timeout, run_pipeline=vision_pipeline) + return FetchResult( rgb_images=rgb_images, depth_images=depth_images, @@ -310,7 +332,7 @@ async def run_async(args: argparse.Namespace, specs: list[StreamSpec]) -> int: start_time = time.perf_counter() while time.perf_counter() - start_time < args.duration: results = await asyncio.gather( - *(fetch_stream_async(streams[spec.label], args.timeout) for spec in specs) + *(fetch_stream_async(streams[spec.label], args.timeout, args.vision_pipeline) for spec in specs) ) should_quit = False diff --git a/PySpotObserver/examples/common_cli.py b/PySpotObserver/examples/common_cli.py index 802146d..509e9c5 100644 --- a/PySpotObserver/examples/common_cli.py +++ b/PySpotObserver/examples/common_cli.py @@ -109,6 +109,10 @@ def build_config_from_args(args: argparse.Namespace) -> SpotConfig: config.password = args.password if hasattr(args, "image_buffer_size") and args.image_buffer_size is not None: config.image_buffer_size = args.image_buffer_size + if hasattr(args, "vision_model_path") and args.vision_model_path is not None: + config.vision_model_path = str(args.vision_model_path) + if hasattr(args, "vision_providers") and args.vision_providers: + config.vision_providers = args.vision_providers if not config.robot_ip: raise ValueError("Robot IP must be set in --config or with --robot-ip.") diff --git a/PySpotObserver/examples/config_example.yaml b/PySpotObserver/examples/config_example.yaml index 501680d..bec26cb 100644 --- a/PySpotObserver/examples/config_example.yaml +++ b/PySpotObserver/examples/config_example.yaml @@ -12,6 +12,12 @@ image_buffer_size: 5 image_quality_percent: 100.0 request_timeout_seconds: 10.0 +# Optional vision pipeline settings +# vision_model_path: "C:/path/to/model.onnx" +# vision_providers: +# - "CUDAExecutionProvider" +# - "CPUExecutionProvider" + # Advanced settings sdk_name: "PySpotObserver" connection_retry_attempts: 3 diff --git a/PySpotObserver/pyspotobserver/__init__.py b/PySpotObserver/pyspotobserver/__init__.py index 8b3f448..1b124ae 100644 --- a/PySpotObserver/pyspotobserver/__init__.py +++ b/PySpotObserver/pyspotobserver/__init__.py @@ -6,13 +6,16 @@ """ from .config import SpotConfig, CameraType -from .connection import SpotConnection -from .camera_stream import SpotCamStream +from .connection import SpotAuthenticationError, SpotConnection, SpotConnectionError +from .camera_stream import SpotCamStream, SpotCamStreamError __version__ = "0.1.0" __all__ = [ "SpotConfig", "CameraType", "SpotConnection", + "SpotConnectionError", + "SpotAuthenticationError", "SpotCamStream", + "SpotCamStreamError", ] diff --git a/PySpotObserver/pyspotobserver/camera_stream.py b/PySpotObserver/pyspotobserver/camera_stream.py index 729b51b..bcf4531 100644 --- a/PySpotObserver/pyspotobserver/camera_stream.py +++ b/PySpotObserver/pyspotobserver/camera_stream.py @@ -8,7 +8,7 @@ import time from dataclasses import dataclass from queue import Queue, Empty -from typing import Optional, List, Tuple +from typing import Dict, Optional, List, Tuple import cv2 import numpy as np @@ -18,6 +18,7 @@ from .config import SpotConfig, CameraType from .color_correction import _ROBOT_CCMS + logger = logging.getLogger(__name__) @@ -92,6 +93,13 @@ def __init__( # Preallocated frame pool (initialized once we know image shapes) self._frame_pool: List[ImageFrame] = [] self._frame_pool_index: int = 0 + self._image_requests: List[image_pb2.ImageRequest] = [] + + # Optional per-stream vision pipeline, imported lazily. + self._vision_pipeline = None + + # Scratch buffers for color correction by image shape. + self._ccm_scratch_by_shape: Dict[Tuple[int, ...], np.ndarray] = {} # Color correction matrices (None if robot IP is not recognized) self._ccms: Optional[dict] = _ROBOT_CCMS.get(config.robot_ip) @@ -154,6 +162,7 @@ def start_streaming(self, camera_mask: int) -> None: # Parse camera mask to get ordered list of cameras self._camera_mask = camera_mask self._camera_order = self._parse_camera_mask(camera_mask) + self._image_requests = self._build_image_requests() logger.info( f"Stream '{self._stream_id}': Starting with cameras: " @@ -164,6 +173,7 @@ def start_streaming(self, camera_mask: int) -> None: self._clear_queue() self._frame_pool = [] self._frame_pool_index = 0 + self._ccm_scratch_by_shape = {} self._frame_count = 0 self._error_count = 0 @@ -217,6 +227,7 @@ def get_current_images( self, timeout: Optional[float] = None, copy: bool = False, + run_pipeline: bool = False, ) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Get the most recent frame of images. @@ -249,12 +260,16 @@ def get_current_images( frame = self._peek_latest_frame() if frame is not None: - if not copy: - return frame.rgb_images, frame.depth_images - return ( - [img.copy() for img in frame.rgb_images], - [img.copy() for img in frame.depth_images], - ) + rgb, depth = frame.rgb_images, frame.depth_images + + if copy: + rgb = [img.copy() for img in rgb] + depth = [img.copy() for img in depth] + + if run_pipeline: + return self._run_vision_pipeline(rgb, depth) + + return rgb, depth wait_timeout = poll_interval if deadline is not None: @@ -270,17 +285,22 @@ def get_current_images( except Empty: continue - if not copy: - return frame.rgb_images, frame.depth_images - return ( - [img.copy() for img in frame.rgb_images], - [img.copy() for img in frame.depth_images], - ) + rgb, depth = frame.rgb_images, frame.depth_images + + if copy: + rgb = [img.copy() for img in rgb] + depth = [img.copy() for img in depth] + + if run_pipeline: + return self._run_vision_pipeline(rgb, depth) + + return rgb, depth async def async_get_current_images( self, timeout: Optional[float] = None, copy: bool = False, + run_pipeline: bool = False, ) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Async version of get_current_images(). @@ -300,9 +320,42 @@ async def async_get_current_images( SpotCamStreamError: If not streaming or timeout occurs """ loop = asyncio.get_event_loop() - return await loop.run_in_executor( - None, self.get_current_images, timeout, copy - ) + + if not run_pipeline: + return await loop.run_in_executor( + None, self.get_current_images, timeout, copy, False + ) + + # Run full pipeline in executor + def _get_and_process(): + rgb, depth = self.get_current_images(timeout, copy, False) + return self._run_vision_pipeline(rgb, depth) + + return await loop.run_in_executor(None, _get_and_process) + + def _run_vision_pipeline( + self, + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + try: + from .vision_pipeline import VisionPipeline, VisionPipelineError + except ImportError as exc: + raise SpotCamStreamError( + "Vision pipeline support is unavailable. " + 'Install PySpotObserver with: pip install -e ".[vision]"' + ) from exc + + if self._vision_pipeline is None: + try: + self._vision_pipeline = VisionPipeline.from_config(self._config) + except VisionPipelineError as exc: + raise SpotCamStreamError(str(exc)) from exc + + try: + return self._vision_pipeline.run(rgb_images, depth_images) + except VisionPipelineError as exc: + raise SpotCamStreamError(str(exc)) from exc def get_camera_order(self) -> List[CameraType]: """ @@ -342,13 +395,10 @@ def _stream_loop(self) -> None: while not self._stop_event.is_set() and self._streaming: try: - # Build image requests for all cameras (RGB + depth) - image_requests = self._build_image_requests() - # Request images from robot start_time = time.monotonic() image_responses = self._image_client.get_image( - image_requests, + self._image_requests, timeout=self._config.request_timeout_seconds, ) request_time = time.monotonic() - start_time @@ -532,7 +582,7 @@ def _fill_frame_from_responses( ccm = None if self._ccms is not None: - ccm = self._ccms[self._camera_order[i]] + ccm = self._ccms[self._camera_order[i]] self._convert_image_response_inplace( responses[rgb_idx], is_depth=False, @@ -702,8 +752,11 @@ def _convert_image_response_inplace( ) np.multiply(img, 1.0 / 255.0, out=out_array, casting="unsafe") if ccm is not None: - img = img @ ccm.T - np.clip(img, 0.0, 1.0, out=img) + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return elif image_proto.format == image_pb2.Image.FORMAT_RAW: @@ -743,6 +796,12 @@ def _convert_image_response_inplace( f"Output array shape mismatch: {out_array.shape} vs {img.shape}" ) np.multiply(img, 1.0 / 255.0, out=out_array, casting="unsafe") + if ccm is not None: + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return elif pixel_format == image_pb2.Image.PIXEL_FORMAT_RGBA_U8: @@ -755,6 +814,12 @@ def _convert_image_response_inplace( f"Output array shape mismatch: {out_array.shape} vs {img.shape}" ) np.multiply(img, 1.0 / 255.0, out=out_array, casting="unsafe") + if ccm is not None: + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return elif pixel_format == image_pb2.Image.PIXEL_FORMAT_GREYSCALE_U8: @@ -768,6 +833,12 @@ def _convert_image_response_inplace( np.multiply(img, 1.0 / 255.0, out=channel, casting="unsafe") out_array[:, :, 1] = channel out_array[:, :, 2] = channel + if ccm is not None: + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return else: @@ -780,15 +851,30 @@ def _convert_image_response_inplace( f"Unsupported image format: {image_proto.format}" ) + def _get_ccm_scratch(self, shape: Tuple[int, ...]) -> np.ndarray: + scratch = self._ccm_scratch_by_shape.get(shape) + if scratch is None: + scratch = np.empty(shape, dtype=np.float32) + self._ccm_scratch_by_shape[shape] = scratch + return scratch + @staticmethod - def _apply_ccm_inplace(img: np.ndarray, matrix: np.ndarray) -> None: + def _apply_ccm_inplace( + img: np.ndarray, + matrix: np.ndarray, + scratch: Optional[np.ndarray] = None, + ) -> None: """ Apply a 3x3 color correction matrix to an (H, W, 3) float32 image in-place. Computes corrected = matrix @ pixel for each pixel (column-vector form), equivalent to img @ matrix.T in numpy row-vector form, then clips to [0, 1]. """ - img[:] = img @ matrix.T + if scratch is None: + img[:] = img @ matrix.T + else: + np.einsum("...c,dc->...d", img, matrix, out=scratch) + np.copyto(img, scratch) np.clip(img, 0.0, 1.0, out=img) @staticmethod diff --git a/PySpotObserver/pyspotobserver/config.py b/PySpotObserver/pyspotobserver/config.py index 7c8b22e..84142ee 100644 --- a/PySpotObserver/pyspotobserver/config.py +++ b/PySpotObserver/pyspotobserver/config.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import IntFlag from pathlib import Path -from typing import Dict, Any +from typing import Any, Dict, List, Optional, Union import yaml @@ -67,9 +67,15 @@ class SpotConfig: image_quality_percent: float = 100.0 """JPEG quality for RGB images (0-100)""" - request_timeout_seconds: float = 500.0 + request_timeout_seconds: float = 10.0 """Timeout for image requests""" + vision_model_path: Optional[str] = None + """Optional ONNX model path for run_pipeline=True""" + + vision_providers: Optional[List[str]] = None + """Optional ONNX Runtime provider preference order""" + # Advanced settings sdk_name: str = "PySpotObserver" """Name to identify this SDK client""" @@ -83,8 +89,12 @@ class SpotConfig: extra_params: Dict[str, Any] = field(default_factory=dict) """Additional user-defined parameters""" + def __post_init__(self) -> None: + if self.request_timeout_seconds <= 0: + raise ValueError("request_timeout_seconds must be positive") + @classmethod - def from_yaml(cls, yaml_path: Path | str) -> "SpotConfig": + def from_yaml(cls, yaml_path: Union[Path, str]) -> "SpotConfig": """ Load configuration from a YAML file. @@ -110,7 +120,7 @@ def from_yaml(cls, yaml_path: Path | str) -> "SpotConfig": return cls(**data) - def to_yaml(self, yaml_path: Path | str) -> None: + def to_yaml(self, yaml_path: Union[Path, str]) -> None: """ Save configuration to a YAML file. @@ -128,11 +138,15 @@ def to_yaml(self, yaml_path: Path | str) -> None: 'image_buffer_size': self.image_buffer_size, 'image_quality_percent': self.image_quality_percent, 'request_timeout_seconds': self.request_timeout_seconds, + 'vision_model_path': self.vision_model_path, + 'vision_providers': self.vision_providers, 'sdk_name': self.sdk_name, 'connection_retry_attempts': self.connection_retry_attempts, 'connection_retry_delay_ms': self.connection_retry_delay_ms, } + data = {key: value for key, value in data.items() if value is not None} + if self.extra_params: data['extra_params'] = self.extra_params diff --git a/PySpotObserver/pyspotobserver/vision_pipeline.py b/PySpotObserver/pyspotobserver/vision_pipeline.py new file mode 100644 index 0000000..8bbdd26 --- /dev/null +++ b/PySpotObserver/pyspotobserver/vision_pipeline.py @@ -0,0 +1,299 @@ +""" +Optional ONNX Runtime vision pipeline support. +""" + +from __future__ import annotations + +import os +import threading +from pathlib import Path +from typing import List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np + +from .config import SpotConfig + + +DEFAULT_MODEL_ENV_VAR = "PYSPOTOBSERVER_VISION_MODEL" +DEFAULT_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider") +DEFAULT_DEPTH_SIZE = (120, 160) # (height, width) + + +class VisionPipelineError(Exception): + """Raised when the optional vision pipeline cannot run.""" + + +def _normalize_providers(providers: Optional[Union[Sequence[str], str]]) -> Tuple[str, ...]: + if providers is None: + return DEFAULT_PROVIDERS + if isinstance(providers, str): + return tuple(part.strip() for part in providers.split(",") if part.strip()) + return tuple(provider for provider in providers if provider) + + +def _dtype_for_onnx_type(type_name: str) -> np.dtype: + normalized = type_name.lower() + if "float16" in normalized: + return np.dtype(np.float16) + if "double" in normalized or "float64" in normalized: + return np.dtype(np.float64) + return np.dtype(np.float32) + + +def _depth_list_from_output(output: np.ndarray, batch_size: int) -> List[np.ndarray]: + output = np.asarray(output) + + if batch_size == 1 and output.ndim == 2: + output = output[np.newaxis, :, :] + elif output.ndim == 4 and output.shape[1] == 1: + output = output[:, 0, :, :] + elif output.ndim == 4 and output.shape[-1] == 1: + output = output[:, :, :, 0] + + if output.ndim != 3 or output.shape[0] != batch_size: + raise VisionPipelineError( + "Vision model output must be shaped as (B, H, W), " + "(B, 1, H, W), or (B, H, W, 1); " + f"got {tuple(output.shape)} for batch size {batch_size}" + ) + + return [np.asarray(output[i]) for i in range(batch_size)] + + +class VisionPipeline: + """ + Per-stream ONNX vision pipeline with reusable input buffers. + + ONNX Runtime is imported lazily so camera streaming remains available without + installing the optional vision extra. + """ + + def __init__( + self, + model_path: Union[str, os.PathLike[str]], + providers: Optional[Union[Sequence[str], str]] = None, + depth_size: Tuple[int, int] = DEFAULT_DEPTH_SIZE, + ): + self.model_path = Path(model_path).expanduser() + self.providers = _normalize_providers(providers) + self.depth_size = depth_size + self._lock = threading.Lock() + self._session = None + self._input_names: List[str] = [] + self._rgb_dtype = np.dtype(np.float32) + self._depth_dtype = np.dtype(np.float32) + self._rgb_buffer: Optional[np.ndarray] = None + self._depth_buffer: Optional[np.ndarray] = None + self._depth_resize_buffer: Optional[np.ndarray] = None + + if not self.model_path.exists(): + raise VisionPipelineError(f"Vision model not found: {self.model_path}") + + @classmethod + def from_config(cls, config: SpotConfig) -> "VisionPipeline": + extra_params = config.extra_params or {} + model_path = ( + config.vision_model_path + or extra_params.get("vision_model_path") + or os.environ.get(DEFAULT_MODEL_ENV_VAR) + ) + if not model_path: + raise VisionPipelineError( + "Vision model path is required for run_pipeline=True. " + "Set SpotConfig.vision_model_path, extra_params['vision_model_path'], " + f"or {DEFAULT_MODEL_ENV_VAR}." + ) + + providers = ( + config.vision_providers + or extra_params.get("vision_providers") + or extra_params.get("vision_provider") + ) + depth_size = extra_params.get("vision_depth_size", DEFAULT_DEPTH_SIZE) + if len(depth_size) != 2: + raise VisionPipelineError("vision_depth_size must contain height and width") + + return cls( + model_path=model_path, + providers=providers, + depth_size=(int(depth_size[0]), int(depth_size[1])), + ) + + def run( + self, + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + if not rgb_images or not depth_images: + raise VisionPipelineError("Vision pipeline requires at least one RGB/depth pair") + if len(rgb_images) != len(depth_images): + raise VisionPipelineError( + f"RGB/depth batch mismatch: {len(rgb_images)} RGB, {len(depth_images)} depth" + ) + + with self._lock: + self._init_session() + self._ensure_buffers(rgb_images, depth_images) + self._fill_buffers(rgb_images, depth_images) + + output = self._session.run( + None, + { + self._input_names[0]: self._rgb_buffer, + self._input_names[1]: self._depth_buffer, + }, + )[0] + + return rgb_images, _depth_list_from_output(output, len(rgb_images)) + + def _init_session(self) -> None: + if self._session is not None: + return + + try: + import onnxruntime as ort + except ImportError as exc: + raise VisionPipelineError( + "Vision pipeline requires ONNX Runtime. " + "Install PySpotObserver with the vision extra: " + 'pip install -e ".[vision]"' + ) from exc + + if hasattr(ort, "preload_dlls"): + ort.preload_dlls() + + available = set(ort.get_available_providers()) + selected = [provider for provider in self.providers if provider in available] + if not selected: + raise VisionPipelineError( + "None of the requested ONNX Runtime providers are available. " + f"requested={list(self.providers)}, available={sorted(available)}" + ) + + self._session = ort.InferenceSession( + str(self.model_path), + providers=selected, + ) + + inputs = self._session.get_inputs() + if len(inputs) < 2: + raise VisionPipelineError( + f"Vision model must expose at least 2 inputs; got {len(inputs)}" + ) + + self._input_names = [inputs[0].name, inputs[1].name] + self._rgb_dtype = _dtype_for_onnx_type(inputs[0].type) + self._depth_dtype = _dtype_for_onnx_type(inputs[1].type) + + def _ensure_buffers( + self, + rgb_images: Sequence[np.ndarray], + depth_images: Sequence[np.ndarray], + ) -> None: + batch_size = len(rgb_images) + h_rgb, w_rgb, channels = rgb_images[0].shape + if channels != 3: + raise VisionPipelineError(f"Expected RGB images with 3 channels, got {channels}") + + depth_h, depth_w = self.depth_size + rgb_shape = (batch_size, 3, h_rgb, w_rgb) + depth_shape = (batch_size, 1, depth_h, depth_w) + + if self._rgb_buffer is None or self._rgb_buffer.shape != rgb_shape: + self._rgb_buffer = np.empty(rgb_shape, dtype=self._rgb_dtype) + elif self._rgb_buffer.dtype != self._rgb_dtype: + self._rgb_buffer = np.empty(rgb_shape, dtype=self._rgb_dtype) + + if self._depth_buffer is None or self._depth_buffer.shape != depth_shape: + self._depth_buffer = np.empty(depth_shape, dtype=self._depth_dtype) + elif self._depth_buffer.dtype != self._depth_dtype: + self._depth_buffer = np.empty(depth_shape, dtype=self._depth_dtype) + + if self._depth_dtype == np.dtype(np.float16): + if ( + self._depth_resize_buffer is None + or self._depth_resize_buffer.shape != depth_shape + ): + self._depth_resize_buffer = np.empty(depth_shape, dtype=np.float32) + else: + self._depth_resize_buffer = None + + for index, rgb in enumerate(rgb_images): + if rgb.shape != (h_rgb, w_rgb, 3): + raise VisionPipelineError( + "All RGB images in a pipeline batch must have the same shape; " + f"image 0 is {(h_rgb, w_rgb, 3)}, image {index} is {rgb.shape}" + ) + for index, depth in enumerate(depth_images): + if depth.ndim != 2: + raise VisionPipelineError( + f"Depth image {index} must be 2D before pipeline processing; got {depth.shape}" + ) + + def _fill_buffers( + self, + rgb_images: Sequence[np.ndarray], + depth_images: Sequence[np.ndarray], + ) -> None: + depth_h, depth_w = self.depth_size + + for index, rgb in enumerate(rgb_images): + self._rgb_buffer[index, 0] = rgb[:, :, 0] + self._rgb_buffer[index, 1] = rgb[:, :, 1] + self._rgb_buffer[index, 2] = rgb[:, :, 2] + + for index, depth in enumerate(depth_images): + depth_dst = self._depth_buffer[index, 0] + resize_dst = ( + self._depth_resize_buffer[index, 0] + if self._depth_resize_buffer is not None + else depth_dst + ) + cv2.resize( + depth, + (depth_w, depth_h), + dst=resize_dst, + interpolation=cv2.INTER_NEAREST, + ) + if resize_dst is not depth_dst: + depth_dst[...] = resize_dst + + +_default_pipeline: Optional[VisionPipeline] = None +_default_pipeline_key: Optional[Tuple[str, Tuple[str, ...], Tuple[int, int]]] = None +_default_pipeline_lock = threading.Lock() + + +def run_vision_pipeline( + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], + *, + model_path: Optional[str] = None, + providers: Optional[Union[Sequence[str], str]] = None, + depth_size: Tuple[int, int] = DEFAULT_DEPTH_SIZE, +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """ + Backwards-compatible functional entry point using a cached default pipeline. + """ + resolved_model_path = model_path or os.environ.get(DEFAULT_MODEL_ENV_VAR) + if not resolved_model_path: + raise VisionPipelineError( + "Vision model path is required. Pass model_path or set " + f"{DEFAULT_MODEL_ENV_VAR}." + ) + + provider_tuple = _normalize_providers(providers) + key = (str(Path(resolved_model_path).expanduser()), provider_tuple, depth_size) + + global _default_pipeline, _default_pipeline_key + with _default_pipeline_lock: + if _default_pipeline is None or _default_pipeline_key != key: + _default_pipeline = VisionPipeline( + model_path=resolved_model_path, + providers=provider_tuple, + depth_size=depth_size, + ) + _default_pipeline_key = key + + return _default_pipeline.run(rgb_images, depth_images) diff --git a/PySpotObserver/requirements.txt b/PySpotObserver/requirements.txt index 4cbfaba..feef701 100644 --- a/PySpotObserver/requirements.txt +++ b/PySpotObserver/requirements.txt @@ -1,12 +1,7 @@ -# Core dependencies -bosdyn-client>=5.0.0 # Boston Dynamics Spot SDK -numpy>=2.0.0,<2.4.0 # Array operations -opencv-python>=4.12.0 # Image processing and display -pyyaml>=6.0.0 # YAML configuration support +# Dependency source of truth is setup.py. +# Run this from the PySpotObserver directory for a development install. +-e .[dev] -# Optional dependencies -# For development -pytest>=7.4.0 # Testing framework -pytest-asyncio>=0.21.0 # Async test support -black>=23.0.0 # Code formatting -mypy>=1.5.0 # Type checking +# Vision pipeline support is optional because it requires ONNX Runtime. +# Install it when needed with: +# pip install -e ".[vision]" diff --git a/PySpotObserver/setup.py b/PySpotObserver/setup.py index 969b0bf..511ef1e 100644 --- a/PySpotObserver/setup.py +++ b/PySpotObserver/setup.py @@ -5,6 +5,24 @@ from setuptools import setup, find_packages from pathlib import Path +INSTALL_REQUIRES = [ + "bosdyn-client>=5.0.0", + "numpy>=2.0.0,<2.4.0", + "opencv-python>=4.12.0", + "pyyaml>=6.0.0", +] + +DEV_REQUIRES = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", + "black>=23.0.0", + "mypy>=1.5.0", +] + +VISION_REQUIRES = [ + "onnxruntime-gpu>=1.25.0", +] + # Read README for long description readme_file = Path(__file__).parent / "README.md" long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists() else "" @@ -32,19 +50,11 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], python_requires=">=3.9", - install_requires=[ - "bosdyn-client>=5.0.0", - "numpy>=2.0.0,<2.4.0", - "opencv-python>=4.12.0", - "pyyaml>=6.0.0", - ], + install_requires=INSTALL_REQUIRES, extras_require={ - "dev": [ - "pytest>=7.4.0", - "pytest-asyncio>=0.21.0", - "black>=23.0.0", - "mypy>=1.5.0", - ], + "dev": DEV_REQUIRES, + "vision": VISION_REQUIRES, + "all": DEV_REQUIRES + VISION_REQUIRES, }, entry_points={ "console_scripts": [ diff --git a/PySpotObserver/tests/__init__.py b/PySpotObserver/tests/__init__.py new file mode 100644 index 0000000..95a7ede --- /dev/null +++ b/PySpotObserver/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for PySpotObserver package. +""" diff --git a/PySpotObserver/tests/test_config.py b/PySpotObserver/tests/test_config.py new file mode 100644 index 0000000..4842d5f --- /dev/null +++ b/PySpotObserver/tests/test_config.py @@ -0,0 +1,126 @@ +""" +Unit tests for configuration module. +""" + +import os +import pytest +from pathlib import Path +import tempfile + +from pyspotobserver.config import SpotConfig, CameraType + + +class TestCameraType: + """Tests for CameraType enum.""" + + def test_camera_mask_combinations(self): + """Test bitwise operations on camera types.""" + mask = CameraType.FRONTLEFT | CameraType.FRONTRIGHT + assert mask & CameraType.FRONTLEFT + assert mask & CameraType.FRONTRIGHT + assert not (mask & CameraType.BACK) + + def test_get_source_name_rgb(self): + """Test RGB camera source name generation.""" + assert CameraType.get_source_name(CameraType.FRONTLEFT) == "frontleft_fisheye_image" + assert CameraType.get_source_name(CameraType.BACK) == "back_fisheye_image" + assert CameraType.get_source_name(CameraType.HAND) == "hand_color_image" + + def test_get_source_name_depth(self): + """Test depth camera source name generation.""" + assert CameraType.get_source_name(CameraType.FRONTLEFT, depth=True) == \ + "frontleft_depth_in_visual_frame" + assert CameraType.get_source_name(CameraType.RIGHT, depth=True) == \ + "right_depth_in_visual_frame" + assert CameraType.get_source_name(CameraType.HAND, depth=True) == \ + "hand_depth_in_hand_color_frame" + + +class TestSpotConfig: + """Tests for SpotConfig dataclass.""" + + def test_default_values(self): + """Test configuration with default values.""" + config = SpotConfig(robot_ip="192.168.80.3") + assert config.robot_ip == "192.168.80.3" + assert config.username == "" + assert config.password == "" + assert config.image_buffer_size == 5 + assert config.image_quality_percent == 100.0 + assert config.request_timeout_seconds == 10.0 + assert config.vision_model_path is None + assert config.vision_providers is None + + def test_request_timeout_must_be_positive(self): + """Test that non-positive request timeouts are rejected.""" + with pytest.raises(ValueError, match="request_timeout_seconds"): + SpotConfig(robot_ip="192.168.80.3", request_timeout_seconds=0) + + def test_custom_values(self): + """Test configuration with custom values.""" + config = SpotConfig( + robot_ip="10.0.0.1", + username="admin", + password="secret", + image_buffer_size=10, + ) + assert config.robot_ip == "10.0.0.1" + assert config.username == "admin" + assert config.password == "secret" + assert config.image_buffer_size == 10 + + def test_repr_redacts_password(self): + """Test that __repr__ redacts password.""" + config = SpotConfig(robot_ip="192.168.80.3", password="secret") + repr_str = repr(config) + assert "secret" not in repr_str + assert "***" in repr_str + + def test_yaml_roundtrip(self): + """Test saving and loading from YAML.""" + config = SpotConfig( + robot_ip="192.168.80.3", + username="testuser", + password="testpass", + image_buffer_size=7, + ) + + fd, raw_path = tempfile.mkstemp(suffix=".yaml") + yaml_path = Path(raw_path) + try: + os.close(fd) + yaml_path.unlink() + + # Save to YAML + config.to_yaml(yaml_path) + assert yaml_path.exists() + + # Load from YAML + loaded_config = SpotConfig.from_yaml(yaml_path) + assert loaded_config.robot_ip == config.robot_ip + assert loaded_config.username == config.username + assert loaded_config.password == config.password + assert loaded_config.image_buffer_size == config.image_buffer_size + finally: + try: + yaml_path.unlink() + except FileNotFoundError: + pass + + def test_yaml_file_not_found(self): + """Test loading from non-existent YAML file.""" + with pytest.raises(FileNotFoundError): + SpotConfig.from_yaml("nonexistent.yaml") + + def test_extra_params(self): + """Test extra_params field.""" + config = SpotConfig( + robot_ip="192.168.80.3", + extra_params={"location": "lab", "experiment": "test1"} + ) + assert config.extra_params["location"] == "lab" + assert config.extra_params["experiment"] == "test1" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/PySpotObserver/tests/test_streaming_lifecycle.py b/PySpotObserver/tests/test_streaming_lifecycle.py new file mode 100644 index 0000000..6cd9f86 --- /dev/null +++ b/PySpotObserver/tests/test_streaming_lifecycle.py @@ -0,0 +1,197 @@ +import threading +import time + +import numpy as np +import pytest +import cv2 +from bosdyn.api import image_pb2 + +from pyspotobserver.camera_stream import ImageFrame, SpotCamStream, SpotCamStreamError +from pyspotobserver.config import CameraType, SpotConfig +from pyspotobserver.connection import SpotConnection, SpotConnectionError + + +class DummyImageClient: + def __init__(self, responses): + self._responses = list(responses) + + def get_image(self, image_requests, timeout): + response = self._responses.pop(0) + if isinstance(response, Exception): + raise response + return response + + +def make_stream(image_client=None) -> SpotCamStream: + return SpotCamStream( + image_client=image_client or DummyImageClient([]), + config=SpotConfig(robot_ip="192.168.80.3", request_timeout_seconds=0.05), + stream_id="test_stream", + ) + + +def test_get_current_images_returns_latest_frame(): + stream = make_stream() + stream._streaming = True + + older = ImageFrame( + rgb_images=[np.full((1, 1, 3), 1.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 1.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=1.0, + ) + newer = ImageFrame( + rgb_images=[np.full((1, 1, 3), 2.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 2.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=2.0, + ) + + stream._image_queue.put(older) + stream._image_queue.put(newer) + + rgb_images, depth_images = stream.get_current_images(timeout=0.01) + + assert float(rgb_images[0][0, 0, 0]) == 2.0 + assert float(depth_images[0][0, 0]) == 2.0 + assert stream._image_queue.qsize() == 2 + + +def test_get_current_images_run_pipeline_requires_model_path(): + stream = make_stream() + stream._streaming = True + frame = ImageFrame( + rgb_images=[np.full((1, 1, 3), 1.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 1.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=1.0, + ) + stream._image_queue.put(frame) + + with pytest.raises(SpotCamStreamError, match="Vision model path is required"): + stream.get_current_images(timeout=0.01, run_pipeline=True) + + +def test_get_current_images_run_pipeline_uses_lazy_pipeline(monkeypatch): + import pyspotobserver.vision_pipeline as vision_pipeline + + class FakeVisionPipeline: + @classmethod + def from_config(cls, config): + return cls() + + def run(self, rgb_images, depth_images): + return rgb_images, [depth + 10.0 for depth in depth_images] + + monkeypatch.setattr(vision_pipeline, "VisionPipeline", FakeVisionPipeline) + + stream = make_stream() + stream._streaming = True + frame = ImageFrame( + rgb_images=[np.full((1, 1, 3), 1.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 2.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=1.0, + ) + stream._image_queue.put(frame) + + rgb_images, depth_images = stream.get_current_images( + timeout=0.01, + run_pipeline=True, + ) + + assert float(rgb_images[0][0, 0, 0]) == 1.0 + assert float(depth_images[0][0, 0]) == 12.0 + + +def test_get_current_images_unblocks_when_stream_stops(): + stream = make_stream() + stream._streaming = True + + result = {} + + def reader(): + try: + stream.get_current_images(timeout=None) + except Exception as exc: # pragma: no cover - exercised by assertion + result["exc"] = exc + + thread = threading.Thread(target=reader) + thread.start() + time.sleep(0.15) + + stream._streaming = False + stream._stop_event.set() + + thread.join(timeout=1.0) + + assert not thread.is_alive() + assert isinstance(result["exc"], SpotCamStreamError) + assert "stopped" in str(result["exc"]).lower() + + +def test_stream_loop_recovers_after_initial_request_error(): + stream = make_stream(DummyImageClient([RuntimeError("boom"), ["ok"]])) + stream._streaming = True + stream._camera_order = [CameraType.FRONTLEFT] + + decoded = [ + np.zeros((1, 1, 3), dtype=np.float32), + np.zeros((1, 1), dtype=np.float32), + ] + + stream._build_image_requests = lambda: [] + stream._decode_initial_responses = lambda responses: decoded + + original_enqueue = stream._enqueue_frame + + def enqueue_and_stop(frame): + original_enqueue(frame) + stream._streaming = False + stream._stop_event.set() + + stream._enqueue_frame = enqueue_and_stop + + stream._stream_loop() + + assert stream.frame_count == 1 + assert stream.error_count == 1 + assert len(stream._frame_pool) == stream._config.image_buffer_size + + +def test_jpeg_inplace_color_correction_updates_output_array(): + stream = make_stream() + bgr = np.full((2, 2, 3), 255, dtype=np.uint8) + ok, encoded = cv2.imencode(".jpg", bgr) + assert ok + + response = image_pb2.ImageResponse() + response.shot.image.format = image_pb2.Image.FORMAT_JPEG + response.shot.image.data = encoded.tobytes() + + out = np.empty((2, 2, 3), dtype=np.float32) + zero_ccm = np.zeros((3, 3), dtype=np.float32) + + stream._convert_image_response_inplace( + response, + is_depth=False, + out_array=out, + ccm=zero_ccm, + ) + + assert np.allclose(out, 0.0) + + +def test_disconnect_raises_if_stream_does_not_stop(): + class BrokenStream: + def stop_streaming(self): + raise SpotCamStreamError("still alive") + + conn = SpotConnection(SpotConfig(robot_ip="192.168.80.3")) + conn._connected = True + conn._cam_streams = {"broken": BrokenStream()} + + with pytest.raises(SpotConnectionError, match="Failed to stop all streams cleanly"): + conn.disconnect() + + assert conn.connected is True diff --git a/install b/install index 862cd5c..440ce5a 160000 --- a/install +++ b/install @@ -1 +1 @@ -Subproject commit 862cd5cf0f2f3da8987638271fa120a7a958847b +Subproject commit 440ce5acb594f13b3ac758d14e0cd03be5341164