Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions isaaclab_arena/assets/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
DeviceRegistry,
EnvironmentRegistry,
HDRImageRegistry,
ObjectRelationLibraryRegistry,
PolicyRegistry,
RetargeterRegistry,
TaskRegistry,
)


Expand Down Expand Up @@ -68,3 +70,25 @@ def register_environment(cls):
else:
registry.register(cls, cls.name)
return cls


# Decorator to register a RelationBase subclass with the ObjectRelationLibraryRegistry.
def register_object_relation(cls):
registry = ObjectRelationLibraryRegistry()
if registry.is_registered(cls.name):
print(f"WARNING: Object relation {cls.name} is already registered. Doing nothing.")
else:
registry.register(cls, cls.name)
return cls


# Decorator to register a TaskBase subclass with the TaskRegistry.
# Keyed by `cls.__name__` so the YAML `type: PascalCase` lookups match without
# requiring a separate `name` attribute on every task class.
def register_task(cls):
registry = TaskRegistry()
if registry.is_registered(cls.__name__):
print(f"WARNING: Task {cls.__name__} is already registered. Doing nothing.")
else:
registry.register(cls, cls.__name__)
return cls
90 changes: 85 additions & 5 deletions isaaclab_arena/assets/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from isaaclab_arena.assets.hdr_image import HDRImage
from isaaclab_arena.assets.teleop_device_base import TeleopDeviceBase
from isaaclab_arena.policy.policy_base import PolicyBase
from isaaclab_arena.relations.relations import RelationBase
from isaaclab_arena.tasks.task_base import TaskBase


# Have to define all classes here in order to avoid circular import.
Expand All @@ -42,7 +44,18 @@ def is_registered(self, key: str) -> bool:
key (str): The name of the component.
"""
# For AssetRegistry and DeviceRegistry, ensure assets are registered before checking
if isinstance(self, (AssetRegistry, DeviceRegistry, RetargeterRegistry, PolicyRegistry, HDRImageRegistry)):
if isinstance(
self,
(
AssetRegistry,
DeviceRegistry,
RetargeterRegistry,
PolicyRegistry,
HDRImageRegistry,
ObjectRelationLibraryRegistry,
TaskRegistry,
),
):
ensure_assets_registered()
return key in self._components

Expand All @@ -56,7 +69,18 @@ def get_component_by_name(self, key: str) -> Any:
Any: The component.
"""
# For AssetRegistry and DeviceRegistry, ensure assets are registered before accessing
if isinstance(self, (AssetRegistry, DeviceRegistry, RetargeterRegistry, PolicyRegistry, HDRImageRegistry)):
if isinstance(
self,
(
AssetRegistry,
DeviceRegistry,
RetargeterRegistry,
PolicyRegistry,
HDRImageRegistry,
ObjectRelationLibraryRegistry,
TaskRegistry,
),
):
ensure_assets_registered()
assert key in self._components, f"component {key} not found, please check if requested component is registered"
return self._components[key]
Expand All @@ -68,7 +92,18 @@ def get_all_keys(self) -> list[str]:
list[str]: The list of keys.
"""
# For AssetRegistry and DeviceRegistry, ensure assets are registered before accessing
if isinstance(self, (AssetRegistry, DeviceRegistry, RetargeterRegistry, PolicyRegistry, HDRImageRegistry)):
if isinstance(
self,
(
AssetRegistry,
DeviceRegistry,
RetargeterRegistry,
PolicyRegistry,
HDRImageRegistry,
ObjectRelationLibraryRegistry,
TaskRegistry,
),
):
ensure_assets_registered()
return list(self._components.keys())

Expand Down Expand Up @@ -217,14 +252,52 @@ def __init__(self):
super().__init__()


class ObjectRelationLibraryRegistry(Registry):
"""Registry for object relation classes."""

def __init__(self):
super().__init__()

def get_object_relation_by_name(self, name: str) -> type["RelationBase"]:
"""Gets an object relation by name.

Args:
name (str): The name of the object relation.
"""
ensure_assets_registered()
return self.get_component_by_name(name)


class TaskRegistry(Registry):
"""Registry for TaskBase subclasses."""

def __init__(self):
super().__init__()

def get_task_by_name(self, name: str) -> type["TaskBase"]:
"""Gets a task class by name.

