-
Notifications
You must be signed in to change notification settings - Fork 52
Heterogenous Object Placement #676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a568855
a04f088
c8c58e2
96278de
9599087
39058fc
a498313
4ff5a59
31105a8
4b0c725
54c229a
bde2403
3a44e1a
b17166a
983309f
2461fc7
93ec979
88ad2e6
0361523
37bd79b
90783a1
9c91263
4307ac2
5e7e9cb
5a2bf9f
318b7d7
4b4da3d
70cacf1
dd1a554
72665c6
c9446c5
64c7d9d
e85612e
2d784c4
16dd3ef
a047add
793664c
e113a98
7969e98
c779958
29f4076
de48ebb
87f0dd6
c2fc41d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import torch | ||
|
|
||
| import isaaclab.sim as sim_utils | ||
| from isaaclab.assets import RigidObjectCfg | ||
| from isaaclab.sensors.contact_sensor.contact_sensor_cfg import ContactSensorCfg | ||
|
|
@@ -17,9 +19,7 @@ | |
|
|
||
|
|
||
| class RigidObjectSet(Object): | ||
| """ | ||
| A set of rigid objects. | ||
| """ | ||
| """A set of rigid objects with one member selected per environment.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
|
|
@@ -40,9 +40,12 @@ def __init__( | |
| scale: The scale of the object set. Note all objects can only have the same scale, if | ||
| different scales are needed, considering scaling the object USD file. | ||
| random_choice: Whether to randomly choose an object from the object set to spawn in | ||
| each environment. If False, object is spawned based on the order of objects in the list. | ||
| each environment. If False, variants are assigned by repeating | ||
| the member order across environments. | ||
| initial_pose: The initial pose of the object from this object set. | ||
| """ | ||
| if len(objects) < 1: | ||
| raise ValueError(f"Object set {name} must contain at least 1 object.") | ||
| if not self._are_all_objects_type_rigid(objects): | ||
| raise ValueError(f"Object set {name} must contain only rigid objects.") | ||
|
|
||
|
|
@@ -60,15 +63,18 @@ def __init__( | |
| "rigid body at the same depth so paths match after rename. " | ||
| f"Rigid body depths by asset: {per_asset}." | ||
| ) | ||
| self.object_usd_paths = self._modify_assets(objects) | ||
| print(f"Modified object USD paths: {self.object_usd_paths}") | ||
| self.member_usd_paths: list[str] = self._modify_assets(objects) | ||
| print(f"Modified object USD paths: {self.member_usd_paths}") | ||
| else: | ||
| self.object_usd_paths = [object.usd_path for object in objects] | ||
| self.member_usd_paths = [] | ||
| for obj in objects: | ||
| assert obj.usd_path is not None | ||
| self.member_usd_paths.append(obj.usd_path) | ||
|
|
||
| self.objects: list[Object] = objects | ||
| self.random_choice = random_choice | ||
| self.variant_indices_by_env: list[int] | None = None | ||
|
|
||
| # Set default prim_path if not provided | ||
| if prim_path is None: | ||
| prim_path = f"{{ENV_REGEX_NS}}/{name}" | ||
|
|
||
|
|
@@ -77,36 +83,137 @@ def __init__( | |
| object_type=ObjectType.RIGID, | ||
| usd_path="", | ||
| prim_path=prim_path, | ||
| scale=(1.0, 1.0, 1.0), # We rewrite the USDs to handle scaling | ||
| scale=(1.0, 1.0, 1.0), # We rewrite the USDs to handle scaling. | ||
| initial_pose=initial_pose, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| @property | ||
| def object_usd_paths(self) -> list[str]: | ||
| """USD paths passed to MultiUsdFileCfg. | ||
|
|
||
| Before assignment this is the member USD list. After assignment this | ||
| returns one USD path per environment based on variant_indices_by_env. | ||
| """ | ||
| if self.variant_indices_by_env is not None: | ||
| return [self.member_usd_paths[idx] for idx in self.variant_indices_by_env] | ||
| return self.member_usd_paths | ||
|
|
||
| def get_bounding_box(self) -> AxisAlignedBoundingBox: | ||
| """Get the bounding box of the object set. | ||
| """Return one local bbox for callers that cannot vary by env. | ||
|
|
||
| Returns the bounding box with the greatest z-extent among all objects in the set. | ||
| This is a heuristic to avoid objects spawning inside their support surfaces. | ||
| The returned bbox has shape (1, 3) and uses the member with the | ||
| greatest z-extent. Heterogeneous placement uses | ||
| get_bounding_box_per_env() after assign_variants() so each env | ||
| uses its actual variant geometry. | ||
| """ | ||
| return max(self.objects, key=lambda obj: obj.get_bounding_box().size[0, 2].item()).get_bounding_box() | ||
|
|
||
| def assign_variants(self, num_envs: int, variant_seed: int | None = None) -> None: | ||
| """Fix one member-variant index per environment. | ||
|
|
||
| The assignment is fixed for the lifetime of the object set so spawned | ||
| USDs and per-env bboxes stay aligned across placement refills. | ||
| Subsequent calls with the same num_envs are no-ops. A call with a | ||
| different num_envs regenerates with a warning. When random_choice is True, each env | ||
| independently samples one variant; otherwise assignments repeat the | ||
| member order across environments. | ||
| Regeneration is safe before the scene is spawned; afterwards, per-env | ||
| bboxes can desync from the spawned USDs. | ||
|
|
||
| Callers invoke this once num_envs is known, before reading | ||
| variant_indices_by_env or get_bounding_box_per_env. | ||
|
|
||
| Args: | ||
| num_envs: Number of environments to assign variants for. | ||
| variant_seed: Optional seed used when random_choice=True. | ||
| """ | ||
| if self.variant_indices_by_env is not None: | ||
| if len(self.variant_indices_by_env) == num_envs: | ||
| return | ||
| print(f"Warning: RigidObjectSet '{self.name}' regenerating variant assignments for {num_envs} envs.") | ||
| self._set_variant_indices_by_env(self._generate_variant_indices(num_envs, variant_seed=variant_seed)) | ||
|
|
||
| def get_bounding_box_per_env(self, num_envs: int) -> AxisAlignedBoundingBox: | ||
| """Return the local bbox for each env's assigned variant. | ||
|
|
||
| Unlike the single-bbox compatibility fallback, this returns the real | ||
| local bbox of the variant assigned to each env, enabling correct | ||
| collision-free placement for heterogeneous scenes. | ||
|
|
||
| Requires assign_variants(num_envs) to have been called first. The | ||
| returned bbox has shape (num_envs, 3). | ||
|
|
||
| Args: | ||
| num_envs: Number of environments. Must match the assignment. | ||
|
|
||
| Returns: | ||
| AxisAlignedBoundingBox with min_point / max_point of | ||
| shape (num_envs, 3). | ||
| """ | ||
| assert self.variant_indices_by_env is not None, ( | ||
| f"RigidObjectSet '{self.name}' has no variant assignment; " | ||
| "call assign_variants(num_envs) before get_bounding_box_per_env()." | ||
| ) | ||
| assert len(self.variant_indices_by_env) == num_envs, ( | ||
| f"RigidObjectSet '{self.name}' got request for {num_envs} envs, " | ||
| f"but is assigned for {len(self.variant_indices_by_env)} envs." | ||
| ) | ||
| bounding_boxes = [obj.get_bounding_box() for obj in self.objects] | ||
|
|
||
| min_pts = torch.stack([bounding_boxes[idx].min_point[0] for idx in self.variant_indices_by_env], dim=0) | ||
| max_pts = torch.stack([bounding_boxes[idx].max_point[0] for idx in self.variant_indices_by_env], dim=0) | ||
| return AxisAlignedBoundingBox(min_point=min_pts, max_point=max_pts) | ||
|
|
||
| def get_contact_sensor_cfg(self, contact_against_object: ObjectBase | None = None) -> ContactSensorCfg: | ||
| # We assume that by here, our USDs have been modified to be compatible with each other | ||
| # and we can use the first USD path to find the shallowest rigid body. | ||
| return super().get_contact_sensor_cfg(contact_against_object, usd_path=self.object_usd_paths[0]) | ||
| # and we can use the canonical first member USD to find the shallowest rigid body. | ||
| return super().get_contact_sensor_cfg(contact_against_object, usd_path=self.member_usd_paths[0]) | ||
|
|
||
| def _generate_variant_indices(self, num_envs: int, variant_seed: int | None = None) -> list[int]: | ||
| """Return one member index per env. | ||
|
|
||
| Ordered sets repeat member order. Random sets sample independently per | ||
| env, using a local generator when variant_seed is set. | ||
| """ | ||
| n = len(self.objects) | ||
| if not self.random_choice: | ||
| return [env_idx % n for env_idx in range(num_envs)] | ||
| if variant_seed is None: | ||
| return torch.randint(low=0, high=n, size=(num_envs,)).tolist() | ||
| generator = torch.Generator() | ||
| generator.manual_seed(variant_seed) | ||
| return torch.randint(low=0, high=n, size=(num_envs,), generator=generator).tolist() | ||
|
|
||
| def _set_variant_indices_by_env(self, variant_indices_by_env: list[int]) -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion to add a short docstring. Hard to know what this function is doing.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doctring added |
||
| """Validate and store variant indices, then sync spawn config when it exists.""" | ||
| n = len(self.objects) | ||
| assert all( | ||
| 0 <= idx < n for idx in variant_indices_by_env | ||
| ), f"RigidObjectSet '{self.name}' variant indices must be in [0, {n}); got {variant_indices_by_env}." | ||
| self.variant_indices_by_env = variant_indices_by_env | ||
| # During __init__, Object.object_cfg has not been built yet; _generate_rigid_cfg() | ||
| # reads object_usd_paths after this assignment. | ||
| spawn_cfg = self.object_cfg.spawn if getattr(self, "object_cfg", None) is not None else None | ||
| if isinstance(spawn_cfg, sim_utils.MultiUsdFileCfg): | ||
| spawn_cfg.usd_path = self.object_usd_paths | ||
|
|
||
| def _are_all_objects_type_rigid(self, objects: list[ObjectBase]) -> bool: | ||
| if objects is None or len(objects) == 0: | ||
| raise ValueError(f"Object set {self.name} must contain at least 1 object.") | ||
| return all(detect_object_type(usd_path=object.usd_path) == ObjectType.RIGID for object in objects) | ||
| def _are_all_objects_type_rigid(self, objects: list[Object]) -> bool: | ||
| for obj in objects: | ||
| assert obj.usd_path is not None | ||
| if detect_object_type(usd_path=obj.usd_path) != ObjectType.RIGID: | ||
| return False | ||
| return True | ||
|
|
||
| def _generate_rigid_cfg(self) -> RigidObjectCfg: | ||
| assert self.object_type == ObjectType.RIGID | ||
| object_cfg = RigidObjectCfg( | ||
| prim_path=self.prim_path, | ||
| spawn=sim_utils.MultiUsdFileCfg( | ||
| usd_path=self.object_usd_paths, | ||
| random_choice=self.random_choice, | ||
| # Arena owns per-env variant assignment so bbox selection and | ||
| # spawned USDs stay aligned. | ||
| random_choice=False, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a small inline comment here that this is hardcoded false because the randomization is handled at the Arena level in this class.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed. inline comment added. |
||
| activate_contact_sensors=True, | ||
| ), | ||
| ) | ||
|
|
@@ -143,6 +250,7 @@ def _asset_modification_possible(self, objects: list[Object]) -> bool: | |
| def _get_all_rigid_body_depths(self, objects: list[Object]) -> list[int]: | ||
| depths = [] | ||
| for asset in objects: | ||
| assert asset.usd_path is not None | ||
| shallowest_rigid_body = find_shallowest_rigid_body(asset.usd_path) | ||
| depth = shallowest_rigid_body.count("/") - 1 if shallowest_rigid_body else -1 | ||
| depths.append(depth) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,12 +60,12 @@ def _solve_relations(self) -> None: | |
| 3. Applies solved positions to objects via a :class:`PooledObjectPlacer` | ||
|
|
||
| Behaviour on reset depends on :attr:`ObjectPlacerParams.resolve_on_reset` | ||
| (overridable from CLI with ``--resolve_on_reset`` / ``--no-resolve_on_reset``): | ||
| (overridable from CLI with --resolve_on_reset / --no-resolve_on_reset): | ||
|
|
||
| * **True** (default) — registers a reset event that draws a fresh layout | ||
| from the pool for each resetting environment. | ||
| * **False** — applies one layout per environment via ``set_initial_pose`` | ||
| so per-object reset events restore the same layout every time. | ||
| * **False** — applies one layout per environment so per-object reset | ||
| events restore the same layout every time. | ||
| """ | ||
| objects_with_relations = self.arena_env.scene.get_objects_with_relations() | ||
|
|
||
|
|
@@ -92,7 +92,13 @@ def _solve_relations(self) -> None: | |
| objects=objects_with_relations, | ||
| placer_params=placer_params, | ||
| pool_size=pool_size, | ||
| num_envs=num_envs, | ||
| ) | ||
| if placement_pool.had_fallbacks: | ||
| print( | ||
| "Warning: Relation placement pool accepted best-loss fallback layouts " | ||
| "that failed strict placement validation." | ||
| ) | ||
|
|
||
| if placer_params.resolve_on_reset: | ||
| anchor_objects_set = set(get_anchor_objects(objects_with_relations)) | ||
|
|
@@ -123,18 +129,32 @@ def _set_init_state_from_pool( | |
| pool: PooledObjectPlacer, | ||
| anchor_objects_set: set, | ||
| ) -> None: | ||
| """Set ``object_cfg.init_state`` from a pool layout so objects spawn at valid positions. | ||
| """Set object_cfg.init_state from pool layouts so objects spawn at valid positions. | ||
|
|
||
| Only touches ``init_state.pos`` / ``init_state.rot`` — does NOT create | ||
| per-object reset events (the placement event handles resets). | ||
| This does not create per-object reset events; the placement event | ||
| handles reset-time layout writes. | ||
| """ | ||
| if pool.requires_env_indexed_layouts: | ||
| layout = pool.sample_with_replacement(1)[0] | ||
| for obj in objects: | ||
| if obj in anchor_objects_set: | ||
| continue | ||
| rotation_xyzw = get_rotation_xyzw(obj) | ||
| # Isaac Lab broadcasts init_state across envs; reset events write env-specific layouts. | ||
| pos = layout.positions.get(obj) | ||
| if pos is None: | ||
| raise RuntimeError(f"Pool layout for env 0 is missing object '{obj.name}'.") | ||
| obj.object_cfg.init_state.pos = pos | ||
| obj.object_cfg.init_state.rot = rotation_xyzw | ||
| return | ||
|
Comment on lines
+137
to
+149
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For env-indexed (heterogeneous) layouts, |
||
|
|
||
| layout = pool.sample_with_replacement(1)[0] | ||
| for obj in objects: | ||
| if obj in anchor_objects_set: | ||
| continue | ||
| pos = layout.positions.get(obj) | ||
| if pos is None: | ||
| continue | ||
| raise RuntimeError(f"Pool layout is missing object '{obj.name}'.") | ||
| rotation_xyzw = get_rotation_xyzw(obj) | ||
| obj.object_cfg.init_state.pos = pos | ||
| obj.object_cfg.init_state.rot = rotation_xyzw | ||
|
|
@@ -145,7 +165,7 @@ def _apply_pool_layouts_to_objects( | |
| pool: PooledObjectPlacer, | ||
| num_envs: int, | ||
| ) -> None: | ||
| """Draw layouts from the pool and apply them to objects via ``set_initial_pose``. | ||
| """Draw layouts from the pool and apply them as per-env initial poses. | ||
|
|
||
| Each non-anchor object gets a :class:`~isaaclab_arena.utils.pose.PosePerEnv` | ||
| so that per-object reset events restore these positions. | ||
|
|
@@ -161,10 +181,9 @@ def _apply_pool_layouts_to_objects( | |
| for env_idx in range(num_envs): | ||
| pos = layouts[env_idx].positions.get(obj) | ||
| if pos is None: | ||
| break | ||
| raise RuntimeError(f"Placement layout for env {env_idx} is missing object '{obj.name}'.") | ||
| poses.append(Pose(position_xyz=pos, rotation_xyzw=rotation_xyzw)) | ||
| else: | ||
| obj.set_initial_pose(PosePerEnv(poses=poses)) | ||
| obj.set_initial_pose(PosePerEnv(poses=poses)) | ||
|
|
||
| def _modify_recorder_cfg_dataset_filename(self, recorder_cfg: RecorderManagerBaseCfg) -> RecorderManagerBaseCfg: | ||
| """Modify the recorder dataset filename to include the timestamp and rank.""" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.