Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ jobs:
image: nvcr.io/nvstaging/isaac-amr/isaaclab_arena:latest
credentials:
username: $oauthtoken
password: ${{ env.NGC_API_KEY }}
password: ${{ secrets.ARENA_NGC_API_KEY }}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by accident?


steps:
# nvidia-smi
Expand Down
111 changes: 8 additions & 103 deletions isaaclab_arena/environments/arena_env_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from isaaclab_tasks.utils import parse_env_cfg
from isaaclab_teleop import IsaacTeleopCfg

from isaaclab_arena.assets.object import Object
from isaaclab_arena.assets.object_reference import ObjectReference
from isaaclab_arena.assets.registries import DeviceRegistry
from isaaclab_arena.embodiments.no_embodiment import NoEmbodiment
from isaaclab_arena.environments.isaaclab_arena_environment import IsaacLabArenaEnvironment
Expand All @@ -28,16 +26,11 @@
IsaacLabArenaManagerBasedRLEnvCfg,
)
from isaaclab_arena.metrics.recorder_manager_utils import metrics_to_recorder_manager_cfg
from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams
from isaaclab_arena.relations.placement_events import get_rotation_xyzw, solve_and_place_objects
from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer
from isaaclab_arena.relations.relation_solver_params import RelationSolverParams
from isaaclab_arena.relations.relations import get_anchor_objects
from isaaclab_arena.relations.relation_placement import prepare_relation_placement
from isaaclab_arena.tasks.no_task import NoTask
from isaaclab_arena.utils.configclass import combine_configclass_instances, make_configclass
from isaaclab_arena.utils.isaaclab_utils.simulation_app import reapply_viewer_cfg
from isaaclab_arena.utils.multiprocess import get_local_rank
from isaaclab_arena.utils.pose import Pose, PosePerEnv


class ArenaEnvBuilder:
Expand Down Expand Up @@ -68,103 +61,15 @@ def _solve_relations(self) -> None:
so per-object reset events restore the same layout every time.
"""
objects_with_relations = self.arena_env.scene.get_objects_with_relations()

if not objects_with_relations:
print("No objects with relations found in scene. Skipping relation solving.")
return

num_envs = self.args.num_envs
cli_resolve = self.args.resolve_on_reset

# The pool applies positions itself, so disable ObjectPlacer's built-in apply.
# Position history and verbose logging are unnecessary for batch-solving a pool.
placer_params = ObjectPlacerParams(
placement_seed=self.args.placement_seed,
apply_positions_to_objects=False,
solver_params=RelationSolverParams(save_position_history=False, verbose=False),
)
if cli_resolve is not None:
placer_params.resolve_on_reset = cli_resolve

pool_size = num_envs * placer_params.min_unique_layouts_per_env

placement_pool = PooledObjectPlacer(
placement_plan = prepare_relation_placement(
objects=objects_with_relations,
placer_params=placer_params,
pool_size=pool_size,
num_envs=self.args.num_envs,
placement_seed=self.args.placement_seed,
resolve_on_reset=self.args.resolve_on_reset,
embodiment=self.arena_env.embodiment,
)

if placer_params.resolve_on_reset:
anchor_objects_set = set(get_anchor_objects(objects_with_relations))
for obj in objects_with_relations:
if obj not in anchor_objects_set and obj.event_cfg is not None:
raise RuntimeError(
f"Non-anchor object '{obj.name}' has an explicit pose-reset event. "
"Relational solving should not be combined with explicit setting of "
"poses on non-anchor objects."
)
# Set init_state so objects spawn at valid positions (not origin).
# The placement event will override these on every reset.
self._set_init_state_from_pool(objects_with_relations, placement_pool, anchor_objects_set)
self._placement_event_cfg = EventTermCfg(
func=solve_and_place_objects,
mode="reset",
params={
"objects": objects_with_relations,
"placement_pool": placement_pool,
},
)
else:
self._apply_pool_layouts_to_objects(objects_with_relations, placement_pool, num_envs)

def _set_init_state_from_pool(
self,
objects: list[Object | ObjectReference],
pool: PooledObjectPlacer,
anchor_objects_set: set,
) -> None:
"""Set ``object_cfg.init_state`` from a pool layout so objects spawn at valid positions.

