Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion isaaclab_arena/relations/object_placer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def place(
objects: list[ObjectBase],
num_envs: int = 1,
result_per_env: bool = True,
*,
seed_offset: int = 0,
) -> PlacementResult | MultiEnvPlacementResult:
"""Place objects according to their spatial relations.

Expand All @@ -85,6 +87,11 @@ def place(
result_per_env: When True (default), each environment gets a distinct
layout. When False, a single best layout is solved and applied
identically to all environments.
seed_offset: Added to ``placement_seed`` for each candidate. Callers
that issue multiple ``place()`` calls (e.g. a pool refill) pass
the cumulative candidate count here so successive batches use a
fresh, non-overlapping seed range. Ignored when
``placement_seed`` is None (no seeding takes place).

Returns:
PlacementResult when a single layout is produced (num_envs=1 or
Expand Down Expand Up @@ -128,7 +135,7 @@ def place(
initial_positions: list[dict[ObjectBase, tuple[float, float, float]]] = []
for candidate_idx in range(num_candidates):
if generator is not None:
generator.manual_seed(self.params.placement_seed + candidate_idx)
generator.manual_seed(self.params.placement_seed + seed_offset + candidate_idx)
initial_positions.append(self._generate_initial_positions(objects, anchor_objects_set, generator))

all_positions = self._solver.solve(objects, initial_positions)
Expand Down
26 changes: 22 additions & 4 deletions isaaclab_arena/relations/pooled_object_placer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class PooledObjectPlacer:
* :meth:`sample_with_replacement` — picks *count* layouts at random
(non-consuming). Used for static initial positions.

Reproducibility is controlled by ``placer_params.placement_seed``: when set,
refills advance a deterministic seed stream (via ``seed_offset``) and
``sample_with_replacement`` draws from a seeded ``random.Random``. When
``placement_seed`` is None, both solver init and sampling fall back to
global RNG state and are non-deterministic.

Args:
objects: All objects (including anchors) participating in relation solving.
placer_params: Parameters forwarded to ``ObjectPlacer`` for the batched solve.
Expand All @@ -41,14 +47,15 @@ def __init__(
placer_params: ObjectPlacerParams,
pool_size: int = 100,
) -> None:
if pool_size < 1:
raise ValueError(f"pool_size must be >= 1, got {pool_size}")
assert pool_size >= 1, f"pool_size must be >= 1, got {pool_size}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Using assert for constructor argument validation is problematic because Python disables all assert statements when run with the -O (optimize) flag. In that mode, a pool_size=0 call will silently pass this check, then max_placement_attempts * 0 = 0 candidates are generated, and the pool initialization will fail later with a generic RuntimeError("Placement pool failed to produce any valid layouts") — losing the precise error message. The original ValueError is the idiomatic Python choice for invalid argument values in a public __init__.

Suggested change
assert pool_size >= 1, f"pool_size must be >= 1, got {pool_size}"
if pool_size < 1:
raise ValueError(f"pool_size must be >= 1, got {pool_size}")


self._objects = objects
self._placer = ObjectPlacer(params=placer_params)
self._pool_size = pool_size
self._layouts: list[PlacementResult] = []
self._next_idx: int = 0
self._rng = random.Random(placer_params.placement_seed)
self._next_seed_offset: int = 0

# Pre-solve the initial batch (runs the gradient solver, no simulation is needed).
self._solve_and_store(pool_size)
Expand All @@ -74,8 +81,19 @@ def _solve_and_store(self, num_layouts: int) -> None:

# place() runs: random init → gradient solve → validate → rank.
# It returns up to num_layouts results; some may fail validation.
# Advancing seed_offset by the candidate count keeps successive batches on
# a fresh, non-overlapping seed range so refills don't replay the initial pool.
seed_offset = self._next_seed_offset
num_candidates = self._placer.params.max_placement_attempts * num_layouts
with torch.inference_mode(False):
result = self._placer.place(self._objects, num_envs=num_layouts, result_per_env=True)
result = self._placer.place(
self._objects, num_envs=num_layouts, result_per_env=True, seed_offset=seed_offset
)
# Keep the offset bookkeeping in step with reality: when placement_seed is None,
# place() ignores seed_offset entirely, so advancing it would suggest a seed stream
# that does not exist.
if self._placer.params.placement_seed is not None:
self._next_seed_offset += num_candidates

# TODO(@zhx06): Simplify once ObjectPlacer.place() always returns MultiEnvPlacementResult.
all_results = result.results if isinstance(result, MultiEnvPlacementResult) else [result]
Expand Down Expand Up @@ -122,7 +140,7 @@ def sample_with_replacement(self, count: int) -> list[PlacementResult]:
Used by ``resolve_on_reset=False`` to assign initial positions
that persist across resets.
"""
return random.choices(self._layouts, k=count)
return self._rng.choices(self._layouts, k=count)