Args:
name (str): The name of the task class (typically the class __name__).
"""
ensure_assets_registered()
return self.get_component_by_name(name)


# Lazy registration to avoid circular imports
_assets_registered = False
# Blocks re-entry: registration decorators call is_registered() -> ensure_assets_registered()
# mid-import, which would re-import a partial module and raise a circular ImportError.
_registration_in_progress = False


def ensure_assets_registered():
"""Ensure all assets are registered. Call this before accessing the registry."""
global _assets_registered
if not _assets_registered:
global _assets_registered, _registration_in_progress
if _assets_registered or _registration_in_progress:
return
_registration_in_progress = True
try:
# Import modules to trigger asset registration via decorators
import isaaclab_arena.assets.background_library # noqa: F401
import isaaclab_arena.assets.device_library # noqa: F401
Expand All @@ -233,5 +306,12 @@ def ensure_assets_registered():
import isaaclab_arena.assets.retargeter_library # noqa: F401
import isaaclab_arena.embodiments # noqa: F401
import isaaclab_arena.policy # noqa: F401
import isaaclab_arena.relations.relations # noqa: F401
import isaaclab_arena.tasks # noqa: F401
except Exception:
_assets_registered = False # roll back so a later call can retry
raise

_assets_registered = True
finally:
_registration_in_progress = False
150 changes: 150 additions & 0 deletions isaaclab_arena/environments/arena_env_graph_conversion_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from isaaclab_arena.assets.asset import Asset
from isaaclab_arena.assets.object_reference import ObjectReference
from isaaclab_arena.assets.registries import AssetRegistry
from isaaclab_arena.environments.arena_env_graph_task_conversion_utils import build_task_from_specs
from isaaclab_arena.environments.arena_env_graph_types import (
ArenaEnvGraphNodeSpec,
ArenaEnvGraphNodeType,
ArenaEnvGraphObjectReferenceNodeSpec,
ArenaEnvGraphSpatialConstraintType,
ArenaEnvGraphStateSpec,
)
from isaaclab_arena.environments.graph_spec_utils import relation_class_for_spatial_constraint_type
from isaaclab_arena.environments.isaaclab_arena_environment import IsaacLabArenaEnvironment
from isaaclab_arena.scene.scene import Scene
from isaaclab_arena.utils.pose import Pose

if TYPE_CHECKING:
from isaaclab_arena.environments.arena_env_graph_spec import ArenaEnvGraphSpec


def build_arena_env_from_graph_spec(graph_spec: ArenaEnvGraphSpec) -> Any:
"""Build an IsaacLabArenaEnvironment from a validated ``ArenaEnvGraphSpec``.

Precondition: ``graph_spec`` is already validated (node refs exist, ids unique, etc.).
"""
# TODO(xinjieyao, 2026-05-26): aggregate every state_spec into a single combined initial state instead of
# picking one. For now we just take the first state_spec, which is the initial state
# for the first task — this matches the previous default behavior.
initial_state_spec = graph_spec.state_specs[0] if graph_spec.state_specs else None

# 1. Materialize every graph node into a live asset, keyed by node id so spatial
# constraints and task args can reference each node by its graph-local id.
assets_by_node_id = _instantiate_assets_from_nodes(graph_spec.nodes, AssetRegistry())

# 2. Wire the initial state's spatial relations / fixed poses into those assets.
if initial_state_spec is not None:
_attach_spatial_constraints_to_assets(initial_state_spec, assets_by_node_id)

# 3. Partition nodes into the env's embodiment (exactly one) and its scene assets.
embodiment, scene_assets = _partition_nodes_into_embodiment_and_scene(graph_spec.nodes, assets_by_node_id)

# 4. Resolve task specs against the same assets_by_node_id so task args bind to the
# actual asset instances created in step 1 (not duplicates).
return IsaacLabArenaEnvironment(
name=graph_spec.env_name,
scene=Scene(assets=scene_assets),
embodiment=embodiment,
task=build_task_from_specs(graph_spec.tasks, assets_by_node_id),
)
Comment on lines +30 to +58
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Super clean!



def _partition_nodes_into_embodiment_and_scene(
node_specs: list[ArenaEnvGraphNodeSpec], assets_by_node_id: dict[str, Any]
) -> tuple[Any, list[Asset]]:
"""Split materialized nodes into the single embodiment asset and the list of scene assets.