Only touches ``init_state.pos`` / ``init_state.rot`` — does NOT create
per-object reset events (the placement event handles resets).
"""
layout = pool.sample_with_replacement(1)[0]
for obj in objects:
if obj in anchor_objects_set:
continue
pos = layout.positions.get(obj)
if pos is None:
continue
rotation_xyzw = get_rotation_xyzw(obj)
obj.object_cfg.init_state.pos = pos
obj.object_cfg.init_state.rot = rotation_xyzw

def _apply_pool_layouts_to_objects(
self,
objects: list[Object | ObjectReference],
pool: PooledObjectPlacer,
num_envs: int,
) -> None:
"""Draw layouts from the pool and apply them to objects via ``set_initial_pose``.

Each non-anchor object gets a :class:`~isaaclab_arena.utils.pose.PosePerEnv`
so that per-object reset events restore these positions.
"""
layouts = pool.sample_with_replacement(num_envs)
anchor_objects_set = set(get_anchor_objects(objects))

for obj in objects:
if obj in anchor_objects_set:
continue
rotation_xyzw = get_rotation_xyzw(obj)
poses = []
for env_idx in range(num_envs):
pos = layouts[env_idx].positions.get(obj)
if pos is None:
break
poses.append(Pose(position_xyz=pos, rotation_xyzw=rotation_xyzw))
else:
obj.set_initial_pose(PosePerEnv(poses=poses))
if placement_plan is not None:
self._placement_event_cfg = placement_plan.placement_event_cfg

def _modify_recorder_cfg_dataset_filename(self, recorder_cfg: RecorderManagerBaseCfg) -> RecorderManagerBaseCfg:
"""Modify the recorder dataset filename to include the timestamp and rank."""
Expand Down
12 changes: 6 additions & 6 deletions isaaclab_arena/relations/placement_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import torch
from typing import TYPE_CHECKING

from isaaclab.envs import ManagerBasedEnv

from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer
from isaaclab_arena.relations.relations import RotateAroundSolution, get_anchor_objects
from isaaclab_arena.utils.pose import Pose

if TYPE_CHECKING:
from isaaclab.envs import ManagerBasedEnv

from isaaclab_arena.assets.object_base import ObjectBase

IDENTITY_ROTATION_XYZW = (0.0, 0.0, 0.0, 1.0)
Expand Down Expand Up @@ -49,18 +49,18 @@ def solve_and_place_objects(
num_reset_envs = len(env_ids)
results_per_env = placement_pool.sample_without_replacement(num_reset_envs)

anchor_objects_set = set(get_anchor_objects(objects))
rotations = {obj: get_rotation_xyzw(obj) for obj in objects if obj not in anchor_objects_set}
anchor_object_names = {obj.name for obj in get_anchor_objects(objects)}
rotations = {obj.name: get_rotation_xyzw(obj) for obj in objects if obj.name not in anchor_object_names}

zero_velocity = torch.zeros(1, 6, device=env.device)
for local_idx, cur_env in enumerate(env_ids.tolist()):
env_id_tensor = torch.tensor([cur_env], device=env.device)
positions = results_per_env[local_idx].positions
for obj, pos in positions.items():
if obj in anchor_objects_set:
if obj.name in anchor_object_names:
continue
asset = env.scene[obj.name]
pose = Pose(position_xyz=pos, rotation_xyzw=rotations[obj])
pose = Pose(position_xyz=pos, rotation_xyzw=rotations[obj.name])
pose_t_xyz_q_xyzw = pose.to_tensor(device=env.device).unsqueeze(0)
pose_t_xyz_q_xyzw[0, :3] += env.scene.env_origins[cur_env, :]
asset.write_root_pose_to_sim(pose_t_xyz_q_xyzw, env_ids=env_id_tensor)
Expand Down
5 changes: 5 additions & 0 deletions isaaclab_arena/relations/pooled_object_placer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,8 @@ def sample_with_replacement(self, count: int) -> list[PlacementResult]:
def remaining(self) -> int:
"""Number of layouts not yet consumed by :meth:`sample_without_replacement`."""
return len(self._layouts) - self._next_idx

@property
def pool_size(self) -> int:
"""Number of layouts requested when refilling the pool."""
return self._pool_size
Loading
Loading