-
Notifications
You must be signed in to change notification settings - Fork 52
Deterministic Object Placement #679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,11 +5,13 @@ | |||||||||
|
|
||||||||||
| """Tests for ObjectPlacer and RelationSolver reproducibility.""" | ||||||||||
|
|
||||||||||
| import pytest | ||||||||||
|
|
||||||||||
| from isaaclab_arena.assets.dummy_object import DummyObject | ||||||||||
| from isaaclab_arena.relations.object_placer import ObjectPlacer | ||||||||||
| from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams | ||||||||||
| from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult | ||||||||||
| from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer | ||||||||||
| from isaaclab_arena.relations.relation_solver import RelationSolver | ||||||||||
| from isaaclab_arena.relations.relation_solver_params import RelationSolverParams | ||||||||||
| from isaaclab_arena.relations.relations import IsAnchor, NextTo, On, Side | ||||||||||
|
|
@@ -244,3 +246,184 @@ def test_object_placer_result_per_env_true_applies_pose_per_env(): | |||||||||
| pose = obj.get_initial_pose() | ||||||||||
| assert isinstance(pose, PosePerEnv), f"{obj.name} should have PosePerEnv, got {type(pose).__name__}" | ||||||||||
| assert len(pose.poses) == num_envs | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_same_seed_produces_identical_samples(): | ||||||||||
| """PooledObjectPlacer.sample_with_replacement must be reproducible under placement_seed.""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8) | ||||||||||
| pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8) | ||||||||||
|
|
||||||||||
| samples1 = pool1.sample_with_replacement(20) | ||||||||||
| samples2 = pool2.sample_with_replacement(20) | ||||||||||
|
|
||||||||||
| for s1, s2 in zip(samples1, samples2): | ||||||||||
| for obj1, pos1 in s1.positions.items(): | ||||||||||
| pos2 = next(p for o, p in s2.positions.items() if o.name == obj1.name) | ||||||||||
| assert pos1 == pos2, f"Mismatch for {obj1.name}: {pos1} != {pos2}" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_different_seeds_produce_different_samples(): | ||||||||||
| """Different placement_seed values should produce different sample sequences.""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer( | ||||||||||
| objects=list(_create_test_objects()), | ||||||||||
| placer_params=ObjectPlacerParams( | ||||||||||
| placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False | ||||||||||
| ), | ||||||||||
| pool_size=8, | ||||||||||
| ) | ||||||||||
| pool2 = PooledObjectPlacer( | ||||||||||
| objects=list(_create_test_objects()), | ||||||||||
| placer_params=ObjectPlacerParams( | ||||||||||
| placement_seed=123, solver_params=solver_params, apply_positions_to_objects=False | ||||||||||
| ), | ||||||||||
| pool_size=8, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| samples1 = pool1.sample_with_replacement(20) | ||||||||||
| samples2 = pool2.sample_with_replacement(20) | ||||||||||
|
|
||||||||||
| any_different = False | ||||||||||
| for s1, s2 in zip(samples1, samples2): | ||||||||||
| for obj1, pos1 in s1.positions.items(): | ||||||||||
| pos2 = next(p for o, p in s2.positions.items() if o.name == obj1.name) | ||||||||||
| if pos1 != pos2: | ||||||||||
| any_different = True | ||||||||||
| break | ||||||||||
| if any_different: | ||||||||||
| break | ||||||||||
| assert any_different, "Different seeds should produce different samples" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _positions_by_name(result: PlacementResult) -> dict[str, tuple[float, float, float]]: | ||||||||||
| return {obj.name: pos for obj, pos in result.positions.items()} | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_builds_identical_layouts_for_same_seed_and_objects(): | ||||||||||
| """Same objects + same placement_seed must produce bit-identical optimized layouts. | ||||||||||
|
|
||||||||||
| Draining the pool via ``sample_without_replacement`` walks ``_layouts`` in order with | ||||||||||
| no RNG involvement, so this compares the solver output directly across the two pools. | ||||||||||
| """ | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) | ||||||||||
| pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) | ||||||||||
|
|
||||||||||
| layouts1 = pool1.sample_without_replacement(4) | ||||||||||
| layouts2 = pool2.sample_without_replacement(4) | ||||||||||
|
|
||||||||||
| for L1, L2 in zip(layouts1, layouts2): | ||||||||||
| assert _positions_by_name(L1) == _positions_by_name(L2) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_continues_seed_stream_across_refill(): | ||||||||||
| """Pool refill must advance the candidate seed stream so it doesn't replay the initial batch.""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) | ||||||||||
| pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) | ||||||||||
|
|
||||||||||
| # Four draws: initial batch (draws 0-1) then a forced refill (draws 2-3). | ||||||||||
| draws1 = [pool1.sample_without_replacement(1)[0] for _ in range(4)] | ||||||||||
| draws2 = [pool2.sample_without_replacement(1)[0] for _ in range(4)] | ||||||||||
|
|
||||||||||
| for d1, d2 in zip(draws1, draws2): | ||||||||||
| assert _positions_by_name(d1) == _positions_by_name(d2), "Same seed must give same draws across refill" | ||||||||||
|
|
||||||||||
| # The refill batch (draws 2 and 3) must not be a permutation of the initial batch (draws 0 and 1). | ||||||||||
| # If seed_offset failed to advance, the second batch would re-solve from seeds 0..N-1 and the | ||||||||||
| # cross-batch positions would coincide. | ||||||||||
| initial_batch = {_positions_by_name(draws1[0]).__repr__(), _positions_by_name(draws1[1]).__repr__()} | ||||||||||
| refill_batch = {_positions_by_name(draws1[2]).__repr__(), _positions_by_name(draws1[3]).__repr__()} | ||||||||||
| assert initial_batch.isdisjoint( | ||||||||||
| refill_batch | ||||||||||
| ), "Refill replayed the initial seed range; seed_offset is not advancing" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_unseeded_does_not_crash_and_diverges(): | ||||||||||
| """placement_seed=None must work (no crash) and produce non-reproducible samples.""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams( | ||||||||||
| placement_seed=None, solver_params=solver_params, apply_positions_to_objects=False | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) | ||||||||||
| pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) | ||||||||||
|
|
||||||||||
| samples1 = pool1.sample_with_replacement(20) | ||||||||||
| samples2 = pool2.sample_with_replacement(20) | ||||||||||
|
|
||||||||||
| any_different = False | ||||||||||
| for s1, s2 in zip(samples1, samples2): | ||||||||||
| if _positions_by_name(s1) != _positions_by_name(s2): | ||||||||||
| any_different = True | ||||||||||
| break | ||||||||||
| assert any_different, "Two unseeded pools should not produce identical samples" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_stored_layouts_have_distinct_positions_dicts(): | ||||||||||
| """Each stored layout must own a distinct positions dict (no aliasing across pool entries).""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) | ||||||||||
|
|
||||||||||
|
Comment on lines
+355
to
+375
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
| pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) | ||||||||||
| draws = pool.sample_without_replacement(4) | ||||||||||
|
|
||||||||||
| for i in range(len(draws)): | ||||||||||
| for j in range(i + 1, len(draws)): | ||||||||||
| assert ( | ||||||||||
| draws[i].positions is not draws[j].positions | ||||||||||
| ), f"Layouts {i} and {j} share the same positions dict reference" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_sample_without_replacement_count_exceeds_pool_size(): | ||||||||||
| """sample_without_replacement(count) where count > pool_size must solve a larger batch in one shot.""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) | ||||||||||
| pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) | ||||||||||
|
|
||||||||||
| draws1 = pool1.sample_without_replacement(5) | ||||||||||
| draws2 = pool2.sample_without_replacement(5) | ||||||||||
|
|
||||||||||
| assert len(draws1) == 5 and len(draws2) == 5 | ||||||||||
| for d1, d2 in zip(draws1, draws2): | ||||||||||
| assert _positions_by_name(d1) == _positions_by_name(d2) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_sample_with_replacement_reproducible_across_refill(): | ||||||||||
| """sample_with_replacement must remain reproducible after a refill mutates the pool.""" | ||||||||||
| solver_params = RelationSolverParams(max_iters=50) | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) | ||||||||||
|
|
||||||||||
| def draw_sequence(pool: PooledObjectPlacer) -> list: | ||||||||||
| before = pool.sample_with_replacement(5) | ||||||||||
| # Consume the initial pool to force a refill. | ||||||||||
| pool.sample_without_replacement(2) | ||||||||||
| pool.sample_without_replacement(1) | ||||||||||
| after = pool.sample_with_replacement(5) | ||||||||||
| return before + after | ||||||||||
|
|
||||||||||
| pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) | ||||||||||
| pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) | ||||||||||
|
|
||||||||||
| seq1 = draw_sequence(pool1) | ||||||||||
| seq2 = draw_sequence(pool2) | ||||||||||
|
|
||||||||||
| for s1, s2 in zip(seq1, seq2): | ||||||||||
| assert _positions_by_name(s1) == _positions_by_name(s2) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_pooled_placer_rejects_pool_size_below_one(): | ||||||||||
| """pool_size < 1 is an invariant violation; the assert must fire.""" | ||||||||||
| placer_params = ObjectPlacerParams(placement_seed=42, solver_params=RelationSolverParams(max_iters=10)) | ||||||||||
| with pytest.raises(AssertionError, match="pool_size must be >= 1"): | ||||||||||
| PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=0) | ||||||||||
|
Comment on lines
+428
to
+429
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assertfor constructor argument validation is problematic because Python disables allassertstatements when run with the-O(optimize) flag. In that mode, apool_size=0call will silently pass this check, thenmax_placement_attempts * 0 = 0candidates are generated, and the pool initialization will fail later with a genericRuntimeError("Placement pool failed to produce any valid layouts")— losing the precise error message. The originalValueErroris the idiomatic Python choice for invalid argument values in a public__init__.