Asserts exactly one EMBODIMENT node. OBJECT / OBJECT_REFERENCE nodes become scene assets;
any other node type raises. Lighting is not handled yet.
"""
embodiment = None
scene_assets: list[Asset] = []
# TODO(xinjieyao, 2026-05-26): include lighting later
for node_spec in node_specs:
if node_spec.type == ArenaEnvGraphNodeType.EMBODIMENT:
assert embodiment is None, "Only one embodiment node can be converted to an IsaacLabArenaEnvironment"
embodiment = assets_by_node_id[node_spec.id]
elif node_spec.type in (ArenaEnvGraphNodeType.OBJECT, ArenaEnvGraphNodeType.OBJECT_REFERENCE):
scene_assets.append(assets_by_node_id[node_spec.id])
else:
raise ValueError(f"Unsupported node type: {node_spec.type}")
return embodiment, scene_assets


def _instantiate_assets_from_nodes(node_specs: list[ArenaEnvGraphNodeSpec], asset_registry: Any) -> dict[str, Any]:
"""Return ``{node.id: live_asset}`` after a single pass over ``node_specs``.

Each ``node_spec.params`` is forwarded verbatim to the asset constructor. Assumes parent
nodes precede their OBJECT_REFERENCE children — guaranteed by ``assert_references_exist``.
"""
assets_by_node_id: dict[str, Any] = {}
for node_spec in node_specs:
# OBJECT_REFERENCE wraps a USD prim inside an already-instantiated parent asset
# (e.g. a table inside a kitchen background). Validation guarantees the parent
# precedes the reference, so it is already in assets_by_node_id here.
if node_spec.type == ArenaEnvGraphNodeType.OBJECT_REFERENCE:
assert isinstance(node_spec, ArenaEnvGraphObjectReferenceNodeSpec)
assets_by_node_id[node_spec.id] = ObjectReference(
name=node_spec.name,
prim_path=node_spec.prim_path,
parent_asset=assets_by_node_id[node_spec.parent],
object_type=node_spec.object_type,
**node_spec.params,
)
else:
# Standard nodes (object / background / embodiment): look up the registered class
# by name and instantiate with the spec's verbatim kwargs.
asset_class = asset_registry.get_asset_by_name(node_spec.name)
assets_by_node_id[node_spec.id] = asset_class(**node_spec.params)
return assets_by_node_id


def _attach_spatial_constraints_to_assets(
state_spec: ArenaEnvGraphStateSpec, assets_by_node_id: dict[str, Any]
) -> None:
"""Attach one Relation per spatial constraint to the asset(s) it targets, in place.

AT_POSE is special-cased to ``set_initial_pose`` since it has no Relation class.
Raises AssertionError on unsupported constraint types or malformed AT_POSE params.
"""
for spatial_constraint in state_spec.spatial_constraints:
parent_asset = assets_by_node_id[spatial_constraint.parent]

# AT_POSE has no Relation class — it pins the parent's initial pose directly,
# bypassing the solver. Pop the two known fields and reject any extras so a typo
# in the spec ('postion_xyz') raises here instead of silently being ignored.
# TODO(xinjieyao, 2026-05-26): move at_pose handling into the placer module.
if spatial_constraint.type == ArenaEnvGraphSpatialConstraintType.AT_POSE:
at_pose_params = dict(spatial_constraint.params)
position_xyz = at_pose_params.pop("position_xyz", None)
rotation_xyzw = at_pose_params.pop("rotation_xyzw", (0.0, 0.0, 0.0, 1.0))
assert (
position_xyz is not None
), f"at_pose constraint '{spatial_constraint.id}' requires params.position_xyz"
assert (
not at_pose_params
), f"Unsupported at_pose params for constraint '{spatial_constraint.id}': {sorted(at_pose_params)}"
parent_asset.set_initial_pose(Pose(position_xyz=position_xyz, rotation_xyzw=rotation_xyzw))
else:
# All other constraint types resolve through ObjectRelationLibraryRegistry. The
# registry returning None signals an enum value with no registered class — would
# be a programming error, not a YAML error.
relation_class = relation_class_for_spatial_constraint_type(spatial_constraint.type)
assert relation_class is not None, f"Unsupported spatial constraint type '{spatial_constraint.type.value}'"
# Unary relations (IS_ANCHOR, POSITION_LIMITS, ...) attach to the parent asset.
# Binary relations (ON, NEXT_TO, ...) attach to the *child* asset, with parent
# as the relation's first constructor arg — matches how add_relation is wired.
if relation_class.is_unary():
parent_asset.add_relation(relation_class(**spatial_constraint.params))
else:
child_asset = assets_by_node_id[spatial_constraint.child]
child_asset.add_relation(relation_class(parent_asset, **spatial_constraint.params))
Loading
Loading