diff --git a/isaaclab_arena/assets/register.py b/isaaclab_arena/assets/register.py index 6071dd9da..fe1718839 100644 --- a/isaaclab_arena/assets/register.py +++ b/isaaclab_arena/assets/register.py @@ -8,8 +8,10 @@ DeviceRegistry, EnvironmentRegistry, HDRImageRegistry, + ObjectRelationLibraryRegistry, PolicyRegistry, RetargeterRegistry, + TaskRegistry, ) @@ -68,3 +70,25 @@ def register_environment(cls): else: registry.register(cls, cls.name) return cls + + +# Decorator to register a RelationBase subclass with the ObjectRelationLibraryRegistry. +def register_object_relation(cls): + registry = ObjectRelationLibraryRegistry() + if registry.is_registered(cls.name): + print(f"WARNING: Object relation {cls.name} is already registered. Doing nothing.") + else: + registry.register(cls, cls.name) + return cls + + +# Decorator to register a TaskBase subclass with the TaskRegistry. +# Keyed by `cls.__name__` so the YAML `type: PascalCase` lookups match without +# requiring a separate `name` attribute on every task class. +def register_task(cls): + registry = TaskRegistry() + if registry.is_registered(cls.__name__): + print(f"WARNING: Task {cls.__name__} is already registered. Doing nothing.") + else: + registry.register(cls, cls.__name__) + return cls diff --git a/isaaclab_arena/assets/registries.py b/isaaclab_arena/assets/registries.py index 5abbc731f..10f9c43bc 100644 --- a/isaaclab_arena/assets/registries.py +++ b/isaaclab_arena/assets/registries.py @@ -16,6 +16,8 @@ from isaaclab_arena.assets.hdr_image import HDRImage from isaaclab_arena.assets.teleop_device_base import TeleopDeviceBase from isaaclab_arena.policy.policy_base import PolicyBase + from isaaclab_arena.relations.relations import RelationBase + from isaaclab_arena.tasks.task_base import TaskBase # Have to define all classes here in order to avoid circular import. @@ -42,7 +44,18 @@ def is_registered(self, key: str) -> bool: key (str): The name of the component. """ # For AssetRegistry and DeviceRegistry, ensure assets are registered before checking - if isinstance(self, (AssetRegistry, DeviceRegistry, RetargeterRegistry, PolicyRegistry, HDRImageRegistry)): + if isinstance( + self, + ( + AssetRegistry, + DeviceRegistry, + RetargeterRegistry, + PolicyRegistry, + HDRImageRegistry, + ObjectRelationLibraryRegistry, + TaskRegistry, + ), + ): ensure_assets_registered() return key in self._components @@ -56,7 +69,18 @@ def get_component_by_name(self, key: str) -> Any: Any: The component. """ # For AssetRegistry and DeviceRegistry, ensure assets are registered before accessing - if isinstance(self, (AssetRegistry, DeviceRegistry, RetargeterRegistry, PolicyRegistry, HDRImageRegistry)): + if isinstance( + self, + ( + AssetRegistry, + DeviceRegistry, + RetargeterRegistry, + PolicyRegistry, + HDRImageRegistry, + ObjectRelationLibraryRegistry, + TaskRegistry, + ), + ): ensure_assets_registered() assert key in self._components, f"component {key} not found, please check if requested component is registered" return self._components[key] @@ -68,7 +92,18 @@ def get_all_keys(self) -> list[str]: list[str]: The list of keys. """ # For AssetRegistry and DeviceRegistry, ensure assets are registered before accessing - if isinstance(self, (AssetRegistry, DeviceRegistry, RetargeterRegistry, PolicyRegistry, HDRImageRegistry)): + if isinstance( + self, + ( + AssetRegistry, + DeviceRegistry, + RetargeterRegistry, + PolicyRegistry, + HDRImageRegistry, + ObjectRelationLibraryRegistry, + TaskRegistry, + ), + ): ensure_assets_registered() return list(self._components.keys()) @@ -217,14 +252,52 @@ def __init__(self): super().__init__() +class ObjectRelationLibraryRegistry(Registry): + """Registry for object relation classes.""" + + def __init__(self): + super().__init__() + + def get_object_relation_by_name(self, name: str) -> type["RelationBase"]: + """Gets an object relation by name. + + Args: + name (str): The name of the object relation. + """ + ensure_assets_registered() + return self.get_component_by_name(name) + + +class TaskRegistry(Registry): + """Registry for TaskBase subclasses.""" + + def __init__(self): + super().__init__() + + def get_task_by_name(self, name: str) -> type["TaskBase"]: + """Gets a task class by name. + + Args: + name (str): The name of the task class (typically the class __name__). + """ + ensure_assets_registered() + return self.get_component_by_name(name) + + # Lazy registration to avoid circular imports _assets_registered = False +# Blocks re-entry: registration decorators call is_registered() -> ensure_assets_registered() +# mid-import, which would re-import a partial module and raise a circular ImportError. +_registration_in_progress = False def ensure_assets_registered(): """Ensure all assets are registered. Call this before accessing the registry.""" - global _assets_registered - if not _assets_registered: + global _assets_registered, _registration_in_progress + if _assets_registered or _registration_in_progress: + return + _registration_in_progress = True + try: # Import modules to trigger asset registration via decorators import isaaclab_arena.assets.background_library # noqa: F401 import isaaclab_arena.assets.device_library # noqa: F401 @@ -233,5 +306,12 @@ def ensure_assets_registered(): import isaaclab_arena.assets.retargeter_library # noqa: F401 import isaaclab_arena.embodiments # noqa: F401 import isaaclab_arena.policy # noqa: F401 + import isaaclab_arena.relations.relations # noqa: F401 + import isaaclab_arena.tasks # noqa: F401 + except Exception: + _assets_registered = False # roll back so a later call can retry + raise _assets_registered = True + finally: + _registration_in_progress = False diff --git a/isaaclab_arena/environments/arena_env_graph_conversion_utils.py b/isaaclab_arena/environments/arena_env_graph_conversion_utils.py new file mode 100644 index 000000000..fb2290aa5 --- /dev/null +++ b/isaaclab_arena/environments/arena_env_graph_conversion_utils.py @@ -0,0 +1,150 @@ +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.object_reference import ObjectReference +from isaaclab_arena.assets.registries import AssetRegistry +from isaaclab_arena.environments.arena_env_graph_task_conversion_utils import build_task_from_specs +from isaaclab_arena.environments.arena_env_graph_types import ( + ArenaEnvGraphNodeSpec, + ArenaEnvGraphNodeType, + ArenaEnvGraphObjectReferenceNodeSpec, + ArenaEnvGraphSpatialConstraintType, + ArenaEnvGraphStateSpec, +) +from isaaclab_arena.environments.graph_spec_utils import relation_class_for_spatial_constraint_type +from isaaclab_arena.environments.isaaclab_arena_environment import IsaacLabArenaEnvironment +from isaaclab_arena.scene.scene import Scene +from isaaclab_arena.utils.pose import Pose + +if TYPE_CHECKING: + from isaaclab_arena.environments.arena_env_graph_spec import ArenaEnvGraphSpec + + +def build_arena_env_from_graph_spec(graph_spec: ArenaEnvGraphSpec) -> Any: + """Build an IsaacLabArenaEnvironment from a validated ``ArenaEnvGraphSpec``. + + Precondition: ``graph_spec`` is already validated (node refs exist, ids unique, etc.). + """ + # TODO(xinjieyao, 2026-05-26): aggregate every state_spec into a single combined initial state instead of + # picking one. For now we just take the first state_spec, which is the initial state + # for the first task — this matches the previous default behavior. + initial_state_spec = graph_spec.state_specs[0] if graph_spec.state_specs else None + + # 1. Materialize every graph node into a live asset, keyed by node id so spatial + # constraints and task args can reference each node by its graph-local id. + assets_by_node_id = _instantiate_assets_from_nodes(graph_spec.nodes, AssetRegistry()) + + # 2. Wire the initial state's spatial relations / fixed poses into those assets. + if initial_state_spec is not None: + _attach_spatial_constraints_to_assets(initial_state_spec, assets_by_node_id) + + # 3. Partition nodes into the env's embodiment (exactly one) and its scene assets. + embodiment, scene_assets = _partition_nodes_into_embodiment_and_scene(graph_spec.nodes, assets_by_node_id) + + # 4. Resolve task specs against the same assets_by_node_id so task args bind to the + # actual asset instances created in step 1 (not duplicates). + return IsaacLabArenaEnvironment( + name=graph_spec.env_name, + scene=Scene(assets=scene_assets), + embodiment=embodiment, + task=build_task_from_specs(graph_spec.tasks, assets_by_node_id), + ) + + +def _partition_nodes_into_embodiment_and_scene( + node_specs: list[ArenaEnvGraphNodeSpec], assets_by_node_id: dict[str, Any] +) -> tuple[Any, list[Asset]]: + """Split materialized nodes into the single embodiment asset and the list of scene assets. + + Asserts exactly one EMBODIMENT node. OBJECT / OBJECT_REFERENCE nodes become scene assets; + any other node type raises. Lighting is not handled yet. + """ + embodiment = None + scene_assets: list[Asset] = [] + # TODO(xinjieyao, 2026-05-26): include lighting later + for node_spec in node_specs: + if node_spec.type == ArenaEnvGraphNodeType.EMBODIMENT: + assert embodiment is None, "Only one embodiment node can be converted to an IsaacLabArenaEnvironment" + embodiment = assets_by_node_id[node_spec.id] + elif node_spec.type in (ArenaEnvGraphNodeType.OBJECT, ArenaEnvGraphNodeType.OBJECT_REFERENCE): + scene_assets.append(assets_by_node_id[node_spec.id]) + else: + raise ValueError(f"Unsupported node type: {node_spec.type}") + return embodiment, scene_assets + + +def _instantiate_assets_from_nodes(node_specs: list[ArenaEnvGraphNodeSpec], asset_registry: Any) -> dict[str, Any]: + """Return ``{node.id: live_asset}`` after a single pass over ``node_specs``. + + Each ``node_spec.params`` is forwarded verbatim to the asset constructor. Assumes parent + nodes precede their OBJECT_REFERENCE children — guaranteed by ``assert_references_exist``. + """ + assets_by_node_id: dict[str, Any] = {} + for node_spec in node_specs: + # OBJECT_REFERENCE wraps a USD prim inside an already-instantiated parent asset + # (e.g. a table inside a kitchen background). Validation guarantees the parent + # precedes the reference, so it is already in assets_by_node_id here. + if node_spec.type == ArenaEnvGraphNodeType.OBJECT_REFERENCE: + assert isinstance(node_spec, ArenaEnvGraphObjectReferenceNodeSpec) + assets_by_node_id[node_spec.id] = ObjectReference( + name=node_spec.name, + prim_path=node_spec.prim_path, + parent_asset=assets_by_node_id[node_spec.parent], + object_type=node_spec.object_type, + **node_spec.params, + ) + else: + # Standard nodes (object / background / embodiment): look up the registered class + # by name and instantiate with the spec's verbatim kwargs. + asset_class = asset_registry.get_asset_by_name(node_spec.name) + assets_by_node_id[node_spec.id] = asset_class(**node_spec.params) + return assets_by_node_id + + +def _attach_spatial_constraints_to_assets( + state_spec: ArenaEnvGraphStateSpec, assets_by_node_id: dict[str, Any] +) -> None: + """Attach one Relation per spatial constraint to the asset(s) it targets, in place. + + AT_POSE is special-cased to ``set_initial_pose`` since it has no Relation class. + Raises AssertionError on unsupported constraint types or malformed AT_POSE params. + """ + for spatial_constraint in state_spec.spatial_constraints: + parent_asset = assets_by_node_id[spatial_constraint.parent] + + # AT_POSE has no Relation class — it pins the parent's initial pose directly, + # bypassing the solver. Pop the two known fields and reject any extras so a typo + # in the spec ('postion_xyz') raises here instead of silently being ignored. + # TODO(xinjieyao, 2026-05-26): move at_pose handling into the placer module. + if spatial_constraint.type == ArenaEnvGraphSpatialConstraintType.AT_POSE: + at_pose_params = dict(spatial_constraint.params) + position_xyz = at_pose_params.pop("position_xyz", None) + rotation_xyzw = at_pose_params.pop("rotation_xyzw", (0.0, 0.0, 0.0, 1.0)) + assert ( + position_xyz is not None + ), f"at_pose constraint '{spatial_constraint.id}' requires params.position_xyz" + assert ( + not at_pose_params + ), f"Unsupported at_pose params for constraint '{spatial_constraint.id}': {sorted(at_pose_params)}" + parent_asset.set_initial_pose(Pose(position_xyz=position_xyz, rotation_xyzw=rotation_xyzw)) + else: + # All other constraint types resolve through ObjectRelationLibraryRegistry. The + # registry returning None signals an enum value with no registered class — would + # be a programming error, not a YAML error. + relation_class = relation_class_for_spatial_constraint_type(spatial_constraint.type) + assert relation_class is not None, f"Unsupported spatial constraint type '{spatial_constraint.type.value}'" + # Unary relations (IS_ANCHOR, POSITION_LIMITS, ...) attach to the parent asset. + # Binary relations (ON, NEXT_TO, ...) attach to the *child* asset, with parent + # as the relation's first constructor arg — matches how add_relation is wired. + if relation_class.is_unary(): + parent_asset.add_relation(relation_class(**spatial_constraint.params)) + else: + child_asset = assets_by_node_id[spatial_constraint.child] + child_asset.add_relation(relation_class(parent_asset, **spatial_constraint.params)) diff --git a/isaaclab_arena/environments/arena_env_graph_spec.py b/isaaclab_arena/environments/arena_env_graph_spec.py index 2f69f97d6..e394c67ad 100644 --- a/isaaclab_arena/environments/arena_env_graph_spec.py +++ b/isaaclab_arena/environments/arena_env_graph_spec.py @@ -5,14 +5,25 @@ import yaml from dataclasses import dataclass, field -from enum import Enum from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from isaaclab_arena.assets.object_type import ObjectType +from isaaclab_arena.environments.arena_env_graph_types import ( + ArenaEnvGraphNodeSpec, + ArenaEnvGraphNodeType, + ArenaEnvGraphObjectReferenceNodeSpec, + ArenaEnvGraphSpatialConstraintSpec, + ArenaEnvGraphSpatialConstraintType, + ArenaEnvGraphStateSpec, + ArenaEnvGraphTaskConstraintSpec, + ArenaEnvGraphTaskConstraintType, + ArenaEnvGraphTaskSpec, +) from isaaclab_arena.environments.graph_spec_utils import ( as_dict, assert_references_exist, + assert_spatial_constraint_shapes, assert_unique_ids, optional_dict, optional_str, @@ -22,120 +33,23 @@ required_str, ) +if TYPE_CHECKING: + from isaaclab_arena.environments.isaaclab_arena_environment import IsaacLabArenaEnvironment -class ArenaEnvGraphNodeType(Enum): - EMBODIMENT = "embodiment" - BACKGROUND = "background" - OBJECT = "object" - OBJECT_REFERENCE = "object_reference" - LIGHTING = "lighting" - - -class ArenaEnvGraphSpatialConstraintType(Enum): - IS_ANCHOR = "is_anchor" - NEXT_TO = "next_to" - ON = "on" - AT_POSE = "at_pose" # through set_initial_pose() - AT_POSITION = "at_position" # through object relation solver: AtPosition - POSITION_LIMITS = "position_limits" - RANDOM_AROUND_SOLUTION = "random_around_solution" - ROTATE_AROUND_SOLUTION = "rotate_around_solution" - # TODO(xinjieyao, 2026-05-21): Support "in" in solver - IN = "in" - - -class ArenaEnvGraphTaskConstraintType(Enum): - REACH = "reach" - - -@dataclass -class ArenaEnvGraphNodeSpec: - """Node in an environment graph. - - Could be an object, an embodiment, a background, etc. Object references — USD prims - inside a parent background asset — are represented by the - :class:`ArenaEnvGraphObjectReferenceNodeSpec` subclass, which adds the extra fields - needed to locate and type the referenced prim. - """ - - id: str - name: str # Name registered in the asset registry - type: ArenaEnvGraphNodeType - # Asset-type specific optional kwargs (e.g. scale, spawn_cfg_addon) — distinct from - # the typed graph metadata above. The Arena environment builder forwards these when - # instantiating the asset class. - params: dict[str, Any] = field(default_factory=dict) - - -# kw_only=True forces the three new fields to be keyword-only in __init__. Required because -# the base class ends with a defaulted field (`params`) and Python forbids non-default args -# from following default ones — placing the new required fields after `*` sidesteps that rule -# and lets us declare them as required (no default) instead of Optional with runtime checks. -@dataclass(kw_only=True) -class ArenaEnvGraphObjectReferenceNodeSpec(ArenaEnvGraphNodeSpec): - """Object-reference node: a USD prim inside a parent background asset. - - All three extra fields are required for this node type — without them the - builder cannot bind to the referenced prim or know how to wrap it. - """ - - parent: str # id of the parent (typically background) node that owns the prim - prim_path: str # USD prim path of the referenced prim (may contain {ENV_REGEX_NS}) - object_type: ObjectType # how to wrap the prim (rigid, articulation, etc.) - - -@dataclass -class ArenaEnvGraphSpatialConstraintSpec: - """Spatial constraint edge in an environment graph state spec. - - It defines a relation between two nodes. - """ - - id: str - type: ArenaEnvGraphSpatialConstraintType - parent: str - child: str | None = None # Optional, e.g. is_anchor constraint does not have a child - # Type-specific optional kwargs for the underlying RelationBase subclass selected by `type` - # (e.g. {x_min, x_max, y_min, y_max} for position_limits; {side, distance} for next_to etc.). - # The Arena environment builder forwards these when constructing the Relation instance. - params: dict[str, Any] = field(default_factory=dict) - - -@dataclass -class ArenaEnvGraphTaskConstraintSpec: - """Task-dependent constraint edge in an environment graph state spec.""" - - id: str - type: ArenaEnvGraphTaskConstraintType - parent: str - child: str | None = None # Optional, could be a robot keeps gripper open or closed, or a single object - # Type-specific optional kwargs for the underlying TaskConstraintBase subclass selected by `type` - # (e.g. grasp pose offset the reach constraint.). - # The Arena environment builder forwards these when constructing the TaskConstraint instance. - params: dict[str, Any] = field(default_factory=dict) - - -@dataclass -class ArenaEnvGraphStateSpec: - """Snapshot of the environment state in the graph. - - Could be an initial, intermediate, or final state. - """ - - id: str - spatial_constraints: list[ArenaEnvGraphSpatialConstraintSpec] = field(default_factory=list) - task_constraints: list[ArenaEnvGraphTaskConstraintSpec] = field(default_factory=list) - - -@dataclass -class ArenaEnvGraphTaskSpec: - """Task entry in an environment graph.""" - id: str - type: str # Task class name, could be a custom task class or a built-in task class - initial_state_spec_id: str - success_state_spec_id: str - task_args: dict[str, Any] = field(default_factory=dict) +# Re-exported for callers that already import these names from this module. +__all__ = [ + "ArenaEnvGraphNodeSpec", + "ArenaEnvGraphNodeType", + "ArenaEnvGraphObjectReferenceNodeSpec", + "ArenaEnvGraphSpatialConstraintSpec", + "ArenaEnvGraphSpatialConstraintType", + "ArenaEnvGraphSpec", + "ArenaEnvGraphStateSpec", + "ArenaEnvGraphTaskConstraintSpec", + "ArenaEnvGraphTaskConstraintType", + "ArenaEnvGraphTaskSpec", +] @dataclass @@ -161,15 +75,20 @@ def from_dict(cls, data: dict[str, Any]) -> "ArenaEnvGraphSpec": tasks = parse_list(data, "tasks", _parse_task) state_specs = parse_list(data, "state_specs", _parse_state_spec) - assert_unique_ids(nodes, tasks, state_specs) - assert_references_exist(nodes, tasks, state_specs) - - return cls( + spec = cls( env_name=required_str(data, "env_name"), nodes=nodes, tasks=tasks, state_specs=state_specs, ) + spec.validate() + return spec + + def validate(self) -> None: + """Validate graph-level ids, references, and relationship shapes.""" + assert_unique_ids(self.nodes, self.tasks, self.state_specs) + assert_references_exist(self.nodes, self.tasks, self.state_specs) + assert_spatial_constraint_shapes(self.state_specs) @property def nodes_by_id(self) -> dict[str, ArenaEnvGraphNodeSpec]: @@ -183,6 +102,21 @@ def tasks_by_id(self) -> dict[str, ArenaEnvGraphTaskSpec]: def state_specs_by_id(self) -> dict[str, ArenaEnvGraphStateSpec]: return {state_spec.id: state_spec for state_spec in self.state_specs} + def to_arena_env(self) -> "IsaacLabArenaEnvironment": + """Convert this graph spec into an `IsaacLabArenaEnvironment`. + + The first ``state_spec`` is used as the scene's initial state. + """ + # Lazy import: build_arena_env_from_graph_spec pulls in Scene -> phyx_utils -> + # pxr.PhysxSchema, which requires SimulationApp. Keeping the import here lets + # data-only consumers of the spec (parsers, tests) import this module before + # SimulationApp is started. + # TODO(xinjieyao, 2026-05-26): once `build_arena_env_from_graph_spec` aggregates across all state_specs, + # this wrapper stays single-arg — no caller-side selection is needed. + from isaaclab_arena.environments.arena_env_graph_conversion_utils import build_arena_env_from_graph_spec + + return build_arena_env_from_graph_spec(self) + def _parse_node(data: Any) -> ArenaEnvGraphNodeSpec: data = as_dict(data, "Node spec") diff --git a/isaaclab_arena/environments/arena_env_graph_task_conversion_utils.py b/isaaclab_arena/environments/arena_env_graph_task_conversion_utils.py new file mode 100644 index 000000000..25d720972 --- /dev/null +++ b/isaaclab_arena/environments/arena_env_graph_task_conversion_utils.py @@ -0,0 +1,131 @@ +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import types +import typing +from typing import TYPE_CHECKING, Any + +from isaaclab_arena.affordances.affordance_base import AffordanceBase +from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.registries import TaskRegistry +from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase + +if TYPE_CHECKING: + from isaaclab_arena.environments.arena_env_graph_types import ArenaEnvGraphTaskSpec + + +# Annotation bases that mark a task __init__ kwarg as a graph-node reference. +# * Asset — direct ("background_scene: Asset") +# * AffordanceBase — task interface enforces an affordance contract on the kwarg +# ("placeable_object: Placeable"). +NODE_REF_BASES: tuple[type, ...] = (Asset, AffordanceBase) + + +def build_task_from_specs(task_specs: list[ArenaEnvGraphTaskSpec], assets_by_node_id: dict[str, Any]) -> Any | None: + """Return None for no specs, the single task for one spec, or a SequentialTaskBase for many.""" + if not task_specs: + return None + task_instances = [_build_task_from_spec(task_spec, assets_by_node_id) for task_spec in task_specs] + if len(task_instances) == 1: + return task_instances[0] + return SequentialTaskBase( + subtasks=task_instances, + desired_subtask_success_state=[True] * len(task_instances), + ) + + +def _build_task_from_spec(task_spec: ArenaEnvGraphTaskSpec, assets_by_node_id: dict[str, Any]) -> Any: + """Look up the task class by name, resolve any Asset-typed kwargs, instantiate.""" + task_class = TaskRegistry().get_task_by_name(task_spec.type) + task_init_kwargs = _resolve_node_refs_in_task_args(task_class, task_spec.task_args, assets_by_node_id) + return task_class(**task_init_kwargs) + + +def _resolve_node_refs_in_task_args( + task_class: type, raw_task_args: dict[str, Any], assets_by_node_id: dict[str, Any] +) -> dict[str, Any]: + """Swap node-id strings for live assets on Asset / list[Asset] params; pass others through. + + Example — for ``PickAndPlaceTask(pick_up_object: Asset, ..., episode_length_s: float)``:: + + raw_task_args = {"pick_up_object": "cube", ..., "episode_length_s": 5.0} + assets_by_node_id = {"cube": , ...} + # -> {"pick_up_object": , ..., "episode_length_s": 5.0} + + Misspelled / non-string node ids raise AssertionError instead of silently passing through. + """ + # Introspect __init__ once — the task class is the single source of truth for which + # params come from graph nodes. `None` from .get() below means "not a node-ref param". + is_collection_by_param_name = find_node_ref_params_in_signature(task_class) + + resolved_task_kwargs: dict[str, Any] = {} + for param_name, raw_param_value in raw_task_args.items(): + is_collection = is_collection_by_param_name.get(param_name) + if is_collection is None: + # Not annotated as a node ref (e.g. floats, strings, tuples) — forward unchanged. + resolved_task_kwargs[param_name] = raw_param_value + elif is_collection: + # list[Asset]-typed param: resolve each element to its live asset. + resolved_task_kwargs[param_name] = [ + _lookup_asset_by_node_id(raw_node_id, assets_by_node_id, task_class, param_name) + for raw_node_id in raw_param_value + ] + else: + # Asset-typed param: resolve the single node id to its live asset. + resolved_task_kwargs[param_name] = _lookup_asset_by_node_id( + raw_param_value, assets_by_node_id, task_class, param_name + ) + return resolved_task_kwargs + + +def _lookup_asset_by_node_id(node_id: Any, assets_by_node_id: dict[str, Any], task_class: type, param_name: str) -> Any: + """Return the live asset for ``node_id``; raise AssertionError naming the task/param on miss.""" + assert ( + isinstance(node_id, str) and node_id in assets_by_node_id + ), f"{task_class.__name__}.{param_name}: unknown node id {node_id!r}" + return assets_by_node_id[node_id] + + +def find_node_ref_params_in_signature(task_class: type) -> dict[str, bool]: + """Return ``{param_name: is_collection}`` for ``__init__`` params annotated as a NODE_REF_BASES subclass. + + Optional / ``X | None`` counts as a node ref. ``is_collection=True`` for ``list[X]`` + (``tuple[X, ...]`` is intentionally unsupported — no task uses it). Single source of + truth for what's a graph-node ref; also consumable by validators / YAML generators. + """ + is_collection_by_param_name: dict[str, bool] = {} + # get_type_hints resolves stringified / forward-ref annotations into real classes so issubclass works. + # e.g. `pick_up_object: "Asset"` (a str under `from __future__ import annotations`) becomes the Asset class. + for param_name, param_annotation in typing.get_type_hints(task_class.__init__).items(): + # Skip the implicit `self` slot and any `return` annotation — neither is a kwarg. + if param_name not in ("self", "return"): + # Walk Union members so `Asset | None` is recognized via its Asset branch. + # First matching branch wins; later branches in the same param are ignored. + for annotation_branch in _strip_none(param_annotation): + # Scalar node ref: annotation is itself an Asset / AffordanceBase subclass. + if isinstance(annotation_branch, type) and issubclass(annotation_branch, NODE_REF_BASES): + is_collection_by_param_name[param_name] = False + break + # Collection node ref: list[X] where X is an Asset / AffordanceBase subclass. + # The isinstance(args[0], type) guard rejects parametrized generics like list[list[Asset]]. + if typing.get_origin(annotation_branch) is list: + list_element_args = typing.get_args(annotation_branch) + if ( + list_element_args + and isinstance(list_element_args[0], type) + and issubclass(list_element_args[0], NODE_REF_BASES) + ): + is_collection_by_param_name[param_name] = True + break + return is_collection_by_param_name + + +def _strip_none(annotation: Any) -> tuple[Any, ...]: + """Non-None members of a ``X | None`` / ``Optional[X]``; ``(annotation,)`` otherwise.""" + if typing.get_origin(annotation) in (typing.Union, types.UnionType): + return tuple(member for member in typing.get_args(annotation) if member is not type(None)) + return (annotation,) diff --git a/isaaclab_arena/environments/arena_env_graph_types.py b/isaaclab_arena/environments/arena_env_graph_types.py new file mode 100644 index 000000000..1ee00038d --- /dev/null +++ b/isaaclab_arena/environments/arena_env_graph_types.py @@ -0,0 +1,137 @@ +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Lightweight schema for the env-graph spec. + +Pure data: enums and dataclasses with no parsing or conversion behavior. Behavior (YAML +loading, validation, conversion entry point) lives in ``arena_env_graph_spec``. + +Lives in its own module so that conversion-utilities can import these names at module +level without creating a circular import back into ``arena_env_graph_spec`` (which itself +depends on the conversion module). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from isaaclab_arena.assets.object_type import ObjectType + + +class ArenaEnvGraphNodeType(Enum): + EMBODIMENT = "embodiment" + BACKGROUND = "background" + OBJECT = "object" + OBJECT_REFERENCE = "object_reference" + LIGHTING = "lighting" + + +class ArenaEnvGraphSpatialConstraintType(Enum): + IS_ANCHOR = "is_anchor" + NEXT_TO = "next_to" + ON = "on" + AT_POSE = "at_pose" # through set_initial_pose() + AT_POSITION = "at_position" # through object relation solver: AtPosition + POSITION_LIMITS = "position_limits" + RANDOM_AROUND_SOLUTION = "random_around_solution" + ROTATE_AROUND_SOLUTION = "rotate_around_solution" + # TODO(xinjieyao, 2026-05-21): Support "in" in solver + IN = "in" + + +class ArenaEnvGraphTaskConstraintType(Enum): + REACH = "reach" + + +@dataclass +class ArenaEnvGraphNodeSpec: + """Node in an environment graph. + + Could be an object, an embodiment, a background, etc. Object references — USD prims + inside a parent background asset — are represented by the + :class:`ArenaEnvGraphObjectReferenceNodeSpec` subclass, which adds the extra fields + needed to locate and type the referenced prim. + """ + + id: str + name: str # Name registered in the asset registry + type: ArenaEnvGraphNodeType + # Asset-type specific optional kwargs (e.g. scale, spawn_cfg_addon) — distinct from + # the typed graph metadata above. The Arena environment builder forwards these when + # instantiating the asset class. + params: dict[str, Any] = field(default_factory=dict) + + +# kw_only=True forces the three new fields to be keyword-only in __init__. Required because +# the base class ends with a defaulted field (`params`) and Python forbids non-default args +# from following default ones — placing the new required fields after `*` sidesteps that rule +# and lets us declare them as required (no default) instead of Optional with runtime checks. +@dataclass(kw_only=True) +class ArenaEnvGraphObjectReferenceNodeSpec(ArenaEnvGraphNodeSpec): + """Object-reference node: a USD prim inside a parent background asset. + + All three extra fields are required for this node type — without them the + builder cannot bind to the referenced prim or know how to wrap it. + """ + + parent: str # id of the parent (typically background) node that owns the prim + prim_path: str # USD prim path of the referenced prim (may contain {ENV_REGEX_NS}) + object_type: ObjectType # how to wrap the prim (rigid, articulation, etc.) + + +@dataclass +class ArenaEnvGraphSpatialConstraintSpec: + """Spatial constraint edge in an environment graph state spec. + + It defines a relation between two nodes. + """ + + id: str + type: ArenaEnvGraphSpatialConstraintType + parent: str + child: str | None = None # Optional, e.g. is_anchor constraint does not have a child + # Type-specific optional kwargs for the underlying RelationBase subclass selected by `type` + # (e.g. {x_min, x_max, y_min, y_max} for position_limits; {side, distance} for next_to etc.). + # The Arena environment builder forwards these when constructing the Relation instance. + params: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ArenaEnvGraphTaskConstraintSpec: + """Task-dependent constraint edge in an environment graph state spec.""" + + id: str + type: ArenaEnvGraphTaskConstraintType + parent: str + child: str | None = None # Optional, could be a robot keeps gripper open or closed, or a single object + # Type-specific optional kwargs for the underlying TaskConstraintBase subclass selected by `type` + # (e.g. grasp pose offset the reach constraint.). + # The Arena environment builder forwards these when constructing the TaskConstraint instance. + params: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ArenaEnvGraphStateSpec: + """Snapshot of the environment state in the graph. + + Could be an initial, intermediate, or final state. + """ + + id: str + spatial_constraints: list[ArenaEnvGraphSpatialConstraintSpec] = field(default_factory=list) + task_constraints: list[ArenaEnvGraphTaskConstraintSpec] = field(default_factory=list) + + +@dataclass +class ArenaEnvGraphTaskSpec: + """Task entry in an environment graph.""" + + id: str + type: str # Task class name, could be a custom task class or a built-in task class + initial_state_spec_id: str + success_state_spec_id: str + task_args: dict[str, Any] = field(default_factory=dict) diff --git a/isaaclab_arena/environments/graph_spec_utils.py b/isaaclab_arena/environments/graph_spec_utils.py index 0769a561a..f9d11fb62 100644 --- a/isaaclab_arena/environments/graph_spec_utils.py +++ b/isaaclab_arena/environments/graph_spec_utils.py @@ -3,42 +3,54 @@ # # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Callable +from collections.abc import Callable, Iterator from enum import Enum from numbers import Real -from typing import Any +from typing import TYPE_CHECKING, Any + +from isaaclab_arena.assets.registries import ObjectRelationLibraryRegistry + +if TYPE_CHECKING: + from isaaclab_arena.environments.arena_env_graph_spec import ArenaEnvGraphSpatialConstraintType + from isaaclab_arena.relations.relations import RelationBase def as_dict(data: Any, spec_name: str) -> dict[str, Any]: + """Require a YAML section to be a mapping before parsing it.""" assert isinstance(data, dict), f"{spec_name} must be a dict, got {type(data).__name__}" return data def parse_list(data: dict[str, Any], key: str, parser: Callable[[Any], Any]) -> list[Any]: + """Parse a list field, treating a missing field as an empty list.""" values = data.get(key, []) assert isinstance(values, list), f"Field '{key}' must be a list" return [parser(value) for value in values] def required_str(data: dict[str, Any], key: str) -> str: + """Read a required non-empty string field.""" value = data.get(key) assert isinstance(value, str) and value, f"Missing required string field '{key}'" return value def optional_str(data: dict[str, Any], key: str) -> str | None: + """Read an optional string field without inventing a default value.""" value = data.get(key) assert value is None or isinstance(value, str), f"Optional field '{key}' must be a string when set" return value def optional_dict(data: dict[str, Any], key: str) -> dict[str, Any]: + """Read an optional mapping field and return a mutable copy.""" value = data.get(key, {}) assert value is None or isinstance(value, dict), f"Optional field '{key}' must be a dict when set" return dict(value or {}) def required_number_sequence(data: dict[str, Any], key: str, length: int) -> tuple[float, ...]: + """Read a fixed-length numeric list such as a position or quaternion.""" value = data.get(key) assert isinstance(value, (list, tuple)), f"Missing required numeric sequence field '{key}'" assert len(value) == length, f"Field '{key}' must contain {length} numbers" @@ -49,6 +61,7 @@ def required_number_sequence(data: dict[str, Any], key: str, length: int) -> tup def required_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum: + """Read a required enum field from its YAML string value.""" value = data.get(key) assert value is not None, f"Missing required field '{key}'" parsed = parse_enum(value, key, enum_type) @@ -57,10 +70,12 @@ def required_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum def optional_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum | None: + """Read an optional enum field from its YAML string value.""" return parse_enum(data.get(key), key, enum_type) def parse_enum(value: Any, key: str, enum_type: type[Enum]) -> Enum | None: + """Convert a YAML string to an enum value and show valid options on failure.""" if value is None or isinstance(value, enum_type): return value assert isinstance(value, str), f"Field '{key}' must be a string when set" @@ -72,6 +87,7 @@ def parse_enum(value: Any, key: str, enum_type: type[Enum]) -> Enum | None: def assert_unique_ids(nodes: list[Any], tasks: list[Any], state_specs: list[Any]) -> None: + """Ensure every graph id is unique, including constraint ids inside states.""" id_locations: dict[str, list[str]] = {} for node in nodes: _add_id_location(id_locations, node.id, f"node '{node.id}'") @@ -89,15 +105,24 @@ def assert_unique_ids(nodes: list[Any], tasks: list[Any], state_specs: list[Any] def assert_references_exist(nodes: list[Any], tasks: list[Any], state_specs: list[Any]) -> None: + """Ensure every graph reference points to a node or state spec that exists.""" node_ids = {node.id for node in nodes} state_spec_ids = {state_spec.id for state_spec in state_specs} + # Track ids seen so far so a node's parent must be defined *earlier* in the list. The + # conversion process (_instantiate_assets_from_nodes) materializes nodes in order and looks + # up the parent, so a parent listed after its reference would otherwise only fail + # there with a raw KeyError. + seen_node_ids: set[str] = set() for node in nodes: - # `parent` only exists on ArenaEnvGraphObjectReferenceNodeSpec; getattr keeps this - # helper generic so it doesn't need to import the subclass. parent = getattr(node, "parent", None) if parent is not None: assert parent in node_ids, f"Node '{node.id}' references unknown parent '{parent}'" + assert parent in seen_node_ids, ( + f"Node '{node.id}' references parent '{parent}' defined later in the node list; " + "a parent must appear before any node that references it" + ) + seen_node_ids.add(node.id) for task in tasks: for label, state_spec_id in ( @@ -129,5 +154,128 @@ def assert_references_exist(nodes: list[Any], tasks: list[Any], state_specs: lis ), f"Constraint '{constraint.id}' references unknown child node '{constraint.child}'" +def assert_spatial_constraint_shapes(state_specs: list[Any]) -> None: + """Check each spatial constraint has the parent/child shape its relation expects.""" + for state_spec in state_specs: + for constraint in state_spec.spatial_constraints: + constraint_type = _enum_value(constraint.type) + if constraint_type == "at_pose": + # Special: no relation class; pose is supplied directly via params. + assert ( + "position_xyz" in constraint.params + ), f"Spatial constraint '{constraint.id}' of type 'at_pose' requires params.position_xyz" + is_unary = True + elif constraint_type == "in": + # Special: no relation class; semantically a binary parent/child constraint. + # TODO(xinjieyao, 2026-05-27): add an `In` relation class so this can resolve through the registry. + is_unary = False + else: + relation_cls = relation_class_for_spatial_constraint_type(constraint.type) + assert ( + relation_cls is not None + ), f"Spatial constraint type '{constraint_type}' is not mapped to a relation class" + is_unary = relation_cls.is_unary() + + if is_unary: + assert ( + constraint.child is None + ), f"Spatial constraint '{constraint.id}' of type '{constraint_type}' must not define a child node" + else: + assert ( + constraint.child is not None + ), f"Spatial constraint '{constraint.id}' of type '{constraint_type}' requires a child node" + + def _add_id_location(id_locations: dict[str, list[str]], spec_id: str, location: str) -> None: id_locations.setdefault(spec_id, []).append(location) + + +def _enum_value(value: Any) -> Any: + return getattr(value, "value", value) + + +def relation_class_for_spatial_constraint_type( + constraint_type: "ArenaEnvGraphSpatialConstraintType", +) -> "type[RelationBase] | None": + """Resolve a spatial-constraint enum member to its RelationBase subclass. + + Returns None for enum members that have no registered class yet (e.g. AT_POSE, + handled via set_initial_pose; IN, not yet supported by the solver). + # TODO(xinjieyao, 2026-05-28): add support for AT_POSE and IN. + """ + registry = ObjectRelationLibraryRegistry() + if registry.is_registered(constraint_type.value): + return registry.get_object_relation_by_name(constraint_type.value) + return None + + +def iter_nested_leaf_values(value: Any, key_path: str = "") -> Iterator[tuple[str, Any]]: + """Walk nested task-arg values while keeping a readable path for errors. + + Example: + >>> list(iter_nested_leaf_values({"object": "mug", "destination": ["table", "shelf"]})) + [('object', 'mug'), ('destination[0]', 'table'), ('destination[1]', 'shelf')] + """ + if isinstance(value, dict): + for key, item in value.items(): + nested_key_path = f"{key_path}.{key}" if key_path else str(key) + yield from iter_nested_leaf_values(item, nested_key_path) + elif isinstance(value, (list, tuple)): + for index, item in enumerate(value): + nested_key_path = f"{key_path}[{index}]" if key_path else f"[{index}]" + yield from iter_nested_leaf_values(item, nested_key_path) + else: + yield key_path, value + + +def map_nested_leaf_values(value: Any, transform: Callable[[Any], Any]) -> Any: + """Apply a transform to nested task-arg leaves while preserving container shape. + + Example: + >>> map_nested_leaf_values({"a": [1, 2], "b": (3, 4)}, lambda x: x * 10) + {'a': [10, 20], 'b': (30, 40)} + """ + if isinstance(value, dict): + return {key: map_nested_leaf_values(item, transform) for key, item in value.items()} + if isinstance(value, list): + return [map_nested_leaf_values(item, transform) for item in value] + if isinstance(value, tuple): + return tuple(map_nested_leaf_values(item, transform) for item in value) + return transform(value) + + +def normalize_identifier(identifier: str) -> str: + """Normalize names so YAML keys can be matched across casing and separators. + + Example: + >>> normalize_identifier("Pickup_Object") + 'pickupobject' + """ + return "".join(char for char in identifier.lower() if char.isalnum()) + + +def camel_to_snake(identifier: str) -> str: + """Turn a class-like name into the module-style name we try during discovery. + + Example: + >>> camel_to_snake("AtPosition") + 'at_position' + """ + chars: list[str] = [] + for index, char in enumerate(identifier): + if char.isupper() and index > 0 and (identifier[index - 1].islower() or identifier[index - 1].isdigit()): + chars.append("_") + chars.append(char.lower()) + return "".join(chars) + + +def strip_suffix(value: str, suffix: str) -> str: + """Remove a suffix only when the value actually has it. + + Example: + >>> strip_suffix("AtPositionSpec", "Spec") + 'AtPosition' + >>> strip_suffix("AtPosition", "Spec") + 'AtPosition' + """ + return value[: -len(suffix)] if value.endswith(suffix) else value diff --git a/isaaclab_arena/relations/relations.py b/isaaclab_arena/relations/relations.py index c204c34ad..752e9f631 100644 --- a/isaaclab_arena/relations/relations.py +++ b/isaaclab_arena/relations/relations.py @@ -11,10 +11,11 @@ from isaaclab.utils.math import euler_xyz_from_quat -from isaaclab_arena.utils.pose import PoseRange +from isaaclab_arena.assets.register import register_object_relation if TYPE_CHECKING: from isaaclab_arena.assets.object_base import ObjectBase + from isaaclab_arena.utils.pose import PoseRange class Side(Enum): @@ -44,12 +45,20 @@ class UnaryRelation(RelationBase): without referencing another object (e.g., AtPosition, PositionLimits). """ - pass + @staticmethod + def is_unary() -> bool: + """Return whether the relation constrains a single object.""" + return True class Relation(RelationBase): """Base class for binary spatial relationships between objects.""" + @staticmethod + def is_unary() -> bool: + """Return whether the relation constrains a single object.""" + return False + def __init__(self, parent: ObjectBase, relation_loss_weight: float = 1.0): """ Args: @@ -60,6 +69,7 @@ def __init__(self, parent: ObjectBase, relation_loss_weight: float = 1.0): self.relation_loss_weight = relation_loss_weight +@register_object_relation class NextTo(Relation): """Represents a 'next to' relationship between objects. @@ -69,6 +79,8 @@ class NextTo(Relation): Note: Loss computation is handled by NextToLossStrategy in relation_loss_strategies.py. """ + name = "next_to" + def __init__( self, parent: ObjectBase, @@ -99,6 +111,7 @@ def __init__( self.cross_position_ratio = cross_position_ratio +@register_object_relation class On(Relation): """Represents an 'on top of' relationship between objects. @@ -109,6 +122,8 @@ class On(Relation): Note: Loss computation is handled by OnLossStrategy in relation_loss_strategies.py. """ + name = "on" + def __init__( self, parent: ObjectBase, @@ -126,6 +141,7 @@ def __init__( self.clearance_m = clearance_m +@register_object_relation class IsAnchor(RelationBase): """Marker indicating this object is an anchor for relation solving. @@ -144,9 +160,15 @@ class IsAnchor(RelationBase): bin.add_relation(NextTo(chair)) """ - pass + name = "is_anchor" + + @staticmethod + def is_unary() -> bool: + """Return whether the relation constrains a single object.""" + return True +@register_object_relation class RandomAroundSolution(RelationBase): """Marker indicating the solver solution should be used as center of a PoseRange. @@ -165,6 +187,13 @@ class RandomAroundSolution(RelationBase): # -> ObjectPlacer sets a PoseRange spanning ±0.1m in X and Y around solved position """ + name = "random_around_solution" + + @staticmethod + def is_unary() -> bool: + """Return whether the relation constrains a single object.""" + return True + def __init__( self, x_half_m: float = 0.0, @@ -236,6 +265,7 @@ def to_pose_range_centered_at( ) +@register_object_relation class RotateAroundSolution(RelationBase): """Marker specifying an explicit rotation to apply on top of the solver solution. @@ -252,6 +282,13 @@ class RotateAroundSolution(RelationBase): # -> ObjectPlacer sets a Pose with solved position and 45° yaw rotation """ + name = "rotate_around_solution" + + @staticmethod + def is_unary() -> bool: + """Return whether the relation constrains a single object.""" + return True + def __init__( self, roll_rad: float = 0.0, @@ -274,6 +311,8 @@ def get_rotation_xyzw(self) -> tuple[float, float, float, float]: Returns: Quaternion rotation converted from roll/pitch/yaw. """ + import torch + from isaaclab.utils.math import quat_from_euler_xyz roll = torch.tensor(self.roll_rad) @@ -283,6 +322,7 @@ def get_rotation_xyzw(self) -> tuple[float, float, float, float]: return tuple(quat.tolist()) +@register_object_relation class AtPosition(UnaryRelation): """Constrains object to specific world coordinates. @@ -299,6 +339,8 @@ class AtPosition(UnaryRelation): mug.add_relation(AtPosition(x=0.5, y=1.0)) """ + name = "at_position" + def __init__( self, x: float | None = None, @@ -322,6 +364,7 @@ def __init__( self.relation_loss_weight = relation_loss_weight +@register_object_relation class PositionLimits(UnaryRelation): """Constrains object position to a world-coordinate axis-aligned box. @@ -332,6 +375,8 @@ class PositionLimits(UnaryRelation): mug.add_relation(PositionLimits(z_min=0.8)) # only constrain Z """ + name = "position_limits" + def __init__( self, x_min: float | None = None, diff --git a/isaaclab_arena/tasks/__init__.py b/isaaclab_arena/tasks/__init__.py index fee3a6a9f..913a466be 100644 --- a/isaaclab_arena/tasks/__init__.py +++ b/isaaclab_arena/tasks/__init__.py @@ -2,3 +2,21 @@ # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 + +# Import concrete task modules so their @register_task decorators fire when the +# tasks package is loaded (e.g. via ensure_assets_registered()). +from isaaclab_arena.tasks import ( # noqa: F401 + assembly_task, + close_door_task, + goal_pose_task, + lift_object_task, + no_task, + open_door_task, + pick_and_place_task, + place_upright_task, + press_button_task, + rotate_revolute_joint_task, + sequential_task_base, + sorting_task, + turn_knob_task, +) diff --git a/isaaclab_arena/tasks/assembly_task.py b/isaaclab_arena/tasks/assembly_task.py index ab3231f6a..56a1dc27e 100644 --- a/isaaclab_arena/tasks/assembly_task.py +++ b/isaaclab_arena/tasks/assembly_task.py @@ -16,6 +16,7 @@ import isaaclab_arena_environments.mdp as mdp from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.register import register_task from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.object_moved import ObjectMovedRateMetric from isaaclab_arena.metrics.success_rate import SuccessRateMetric @@ -25,6 +26,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class AssemblyTask(TaskBase): """ Assembly task where an object needs to be assembled with a base object, like peg insert, gear mesh, etc. diff --git a/isaaclab_arena/tasks/close_door_task.py b/isaaclab_arena/tasks/close_door_task.py index 08fb1a6b1..114a5773c 100644 --- a/isaaclab_arena/tasks/close_door_task.py +++ b/isaaclab_arena/tasks/close_door_task.py @@ -10,11 +10,13 @@ from isaaclab.utils import configclass from isaaclab_arena.affordances.openable import Openable +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.tasks.common.open_close_door_mimic import RotateDoorMimicEnvCfg from isaaclab_arena.tasks.rotate_revolute_joint_task import RotateRevoluteJointTask +@register_task class CloseDoorTask(RotateRevoluteJointTask): def __init__( self, diff --git a/isaaclab_arena/tasks/goal_pose_task.py b/isaaclab_arena/tasks/goal_pose_task.py index 93585ef2a..3162c9f09 100644 --- a/isaaclab_arena/tasks/goal_pose_task.py +++ b/isaaclab_arena/tasks/goal_pose_task.py @@ -13,6 +13,7 @@ from isaaclab.utils import configclass from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.register import register_task from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.object_moved import ObjectMovedRateMetric from isaaclab_arena.metrics.success_rate import SuccessRateMetric @@ -21,6 +22,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class GoalPoseTask(TaskBase): def __init__( self, diff --git a/isaaclab_arena/tasks/lift_object_task.py b/isaaclab_arena/tasks/lift_object_task.py index 605bfccb9..4c4461fd7 100644 --- a/isaaclab_arena/tasks/lift_object_task.py +++ b/isaaclab_arena/tasks/lift_object_task.py @@ -17,6 +17,7 @@ from isaaclab_tasks.manager_based.manipulation.dexsuite import dexsuite_env_cfg as dexsuite from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.embodiment_base import EmbodimentBase from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.success_rate import SuccessRateMetric @@ -28,6 +29,7 @@ from isaaclab_arena.utils.pose import PoseRange +@register_task class LiftObjectTask(TaskBase): def __init__( self, @@ -130,6 +132,7 @@ class LiftObjectTerminationsCfg: success: TerminationTermCfg = MISSING +@register_task class LiftObjectTaskRL(LiftObjectTask): def __init__( self, @@ -377,6 +380,7 @@ class DexsuiteLiftTerminationsCfg(dexsuite.TerminationsCfg): ) +@register_task class DexsuiteLiftTask(LiftObjectTask): """Dexsuite lift task for Arena evaluation. diff --git a/isaaclab_arena/tasks/no_task.py b/isaaclab_arena/tasks/no_task.py index 62a54ee66..b25d6d61a 100644 --- a/isaaclab_arena/tasks/no_task.py +++ b/isaaclab_arena/tasks/no_task.py @@ -5,10 +5,12 @@ from isaaclab.envs.common import ViewerCfg +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.tasks.task_base import TaskBase +@register_task class NoTask(TaskBase): """Null object for environments without a task.""" diff --git a/isaaclab_arena/tasks/open_door_task.py b/isaaclab_arena/tasks/open_door_task.py index e0dd5eae9..4e5287226 100644 --- a/isaaclab_arena/tasks/open_door_task.py +++ b/isaaclab_arena/tasks/open_door_task.py @@ -10,11 +10,13 @@ from isaaclab.utils import configclass from isaaclab_arena.affordances.openable import Openable +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.tasks.common.open_close_door_mimic import RotateDoorMimicEnvCfg from isaaclab_arena.tasks.rotate_revolute_joint_task import RotateRevoluteJointTask +@register_task class OpenDoorTask(RotateRevoluteJointTask): def __init__( self, diff --git a/isaaclab_arena/tasks/pick_and_place_task.py b/isaaclab_arena/tasks/pick_and_place_task.py index 364decffc..bc029ba40 100644 --- a/isaaclab_arena/tasks/pick_and_place_task.py +++ b/isaaclab_arena/tasks/pick_and_place_task.py @@ -15,6 +15,7 @@ from isaaclab.utils import configclass from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.object_moved import ObjectMovedRateMetric @@ -25,6 +26,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class PickAndPlaceTask(TaskBase): """Pick-and-place task. Success fires when the pick-up object contacts the destination with low velocity. Failure (object_dropped) fires when the object falls below the diff --git a/isaaclab_arena/tasks/place_upright_task.py b/isaaclab_arena/tasks/place_upright_task.py index d041bf529..1a82b00b4 100644 --- a/isaaclab_arena/tasks/place_upright_task.py +++ b/isaaclab_arena/tasks/place_upright_task.py @@ -14,6 +14,7 @@ from isaaclab.utils import configclass from isaaclab_arena.affordances.placeable import Placeable +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.object_moved import ObjectMovedRateMetric @@ -23,6 +24,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class PlaceUprightTask(TaskBase): def __init__( diff --git a/isaaclab_arena/tasks/press_button_task.py b/isaaclab_arena/tasks/press_button_task.py index 36dffe31a..d3f7141d8 100644 --- a/isaaclab_arena/tasks/press_button_task.py +++ b/isaaclab_arena/tasks/press_button_task.py @@ -12,6 +12,7 @@ from isaaclab.utils import configclass from isaaclab_arena.affordances.pressable import Pressable +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.success_rate import SuccessRateMetric @@ -19,6 +20,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class PressButtonTask(TaskBase): def __init__( self, diff --git a/isaaclab_arena/tasks/rotate_revolute_joint_task.py b/isaaclab_arena/tasks/rotate_revolute_joint_task.py index 8cb44af84..4d94b6041 100644 --- a/isaaclab_arena/tasks/rotate_revolute_joint_task.py +++ b/isaaclab_arena/tasks/rotate_revolute_joint_task.py @@ -11,6 +11,7 @@ from isaaclab.utils import configclass from isaaclab_arena.affordances.openable import Openable +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.revolute_joint_moved_rate import RevoluteJointMovedRateMetric @@ -19,6 +20,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class RotateRevoluteJointTask(TaskBase): def __init__( self, diff --git a/isaaclab_arena/tasks/sorting_task.py b/isaaclab_arena/tasks/sorting_task.py index 1470a0428..528acd579 100644 --- a/isaaclab_arena/tasks/sorting_task.py +++ b/isaaclab_arena/tasks/sorting_task.py @@ -13,6 +13,7 @@ from isaaclab.utils import configclass from isaaclab_arena.assets.asset import Asset +from isaaclab_arena.assets.register import register_task from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.success_rate import SuccessRateMetric from isaaclab_arena.tasks.task_base import TaskBase @@ -21,6 +22,7 @@ from isaaclab_arena.utils.configclass import make_configclass +@register_task class SortMultiObjectTask(TaskBase): def __init__( diff --git a/isaaclab_arena/tasks/turn_knob_task.py b/isaaclab_arena/tasks/turn_knob_task.py index 5ec557184..6eda3dd4c 100644 --- a/isaaclab_arena/tasks/turn_knob_task.py +++ b/isaaclab_arena/tasks/turn_knob_task.py @@ -12,6 +12,7 @@ from isaaclab.utils import configclass from isaaclab_arena.affordances.turnable import Turnable +from isaaclab_arena.assets.register import register_task from isaaclab_arena.embodiments.common.arm_mode import ArmMode from isaaclab_arena.metrics.metric_base import MetricBase from isaaclab_arena.metrics.success_rate import SuccessRateMetric @@ -19,6 +20,7 @@ from isaaclab_arena.utils.cameras import get_viewer_cfg_look_at_object +@register_task class TurnKnobTask(TaskBase): def __init__( self, diff --git a/isaaclab_arena/tests/test_arena_env_graph_spec.py b/isaaclab_arena/tests/test_arena_env_graph_spec.py index e1630b0d2..19623dc21 100644 --- a/isaaclab_arena/tests/test_arena_env_graph_spec.py +++ b/isaaclab_arena/tests/test_arena_env_graph_spec.py @@ -5,6 +5,8 @@ from pathlib import Path +import pytest + from isaaclab_arena.assets.object_type import ObjectType from isaaclab_arena.environments.arena_env_graph_spec import ( ArenaEnvGraphNodeType, @@ -13,6 +15,8 @@ ArenaEnvGraphSpec, ArenaEnvGraphStateSpec, ) +from isaaclab_arena.environments.graph_spec_utils import relation_class_for_spatial_constraint_type +from isaaclab_arena.relations.relations import IsAnchor, PositionLimits TEST_DATA_DIR = Path(__file__).parent / "test_data" @@ -22,8 +26,8 @@ def test_arena_env_graph_spec_loads_pick_and_place_yaml(): assert spec.env_name == "pick_and_place_maple_table_default" assert len(spec.nodes) == 6 - assert len(spec.tasks) == 1 - assert len(spec.state_specs) == 2 + assert len(spec.tasks) == 2 + assert len(spec.state_specs) == 3 table = spec.nodes_by_id["maple_table_robolab_table"] assert isinstance(table, ArenaEnvGraphObjectReferenceNodeSpec) @@ -38,8 +42,13 @@ def test_arena_env_graph_spec_loads_pick_and_place_yaml(): task = spec.tasks_by_id["pick_and_place_0"] assert task.initial_state_spec_id == "state_spec_0" assert task.success_state_spec_id == "state_spec_1" - assert task.task_args["object"] == "rubiks_cube_hot3d_robolab" - assert task.task_args["destination"] == "bowl_ycb_robolab" + assert task.task_args["pick_up_object"] == "rubiks_cube_hot3d_robolab" + assert task.task_args["destination_location"] == "bowl_ycb_robolab" + + second_task = spec.tasks_by_id["pick_and_place_1"] + assert second_task.initial_state_spec_id == "state_spec_1" + assert second_task.success_state_spec_id == "state_spec_2" + assert second_task.task_args["pick_up_object"] == "mug_ycb_robolab" initial_state = spec.state_specs_by_id["state_spec_0"] assert isinstance(initial_state, ArenaEnvGraphStateSpec) @@ -70,6 +79,15 @@ def test_arena_env_graph_spec_loads_pick_and_place_yaml(): assert final_mug_pose.params["position_xyz"] == (0.65, 0.25, 0.85) assert final_mug_pose.params["rotation_xyzw"] == (0.0, 0.0, 0.0, 1.0) + table_anchor = initial_state.spatial_constraints[0] + assert table_anchor.type == ArenaEnvGraphSpatialConstraintType.IS_ANCHOR + assert relation_class_for_spatial_constraint_type(table_anchor.type) is IsAnchor + assert relation_class_for_spatial_constraint_type(cube_limits.type) is PositionLimits + assert ( + relation_class_for_spatial_constraint_type(initial_mug_pose.type) is None + ) # at_pose: handled via set_initial_pose + assert relation_class_for_spatial_constraint_type(in_constraint.type) is None # in: not yet supported by solver + def test_arena_env_graph_spec_parses_optional_task_constraints_and_at_pose(): data = _minimal_env_graph_data() @@ -87,6 +105,52 @@ def test_arena_env_graph_spec_parses_optional_task_constraints_and_at_pose(): assert fixed_pose.params["rotation_xyzw"] == (0.0, 0.0, 0.0, 1.0) +def test_arena_env_graph_spec_validate_rejects_mutated_missing_reference(): + spec = ArenaEnvGraphSpec.from_dict(_minimal_env_graph_data()) + spec.state_specs[0].spatial_constraints[0].parent = "missing_table" + + with pytest.raises(AssertionError, match="unknown parent node 'missing_table'"): + spec.validate() + + +def test_arena_env_graph_spec_validate_rejects_mutated_invalid_relationship_shape(): + spec = ArenaEnvGraphSpec.from_dict(_minimal_env_graph_data()) + constraint = spec.state_specs[0].spatial_constraints[0] + constraint.type = ArenaEnvGraphSpatialConstraintType.ON + + with pytest.raises(AssertionError, match="requires a child node"): + spec.validate() + + +def _test_arena_env_graph_conversion_builds_sequential_pick_and_place_task(simulation_app): + from isaaclab_arena.tasks.pick_and_place_task import PickAndPlaceTask + from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase + + spec = ArenaEnvGraphSpec.from_yaml(TEST_DATA_DIR / "pick_and_place_maple_table_env_graph.yaml") + arena_env = spec.to_arena_env() + + assert arena_env.name == "pick_and_place_maple_table_default" + assert isinstance(arena_env.task, SequentialTaskBase) + assert arena_env.task.desired_subtask_success_state == [True, True] + assert len(arena_env.task.subtasks) == 2 + assert all(isinstance(subtask, PickAndPlaceTask) for subtask in arena_env.task.subtasks) + assert arena_env.task.subtasks[0].pick_up_object.name == "rubiks_cube_hot3d_robolab" + assert arena_env.task.subtasks[1].pick_up_object.name == "mug_ycb_robolab" + assert all(subtask.destination_location.name == "bowl_ycb_robolab" for subtask in arena_env.task.subtasks) + assert all(subtask.background_scene.name == "maple_table_robolab" for subtask in arena_env.task.subtasks) + + return True + + +def test_arena_env_graph_conversion_builds_sequential_pick_and_place_task(): + pytest.importorskip("isaaclab.app") + + from isaaclab_arena.tests.utils.subprocess import run_simulation_app_function + + result = run_simulation_app_function(_test_arena_env_graph_conversion_builds_sequential_pick_and_place_task) + assert result + + def test_arena_env_graph_spec_rejects_invalid_data(): cases = [ ( @@ -134,6 +198,16 @@ def test_arena_env_graph_spec_rejects_invalid_data(): lambda data: data["state_specs"][0]["spatial_constraints"][0].pop("parent"), "Missing required string field 'parent'", ), + ( + "relationship missing child", + lambda data: data["state_specs"][0]["spatial_constraints"][0].__setitem__("type", "on"), + "requires a child node", + ), + ( + "unary relationship with child", + lambda data: data["state_specs"][0]["spatial_constraints"][0].__setitem__("child", "cube"), + "must not define a child node", + ), ( "old state edges wrapper", _move_state_constraints_under_edges, @@ -141,27 +215,27 @@ def test_arena_env_graph_spec_rejects_invalid_data(): ), ( "missing node parent reference", - lambda data: data["nodes"][1].__setitem__("parent", "missing_background"), + lambda data: data["nodes"][2].__setitem__("parent", "missing_background"), "unknown parent 'missing_background'", ), ( "unknown object type", - lambda data: data["nodes"][1].__setitem__("object_type", "unknown"), + lambda data: data["nodes"][2].__setitem__("object_type", "unknown"), "Unknown object_type 'unknown'", ), ( "object_reference missing parent", - lambda data: data["nodes"][1].pop("parent"), + lambda data: data["nodes"][2].pop("parent"), "Missing required string field 'parent'", ), ( "object_reference missing prim_path", - lambda data: data["nodes"][1].pop("prim_path"), + lambda data: data["nodes"][2].pop("prim_path"), "Missing required string field 'prim_path'", ), ( "object_reference missing object_type", - lambda data: data["nodes"][1].pop("object_type"), + lambda data: data["nodes"][2].pop("object_type"), "Missing required field 'object_type'", ), ( @@ -205,7 +279,9 @@ def _minimal_env_graph_data(): "env_name": "minimal_env_graph", "nodes": [ {"id": "robot", "name": "robot", "type": "embodiment"}, - # Kept at index 1 so the bad-data mutation lambdas below can address it. + {"id": "background", "name": "background", "type": "background"}, + # Kept at index 2 (after its parent at index 1) so the bad-data mutation lambdas + # below can address it, and so the order satisfies the upstream ordering contract. { "id": "table", "name": "table", @@ -214,12 +290,11 @@ def _minimal_env_graph_data(): "prim_path": "{ENV_REGEX_NS}/background/table", "object_type": "rigid", }, - {"id": "background", "name": "background", "type": "background"}, {"id": "cube", "name": "cube", "type": "object"}, ], "tasks": [{ "id": "task_0", - "type": "pick_and_place", + "type": "PickAndPlaceTask", "initial_state_spec_id": "state_0", "success_state_spec_id": "state_0", }], diff --git a/isaaclab_arena/tests/test_data/pick_and_place_maple_table_env_graph.yaml b/isaaclab_arena/tests/test_data/pick_and_place_maple_table_env_graph.yaml index 3e4eab4b0..ac71e6435 100644 --- a/isaaclab_arena/tests/test_data/pick_and_place_maple_table_env_graph.yaml +++ b/isaaclab_arena/tests/test_data/pick_and_place_maple_table_env_graph.yaml @@ -35,13 +35,23 @@ nodes: tasks: - id: pick_and_place_0 - type: pick_and_place + type: PickAndPlaceTask initial_state_spec_id: state_spec_0 success_state_spec_id: state_spec_1 task_args: - object: rubiks_cube_hot3d_robolab - destination: bowl_ycb_robolab - background: maple_table_robolab + pick_up_object: rubiks_cube_hot3d_robolab + destination_location: bowl_ycb_robolab + background_scene: maple_table_robolab + episode_length_s: 20.0 + + - id: pick_and_place_1 + type: PickAndPlaceTask + initial_state_spec_id: state_spec_1 + success_state_spec_id: state_spec_2 + task_args: + pick_up_object: mug_ycb_robolab + destination_location: bowl_ycb_robolab + background_scene: maple_table_robolab episode_length_s: 20.0 state_specs: @@ -125,7 +135,34 @@ state_specs: rotation_xyzw: [0.0, 0.0, 0.0, 1.0] task_constraints: - - id: state_spec_1_droid_reach_bowl_ycb_robolab + - id: state_spec_1_droid_reach_mug_ycb_robolab + type: "reach" + parent: droid_abs_joint_pos + child: mug_ycb_robolab + + - id: state_spec_2 + spatial_constraints: + - id: state_spec_2_maple_table_robolab_table_is_anchor + type: is_anchor + parent: maple_table_robolab_table + + - id: state_spec_2_bowl_ycb_robolab_on_maple_table_robolab_table + type: "on" + parent: maple_table_robolab_table + child: bowl_ycb_robolab + + - id: state_spec_2_rubiks_cube_hot3d_robolab_in_bowl_ycb_robolab + type: "in" + parent: bowl_ycb_robolab + child: rubiks_cube_hot3d_robolab + + - id: state_spec_2_mug_ycb_robolab_in_bowl_ycb_robolab + type: "in" + parent: bowl_ycb_robolab + child: mug_ycb_robolab + + task_constraints: + - id: state_spec_2_droid_reach_bowl_ycb_robolab type: "reach" parent: droid_abs_joint_pos child: bowl_ycb_robolab diff --git a/isaaclab_arena/tests/test_task_registry.py b/isaaclab_arena/tests/test_task_registry.py new file mode 100644 index 000000000..393e35c09 --- /dev/null +++ b/isaaclab_arena/tests/test_task_registry.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from isaaclab_arena.tests.utils.subprocess import run_simulation_app_function + + +def _test_task_registry_resolves_concrete_tasks(simulation_app): + from isaaclab_arena.assets.registries import TaskRegistry + + expected = [ + "NoTask", + "PickAndPlaceTask", + "PlaceUprightTask", + "LiftObjectTask", + "LiftObjectTaskRL", + "DexsuiteLiftTask", + "GoalPoseTask", + "PressButtonTask", + "RotateRevoluteJointTask", + "CloseDoorTask", + "OpenDoorTask", + "TurnKnobTask", + "AssemblyTask", + "SortMultiObjectTask", + ] + registry = TaskRegistry() + for name in expected: + cls = registry.get_task_by_name(name) + assert cls.__name__ == name, f"{name} -> {cls.__name__}" + return True + + +def test_task_registry_resolves_concrete_tasks(): + result = run_simulation_app_function(_test_task_registry_resolves_concrete_tasks) + assert result