-
Notifications
You must be signed in to change notification settings - Fork 52
Add NotOn, NotNextTo loss strategy and solver #732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |
| from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox | ||
|
|
||
| if TYPE_CHECKING: | ||
| from isaaclab_arena.relations.relations import AtPosition, NextTo, On, PositionLimits, Relation | ||
| from isaaclab_arena.relations.relations import AtPosition, NextTo, NotNextTo, NotOn, On, PositionLimits, Relation | ||
|
|
||
| from isaaclab_arena.relations.relations import Side | ||
|
|
||
|
|
@@ -319,6 +319,184 @@ def compute_loss( | |
| return result.squeeze(0) if single_input else result | ||
|
|
||
|
|
||
| class NotOnLossStrategy(RelationLossStrategy): | ||
| """Loss strategy for ``NotOn`` — push child out of parent's XY footprint. | ||
|
|
||
| The inverse of ``OnLossStrategy``. ``NotOn`` is satisfied as soon | ||
| as either the X or Y footprint has *escaped* the parent — Z is | ||
| incidental, since with no XY overlap there is nothing to stack on. | ||
|
|
||
| Loss = ``slope * min(inside_x, inside_y)``, where ``inside_axis`` | ||
| is the child center's "depth inside" the placement band along that | ||
| axis (0 when outside, > 0 when inside). The ``min`` means a | ||
| single-axis escape is sufficient. The gradient on the smaller | ||
| inside-axis points toward the nearer edge — the optimizer pops | ||
| the child out along the path of least travel. | ||
|
|
||
| If ``margin_m > 0``, the parent's effective footprint is widened | ||
| by ``margin_m`` along both XY axes, so the loss does not drop to | ||
| zero until the child has cleared the parent by that margin. | ||
| """ | ||
|
|
||
| def __init__(self, slope: float = 100.0, margin_m: float = 0.0, debug: bool = False): | ||
| """ | ||
| Args: | ||
| slope: Loss magnitude per meter of inside-penetration. | ||
| Matches the default of ``OnLossStrategy``. | ||
| margin_m: Optional safety margin (meters) added to each XY | ||
| extent of the parent's footprint. The loss falls to | ||
| zero only when the child has cleared the widened | ||
| footprint. Default 0.0. | ||
| debug: If True, print the per-axis inside-penetration. | ||
| """ | ||
| assert slope >= 0.0, f"slope must be non-negative, got {slope}" | ||
| assert margin_m >= 0.0, f"margin_m must be non-negative, got {margin_m}" | ||
| self.slope = slope | ||
| self.margin_m = margin_m | ||
| self.debug = debug | ||
|
|
||
| def compute_loss( | ||
| self, | ||
| relation: "NotOn", | ||
| child_pos: torch.Tensor, | ||
| child_bbox: AxisAlignedBoundingBox, | ||
| parent_world_bbox: AxisAlignedBoundingBox, | ||
| ) -> torch.Tensor: | ||
| """Compute loss for ``NotOn``.""" | ||
| single_input = child_pos.dim() == 1 | ||
| if single_input: | ||
| child_pos = child_pos.unsqueeze(0) | ||
|
|
||
| # The On-valid band: child center positions for which the | ||
| # child's footprint is entirely inside the parent's footprint. | ||
| # Inflated by margin_m so Not(On) keeps pushing past the rim. | ||
| m = self.margin_m | ||
| valid_x_min = parent_world_bbox.min_point[:, 0] - child_bbox.min_point[:, 0] - m | ||
| valid_x_max = parent_world_bbox.max_point[:, 0] - child_bbox.max_point[:, 0] + m | ||
| valid_y_min = parent_world_bbox.min_point[:, 1] - child_bbox.min_point[:, 1] - m | ||
| valid_y_max = parent_world_bbox.max_point[:, 1] - child_bbox.max_point[:, 1] + m | ||
|
|
||
| # Inside-band penetration: distance from child center to the | ||
| # nearer edge, clamped to >= 0. Zero when the child has | ||
| # escaped the band along that axis. | ||
| zero = torch.zeros((), dtype=child_pos.dtype, device=child_pos.device) | ||
| inside_x = torch.maximum(zero, torch.minimum(child_pos[:, 0] - valid_x_min, valid_x_max - child_pos[:, 0])) | ||
| inside_y = torch.maximum(zero, torch.minimum(child_pos[:, 1] - valid_y_min, valid_y_max - child_pos[:, 1])) | ||
|
|
||
| # min(): a single-axis escape is enough to satisfy Not(On). | ||
| loss = self.slope * torch.minimum(inside_x, inside_y) | ||
|
|
||
| if self.debug and child_pos.shape[0] == 1: | ||
| print( | ||
| f" [NotOn] inside_x={inside_x[0].item():.6f} inside_y={inside_y[0].item():.6f} " | ||
| f"-> loss={loss[0].item():.6f}" | ||
| ) | ||
|
|
||
| result = relation.relation_loss_weight * loss | ||
| return result.squeeze(0) if single_input else result | ||
|
|
||
|
|
||
| class NotNextToLossStrategy(RelationLossStrategy): | ||
| """Loss strategy for ``NotNextTo`` — push child out of the NextTo zone on the given side. | ||
|
|
||
| The inverse of ``NextToLossStrategy``. The NextTo "satisfied | ||
| region" is the conjunction of three conditions: | ||
| 1. **Half-plane:** child on the correct side of the parent edge. | ||
| 2. **Cross band:** child's perpendicular position inside the | ||
| parent's perpendicular extent. | ||
| 3. **Target distance:** primary-axis position equals | ||
| ``parent_edge + direction * distance_m``. | ||
|
|
||
| ``NotNextTo`` is satisfied as soon as *any one* condition is | ||
| violated by at least ``margin_m`` meters. | ||
|
|
||
| Loss = ``slope * min(relu(margin_m - escape_side), | ||
| relu(margin_m - escape_cross), relu(margin_m - escape_dist))``. | ||
| The gradient points along whichever escape is cheapest, so the | ||
| optimizer naturally escapes by the easiest of the three directions. | ||
| """ | ||
|
|
||
| def __init__(self, slope: float = 10.0, margin_m: float = 0.05, debug: bool = False): | ||
| """ | ||
| Args: | ||
| slope: Loss magnitude per meter inside the safety margin. | ||
| Matches the default of ``NextToLossStrategy``. | ||
| margin_m: Meters of clearance required along whichever | ||
| escape axis the optimizer takes. Default 5 cm. | ||
| debug: If True, print the per-condition escape distances. | ||
| """ | ||
| assert slope >= 0.0, f"slope must be non-negative, got {slope}" | ||
| assert margin_m > 0.0, f"margin_m must be positive, got {margin_m}" | ||
| self.slope = slope | ||
| self.margin_m = margin_m | ||
| self.debug = debug | ||
|
|
||
| def compute_loss( | ||
| self, | ||
| relation: "NotNextTo", | ||
| child_pos: torch.Tensor, | ||
| child_bbox: AxisAlignedBoundingBox, | ||
| parent_world_bbox: AxisAlignedBoundingBox, | ||
| ) -> torch.Tensor: | ||
| """Compute loss for ``NotNextTo``.""" | ||
| single_input = child_pos.dim() == 1 | ||
| if single_input: | ||
| child_pos = child_pos.unsqueeze(0) | ||
|
|
||
| cfg = SIDE_CONFIGS[relation.side] | ||
| distance = relation.distance_m | ||
|
|
||
| # Mirror NextToLossStrategy's target-position derivation. | ||
| if cfg.direction == Direction.POSITIVE: | ||
| parent_edge = parent_world_bbox.max_point[:, cfg.primary_axis] | ||
| child_offset = child_bbox.min_point[:, cfg.primary_axis] | ||
| else: | ||
| parent_edge = parent_world_bbox.min_point[:, cfg.primary_axis] | ||
| child_offset = child_bbox.max_point[:, cfg.primary_axis] | ||
| target_pos = parent_edge + cfg.direction * distance - child_offset | ||
|
|
||
| # Cross band: child placed at target position within parent's perpendicular extent. | ||
| parent_band_min = parent_world_bbox.min_point[:, cfg.band_axis] | ||
| parent_band_max = parent_world_bbox.max_point[:, cfg.band_axis] | ||
| valid_band_min = parent_band_min - child_bbox.min_point[:, cfg.band_axis] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could reuse some existing losses like the |
||
| valid_band_max = parent_band_max - child_bbox.max_point[:, cfg.band_axis] | ||
|
|
||
| primary = child_pos[:, cfg.primary_axis] | ||
| cross = child_pos[:, cfg.band_axis] | ||
| zero = torch.zeros((), dtype=child_pos.dtype, device=child_pos.device) | ||
|
|
||
| # escape_side: how far on the WRONG side of parent's edge (0 if on correct side). | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like the half-plane logic that |
||
| # direction = +1 means child should be > parent_edge; wrong-side amount = parent_edge - child. | ||
| escape_side = torch.maximum(zero, (parent_edge - primary) * cfg.direction) | ||
|
|
||
| # escape_cross: how far OUTSIDE the perpendicular band (0 inside). | ||
| escape_cross = torch.maximum(zero, valid_band_min - cross) + torch.maximum(zero, cross - valid_band_max) | ||
|
|
||
| # escape_dist: how far the primary-axis position is from the target distance (always >= 0). | ||
| escape_dist = torch.abs(primary - target_pos) | ||
|
|
||
| # Per-condition "how much of the margin is unfilled". Zero once that escape passes margin. | ||
| margin = self.margin_m | ||
| gap_side = torch.maximum(zero, margin - escape_side) | ||
| gap_cross = torch.maximum(zero, margin - escape_cross) | ||
| gap_dist = torch.maximum(zero, margin - escape_dist) | ||
|
|
||
| # min(): a single escape past the margin is enough to satisfy Not(NextTo). | ||
| loss = self.slope * torch.minimum(torch.minimum(gap_side, gap_cross), gap_dist) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, suggestion to move the loss computation into |
||
|
|
||
| if self.debug and child_pos.shape[0] == 1: | ||
| print( | ||
| f" [NotNextTo] {relation.side.value}: " | ||
| f"escape_side={escape_side[0].item():.4f} " | ||
| f"escape_cross={escape_cross[0].item():.4f} " | ||
| f"escape_dist={escape_dist[0].item():.4f} " | ||
| f"-> loss={loss[0].item():.6f}" | ||
| ) | ||
|
|
||
| result = relation.relation_loss_weight * loss | ||
| return result.squeeze(0) if single_input else result | ||
|
|
||
|
|
||
| class NoCollisionLossStrategy: | ||
| """Loss strategy for no-overlap constraints between objects. | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion to move the all loss computation into
loss_primitives.pyinto respective functions next to the others.