@property
def remaining(self) -> int:
Expand Down
183 changes: 183 additions & 0 deletions isaaclab_arena/tests/test_object_placer_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

"""Tests for ObjectPlacer and RelationSolver reproducibility."""

import pytest

from isaaclab_arena.assets.dummy_object import DummyObject
from isaaclab_arena.relations.object_placer import ObjectPlacer
from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams
from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult
from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer
from isaaclab_arena.relations.relation_solver import RelationSolver
from isaaclab_arena.relations.relation_solver_params import RelationSolverParams
from isaaclab_arena.relations.relations import IsAnchor, NextTo, On, Side
Expand Down Expand Up @@ -244,3 +246,184 @@ def test_object_placer_result_per_env_true_applies_pose_per_env():
pose = obj.get_initial_pose()
assert isinstance(pose, PosePerEnv), f"{obj.name} should have PosePerEnv, got {type(pose).__name__}"
assert len(pose.poses) == num_envs


def test_pooled_placer_same_seed_produces_identical_samples():
"""PooledObjectPlacer.sample_with_replacement must be reproducible under placement_seed."""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False)

pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8)
pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8)

samples1 = pool1.sample_with_replacement(20)
samples2 = pool2.sample_with_replacement(20)

for s1, s2 in zip(samples1, samples2):
for obj1, pos1 in s1.positions.items():
pos2 = next(p for o, p in s2.positions.items() if o.name == obj1.name)
assert pos1 == pos2, f"Mismatch for {obj1.name}: {pos1} != {pos2}"


def test_pooled_placer_different_seeds_produce_different_samples():
"""Different placement_seed values should produce different sample sequences."""
solver_params = RelationSolverParams(max_iters=50)

pool1 = PooledObjectPlacer(
objects=list(_create_test_objects()),
placer_params=ObjectPlacerParams(
placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False
),
pool_size=8,
)
pool2 = PooledObjectPlacer(
objects=list(_create_test_objects()),
placer_params=ObjectPlacerParams(
placement_seed=123, solver_params=solver_params, apply_positions_to_objects=False
),
pool_size=8,
)

samples1 = pool1.sample_with_replacement(20)
samples2 = pool2.sample_with_replacement(20)

any_different = False
for s1, s2 in zip(samples1, samples2):
for obj1, pos1 in s1.positions.items():
pos2 = next(p for o, p in s2.positions.items() if o.name == obj1.name)
if pos1 != pos2:
any_different = True
break
if any_different:
break
assert any_different, "Different seeds should produce different samples"


def _positions_by_name(result: PlacementResult) -> dict[str, tuple[float, float, float]]:
return {obj.name: pos for obj, pos in result.positions.items()}


def test_pooled_placer_builds_identical_layouts_for_same_seed_and_objects():
"""Same objects + same placement_seed must produce bit-identical optimized layouts.

Draining the pool via ``sample_without_replacement`` walks ``_layouts`` in order with
no RNG involvement, so this compares the solver output directly across the two pools.
"""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False)

pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4)
pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4)

layouts1 = pool1.sample_without_replacement(4)
layouts2 = pool2.sample_without_replacement(4)

for L1, L2 in zip(layouts1, layouts2):
assert _positions_by_name(L1) == _positions_by_name(L2)


def test_pooled_placer_continues_seed_stream_across_refill():
"""Pool refill must advance the candidate seed stream so it doesn't replay the initial batch."""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False)

pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2)
pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2)

# Four draws: initial batch (draws 0-1) then a forced refill (draws 2-3).
draws1 = [pool1.sample_without_replacement(1)[0] for _ in range(4)]
draws2 = [pool2.sample_without_replacement(1)[0] for _ in range(4)]

