From 711f5c6107362df1f9a713fae34a459492aec264 Mon Sep 17 00:00:00 2001 From: Aanya Agrawal <64680673+screechingviolet@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:10:27 -0400 Subject: [PATCH 1/3] added vision pipeline functionality to async streaming --- PySpotObserver/basic_streaming_copy.py | 412 ++++++++++++++++++ PySpotObserver/examples/basic_streaming.py | 13 +- .../pyspotobserver/camera_stream.py | 51 ++- .../pyspotobserver/vision_pipeline.py | 122 ++++++ PySpotObserver/requirements.txt | 1 + 5 files changed, 581 insertions(+), 18 deletions(-) create mode 100644 PySpotObserver/basic_streaming_copy.py create mode 100644 PySpotObserver/pyspotobserver/vision_pipeline.py diff --git a/PySpotObserver/basic_streaming_copy.py b/PySpotObserver/basic_streaming_copy.py new file mode 100644 index 0000000..5e99389 --- /dev/null +++ b/PySpotObserver/basic_streaming_copy.py @@ -0,0 +1,412 @@ +""" +Unified streaming example for Spot camera feeds. + +This example combines the old basic, async, and multi-stream demos into one CLI. +Use `--async-mode` to switch to async retrieval. Add `--secondary-cameras` for a +second stream, and optionally `--secondary-robot-ip` to run that stream on a +second robot using the same credentials. +""" + +from __future__ import annotations + +import argparse +import asyncio +from contextlib import AsyncExitStack, ExitStack +from dataclasses import dataclass +import logging +import time +from typing import Sequence + + +import cv2 +import numpy as np + +from pyspotobserver import CameraType, SpotConfig, SpotConnection +from examples.common_cli import ( + add_common_connection_arguments, + build_camera_mask, + build_config_from_args, + parse_camera_list, +) + + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +@dataclass +class TimingStats: + frames: int = 0 + fetch_seconds: float = 0.0 + display_seconds: float = 0.0 + loop_seconds: float = 0.0 + + def add(self, *, fetch_seconds: float, display_seconds: float, loop_seconds: float) -> None: + self.frames += 1 + self.fetch_seconds += fetch_seconds + self.display_seconds += display_seconds + self.loop_seconds += loop_seconds + + +@dataclass +class FetchResult: + rgb_images: Sequence[np.ndarray] + depth_images: Sequence[np.ndarray] + fetch_seconds: float + + +@dataclass +class StreamSpec: + label: str + stream_label: str + stream_id: str + cameras: list[CameraType] + robot_label: str + + @property + def display_label(self) -> str: + return f"{self.stream_label} [{self.robot_label}]" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + add_common_connection_arguments(parser) + parser.add_argument( + "--cameras", + default="frontleft,frontright", + help="Comma-separated cameras for the primary stream.", + ) + parser.add_argument( + "--secondary-cameras", + help="Optional comma-separated cameras for a second stream configuration on each connected robot.", + ) + parser.add_argument( + "--secondary-robot-ip", + help="Optional second robot IP. When provided, the configured stream set is started on this robot too.", + ) + parser.add_argument( + "--duration", + type=float, + default=30.0, + help="Maximum streaming duration in seconds.", + ) + parser.add_argument( + "--timeout", + type=float, + default=2.0, + help="Per-frame retrieval timeout in seconds.", + ) + parser.add_argument( + "--stream-id", + default="primary_stream", + help="Stream identifier for the primary stream.", + ) + parser.add_argument( + "--secondary-stream-id", + default="secondary_stream", + help="Stream identifier for the optional second stream.", + ) + parser.add_argument( + "--async-mode", + action="store_true", + help="Use async connection management and async frame retrieval.", + ) + parser.add_argument( + "--no-display", + action="store_true", + help="Disable OpenCV windows and only log frame metadata.", + ) + parser.add_argument( + "--print-timing", + action="store_true", + help="Print per-stream timing summary at the end of the run.", + ) + parser.add_argument( + "--vision-pipeline", + action="store_true", + help="Run the vision pipeline on the Spot outputs" + ) + return parser.parse_args() + + +def build_stream_specs(args: argparse.Namespace) -> list[StreamSpec]: + primary_cameras = parse_camera_list(args.cameras) + stream_templates = [ + ("primary", args.stream_id, primary_cameras), + ] + if args.secondary_cameras: + stream_templates.append( + ("secondary", args.secondary_stream_id, parse_camera_list(args.secondary_cameras)) + ) + + robot_labels = ["primary"] + if args.secondary_robot_ip: + robot_labels.append("secondary") + + specs = [] + for robot_label in robot_labels: + for stream_label, stream_id, cameras in stream_templates: + specs.append( + StreamSpec( + label=f"{robot_label}:{stream_label}", + stream_label=stream_label, + stream_id=stream_id, + cameras=cameras.copy(), + robot_label=robot_label, + ) + ) + return specs + + +def display_images(window_prefix: str, stream, rgb_images: Sequence[np.ndarray], depth_images: Sequence[np.ndarray]) -> bool: + for i, (rgb, depth) in enumerate(zip(rgb_images, depth_images)): + camera_name = stream.get_camera_order()[i].name + rgb_display = (rgb * 255).astype(np.uint8) + rgb_display = cv2.cvtColor(rgb_display, cv2.COLOR_RGB2BGR) + + depth_normalized = cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + depth_colored = cv2.applyColorMap(depth_normalized[0], cv2.COLORMAP_JET) + cv2.imshow(f"{window_prefix} - {camera_name} - RGB", rgb_display) + cv2.imshow(f"{window_prefix} - {camera_name} - Depth", depth_colored) + + return bool(cv2.waitKey(1) & 0xFF == ord("q")) + + +def print_timing_summary( + timing_by_label: dict[str, TimingStats], + specs: Sequence[StreamSpec], + connections: dict[str, SpotConnection], + elapsed_seconds: float, +) -> None: + print("Timing summary:") + print(f"Elapsed wall time: {elapsed_seconds:.3f} s") + timing_by_robot: dict[str, TimingStats] = {} + for spec in specs: + stats = timing_by_label[spec.label] + avg_fetch_ms = (stats.fetch_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 + avg_display_ms = (stats.display_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 + avg_loop_ms = (stats.loop_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 + fps = stats.frames / elapsed_seconds if elapsed_seconds > 0 else 0.0 + print( + f"{spec.display_label} ({connections[spec.robot_label].config.robot_ip}): " + f"frames={stats.frames}, avg_fetch={avg_fetch_ms:.3f} ms, " + f"avg_display={avg_display_ms:.3f} ms, avg_loop={avg_loop_ms:.3f} ms, fps={fps:.2f}" + ) + robot_stats = timing_by_robot.setdefault(spec.robot_label, TimingStats()) + robot_stats.frames += stats.frames + robot_stats.fetch_seconds += stats.fetch_seconds + robot_stats.display_seconds += stats.display_seconds + robot_stats.loop_seconds += stats.loop_seconds + + if len(timing_by_robot) > 1: + print("Per-robot aggregate throughput:") + for robot_label, stats in timing_by_robot.items(): + avg_fetch_ms = (stats.fetch_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 + avg_display_ms = (stats.display_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 + avg_loop_ms = (stats.loop_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 + fps = stats.frames / elapsed_seconds if elapsed_seconds > 0 else 0.0 + print( + f"{robot_label} robot ({connections[robot_label].config.robot_ip}): " + f"frames={stats.frames}, avg_fetch={avg_fetch_ms:.3f} ms, " + f"avg_display={avg_display_ms:.3f} ms, avg_loop={avg_loop_ms:.3f} ms, fps={fps:.2f}" + ) + + +def build_connection_configs(args: argparse.Namespace) -> dict[str, SpotConfig]: + primary_config = build_config_from_args(args) + configs: dict[str, SpotConfig] = {"primary": primary_config} + if args.secondary_robot_ip: + secondary_config = type(primary_config)(**vars(primary_config)) + secondary_config.robot_ip = args.secondary_robot_ip + configs["secondary"] = secondary_config + return configs + + +def start_streams(connections: dict[str, SpotConnection], specs: list[StreamSpec]) -> dict[str, object]: + streams = {} + for spec in specs: + conn = connections[spec.robot_label] + stream = conn.create_cam_stream(stream_id=spec.stream_id) + stream.start_streaming(build_camera_mask(spec.cameras)) + streams[spec.label] = stream + logger.info( + "%s stream on %s robot (%s) cameras: %s", + spec.display_label, + spec.robot_label, + conn.config.robot_ip, + stream.get_camera_order(), + ) + return streams + + +def run_sync(args: argparse.Namespace, specs: list[StreamSpec]) -> int: + connection_configs = build_connection_configs(args) + timing_by_label = {spec.label: TimingStats() for spec in specs} + overall_start = time.perf_counter() + + with ExitStack() as stack: + connections = { + label: stack.enter_context(SpotConnection(config)) + for label, config in connection_configs.items() + } + for label, conn in connections.items(): + logger.info("Connected to %s robot: %s", label, conn) + + streams = start_streams(connections, specs) + + try: + start_time = time.perf_counter() + should_quit = False + while time.perf_counter() - start_time < args.duration and not should_quit: + for spec in specs: + loop_start = time.perf_counter() + stream = streams[spec.label] + + fetch_start = time.perf_counter() + rgb_images, depth_images = stream.get_current_images(timeout=args.timeout) + fetch_elapsed = time.perf_counter() - fetch_start + + display_elapsed = 0.0 + if not args.no_display: + display_start = time.perf_counter() + should_quit = display_images(spec.display_label, stream, rgb_images, depth_images) + display_elapsed = time.perf_counter() - display_start + + timing_by_label[spec.label].add( + fetch_seconds=fetch_elapsed, + display_seconds=display_elapsed, + loop_seconds=fetch_elapsed + display_elapsed, + ) + + if should_quit: + logger.info("User requested quit") + break + finally: + finalize_sync(connections, streams, specs, timing_by_label, args.print_timing, overall_start) + return 0 + + +async def fetch_stream_async(stream, timeout: float, vision_pipeline: bool) -> FetchResult: + fetch_start = time.perf_counter() + rgb_images, depth_images = await stream.async_get_current_images(timeout=timeout, run_pipeline=vision_pipeline) + + return FetchResult( + rgb_images=rgb_images, + depth_images=depth_images, + fetch_seconds=time.perf_counter() - fetch_start, + ) + + +async def run_async(args: argparse.Namespace, specs: list[StreamSpec]) -> int: + connection_configs = build_connection_configs(args) + timing_by_label = {spec.label: TimingStats() for spec in specs} + overall_start = time.perf_counter() + + async with AsyncExitStack() as stack: + connections = {} + for label, config in connection_configs.items(): + connections[label] = await stack.enter_async_context(SpotConnection(config)) + for label, conn in connections.items(): + logger.info("Connected to %s robot: %s", label, conn) + + streams = start_streams(connections, specs) + + try: + start_time = time.perf_counter() + while time.perf_counter() - start_time < args.duration: + results = await asyncio.gather( + *(fetch_stream_async(streams[spec.label], args.timeout, args.vision_pipeline) for spec in specs) + ) + + should_quit = False + for spec, result in zip(specs, results): + display_elapsed = 0.0 + if not args.no_display: + display_start = time.perf_counter() + should_quit = display_images( + spec.display_label, + streams[spec.label], + result.rgb_images, + result.depth_images, + ) or should_quit + display_elapsed = time.perf_counter() - display_start + + timing_by_label[spec.label].add( + fetch_seconds=result.fetch_seconds, + display_seconds=display_elapsed, + loop_seconds=result.fetch_seconds + display_elapsed, + ) + + if should_quit: + logger.info("User requested quit") + break + + await asyncio.sleep(0.01) + finally: + finalize_async(connections, streams, specs, timing_by_label, args.print_timing, overall_start) + return 0 + + +def finalize_sync( + connections, + streams, + specs: Sequence[StreamSpec], + timing_by_label: dict[str, TimingStats], + print_timing: bool, + overall_start: float, +) -> int: + for label, stream in streams.items(): + stream.stop_streaming() + logger.info( + "%s stream stats: frames=%s, errors=%s", + label, + stream.frame_count, + stream.error_count, + ) + cv2.destroyAllWindows() + for label, conn in connections.items(): + logger.info("%s robot active streams: %s", label, conn.list_streams()) + logger.info("Disconnected from robot(s)") + if print_timing: + print_timing_summary(timing_by_label, specs, connections, time.perf_counter() - overall_start) + return 0 + + +def finalize_async( + connections, + streams, + specs: Sequence[StreamSpec], + timing_by_label: dict[str, TimingStats], + print_timing: bool, + overall_start: float, +) -> int: + for label, stream in streams.items(): + stream.stop_streaming() + logger.info( + "%s stream stats: frames=%s, errors=%s", + label, + stream.frame_count, + stream.error_count, + ) + cv2.destroyAllWindows() + for label, conn in connections.items(): + logger.info("%s robot active streams: %s", label, conn.list_streams()) + logger.info("Disconnected from robot(s)") + if print_timing: + print_timing_summary(timing_by_label, specs, connections, time.perf_counter() - overall_start) + return 0 + + +def main() -> int: + args = parse_args() + specs = build_stream_specs(args) + if args.async_mode: + return asyncio.run(run_async(args, specs)) + return run_sync(args, specs) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/PySpotObserver/examples/basic_streaming.py b/PySpotObserver/examples/basic_streaming.py index 359b806..dbed5c9 100644 --- a/PySpotObserver/examples/basic_streaming.py +++ b/PySpotObserver/examples/basic_streaming.py @@ -17,6 +17,7 @@ import time from typing import Sequence + import cv2 import numpy as np @@ -123,6 +124,11 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Print per-stream timing summary at the end of the run.", ) + parser.add_argument( + "--vision-pipeline", + action="store_true", + help="Run the vision pipeline on the Spot outputs" + ) return parser.parse_args() @@ -282,9 +288,10 @@ def run_sync(args: argparse.Namespace, specs: list[StreamSpec]) -> int: return 0 -async def fetch_stream_async(stream, timeout: float) -> FetchResult: +async def fetch_stream_async(stream, timeout: float, vision_pipeline: bool) -> FetchResult: fetch_start = time.perf_counter() - rgb_images, depth_images = await stream.async_get_current_images(timeout=timeout) + rgb_images, depth_images = await stream.async_get_current_images(timeout=timeout, run_pipeline=vision_pipeline) + return FetchResult( rgb_images=rgb_images, depth_images=depth_images, @@ -310,7 +317,7 @@ async def run_async(args: argparse.Namespace, specs: list[StreamSpec]) -> int: start_time = time.perf_counter() while time.perf_counter() - start_time < args.duration: results = await asyncio.gather( - *(fetch_stream_async(streams[spec.label], args.timeout) for spec in specs) + *(fetch_stream_async(streams[spec.label], args.timeout, args.vision_pipeline) for spec in specs) ) should_quit = False diff --git a/PySpotObserver/pyspotobserver/camera_stream.py b/PySpotObserver/pyspotobserver/camera_stream.py index 729b51b..17e5bc9 100644 --- a/PySpotObserver/pyspotobserver/camera_stream.py +++ b/PySpotObserver/pyspotobserver/camera_stream.py @@ -17,6 +17,8 @@ from .config import SpotConfig, CameraType from .color_correction import _ROBOT_CCMS +from .vision_pipeline import run_vision_pipeline + logger = logging.getLogger(__name__) @@ -217,6 +219,7 @@ def get_current_images( self, timeout: Optional[float] = None, copy: bool = False, + run_pipeline: bool = False, ) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Get the most recent frame of images. @@ -249,12 +252,16 @@ def get_current_images( frame = self._peek_latest_frame() if frame is not None: - if not copy: - return frame.rgb_images, frame.depth_images - return ( - [img.copy() for img in frame.rgb_images], - [img.copy() for img in frame.depth_images], - ) + rgb, depth = frame.rgb_images, frame.depth_images + + if copy: + rgb = [img.copy() for img in rgb] + depth = [img.copy() for img in depth] + + if run_pipeline and self._vision_pipeline: + return run_vision_pipeline(rgb, depth) + + return rgb, depth wait_timeout = poll_interval if deadline is not None: @@ -270,17 +277,22 @@ def get_current_images( except Empty: continue - if not copy: - return frame.rgb_images, frame.depth_images - return ( - [img.copy() for img in frame.rgb_images], - [img.copy() for img in frame.depth_images], - ) + rgb, depth = frame.rgb_images, frame.depth_images + + if copy: + rgb = [img.copy() for img in rgb] + depth = [img.copy() for img in depth] + + if run_pipeline and self._vision_pipeline: + return run_vision_pipeline(rgb, depth) + + return rgb, depth async def async_get_current_images( self, timeout: Optional[float] = None, copy: bool = False, + run_pipeline: bool = False, ) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Async version of get_current_images(). @@ -300,9 +312,18 @@ async def async_get_current_images( SpotCamStreamError: If not streaming or timeout occurs """ loop = asyncio.get_event_loop() - return await loop.run_in_executor( - None, self.get_current_images, timeout, copy - ) + + if not run_pipeline: + return await loop.run_in_executor( + None, self.get_current_images, timeout, copy, False + ) + + # Run full pipeline in executor + def _get_and_process(): + rgb, depth = self.get_current_images(timeout, copy, False) + return run_vision_pipeline(rgb, depth) + + return await loop.run_in_executor(None, _get_and_process) def get_camera_order(self) -> List[CameraType]: """ diff --git a/PySpotObserver/pyspotobserver/vision_pipeline.py b/PySpotObserver/pyspotobserver/vision_pipeline.py new file mode 100644 index 0000000..b242ce0 --- /dev/null +++ b/PySpotObserver/pyspotobserver/vision_pipeline.py @@ -0,0 +1,122 @@ +from typing import List, Tuple +import numpy as np +import threading +import onnxruntime as ort +import cv2 + +# ------------------------- +# Global singleton state +# ------------------------- +_session = None +_input_names = None +_lock = threading.Lock() + +# Preallocated buffers (resized only if needed) +_rgb_buffer = None +_depth_buffer = None + +def _init_session(): + global _session, _input_names + + if _session is None: + ort.preload_dlls() + + _session = ort.InferenceSession( + "C:/Users/jtoribio/Documents/aanya/fileexchange/promptda-vitl-rotated_optimized_batch2_fp16.onnx", + providers=["CUDAExecutionProvider"], + ) + + _input_names = [i.name for i in _session.get_inputs()] + + +def _ensure_buffers(batch_size, rgb_shape, depth_shape): + """ + Allocate buffers only if size/shape changed. + """ + global _rgb_buffer, _depth_buffer + + # Expected shapes: + # rgb: (B, 3, H, W) + # depth: (B, 1, H, W) + + if ( + _rgb_buffer is None + or _rgb_buffer.shape != (batch_size, 3, rgb_shape[0], rgb_shape[1]) + ): + _rgb_buffer = np.empty( + (batch_size, 3, rgb_shape[0], rgb_shape[1]), + dtype=np.float32, + ) + + if ( + _depth_buffer is None + or _depth_buffer.shape != (batch_size, 1, depth_shape[0], depth_shape[1]) + ): + _depth_buffer = np.empty( + (batch_size, 1, depth_shape[0], depth_shape[1]), + dtype=np.float32, + ) + + +def run_vision_pipeline( + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], +): + """ + High-performance pipeline: + - No per-call allocations (after warmup) + - No np.stack + - Thread-safe ONNX execution + """ + + _init_session() + + batch_size = len(rgb_images) + + # Assume all images same shape + h_rgb, w_rgb, _ = rgb_images[0].shape + h_d, w_d = depth_images[0].shape[:2] + + MODEL_DEPTH_SHAPE = (120, 160) + + _ensure_buffers(batch_size, (h_rgb, w_rgb), MODEL_DEPTH_SHAPE) + + # ------------------------- + # Fill buffers (no stack) + # ------------------------- + for i in range(batch_size): + # RGB: HWC -> CHW + _rgb_buffer[i, 0] = rgb_images[i][:, :, 0] + _rgb_buffer[i, 1] = rgb_images[i][:, :, 1] + _rgb_buffer[i, 2] = rgb_images[i][:, :, 2] + + # Depth: HW -> 1HW + resized = cv2.resize( + depth_images[i], + (160, 120), # (width, height) for OpenCV + interpolation=cv2.INTER_NEAREST, # important for depth! + ) + + _depth_buffer[i, 0] = resized + + # ------------------------- + # Inference (thread-safe) + # ------------------------- + with _lock: + output = _session.run( + None, + { + _input_names[0]: _rgb_buffer, + _input_names[1]: _depth_buffer, + }, + )[0] + + # # ------------------------- + # # Fast normalization (in-place) + # # ------------------------- + # out_min = output.min() + # out_max = output.max() + # scale = 1.0 / (out_max - out_min + 1e-8) + # output = (output - out_min) * scale + + return rgb_images, output \ No newline at end of file diff --git a/PySpotObserver/requirements.txt b/PySpotObserver/requirements.txt index 4cbfaba..b129369 100644 --- a/PySpotObserver/requirements.txt +++ b/PySpotObserver/requirements.txt @@ -10,3 +10,4 @@ pytest>=7.4.0 # Testing framework pytest-asyncio>=0.21.0 # Async test support black>=23.0.0 # Code formatting mypy>=1.5.0 # Type checking +onnxruntime-gpu>=1.25.0 \ No newline at end of file From 72677fcd51a53c505c7562fd7a9567681fa3a9be Mon Sep 17 00:00:00 2001 From: Faisal Zaghloul Date: Wed, 10 Jun 2026 16:07:19 -0400 Subject: [PATCH 2/3] fixes --- PySpotObserver/QUICKSTART.md | 11 +- PySpotObserver/README.md | 13 +- PySpotObserver/basic_streaming_copy.py | 412 ------------------ PySpotObserver/examples/README.md | 4 +- PySpotObserver/examples/basic_streaming.py | 17 +- PySpotObserver/examples/common_cli.py | 4 + PySpotObserver/examples/config_example.yaml | 6 + PySpotObserver/pyspotobserver/__init__.py | 7 +- .../pyspotobserver/camera_stream.py | 97 ++++- PySpotObserver/pyspotobserver/config.py | 22 +- .../pyspotobserver/vision_pipeline.py | 363 +++++++++++---- PySpotObserver/requirements.txt | 18 +- PySpotObserver/setup.py | 34 +- install | 2 +- 14 files changed, 454 insertions(+), 556 deletions(-) delete mode 100644 PySpotObserver/basic_streaming_copy.py diff --git a/PySpotObserver/QUICKSTART.md b/PySpotObserver/QUICKSTART.md index 7178474..b50765b 100644 --- a/PySpotObserver/QUICKSTART.md +++ b/PySpotObserver/QUICKSTART.md @@ -8,7 +8,14 @@ pip install -r requirements.txt ``` -2. **Install package in development mode:** +2. **Optional vision pipeline support:** + ```bash + pip install -e ".[vision]" + ``` + +The dependency source of truth is `setup.py`; `requirements.txt` installs the development extra without duplicating package lists. + +3. **Minimal install without dev tools (alternative to step 1):** ```bash pip install -e . ``` @@ -215,4 +222,6 @@ logging.basicConfig( **Import errors**: Make sure all dependencies are installed: `pip install -r requirements.txt` +**Vision pipeline errors**: Install `pip install -e ".[vision]"` and set `vision_model_path` in config, pass `--vision-model-path`, or set `PYSPOTOBSERVER_VISION_MODEL` + **No images received**: Check that cameras are not already in use by another client diff --git a/PySpotObserver/README.md b/PySpotObserver/README.md index 4363785..25875db 100644 --- a/PySpotObserver/README.md +++ b/PySpotObserver/README.md @@ -11,6 +11,7 @@ A clean, Pythonic interface for streaming camera data from Boston Dynamics Spot - **YAML Configuration**: Load settings from config files or pass as parameters - **Multi-stream**: Support for multiple concurrent camera streams - **Thread-safe**: Background streaming with thread-safe image buffering +- **Optional Vision Pipeline**: ONNX Runtime inference can be enabled explicitly without making it a base dependency ## Installation @@ -27,6 +28,14 @@ pip install -e . pip install -e ".[dev]" ``` +### With optional vision pipeline support + +```bash +pip install -e ".[vision]" +``` + +`requirements.txt` delegates to `setup.py` for a development install, so dependency declarations stay in one place. + ## Quick Start ### Basic Synchronous Usage @@ -97,6 +106,7 @@ username: "" password: "" image_buffer_size: 5 image_quality_percent: 100.0 +request_timeout_seconds: 10.0 ``` ## Architecture @@ -165,7 +175,7 @@ The Python implementation differs in these ways: - **No CUDA**: Uses CPU-only NumPy arrays instead of GPU memory - **FIFO Queue**: Simpler queue.Queue instead of custom circular buffer - **Threading**: Python threading instead of C++ jthread -- **No ML Pipeline**: Focuses on camera streaming only (no inference) +- **Optional ML Pipeline**: Inference is available through an optional ONNX Runtime extra - **Simplified**: Removes Unity plugin and DLL export complexity ## Requirements @@ -175,6 +185,7 @@ The Python implementation differs in these ways: - NumPy - OpenCV (opencv-python) - PyYAML +- ONNX Runtime GPU (`onnxruntime-gpu`) only when installing the `vision` extra ## Contributing diff --git a/PySpotObserver/basic_streaming_copy.py b/PySpotObserver/basic_streaming_copy.py deleted file mode 100644 index 5e99389..0000000 --- a/PySpotObserver/basic_streaming_copy.py +++ /dev/null @@ -1,412 +0,0 @@ -""" -Unified streaming example for Spot camera feeds. - -This example combines the old basic, async, and multi-stream demos into one CLI. -Use `--async-mode` to switch to async retrieval. Add `--secondary-cameras` for a -second stream, and optionally `--secondary-robot-ip` to run that stream on a -second robot using the same credentials. -""" - -from __future__ import annotations - -import argparse -import asyncio -from contextlib import AsyncExitStack, ExitStack -from dataclasses import dataclass -import logging -import time -from typing import Sequence - - -import cv2 -import numpy as np - -from pyspotobserver import CameraType, SpotConfig, SpotConnection -from examples.common_cli import ( - add_common_connection_arguments, - build_camera_mask, - build_config_from_args, - parse_camera_list, -) - - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -@dataclass -class TimingStats: - frames: int = 0 - fetch_seconds: float = 0.0 - display_seconds: float = 0.0 - loop_seconds: float = 0.0 - - def add(self, *, fetch_seconds: float, display_seconds: float, loop_seconds: float) -> None: - self.frames += 1 - self.fetch_seconds += fetch_seconds - self.display_seconds += display_seconds - self.loop_seconds += loop_seconds - - -@dataclass -class FetchResult: - rgb_images: Sequence[np.ndarray] - depth_images: Sequence[np.ndarray] - fetch_seconds: float - - -@dataclass -class StreamSpec: - label: str - stream_label: str - stream_id: str - cameras: list[CameraType] - robot_label: str - - @property - def display_label(self) -> str: - return f"{self.stream_label} [{self.robot_label}]" - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - add_common_connection_arguments(parser) - parser.add_argument( - "--cameras", - default="frontleft,frontright", - help="Comma-separated cameras for the primary stream.", - ) - parser.add_argument( - "--secondary-cameras", - help="Optional comma-separated cameras for a second stream configuration on each connected robot.", - ) - parser.add_argument( - "--secondary-robot-ip", - help="Optional second robot IP. When provided, the configured stream set is started on this robot too.", - ) - parser.add_argument( - "--duration", - type=float, - default=30.0, - help="Maximum streaming duration in seconds.", - ) - parser.add_argument( - "--timeout", - type=float, - default=2.0, - help="Per-frame retrieval timeout in seconds.", - ) - parser.add_argument( - "--stream-id", - default="primary_stream", - help="Stream identifier for the primary stream.", - ) - parser.add_argument( - "--secondary-stream-id", - default="secondary_stream", - help="Stream identifier for the optional second stream.", - ) - parser.add_argument( - "--async-mode", - action="store_true", - help="Use async connection management and async frame retrieval.", - ) - parser.add_argument( - "--no-display", - action="store_true", - help="Disable OpenCV windows and only log frame metadata.", - ) - parser.add_argument( - "--print-timing", - action="store_true", - help="Print per-stream timing summary at the end of the run.", - ) - parser.add_argument( - "--vision-pipeline", - action="store_true", - help="Run the vision pipeline on the Spot outputs" - ) - return parser.parse_args() - - -def build_stream_specs(args: argparse.Namespace) -> list[StreamSpec]: - primary_cameras = parse_camera_list(args.cameras) - stream_templates = [ - ("primary", args.stream_id, primary_cameras), - ] - if args.secondary_cameras: - stream_templates.append( - ("secondary", args.secondary_stream_id, parse_camera_list(args.secondary_cameras)) - ) - - robot_labels = ["primary"] - if args.secondary_robot_ip: - robot_labels.append("secondary") - - specs = [] - for robot_label in robot_labels: - for stream_label, stream_id, cameras in stream_templates: - specs.append( - StreamSpec( - label=f"{robot_label}:{stream_label}", - stream_label=stream_label, - stream_id=stream_id, - cameras=cameras.copy(), - robot_label=robot_label, - ) - ) - return specs - - -def display_images(window_prefix: str, stream, rgb_images: Sequence[np.ndarray], depth_images: Sequence[np.ndarray]) -> bool: - for i, (rgb, depth) in enumerate(zip(rgb_images, depth_images)): - camera_name = stream.get_camera_order()[i].name - rgb_display = (rgb * 255).astype(np.uint8) - rgb_display = cv2.cvtColor(rgb_display, cv2.COLOR_RGB2BGR) - - depth_normalized = cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) - depth_colored = cv2.applyColorMap(depth_normalized[0], cv2.COLORMAP_JET) - cv2.imshow(f"{window_prefix} - {camera_name} - RGB", rgb_display) - cv2.imshow(f"{window_prefix} - {camera_name} - Depth", depth_colored) - - return bool(cv2.waitKey(1) & 0xFF == ord("q")) - - -def print_timing_summary( - timing_by_label: dict[str, TimingStats], - specs: Sequence[StreamSpec], - connections: dict[str, SpotConnection], - elapsed_seconds: float, -) -> None: - print("Timing summary:") - print(f"Elapsed wall time: {elapsed_seconds:.3f} s") - timing_by_robot: dict[str, TimingStats] = {} - for spec in specs: - stats = timing_by_label[spec.label] - avg_fetch_ms = (stats.fetch_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 - avg_display_ms = (stats.display_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 - avg_loop_ms = (stats.loop_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 - fps = stats.frames / elapsed_seconds if elapsed_seconds > 0 else 0.0 - print( - f"{spec.display_label} ({connections[spec.robot_label].config.robot_ip}): " - f"frames={stats.frames}, avg_fetch={avg_fetch_ms:.3f} ms, " - f"avg_display={avg_display_ms:.3f} ms, avg_loop={avg_loop_ms:.3f} ms, fps={fps:.2f}" - ) - robot_stats = timing_by_robot.setdefault(spec.robot_label, TimingStats()) - robot_stats.frames += stats.frames - robot_stats.fetch_seconds += stats.fetch_seconds - robot_stats.display_seconds += stats.display_seconds - robot_stats.loop_seconds += stats.loop_seconds - - if len(timing_by_robot) > 1: - print("Per-robot aggregate throughput:") - for robot_label, stats in timing_by_robot.items(): - avg_fetch_ms = (stats.fetch_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 - avg_display_ms = (stats.display_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 - avg_loop_ms = (stats.loop_seconds / stats.frames) * 1000.0 if stats.frames else 0.0 - fps = stats.frames / elapsed_seconds if elapsed_seconds > 0 else 0.0 - print( - f"{robot_label} robot ({connections[robot_label].config.robot_ip}): " - f"frames={stats.frames}, avg_fetch={avg_fetch_ms:.3f} ms, " - f"avg_display={avg_display_ms:.3f} ms, avg_loop={avg_loop_ms:.3f} ms, fps={fps:.2f}" - ) - - -def build_connection_configs(args: argparse.Namespace) -> dict[str, SpotConfig]: - primary_config = build_config_from_args(args) - configs: dict[str, SpotConfig] = {"primary": primary_config} - if args.secondary_robot_ip: - secondary_config = type(primary_config)(**vars(primary_config)) - secondary_config.robot_ip = args.secondary_robot_ip - configs["secondary"] = secondary_config - return configs - - -def start_streams(connections: dict[str, SpotConnection], specs: list[StreamSpec]) -> dict[str, object]: - streams = {} - for spec in specs: - conn = connections[spec.robot_label] - stream = conn.create_cam_stream(stream_id=spec.stream_id) - stream.start_streaming(build_camera_mask(spec.cameras)) - streams[spec.label] = stream - logger.info( - "%s stream on %s robot (%s) cameras: %s", - spec.display_label, - spec.robot_label, - conn.config.robot_ip, - stream.get_camera_order(), - ) - return streams - - -def run_sync(args: argparse.Namespace, specs: list[StreamSpec]) -> int: - connection_configs = build_connection_configs(args) - timing_by_label = {spec.label: TimingStats() for spec in specs} - overall_start = time.perf_counter() - - with ExitStack() as stack: - connections = { - label: stack.enter_context(SpotConnection(config)) - for label, config in connection_configs.items() - } - for label, conn in connections.items(): - logger.info("Connected to %s robot: %s", label, conn) - - streams = start_streams(connections, specs) - - try: - start_time = time.perf_counter() - should_quit = False - while time.perf_counter() - start_time < args.duration and not should_quit: - for spec in specs: - loop_start = time.perf_counter() - stream = streams[spec.label] - - fetch_start = time.perf_counter() - rgb_images, depth_images = stream.get_current_images(timeout=args.timeout) - fetch_elapsed = time.perf_counter() - fetch_start - - display_elapsed = 0.0 - if not args.no_display: - display_start = time.perf_counter() - should_quit = display_images(spec.display_label, stream, rgb_images, depth_images) - display_elapsed = time.perf_counter() - display_start - - timing_by_label[spec.label].add( - fetch_seconds=fetch_elapsed, - display_seconds=display_elapsed, - loop_seconds=fetch_elapsed + display_elapsed, - ) - - if should_quit: - logger.info("User requested quit") - break - finally: - finalize_sync(connections, streams, specs, timing_by_label, args.print_timing, overall_start) - return 0 - - -async def fetch_stream_async(stream, timeout: float, vision_pipeline: bool) -> FetchResult: - fetch_start = time.perf_counter() - rgb_images, depth_images = await stream.async_get_current_images(timeout=timeout, run_pipeline=vision_pipeline) - - return FetchResult( - rgb_images=rgb_images, - depth_images=depth_images, - fetch_seconds=time.perf_counter() - fetch_start, - ) - - -async def run_async(args: argparse.Namespace, specs: list[StreamSpec]) -> int: - connection_configs = build_connection_configs(args) - timing_by_label = {spec.label: TimingStats() for spec in specs} - overall_start = time.perf_counter() - - async with AsyncExitStack() as stack: - connections = {} - for label, config in connection_configs.items(): - connections[label] = await stack.enter_async_context(SpotConnection(config)) - for label, conn in connections.items(): - logger.info("Connected to %s robot: %s", label, conn) - - streams = start_streams(connections, specs) - - try: - start_time = time.perf_counter() - while time.perf_counter() - start_time < args.duration: - results = await asyncio.gather( - *(fetch_stream_async(streams[spec.label], args.timeout, args.vision_pipeline) for spec in specs) - ) - - should_quit = False - for spec, result in zip(specs, results): - display_elapsed = 0.0 - if not args.no_display: - display_start = time.perf_counter() - should_quit = display_images( - spec.display_label, - streams[spec.label], - result.rgb_images, - result.depth_images, - ) or should_quit - display_elapsed = time.perf_counter() - display_start - - timing_by_label[spec.label].add( - fetch_seconds=result.fetch_seconds, - display_seconds=display_elapsed, - loop_seconds=result.fetch_seconds + display_elapsed, - ) - - if should_quit: - logger.info("User requested quit") - break - - await asyncio.sleep(0.01) - finally: - finalize_async(connections, streams, specs, timing_by_label, args.print_timing, overall_start) - return 0 - - -def finalize_sync( - connections, - streams, - specs: Sequence[StreamSpec], - timing_by_label: dict[str, TimingStats], - print_timing: bool, - overall_start: float, -) -> int: - for label, stream in streams.items(): - stream.stop_streaming() - logger.info( - "%s stream stats: frames=%s, errors=%s", - label, - stream.frame_count, - stream.error_count, - ) - cv2.destroyAllWindows() - for label, conn in connections.items(): - logger.info("%s robot active streams: %s", label, conn.list_streams()) - logger.info("Disconnected from robot(s)") - if print_timing: - print_timing_summary(timing_by_label, specs, connections, time.perf_counter() - overall_start) - return 0 - - -def finalize_async( - connections, - streams, - specs: Sequence[StreamSpec], - timing_by_label: dict[str, TimingStats], - print_timing: bool, - overall_start: float, -) -> int: - for label, stream in streams.items(): - stream.stop_streaming() - logger.info( - "%s stream stats: frames=%s, errors=%s", - label, - stream.frame_count, - stream.error_count, - ) - cv2.destroyAllWindows() - for label, conn in connections.items(): - logger.info("%s robot active streams: %s", label, conn.list_streams()) - logger.info("Disconnected from robot(s)") - if print_timing: - print_timing_summary(timing_by_label, specs, connections, time.perf_counter() - overall_start) - return 0 - - -def main() -> int: - args = parse_args() - specs = build_stream_specs(args) - if args.async_mode: - return asyncio.run(run_async(args, specs)) - return run_sync(args, specs) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/PySpotObserver/examples/README.md b/PySpotObserver/examples/README.md index 80663eb..f223391 100644 --- a/PySpotObserver/examples/README.md +++ b/PySpotObserver/examples/README.md @@ -10,9 +10,10 @@ Install the package and dependencies first: ```bash pip install -r requirements.txt -pip install -e . ``` +For `--vision-pipeline`, also install `pip install -e ".[vision]"` and provide a model path with `--vision-model-path`, config, or `PYSPOTOBSERVER_VISION_MODEL`. + ## Configuration The examples load `examples/config_example.yaml` by default. Fill in at least: @@ -37,6 +38,7 @@ python examples/basic_streaming.py --robot-ip 192.168.80.3 --username --p - asynchronous streaming with `--async-mode` - one or two stream configurations, mirrored across one or two robots - optional OpenCV display +- optional ONNX vision pipeline with `--vision-pipeline` - optional timing summaries with `--print-timing` Show the full CLI: diff --git a/PySpotObserver/examples/basic_streaming.py b/PySpotObserver/examples/basic_streaming.py index dbed5c9..f15274c 100644 --- a/PySpotObserver/examples/basic_streaming.py +++ b/PySpotObserver/examples/basic_streaming.py @@ -14,6 +14,7 @@ from contextlib import AsyncExitStack, ExitStack from dataclasses import dataclass import logging +from pathlib import Path import time from typing import Sequence @@ -129,6 +130,17 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Run the vision pipeline on the Spot outputs" ) + parser.add_argument( + "--vision-model-path", + type=Path, + help="ONNX model path for --vision-pipeline. Overrides config and environment.", + ) + parser.add_argument( + "--vision-provider", + action="append", + dest="vision_providers", + help="ONNX Runtime provider to request. Repeat to set provider preference order.", + ) return parser.parse_args() @@ -265,7 +277,10 @@ def run_sync(args: argparse.Namespace, specs: list[StreamSpec]) -> int: stream = streams[spec.label] fetch_start = time.perf_counter() - rgb_images, depth_images = stream.get_current_images(timeout=args.timeout) + rgb_images, depth_images = stream.get_current_images( + timeout=args.timeout, + run_pipeline=args.vision_pipeline, + ) fetch_elapsed = time.perf_counter() - fetch_start display_elapsed = 0.0 diff --git a/PySpotObserver/examples/common_cli.py b/PySpotObserver/examples/common_cli.py index 802146d..509e9c5 100644 --- a/PySpotObserver/examples/common_cli.py +++ b/PySpotObserver/examples/common_cli.py @@ -109,6 +109,10 @@ def build_config_from_args(args: argparse.Namespace) -> SpotConfig: config.password = args.password if hasattr(args, "image_buffer_size") and args.image_buffer_size is not None: config.image_buffer_size = args.image_buffer_size + if hasattr(args, "vision_model_path") and args.vision_model_path is not None: + config.vision_model_path = str(args.vision_model_path) + if hasattr(args, "vision_providers") and args.vision_providers: + config.vision_providers = args.vision_providers if not config.robot_ip: raise ValueError("Robot IP must be set in --config or with --robot-ip.") diff --git a/PySpotObserver/examples/config_example.yaml b/PySpotObserver/examples/config_example.yaml index 501680d..bec26cb 100644 --- a/PySpotObserver/examples/config_example.yaml +++ b/PySpotObserver/examples/config_example.yaml @@ -12,6 +12,12 @@ image_buffer_size: 5 image_quality_percent: 100.0 request_timeout_seconds: 10.0 +# Optional vision pipeline settings +# vision_model_path: "C:/path/to/model.onnx" +# vision_providers: +# - "CUDAExecutionProvider" +# - "CPUExecutionProvider" + # Advanced settings sdk_name: "PySpotObserver" connection_retry_attempts: 3 diff --git a/PySpotObserver/pyspotobserver/__init__.py b/PySpotObserver/pyspotobserver/__init__.py index 8b3f448..1b124ae 100644 --- a/PySpotObserver/pyspotobserver/__init__.py +++ b/PySpotObserver/pyspotobserver/__init__.py @@ -6,13 +6,16 @@ """ from .config import SpotConfig, CameraType -from .connection import SpotConnection -from .camera_stream import SpotCamStream +from .connection import SpotAuthenticationError, SpotConnection, SpotConnectionError +from .camera_stream import SpotCamStream, SpotCamStreamError __version__ = "0.1.0" __all__ = [ "SpotConfig", "CameraType", "SpotConnection", + "SpotConnectionError", + "SpotAuthenticationError", "SpotCamStream", + "SpotCamStreamError", ] diff --git a/PySpotObserver/pyspotobserver/camera_stream.py b/PySpotObserver/pyspotobserver/camera_stream.py index 17e5bc9..bcf4531 100644 --- a/PySpotObserver/pyspotobserver/camera_stream.py +++ b/PySpotObserver/pyspotobserver/camera_stream.py @@ -8,7 +8,7 @@ import time from dataclasses import dataclass from queue import Queue, Empty -from typing import Optional, List, Tuple +from typing import Dict, Optional, List, Tuple import cv2 import numpy as np @@ -17,7 +17,6 @@ from .config import SpotConfig, CameraType from .color_correction import _ROBOT_CCMS -from .vision_pipeline import run_vision_pipeline logger = logging.getLogger(__name__) @@ -94,6 +93,13 @@ def __init__( # Preallocated frame pool (initialized once we know image shapes) self._frame_pool: List[ImageFrame] = [] self._frame_pool_index: int = 0 + self._image_requests: List[image_pb2.ImageRequest] = [] + + # Optional per-stream vision pipeline, imported lazily. + self._vision_pipeline = None + + # Scratch buffers for color correction by image shape. + self._ccm_scratch_by_shape: Dict[Tuple[int, ...], np.ndarray] = {} # Color correction matrices (None if robot IP is not recognized) self._ccms: Optional[dict] = _ROBOT_CCMS.get(config.robot_ip) @@ -156,6 +162,7 @@ def start_streaming(self, camera_mask: int) -> None: # Parse camera mask to get ordered list of cameras self._camera_mask = camera_mask self._camera_order = self._parse_camera_mask(camera_mask) + self._image_requests = self._build_image_requests() logger.info( f"Stream '{self._stream_id}': Starting with cameras: " @@ -166,6 +173,7 @@ def start_streaming(self, camera_mask: int) -> None: self._clear_queue() self._frame_pool = [] self._frame_pool_index = 0 + self._ccm_scratch_by_shape = {} self._frame_count = 0 self._error_count = 0 @@ -258,8 +266,8 @@ def get_current_images( rgb = [img.copy() for img in rgb] depth = [img.copy() for img in depth] - if run_pipeline and self._vision_pipeline: - return run_vision_pipeline(rgb, depth) + if run_pipeline: + return self._run_vision_pipeline(rgb, depth) return rgb, depth @@ -283,8 +291,8 @@ def get_current_images( rgb = [img.copy() for img in rgb] depth = [img.copy() for img in depth] - if run_pipeline and self._vision_pipeline: - return run_vision_pipeline(rgb, depth) + if run_pipeline: + return self._run_vision_pipeline(rgb, depth) return rgb, depth @@ -321,10 +329,34 @@ async def async_get_current_images( # Run full pipeline in executor def _get_and_process(): rgb, depth = self.get_current_images(timeout, copy, False) - return run_vision_pipeline(rgb, depth) + return self._run_vision_pipeline(rgb, depth) return await loop.run_in_executor(None, _get_and_process) + def _run_vision_pipeline( + self, + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + try: + from .vision_pipeline import VisionPipeline, VisionPipelineError + except ImportError as exc: + raise SpotCamStreamError( + "Vision pipeline support is unavailable. " + 'Install PySpotObserver with: pip install -e ".[vision]"' + ) from exc + + if self._vision_pipeline is None: + try: + self._vision_pipeline = VisionPipeline.from_config(self._config) + except VisionPipelineError as exc: + raise SpotCamStreamError(str(exc)) from exc + + try: + return self._vision_pipeline.run(rgb_images, depth_images) + except VisionPipelineError as exc: + raise SpotCamStreamError(str(exc)) from exc + def get_camera_order(self) -> List[CameraType]: """ Get the ordered list of cameras being streamed. @@ -363,13 +395,10 @@ def _stream_loop(self) -> None: while not self._stop_event.is_set() and self._streaming: try: - # Build image requests for all cameras (RGB + depth) - image_requests = self._build_image_requests() - # Request images from robot start_time = time.monotonic() image_responses = self._image_client.get_image( - image_requests, + self._image_requests, timeout=self._config.request_timeout_seconds, ) request_time = time.monotonic() - start_time @@ -553,7 +582,7 @@ def _fill_frame_from_responses( ccm = None if self._ccms is not None: - ccm = self._ccms[self._camera_order[i]] + ccm = self._ccms[self._camera_order[i]] self._convert_image_response_inplace( responses[rgb_idx], is_depth=False, @@ -723,8 +752,11 @@ def _convert_image_response_inplace( ) np.multiply(img, 1.0 / 255.0, out=out_array, casting="unsafe") if ccm is not None: - img = img @ ccm.T - np.clip(img, 0.0, 1.0, out=img) + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return elif image_proto.format == image_pb2.Image.FORMAT_RAW: @@ -764,6 +796,12 @@ def _convert_image_response_inplace( f"Output array shape mismatch: {out_array.shape} vs {img.shape}" ) np.multiply(img, 1.0 / 255.0, out=out_array, casting="unsafe") + if ccm is not None: + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return elif pixel_format == image_pb2.Image.PIXEL_FORMAT_RGBA_U8: @@ -776,6 +814,12 @@ def _convert_image_response_inplace( f"Output array shape mismatch: {out_array.shape} vs {img.shape}" ) np.multiply(img, 1.0 / 255.0, out=out_array, casting="unsafe") + if ccm is not None: + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return elif pixel_format == image_pb2.Image.PIXEL_FORMAT_GREYSCALE_U8: @@ -789,6 +833,12 @@ def _convert_image_response_inplace( np.multiply(img, 1.0 / 255.0, out=channel, casting="unsafe") out_array[:, :, 1] = channel out_array[:, :, 2] = channel + if ccm is not None: + self._apply_ccm_inplace( + out_array, + ccm, + scratch=self._get_ccm_scratch(out_array.shape), + ) return else: @@ -801,15 +851,30 @@ def _convert_image_response_inplace( f"Unsupported image format: {image_proto.format}" ) + def _get_ccm_scratch(self, shape: Tuple[int, ...]) -> np.ndarray: + scratch = self._ccm_scratch_by_shape.get(shape) + if scratch is None: + scratch = np.empty(shape, dtype=np.float32) + self._ccm_scratch_by_shape[shape] = scratch + return scratch + @staticmethod - def _apply_ccm_inplace(img: np.ndarray, matrix: np.ndarray) -> None: + def _apply_ccm_inplace( + img: np.ndarray, + matrix: np.ndarray, + scratch: Optional[np.ndarray] = None, + ) -> None: """ Apply a 3x3 color correction matrix to an (H, W, 3) float32 image in-place. Computes corrected = matrix @ pixel for each pixel (column-vector form), equivalent to img @ matrix.T in numpy row-vector form, then clips to [0, 1]. """ - img[:] = img @ matrix.T + if scratch is None: + img[:] = img @ matrix.T + else: + np.einsum("...c,dc->...d", img, matrix, out=scratch) + np.copyto(img, scratch) np.clip(img, 0.0, 1.0, out=img) @staticmethod diff --git a/PySpotObserver/pyspotobserver/config.py b/PySpotObserver/pyspotobserver/config.py index 7c8b22e..84142ee 100644 --- a/PySpotObserver/pyspotobserver/config.py +++ b/PySpotObserver/pyspotobserver/config.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import IntFlag from pathlib import Path -from typing import Dict, Any +from typing import Any, Dict, List, Optional, Union import yaml @@ -67,9 +67,15 @@ class SpotConfig: image_quality_percent: float = 100.0 """JPEG quality for RGB images (0-100)""" - request_timeout_seconds: float = 500.0 + request_timeout_seconds: float = 10.0 """Timeout for image requests""" + vision_model_path: Optional[str] = None + """Optional ONNX model path for run_pipeline=True""" + + vision_providers: Optional[List[str]] = None + """Optional ONNX Runtime provider preference order""" + # Advanced settings sdk_name: str = "PySpotObserver" """Name to identify this SDK client""" @@ -83,8 +89,12 @@ class SpotConfig: extra_params: Dict[str, Any] = field(default_factory=dict) """Additional user-defined parameters""" + def __post_init__(self) -> None: + if self.request_timeout_seconds <= 0: + raise ValueError("request_timeout_seconds must be positive") + @classmethod - def from_yaml(cls, yaml_path: Path | str) -> "SpotConfig": + def from_yaml(cls, yaml_path: Union[Path, str]) -> "SpotConfig": """ Load configuration from a YAML file. @@ -110,7 +120,7 @@ def from_yaml(cls, yaml_path: Path | str) -> "SpotConfig": return cls(**data) - def to_yaml(self, yaml_path: Path | str) -> None: + def to_yaml(self, yaml_path: Union[Path, str]) -> None: """ Save configuration to a YAML file. @@ -128,11 +138,15 @@ def to_yaml(self, yaml_path: Path | str) -> None: 'image_buffer_size': self.image_buffer_size, 'image_quality_percent': self.image_quality_percent, 'request_timeout_seconds': self.request_timeout_seconds, + 'vision_model_path': self.vision_model_path, + 'vision_providers': self.vision_providers, 'sdk_name': self.sdk_name, 'connection_retry_attempts': self.connection_retry_attempts, 'connection_retry_delay_ms': self.connection_retry_delay_ms, } + data = {key: value for key, value in data.items() if value is not None} + if self.extra_params: data['extra_params'] = self.extra_params diff --git a/PySpotObserver/pyspotobserver/vision_pipeline.py b/PySpotObserver/pyspotobserver/vision_pipeline.py index b242ce0..8bbdd26 100644 --- a/PySpotObserver/pyspotobserver/vision_pipeline.py +++ b/PySpotObserver/pyspotobserver/vision_pipeline.py @@ -1,122 +1,299 @@ -from typing import List, Tuple -import numpy as np +""" +Optional ONNX Runtime vision pipeline support. +""" + +from __future__ import annotations + +import os import threading -import onnxruntime as ort +from pathlib import Path +from typing import List, Optional, Sequence, Tuple, Union + import cv2 +import numpy as np + +from .config import SpotConfig + + +DEFAULT_MODEL_ENV_VAR = "PYSPOTOBSERVER_VISION_MODEL" +DEFAULT_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider") +DEFAULT_DEPTH_SIZE = (120, 160) # (height, width) + + +class VisionPipelineError(Exception): + """Raised when the optional vision pipeline cannot run.""" + + +def _normalize_providers(providers: Optional[Union[Sequence[str], str]]) -> Tuple[str, ...]: + if providers is None: + return DEFAULT_PROVIDERS + if isinstance(providers, str): + return tuple(part.strip() for part in providers.split(",") if part.strip()) + return tuple(provider for provider in providers if provider) -# ------------------------- -# Global singleton state -# ------------------------- -_session = None -_input_names = None -_lock = threading.Lock() -# Preallocated buffers (resized only if needed) -_rgb_buffer = None -_depth_buffer = None +def _dtype_for_onnx_type(type_name: str) -> np.dtype: + normalized = type_name.lower() + if "float16" in normalized: + return np.dtype(np.float16) + if "double" in normalized or "float64" in normalized: + return np.dtype(np.float64) + return np.dtype(np.float32) -def _init_session(): - global _session, _input_names - if _session is None: - ort.preload_dlls() +def _depth_list_from_output(output: np.ndarray, batch_size: int) -> List[np.ndarray]: + output = np.asarray(output) - _session = ort.InferenceSession( - "C:/Users/jtoribio/Documents/aanya/fileexchange/promptda-vitl-rotated_optimized_batch2_fp16.onnx", - providers=["CUDAExecutionProvider"], + if batch_size == 1 and output.ndim == 2: + output = output[np.newaxis, :, :] + elif output.ndim == 4 and output.shape[1] == 1: + output = output[:, 0, :, :] + elif output.ndim == 4 and output.shape[-1] == 1: + output = output[:, :, :, 0] + + if output.ndim != 3 or output.shape[0] != batch_size: + raise VisionPipelineError( + "Vision model output must be shaped as (B, H, W), " + "(B, 1, H, W), or (B, H, W, 1); " + f"got {tuple(output.shape)} for batch size {batch_size}" ) - _input_names = [i.name for i in _session.get_inputs()] + return [np.asarray(output[i]) for i in range(batch_size)] -def _ensure_buffers(batch_size, rgb_shape, depth_shape): - """ - Allocate buffers only if size/shape changed. +class VisionPipeline: """ - global _rgb_buffer, _depth_buffer + Per-stream ONNX vision pipeline with reusable input buffers. - # Expected shapes: - # rgb: (B, 3, H, W) - # depth: (B, 1, H, W) + ONNX Runtime is imported lazily so camera streaming remains available without + installing the optional vision extra. + """ - if ( - _rgb_buffer is None - or _rgb_buffer.shape != (batch_size, 3, rgb_shape[0], rgb_shape[1]) + def __init__( + self, + model_path: Union[str, os.PathLike[str]], + providers: Optional[Union[Sequence[str], str]] = None, + depth_size: Tuple[int, int] = DEFAULT_DEPTH_SIZE, ): - _rgb_buffer = np.empty( - (batch_size, 3, rgb_shape[0], rgb_shape[1]), - dtype=np.float32, + self.model_path = Path(model_path).expanduser() + self.providers = _normalize_providers(providers) + self.depth_size = depth_size + self._lock = threading.Lock() + self._session = None + self._input_names: List[str] = [] + self._rgb_dtype = np.dtype(np.float32) + self._depth_dtype = np.dtype(np.float32) + self._rgb_buffer: Optional[np.ndarray] = None + self._depth_buffer: Optional[np.ndarray] = None + self._depth_resize_buffer: Optional[np.ndarray] = None + + if not self.model_path.exists(): + raise VisionPipelineError(f"Vision model not found: {self.model_path}") + + @classmethod + def from_config(cls, config: SpotConfig) -> "VisionPipeline": + extra_params = config.extra_params or {} + model_path = ( + config.vision_model_path + or extra_params.get("vision_model_path") + or os.environ.get(DEFAULT_MODEL_ENV_VAR) ) + if not model_path: + raise VisionPipelineError( + "Vision model path is required for run_pipeline=True. " + "Set SpotConfig.vision_model_path, extra_params['vision_model_path'], " + f"or {DEFAULT_MODEL_ENV_VAR}." + ) - if ( - _depth_buffer is None - or _depth_buffer.shape != (batch_size, 1, depth_shape[0], depth_shape[1]) - ): - _depth_buffer = np.empty( - (batch_size, 1, depth_shape[0], depth_shape[1]), - dtype=np.float32, + providers = ( + config.vision_providers + or extra_params.get("vision_providers") + or extra_params.get("vision_provider") ) + depth_size = extra_params.get("vision_depth_size", DEFAULT_DEPTH_SIZE) + if len(depth_size) != 2: + raise VisionPipelineError("vision_depth_size must contain height and width") + return cls( + model_path=model_path, + providers=providers, + depth_size=(int(depth_size[0]), int(depth_size[1])), + ) -def run_vision_pipeline( - rgb_images: List[np.ndarray], - depth_images: List[np.ndarray], -): - """ - High-performance pipeline: - - No per-call allocations (after warmup) - - No np.stack - - Thread-safe ONNX execution - """ + def run( + self, + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + if not rgb_images or not depth_images: + raise VisionPipelineError("Vision pipeline requires at least one RGB/depth pair") + if len(rgb_images) != len(depth_images): + raise VisionPipelineError( + f"RGB/depth batch mismatch: {len(rgb_images)} RGB, {len(depth_images)} depth" + ) - _init_session() + with self._lock: + self._init_session() + self._ensure_buffers(rgb_images, depth_images) + self._fill_buffers(rgb_images, depth_images) - batch_size = len(rgb_images) + output = self._session.run( + None, + { + self._input_names[0]: self._rgb_buffer, + self._input_names[1]: self._depth_buffer, + }, + )[0] - # Assume all images same shape - h_rgb, w_rgb, _ = rgb_images[0].shape - h_d, w_d = depth_images[0].shape[:2] + return rgb_images, _depth_list_from_output(output, len(rgb_images)) - MODEL_DEPTH_SHAPE = (120, 160) + def _init_session(self) -> None: + if self._session is not None: + return - _ensure_buffers(batch_size, (h_rgb, w_rgb), MODEL_DEPTH_SHAPE) + try: + import onnxruntime as ort + except ImportError as exc: + raise VisionPipelineError( + "Vision pipeline requires ONNX Runtime. " + "Install PySpotObserver with the vision extra: " + 'pip install -e ".[vision]"' + ) from exc - # ------------------------- - # Fill buffers (no stack) - # ------------------------- - for i in range(batch_size): - # RGB: HWC -> CHW - _rgb_buffer[i, 0] = rgb_images[i][:, :, 0] - _rgb_buffer[i, 1] = rgb_images[i][:, :, 1] - _rgb_buffer[i, 2] = rgb_images[i][:, :, 2] + if hasattr(ort, "preload_dlls"): + ort.preload_dlls() - # Depth: HW -> 1HW - resized = cv2.resize( - depth_images[i], - (160, 120), # (width, height) for OpenCV - interpolation=cv2.INTER_NEAREST, # important for depth! + available = set(ort.get_available_providers()) + selected = [provider for provider in self.providers if provider in available] + if not selected: + raise VisionPipelineError( + "None of the requested ONNX Runtime providers are available. " + f"requested={list(self.providers)}, available={sorted(available)}" + ) + + self._session = ort.InferenceSession( + str(self.model_path), + providers=selected, + ) + + inputs = self._session.get_inputs() + if len(inputs) < 2: + raise VisionPipelineError( + f"Vision model must expose at least 2 inputs; got {len(inputs)}" + ) + + self._input_names = [inputs[0].name, inputs[1].name] + self._rgb_dtype = _dtype_for_onnx_type(inputs[0].type) + self._depth_dtype = _dtype_for_onnx_type(inputs[1].type) + + def _ensure_buffers( + self, + rgb_images: Sequence[np.ndarray], + depth_images: Sequence[np.ndarray], + ) -> None: + batch_size = len(rgb_images) + h_rgb, w_rgb, channels = rgb_images[0].shape + if channels != 3: + raise VisionPipelineError(f"Expected RGB images with 3 channels, got {channels}") + + depth_h, depth_w = self.depth_size + rgb_shape = (batch_size, 3, h_rgb, w_rgb) + depth_shape = (batch_size, 1, depth_h, depth_w) + + if self._rgb_buffer is None or self._rgb_buffer.shape != rgb_shape: + self._rgb_buffer = np.empty(rgb_shape, dtype=self._rgb_dtype) + elif self._rgb_buffer.dtype != self._rgb_dtype: + self._rgb_buffer = np.empty(rgb_shape, dtype=self._rgb_dtype) + + if self._depth_buffer is None or self._depth_buffer.shape != depth_shape: + self._depth_buffer = np.empty(depth_shape, dtype=self._depth_dtype) + elif self._depth_buffer.dtype != self._depth_dtype: + self._depth_buffer = np.empty(depth_shape, dtype=self._depth_dtype) + + if self._depth_dtype == np.dtype(np.float16): + if ( + self._depth_resize_buffer is None + or self._depth_resize_buffer.shape != depth_shape + ): + self._depth_resize_buffer = np.empty(depth_shape, dtype=np.float32) + else: + self._depth_resize_buffer = None + + for index, rgb in enumerate(rgb_images): + if rgb.shape != (h_rgb, w_rgb, 3): + raise VisionPipelineError( + "All RGB images in a pipeline batch must have the same shape; " + f"image 0 is {(h_rgb, w_rgb, 3)}, image {index} is {rgb.shape}" + ) + for index, depth in enumerate(depth_images): + if depth.ndim != 2: + raise VisionPipelineError( + f"Depth image {index} must be 2D before pipeline processing; got {depth.shape}" + ) + + def _fill_buffers( + self, + rgb_images: Sequence[np.ndarray], + depth_images: Sequence[np.ndarray], + ) -> None: + depth_h, depth_w = self.depth_size + + for index, rgb in enumerate(rgb_images): + self._rgb_buffer[index, 0] = rgb[:, :, 0] + self._rgb_buffer[index, 1] = rgb[:, :, 1] + self._rgb_buffer[index, 2] = rgb[:, :, 2] + + for index, depth in enumerate(depth_images): + depth_dst = self._depth_buffer[index, 0] + resize_dst = ( + self._depth_resize_buffer[index, 0] + if self._depth_resize_buffer is not None + else depth_dst + ) + cv2.resize( + depth, + (depth_w, depth_h), + dst=resize_dst, + interpolation=cv2.INTER_NEAREST, + ) + if resize_dst is not depth_dst: + depth_dst[...] = resize_dst + + +_default_pipeline: Optional[VisionPipeline] = None +_default_pipeline_key: Optional[Tuple[str, Tuple[str, ...], Tuple[int, int]]] = None +_default_pipeline_lock = threading.Lock() + + +def run_vision_pipeline( + rgb_images: List[np.ndarray], + depth_images: List[np.ndarray], + *, + model_path: Optional[str] = None, + providers: Optional[Union[Sequence[str], str]] = None, + depth_size: Tuple[int, int] = DEFAULT_DEPTH_SIZE, +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """ + Backwards-compatible functional entry point using a cached default pipeline. + """ + resolved_model_path = model_path or os.environ.get(DEFAULT_MODEL_ENV_VAR) + if not resolved_model_path: + raise VisionPipelineError( + "Vision model path is required. Pass model_path or set " + f"{DEFAULT_MODEL_ENV_VAR}." ) - _depth_buffer[i, 0] = resized - - # ------------------------- - # Inference (thread-safe) - # ------------------------- - with _lock: - output = _session.run( - None, - { - _input_names[0]: _rgb_buffer, - _input_names[1]: _depth_buffer, - }, - )[0] - - # # ------------------------- - # # Fast normalization (in-place) - # # ------------------------- - # out_min = output.min() - # out_max = output.max() - # scale = 1.0 / (out_max - out_min + 1e-8) - # output = (output - out_min) * scale - - return rgb_images, output \ No newline at end of file + provider_tuple = _normalize_providers(providers) + key = (str(Path(resolved_model_path).expanduser()), provider_tuple, depth_size) + + global _default_pipeline, _default_pipeline_key + with _default_pipeline_lock: + if _default_pipeline is None or _default_pipeline_key != key: + _default_pipeline = VisionPipeline( + model_path=resolved_model_path, + providers=provider_tuple, + depth_size=depth_size, + ) + _default_pipeline_key = key + + return _default_pipeline.run(rgb_images, depth_images) diff --git a/PySpotObserver/requirements.txt b/PySpotObserver/requirements.txt index b129369..feef701 100644 --- a/PySpotObserver/requirements.txt +++ b/PySpotObserver/requirements.txt @@ -1,13 +1,7 @@ -# Core dependencies -bosdyn-client>=5.0.0 # Boston Dynamics Spot SDK -numpy>=2.0.0,<2.4.0 # Array operations -opencv-python>=4.12.0 # Image processing and display -pyyaml>=6.0.0 # YAML configuration support +# Dependency source of truth is setup.py. +# Run this from the PySpotObserver directory for a development install. +-e .[dev] -# Optional dependencies -# For development -pytest>=7.4.0 # Testing framework -pytest-asyncio>=0.21.0 # Async test support -black>=23.0.0 # Code formatting -mypy>=1.5.0 # Type checking -onnxruntime-gpu>=1.25.0 \ No newline at end of file +# Vision pipeline support is optional because it requires ONNX Runtime. +# Install it when needed with: +# pip install -e ".[vision]" diff --git a/PySpotObserver/setup.py b/PySpotObserver/setup.py index 969b0bf..511ef1e 100644 --- a/PySpotObserver/setup.py +++ b/PySpotObserver/setup.py @@ -5,6 +5,24 @@ from setuptools import setup, find_packages from pathlib import Path +INSTALL_REQUIRES = [ + "bosdyn-client>=5.0.0", + "numpy>=2.0.0,<2.4.0", + "opencv-python>=4.12.0", + "pyyaml>=6.0.0", +] + +DEV_REQUIRES = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", + "black>=23.0.0", + "mypy>=1.5.0", +] + +VISION_REQUIRES = [ + "onnxruntime-gpu>=1.25.0", +] + # Read README for long description readme_file = Path(__file__).parent / "README.md" long_description = readme_file.read_text(encoding="utf-8") if readme_file.exists() else "" @@ -32,19 +50,11 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], python_requires=">=3.9", - install_requires=[ - "bosdyn-client>=5.0.0", - "numpy>=2.0.0,<2.4.0", - "opencv-python>=4.12.0", - "pyyaml>=6.0.0", - ], + install_requires=INSTALL_REQUIRES, extras_require={ - "dev": [ - "pytest>=7.4.0", - "pytest-asyncio>=0.21.0", - "black>=23.0.0", - "mypy>=1.5.0", - ], + "dev": DEV_REQUIRES, + "vision": VISION_REQUIRES, + "all": DEV_REQUIRES + VISION_REQUIRES, }, entry_points={ "console_scripts": [ diff --git a/install b/install index 862cd5c..440ce5a 160000 --- a/install +++ b/install @@ -1 +1 @@ -Subproject commit 862cd5cf0f2f3da8987638271fa120a7a958847b +Subproject commit 440ce5acb594f13b3ac758d14e0cd03be5341164 From ea4bf4d225ad070e77ad934b3ac71f348d2b0004 Mon Sep 17 00:00:00 2001 From: Faisal Zaghloul Date: Wed, 10 Jun 2026 16:24:21 -0400 Subject: [PATCH 3/3] Add pytest files --- PySpotObserver/tests/__init__.py | 3 + PySpotObserver/tests/test_config.py | 126 +++++++++++ .../tests/test_streaming_lifecycle.py | 197 ++++++++++++++++++ 3 files changed, 326 insertions(+) create mode 100644 PySpotObserver/tests/__init__.py create mode 100644 PySpotObserver/tests/test_config.py create mode 100644 PySpotObserver/tests/test_streaming_lifecycle.py diff --git a/PySpotObserver/tests/__init__.py b/PySpotObserver/tests/__init__.py new file mode 100644 index 0000000..95a7ede --- /dev/null +++ b/PySpotObserver/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for PySpotObserver package. +""" diff --git a/PySpotObserver/tests/test_config.py b/PySpotObserver/tests/test_config.py new file mode 100644 index 0000000..4842d5f --- /dev/null +++ b/PySpotObserver/tests/test_config.py @@ -0,0 +1,126 @@ +""" +Unit tests for configuration module. +""" + +import os +import pytest +from pathlib import Path +import tempfile + +from pyspotobserver.config import SpotConfig, CameraType + + +class TestCameraType: + """Tests for CameraType enum.""" + + def test_camera_mask_combinations(self): + """Test bitwise operations on camera types.""" + mask = CameraType.FRONTLEFT | CameraType.FRONTRIGHT + assert mask & CameraType.FRONTLEFT + assert mask & CameraType.FRONTRIGHT + assert not (mask & CameraType.BACK) + + def test_get_source_name_rgb(self): + """Test RGB camera source name generation.""" + assert CameraType.get_source_name(CameraType.FRONTLEFT) == "frontleft_fisheye_image" + assert CameraType.get_source_name(CameraType.BACK) == "back_fisheye_image" + assert CameraType.get_source_name(CameraType.HAND) == "hand_color_image" + + def test_get_source_name_depth(self): + """Test depth camera source name generation.""" + assert CameraType.get_source_name(CameraType.FRONTLEFT, depth=True) == \ + "frontleft_depth_in_visual_frame" + assert CameraType.get_source_name(CameraType.RIGHT, depth=True) == \ + "right_depth_in_visual_frame" + assert CameraType.get_source_name(CameraType.HAND, depth=True) == \ + "hand_depth_in_hand_color_frame" + + +class TestSpotConfig: + """Tests for SpotConfig dataclass.""" + + def test_default_values(self): + """Test configuration with default values.""" + config = SpotConfig(robot_ip="192.168.80.3") + assert config.robot_ip == "192.168.80.3" + assert config.username == "" + assert config.password == "" + assert config.image_buffer_size == 5 + assert config.image_quality_percent == 100.0 + assert config.request_timeout_seconds == 10.0 + assert config.vision_model_path is None + assert config.vision_providers is None + + def test_request_timeout_must_be_positive(self): + """Test that non-positive request timeouts are rejected.""" + with pytest.raises(ValueError, match="request_timeout_seconds"): + SpotConfig(robot_ip="192.168.80.3", request_timeout_seconds=0) + + def test_custom_values(self): + """Test configuration with custom values.""" + config = SpotConfig( + robot_ip="10.0.0.1", + username="admin", + password="secret", + image_buffer_size=10, + ) + assert config.robot_ip == "10.0.0.1" + assert config.username == "admin" + assert config.password == "secret" + assert config.image_buffer_size == 10 + + def test_repr_redacts_password(self): + """Test that __repr__ redacts password.""" + config = SpotConfig(robot_ip="192.168.80.3", password="secret") + repr_str = repr(config) + assert "secret" not in repr_str + assert "***" in repr_str + + def test_yaml_roundtrip(self): + """Test saving and loading from YAML.""" + config = SpotConfig( + robot_ip="192.168.80.3", + username="testuser", + password="testpass", + image_buffer_size=7, + ) + + fd, raw_path = tempfile.mkstemp(suffix=".yaml") + yaml_path = Path(raw_path) + try: + os.close(fd) + yaml_path.unlink() + + # Save to YAML + config.to_yaml(yaml_path) + assert yaml_path.exists() + + # Load from YAML + loaded_config = SpotConfig.from_yaml(yaml_path) + assert loaded_config.robot_ip == config.robot_ip + assert loaded_config.username == config.username + assert loaded_config.password == config.password + assert loaded_config.image_buffer_size == config.image_buffer_size + finally: + try: + yaml_path.unlink() + except FileNotFoundError: + pass + + def test_yaml_file_not_found(self): + """Test loading from non-existent YAML file.""" + with pytest.raises(FileNotFoundError): + SpotConfig.from_yaml("nonexistent.yaml") + + def test_extra_params(self): + """Test extra_params field.""" + config = SpotConfig( + robot_ip="192.168.80.3", + extra_params={"location": "lab", "experiment": "test1"} + ) + assert config.extra_params["location"] == "lab" + assert config.extra_params["experiment"] == "test1" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/PySpotObserver/tests/test_streaming_lifecycle.py b/PySpotObserver/tests/test_streaming_lifecycle.py new file mode 100644 index 0000000..6cd9f86 --- /dev/null +++ b/PySpotObserver/tests/test_streaming_lifecycle.py @@ -0,0 +1,197 @@ +import threading +import time + +import numpy as np +import pytest +import cv2 +from bosdyn.api import image_pb2 + +from pyspotobserver.camera_stream import ImageFrame, SpotCamStream, SpotCamStreamError +from pyspotobserver.config import CameraType, SpotConfig +from pyspotobserver.connection import SpotConnection, SpotConnectionError + + +class DummyImageClient: + def __init__(self, responses): + self._responses = list(responses) + + def get_image(self, image_requests, timeout): + response = self._responses.pop(0) + if isinstance(response, Exception): + raise response + return response + + +def make_stream(image_client=None) -> SpotCamStream: + return SpotCamStream( + image_client=image_client or DummyImageClient([]), + config=SpotConfig(robot_ip="192.168.80.3", request_timeout_seconds=0.05), + stream_id="test_stream", + ) + + +def test_get_current_images_returns_latest_frame(): + stream = make_stream() + stream._streaming = True + + older = ImageFrame( + rgb_images=[np.full((1, 1, 3), 1.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 1.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=1.0, + ) + newer = ImageFrame( + rgb_images=[np.full((1, 1, 3), 2.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 2.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=2.0, + ) + + stream._image_queue.put(older) + stream._image_queue.put(newer) + + rgb_images, depth_images = stream.get_current_images(timeout=0.01) + + assert float(rgb_images[0][0, 0, 0]) == 2.0 + assert float(depth_images[0][0, 0]) == 2.0 + assert stream._image_queue.qsize() == 2 + + +def test_get_current_images_run_pipeline_requires_model_path(): + stream = make_stream() + stream._streaming = True + frame = ImageFrame( + rgb_images=[np.full((1, 1, 3), 1.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 1.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=1.0, + ) + stream._image_queue.put(frame) + + with pytest.raises(SpotCamStreamError, match="Vision model path is required"): + stream.get_current_images(timeout=0.01, run_pipeline=True) + + +def test_get_current_images_run_pipeline_uses_lazy_pipeline(monkeypatch): + import pyspotobserver.vision_pipeline as vision_pipeline + + class FakeVisionPipeline: + @classmethod + def from_config(cls, config): + return cls() + + def run(self, rgb_images, depth_images): + return rgb_images, [depth + 10.0 for depth in depth_images] + + monkeypatch.setattr(vision_pipeline, "VisionPipeline", FakeVisionPipeline) + + stream = make_stream() + stream._streaming = True + frame = ImageFrame( + rgb_images=[np.full((1, 1, 3), 1.0, dtype=np.float32)], + depth_images=[np.full((1, 1), 2.0, dtype=np.float32)], + camera_order=[CameraType.FRONTLEFT], + timestamp=1.0, + ) + stream._image_queue.put(frame) + + rgb_images, depth_images = stream.get_current_images( + timeout=0.01, + run_pipeline=True, + ) + + assert float(rgb_images[0][0, 0, 0]) == 1.0 + assert float(depth_images[0][0, 0]) == 12.0 + + +def test_get_current_images_unblocks_when_stream_stops(): + stream = make_stream() + stream._streaming = True + + result = {} + + def reader(): + try: + stream.get_current_images(timeout=None) + except Exception as exc: # pragma: no cover - exercised by assertion + result["exc"] = exc + + thread = threading.Thread(target=reader) + thread.start() + time.sleep(0.15) + + stream._streaming = False + stream._stop_event.set() + + thread.join(timeout=1.0) + + assert not thread.is_alive() + assert isinstance(result["exc"], SpotCamStreamError) + assert "stopped" in str(result["exc"]).lower() + + +def test_stream_loop_recovers_after_initial_request_error(): + stream = make_stream(DummyImageClient([RuntimeError("boom"), ["ok"]])) + stream._streaming = True + stream._camera_order = [CameraType.FRONTLEFT] + + decoded = [ + np.zeros((1, 1, 3), dtype=np.float32), + np.zeros((1, 1), dtype=np.float32), + ] + + stream._build_image_requests = lambda: [] + stream._decode_initial_responses = lambda responses: decoded + + original_enqueue = stream._enqueue_frame + + def enqueue_and_stop(frame): + original_enqueue(frame) + stream._streaming = False + stream._stop_event.set() + + stream._enqueue_frame = enqueue_and_stop + + stream._stream_loop() + + assert stream.frame_count == 1 + assert stream.error_count == 1 + assert len(stream._frame_pool) == stream._config.image_buffer_size + + +def test_jpeg_inplace_color_correction_updates_output_array(): + stream = make_stream() + bgr = np.full((2, 2, 3), 255, dtype=np.uint8) + ok, encoded = cv2.imencode(".jpg", bgr) + assert ok + + response = image_pb2.ImageResponse() + response.shot.image.format = image_pb2.Image.FORMAT_JPEG + response.shot.image.data = encoded.tobytes() + + out = np.empty((2, 2, 3), dtype=np.float32) + zero_ccm = np.zeros((3, 3), dtype=np.float32) + + stream._convert_image_response_inplace( + response, + is_depth=False, + out_array=out, + ccm=zero_ccm, + ) + + assert np.allclose(out, 0.0) + + +def test_disconnect_raises_if_stream_does_not_stop(): + class BrokenStream: + def stop_streaming(self): + raise SpotCamStreamError("still alive") + + conn = SpotConnection(SpotConfig(robot_ip="192.168.80.3")) + conn._connected = True + conn._cam_streams = {"broken": BrokenStream()} + + with pytest.raises(SpotConnectionError, match="Failed to stop all streams cleanly"): + conn.disconnect() + + assert conn.connected is True