for d1, d2 in zip(draws1, draws2):
assert _positions_by_name(d1) == _positions_by_name(d2), "Same seed must give same draws across refill"

# The refill batch (draws 2 and 3) must not be a permutation of the initial batch (draws 0 and 1).
# If seed_offset failed to advance, the second batch would re-solve from seeds 0..N-1 and the
# cross-batch positions would coincide.
initial_batch = {_positions_by_name(draws1[0]).__repr__(), _positions_by_name(draws1[1]).__repr__()}
refill_batch = {_positions_by_name(draws1[2]).__repr__(), _positions_by_name(draws1[3]).__repr__()}
assert initial_batch.isdisjoint(
refill_batch
), "Refill replayed the initial seed range; seed_offset is not advancing"


def test_pooled_placer_unseeded_does_not_crash_and_diverges():
"""placement_seed=None must work (no crash) and produce non-reproducible samples."""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(
placement_seed=None, solver_params=solver_params, apply_positions_to_objects=False
)

pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4)
pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4)

samples1 = pool1.sample_with_replacement(20)
samples2 = pool2.sample_with_replacement(20)

any_different = False
for s1, s2 in zip(samples1, samples2):
if _positions_by_name(s1) != _positions_by_name(s2):
any_different = True
break
assert any_different, "Two unseeded pools should not produce identical samples"


def test_pooled_placer_stored_layouts_have_distinct_positions_dicts():
"""Each stored layout must own a distinct positions dict (no aliasing across pool entries)."""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False)

Comment on lines +355 to +375
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Non-deterministic test asserting non-determinism

test_pooled_placer_unseeded_does_not_crash_and_diverges asserts that two independently unseeded pools produce at least one differing sample across 20 draws. While this is overwhelmingly likely in practice, it is not guaranteed: if the gradient solver converges to the same global minimum from two different random inits, all positions could coincide and the assertion fires incorrectly. Consider replacing the divergence assertion with a no-crash smoke test.

pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4)
draws = pool.sample_without_replacement(4)

for i in range(len(draws)):
for j in range(i + 1, len(draws)):
assert (
draws[i].positions is not draws[j].positions
), f"Layouts {i} and {j} share the same positions dict reference"


def test_pooled_placer_sample_without_replacement_count_exceeds_pool_size():
"""sample_without_replacement(count) where count > pool_size must solve a larger batch in one shot."""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False)

pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2)
pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2)

draws1 = pool1.sample_without_replacement(5)
draws2 = pool2.sample_without_replacement(5)

assert len(draws1) == 5 and len(draws2) == 5
for d1, d2 in zip(draws1, draws2):
assert _positions_by_name(d1) == _positions_by_name(d2)


def test_pooled_placer_sample_with_replacement_reproducible_across_refill():
"""sample_with_replacement must remain reproducible after a refill mutates the pool."""
solver_params = RelationSolverParams(max_iters=50)
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False)

def draw_sequence(pool: PooledObjectPlacer) -> list:
before = pool.sample_with_replacement(5)
# Consume the initial pool to force a refill.
pool.sample_without_replacement(2)
pool.sample_without_replacement(1)
after = pool.sample_with_replacement(5)
return before + after

pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2)
pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2)

seq1 = draw_sequence(pool1)
seq2 = draw_sequence(pool2)

for s1, s2 in zip(seq1, seq2):
assert _positions_by_name(s1) == _positions_by_name(s2)


def test_pooled_placer_rejects_pool_size_below_one():
"""pool_size < 1 is an invariant violation; the assert must fire."""
placer_params = ObjectPlacerParams(placement_seed=42, solver_params=RelationSolverParams(max_iters=10))
with pytest.raises(AssertionError, match="pool_size must be >= 1"):
PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=0)
Comment on lines +428 to +429
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 This test is coupled to the implementation mechanism (AssertionError) rather than the semantic contract (ValueError). If the validation is restored to a ValueError (which is idiomatic for invalid constructor arguments), this test would incorrectly pass the assert but then fail with the wrong exception type. It should catch ValueError to verify the public API contract.

Suggested change
with pytest.raises(AssertionError, match="pool_size must be >= 1"):
PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=0)
with pytest.raises(ValueError, match="pool_size must be >= 1"):
PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=0)

Loading