def record_post_step(self):
    """Return the recorder key and a tensor snapshot of the subtask success flags.

    Returns:
        A ``(name, tensor)`` tuple. ``name`` is the recorder term name taken from the
        cfg, and ``tensor`` is a bool tensor built from ``env._subtask_ever_succeeded``
        on the environment's device.
    """
    # torch.tensor() always copies its input, so the result is already independent of
    # the Python lists stored on the env; an additional clone() is not needed.
    # The dtype is pinned so an empty state list cannot silently become float32.
    subtask_ever_succeeded = torch.tensor(
        self._env._subtask_ever_succeeded, dtype=torch.bool, device=self._env.device
    )
    return self.name, subtask_ever_succeeded
def compute_metric_from_recording(self, recorded_metric_data: list[np.ndarray]) -> list:
    """Compute the fraction of recorded episodes in which each subtask succeeded.

    Args:
        recorded_metric_data: One array per episode. Each array is reduced with ``any``
            over its first axis, so for a ``(num_steps, num_subtasks)`` bool array a
            subtask counts as successful if it was flagged on any step. A flat
            ``(num_subtasks,)`` array is also accepted and treated as a single step.

    Returns:
        A list with one success rate per subtask, or ``[0.0]`` when no episodes were
        recorded.
    """
    num_demos = len(recorded_metric_data)
    if num_demos == 0:
        return [0.0]

    # atleast_2d lets a flat per-episode record flow through the same axis-0 reduction
    # instead of failing on a missing second dimension.
    per_episode_success = [
        np.any(np.atleast_2d(episode), axis=0).astype(float) for episode in recorded_metric_data
    ]
    subtask_success_rates = np.sum(per_episode_success, axis=0) / num_demos

    return subtask_success_rates.tolist()
+ """ + + def __init__( + self, + subtasks: list[TaskBase], + episode_length_s: float | None = None, + desired_subtask_success_state: list[bool | None] | None = None, + ): + super().__init__(episode_length_s) + assert len(subtasks) > 0, "Composite task requires at least one subtask" + self.subtasks = subtasks + + if desired_subtask_success_state is not None: + assert len(desired_subtask_success_state) == len( + subtasks + ), "Desired subtask success state must be the same length as the number of subtasks" + assert all( + s is None or isinstance(s, bool) for s in desired_subtask_success_state + ), "Desired subtask success state entries must each be True, False, or None" + self.desired_subtask_success_state = desired_subtask_success_state + + @staticmethod + def _add_suffix_configclass_transform(fields: list[tuple], suffix: str) -> list[tuple]: + "Config transformation to add a suffix to all field names." + return [(f"{name}{suffix}", ftype, value) for name, ftype, value in fields] + + @staticmethod + def _remove_configclass_transform(fields: list[tuple], exclude_fields: set[str]) -> list[tuple]: + "Config transformation to remove all fields in an exclude set." + return [(name, ftype, value) for name, ftype, value in fields if name not in exclude_fields] + + @staticmethod + def _evaluate_subtask_successes( + env, + subtasks: list[TaskBase], + subtask_indices, + ) -> list[list[bool]]: + """Evaluate the success function of selected subtasks across all envs. + + Args: + env: The environment instance. + subtasks: Full list of subtasks for this composite task. + subtask_indices: Iterable of subtask indices to evaluate. Indices not in this + iterable are left as False in the returned matrix. + + Returns: + A (num_envs x len(subtasks)) list of bools, where entry [env_idx][subtask_idx] + is True if that subtask's success function returned True this step. 
+ """ + subtask_currently_succeeding = [[False for _ in subtasks] for _ in range(env.num_envs)] + for subtask_idx in subtask_indices: + subtask_success_func = subtasks[subtask_idx].get_termination_cfg().success.func + subtask_success_params = subtasks[subtask_idx].get_termination_cfg().success.params + results = subtask_success_func(env, **subtask_success_params) + for env_idx in range(env.num_envs): + if results[env_idx]: + subtask_currently_succeeding[env_idx][subtask_idx] = True + return subtask_currently_succeeding + + @staticmethod + def composite_task_success_func( + env, + subtasks: list[TaskBase], + desired_subtask_success_state: list[bool | None] | None, + ) -> torch.Tensor: + """Composite task composite success function. + + Args: + env: The environment instance. + subtasks: List of subtasks that compose this composite task. + desired_subtask_success_state: (Optional) Precise success state for each subtask during the final time step. + Can be used to enforce a specific current state for each subtask at the end of the episode. + + Returns: + A bool tensor of shape (num_envs,) indicating composite success per env. + """ + # Initialize each env's subtask success state to False if not already initialized + if not hasattr(env, "_subtask_ever_succeeded"): + env._subtask_ever_succeeded = [[False for _ in subtasks] for _ in range(env.num_envs)] + + # Evaluate every subtask's success function (composite tasks have no ordering constraint). + subtask_currently_succeeding = CompositeTaskBase._evaluate_subtask_successes( + env, subtasks, range(len(subtasks)) + ) + for env_idx in range(env.num_envs): + for subtask_idx in range(len(subtasks)): + if subtask_currently_succeeding[env_idx][subtask_idx]: + env._subtask_ever_succeeded[env_idx][subtask_idx] = True + + # Compute composite task success state for each env. + # Entries in `desired_subtask_success_state` set to None are "don't cares" and + # may be any state. 
For each subtask it must (a) have been evaluated as True + # at some point and (b) currently match the desired value. + if desired_subtask_success_state is not None: + per_env_success = [] + for env_idx in range(env.num_envs): + env_success = True + for i, desired in enumerate(desired_subtask_success_state): + if desired is None: + continue + # Check that both the subtask has ever succeeded and currently matches the desired success state. + ever_succeeded = env._subtask_ever_succeeded[env_idx][i] + currently_matches = subtask_currently_succeeding[env_idx][i] == desired + if not (ever_succeeded and currently_matches): + env_success = False + break + per_env_success.append(env_success) + else: + per_env_success = [all(env_successes) for env_successes in env._subtask_ever_succeeded] + + success_tensor = torch.tensor(per_env_success, dtype=torch.bool, device=env.device) + + env.extras["subtask_success_state"] = copy.copy(env._subtask_ever_succeeded) + + return success_tensor + + @staticmethod + def reset_subtask_success_state( + env, + env_ids, + subtasks: list[TaskBase], + ) -> None: + "Reset subtask success vector for each environment." + # Initialize each env's subtask success state to False + if not hasattr(env, "_subtask_ever_succeeded"): + env._subtask_ever_succeeded = [[False for _ in subtasks] for _ in range(env.num_envs)] + else: + for env_id in env_ids: + env._subtask_ever_succeeded[env_id] = [False for _ in subtasks] + + def get_scene_cfg(self) -> Any: + "Make combined scene cfg from all subtasks." + # Check for duplicate fields across subtask scene configs and warn if found + duplicates = check_configclass_field_duplicates(*(subtask.get_scene_cfg() for subtask in self.subtasks)) + if duplicates: + warnings.warn( + f"\n[WARNING] Duplicate scene config fields found across subtasks: {duplicates}. 
def get_events_cfg(self) -> Any:
    """Build the combined events cfg from all subtasks plus the composite reset event.

    Every subtask's event fields are renamed with a ``_subtask_{i}`` suffix so that
    identically named events from different subtasks do not collide. Subtasks that
    contribute no events are skipped.

    Returns:
        The combined "EventsCfg" instance, including the event that resets the
        subtask success state.
    """
    # Collect events_cfgs from subtasks with renamed fields to avoid collisions
    renamed_events_cfgs = []
    for i, subtask in enumerate(self.subtasks):
        subtask_events_cfg = subtask.get_events_cfg()
        if subtask_events_cfg is None:
            continue
        renamed_cfg = transform_configclass_instance(
            subtask_events_cfg, partial(self._add_suffix_configclass_transform, suffix=f"_subtask_{i}")
        )
        # NOTE(review): the transform presumably returns None when no fields remain
        # (e.g. a subtask with an empty events cfg). Such a subtask has nothing to
        # contribute, so skip it instead of aborting environment construction.
        if renamed_cfg is None:
            continue
        renamed_events_cfgs.append(renamed_cfg)

    # Add reset subtask success state event to the combined events cfgs
    events_cfg = combine_configclass_instances(
        "EventsCfg", *renamed_events_cfgs, self._make_composite_task_events_cfg()
    )

    return events_cfg
def _combine_subtask_metrics(self, subtask_idxs: list[int]) -> list[MetricBase]:
    """Combine metrics from the subtasks with the given indices.

    Per-subtask "success_rate" metrics are intentionally collapsed into a single shared
    entry, as the composite task should only have one success rate metric. Every other
    metric is returned as a copy whose ``name`` and ``recorder_term_name`` carry a
    ``_subtask_{idx}`` suffix.

    Args:
        subtask_idxs: Indices into ``self.subtasks`` whose metrics are collected.

    Returns:
        The combined list of metric copies, in subtask order.
    """
    combined_metrics = []
    has_success_rate = False

    for subtask_idx in subtask_idxs:
        for metric in self.subtasks[subtask_idx].get_metrics():
            if metric.name == "success_rate":
                # Keep only the first success_rate metric encountered.
                if not has_success_rate:
                    combined_metrics.append(copy.copy(metric))
                    has_success_rate = True
                continue

            # Rename a copy, never the subtask's own object: a metric instance that is
            # shared or cached by the subtask would otherwise accumulate suffixes.
            renamed_metric = copy.copy(metric)
            renamed_metric.name = f"{metric.name}_subtask_{subtask_idx}"
            renamed_metric.recorder_term_name = f"{metric.recorder_term_name}_subtask_{subtask_idx}"
            combined_metrics.append(renamed_metric)

    return combined_metrics
def combine_mimic_subtask_configs(self, arm_mode: ArmMode) -> dict[str, list[SubTaskConfig]]:
    """Combine the Mimic subtask configs of all subtasks, keyed by eef name.

    Each Mimic subtask's term signal is prefixed with ``subtask_{i}_{eef_name}_`` so
    signals from different subtasks stay distinct. A Mimic subtask without a term
    signal gets the default ``subtask_{i}_{eef_name}_last_mimic_subtask``.

    Args:
        arm_mode: Arm mode forwarded to each subtask's ``get_mimic_env_cfg``.

    Returns:
        A dict mapping each eef name to the concatenated list of renamed Mimic subtask
        configs, in subtask order.

    Raises:
        ValueError: If the subtasks do not all expose the same Mimic eef names.
    """
    mimic_eef_names = self._validate_consistent_mimic_eef_names(arm_mode)

    combined_mimic_subtask_configs = {eef_name: [] for eef_name in mimic_eef_names}

    # Combine the "Mimic subtask" cfgs from all subtasks
    for i, subtask in enumerate(self.subtasks):
        mimic_env_cfg = subtask.get_mimic_env_cfg(arm_mode)
        for eef_name in mimic_eef_names:
            for mimic_subtask in mimic_env_cfg.subtask_configs[eef_name]:
                # Rename a shallow copy so the subtask's own config is left untouched;
                # renaming in place would compound the prefix if the subtask hands
                # out the same config objects on a later call.
                renamed_subtask = copy.copy(mimic_subtask)
                if not mimic_subtask.subtask_term_signal:
                    # The last Mimic subtasks may not have an explicit term signal name
                    # so give it a default name if it doesn't already have one.
                    renamed_subtask.subtask_term_signal = f"subtask_{i}_{eef_name}_last_mimic_subtask"
                else:
                    renamed_subtask.subtask_term_signal = (
                        f"subtask_{i}_{eef_name}_{mimic_subtask.subtask_term_signal}"
                    )
                combined_mimic_subtask_configs[eef_name].append(renamed_subtask)

    return combined_mimic_subtask_configs
EventTermCfg = MISSING - - -@configclass -class TerminationsCfg: - success: TerminationTermCfg = MISSING - - -class SubtaskSuccessStateRecorder(RecorderTerm): - """Records the subtask success state just before the environment is reset.""" - - def __init__(self, cfg, env): - super().__init__(cfg, env) - self.name = cfg.name - - def record_post_step(self): - # Return subtask success state as a torch tensor - subtask_success_state = torch.tensor(self._env._subtask_success_state, device=self._env.device) - return self.name, subtask_success_state.clone() - - -@configclass -class SubtaskSuccessStateRecorderCfg(RecorderTermCfg): - class_type: type[RecorderTerm] = SubtaskSuccessStateRecorder - name: str = "subtask_success_rate" - - -class SubtaskSuccessRateMetric(MetricBase): - """Computes the per-subtask success rates. - - Returns a dict with success rate for each subtask. - """ - - name = "subtask_success_rate" - recorder_term_name = "subtask_success_rate" - - def __init__(self): - super().__init__() - - def get_recorder_term_cfg(self) -> RecorderTermCfg: - """Return the recorder term configuration for the subtask success state metric.""" - return SubtaskSuccessStateRecorderCfg(name=self.recorder_term_name) - - def compute_metric_from_recording(self, recorded_metric_data: list[np.ndarray]) -> list: - """Computes per-subtask success rates. - - Args: - recorded_metric_data: List of arrays, each shape (num_subtasks,) with bool values. - - Returns: - List of success rates for each subtask. 
- """ - num_demos = len(recorded_metric_data) - if num_demos == 0: - return [0.0] - - num_subtasks = recorded_metric_data[0].shape[1] - subtask_successes = np.zeros(num_subtasks, dtype=float) - for ep in range(num_demos): - ep_subtask_success_result = np.any(recorded_metric_data[ep], axis=0).astype(float) - subtask_successes += ep_subtask_success_result - subtask_success_rates = subtask_successes / num_demos - return subtask_success_rates.tolist() - - -class SequentialTaskBase(TaskBase): +class SequentialTaskBase(CompositeTaskBase): """ - A base class for composite tasks composed sequentially from multiple subtasks. + A class for composite tasks composed sequentially from multiple subtasks. The sequential task takes a list of TaskBase instances (subtasks), and automatically collects configs to form a composite task. @@ -106,83 +22,72 @@ class SequentialTaskBase(TaskBase): without affecting the completeness of the overall sequential task. """ - def __init__( - self, - subtasks: list[TaskBase], - episode_length_s: float | None = None, - desired_subtask_success_state: list[bool] | None = None, - ): - super().__init__(episode_length_s) - assert len(subtasks) > 0, "SequentialTaskBase requires at least one subtask" - self.subtasks = subtasks - - if desired_subtask_success_state is not None: - assert len(desired_subtask_success_state) == len( - subtasks - ), "Desired subtask success state must be the same length as the number of subtasks" - self.desired_subtask_success_state = desired_subtask_success_state - - @staticmethod - def add_suffix_configclass_transform(fields: list[tuple], suffix: str) -> list[tuple]: - "Config transformation to add a suffix to all field names." - return [(f"{name}{suffix}", ftype, value) for name, ftype, value in fields] - - @staticmethod - def remove_configclass_transform(fields: list[tuple], exclude_fields: set[str]) -> list[tuple]: - "Config transformation to remove all fields in an exclude set." 
@staticmethod
def composite_task_success_func(
    env,
    subtasks: list[TaskBase],
    desired_subtask_success_state: list[bool | None] | None,
) -> torch.Tensor:
    """Sequential task composite success function.

    Subtasks are completed in order: each env tracks the index of its current subtask,
    and only a success of that current subtask is recorded and advances the index.

    Args:
        env: The environment instance.
        subtasks: List of subtasks that compose this sequential task.
        desired_subtask_success_state: (Optional) Precise success state for each subtask during the final time step.
            Can be used to enforce a specific current state for each subtask at the end of the episode.
            Entries set to None are "don't cares".

    Returns:
        A bool tensor of shape (num_envs,) indicating composite success per env.
    """
    # Lazily create the per-env success flags and the per-env state machine index.
    if not hasattr(env, "_subtask_ever_succeeded"):
        env._subtask_ever_succeeded = [[False for _ in subtasks] for _ in range(env.num_envs)]
    if not hasattr(env, "_current_subtask_idx"):
        env._current_subtask_idx = [0 for _ in range(env.num_envs)]

    # With a desired end state every subtask's current status is needed; otherwise only
    # the subtasks that some env is currently working on have to be evaluated.
    if desired_subtask_success_state is not None:
        subtasks_to_evaluate = range(len(subtasks))
    else:
        subtasks_to_evaluate = sorted(set(env._current_subtask_idx))

    subtask_currently_succeeding = CompositeTaskBase._evaluate_subtask_successes(
        env, subtasks, subtasks_to_evaluate
    )

    # Advance the state machine per env using the precomputed active-subtask result.
    for env_idx in range(env.num_envs):
        current_subtask_idx = env._current_subtask_idx[env_idx]
        if subtask_currently_succeeding[env_idx][current_subtask_idx]:
            env._subtask_ever_succeeded[env_idx][current_subtask_idx] = True
            if current_subtask_idx < len(subtasks) - 1:
                env._current_subtask_idx[env_idx] += 1

    if desired_subtask_success_state is not None:
        # For every subtask with a non-None desired value, it must (a) have been
        # evaluated as True at some point and (b) currently match the desired value.
        per_env_success = [
            all(
                env._subtask_ever_succeeded[env_idx][i]
                and subtask_currently_succeeding[env_idx][i] == desired
                for i, desired in enumerate(desired_subtask_success_state)
                if desired is not None
            )
            for env_idx in range(env.num_envs)
        ]
    else:
        per_env_success = [all(env_successes) for env_successes in env._subtask_ever_succeeded]

    success_tensor = torch.tensor(per_env_success, dtype=torch.bool, device=env.device)

    # Publish a real snapshot: copy each row, because a shallow copy of the outer list
    # would still share the inner rows that later steps mutate in place.
    env.extras["subtask_success_state"] = [list(env_successes) for env_successes in env._subtask_ever_succeeded]

    return success_tensor
# Initialize each env's subtask success state to False - if not hasattr(env, "_subtask_success_state"): - env._subtask_success_state = [[False for _ in subtasks] for _ in range(env.num_envs)] + if not hasattr(env, "_subtask_ever_succeeded"): + env._subtask_ever_succeeded = [[False for _ in subtasks] for _ in range(env.num_envs)] else: for env_id in env_ids: - env._subtask_success_state[env_id] = [False for _ in subtasks] + env._subtask_ever_succeeded[env_id] = [False for _ in subtasks] # Initialize each env's current subtask index (state machine) to 0 if not hasattr(env, "_current_subtask_idx"): @@ -206,156 +111,3 @@ def reset_subtask_success_state( else: for env_id in env_ids: env._current_subtask_idx[env_id] = 0 - - def get_scene_cfg(self) -> configclass: - "Make combined scene cfg from all subtasks." - # Check for duplicate fields across subtask scene configs and warn if found - duplicates = check_configclass_field_duplicates(*(subtask.get_scene_cfg() for subtask in self.subtasks)) - if duplicates: - import warnings - - warnings.warn( - f"\n[WARNING] Duplicate scene config fields found across subtasks: {duplicates}. " - "Duplicates will be ignored.\n", - UserWarning, - ) - - scene_cfg = combine_configclass_instances("SceneCfg", *(subtask.get_scene_cfg() for subtask in self.subtasks)) - return scene_cfg - - def make_sequential_task_events_cfg(self) -> configclass: - "Make event to reset subtask success state." - reset_subtask_success_state = EventTermCfg( - func=self.reset_subtask_success_state, - mode="reset", - params={ - "subtasks": self.subtasks, - }, - ) - - return SequentialTaskEventsCfg( - reset_subtask_success_state=reset_subtask_success_state, - ) - - def get_events_cfg(self) -> configclass: - "Make combined events cfg from all subtasks." 
- # Collect events_cfgs from subtasks with renamed fields to avoid collisions - renamed_events_cfgs = [] - for i, subtask in enumerate(self.subtasks): - subtask_events_cfg = subtask.get_events_cfg() - renamed_cfg = transform_configclass_instance( - subtask_events_cfg, partial(self.add_suffix_configclass_transform, suffix=f"_subtask_{i}") - ) - if renamed_cfg is not None: - renamed_events_cfgs.append(renamed_cfg) - - # Add reset subtask success state event to the combined events cfgs - events_cfg = combine_configclass_instances( - "EventsCfg", *renamed_events_cfgs, self.make_sequential_task_events_cfg() - ) - - return events_cfg - - def make_sequential_task_termination_cfg(self) -> configclass: - "Make composite success check termination term." - success = TerminationTermCfg( - func=self.sequential_task_success_func, - params={ - "subtasks": self.subtasks, - "desired_subtask_success_state": self.desired_subtask_success_state, - }, - ) - - return TerminationsCfg( - success=success, - ) - - def get_termination_cfg(self) -> configclass: - "Make combined termination cfg from all subtasks." - # Collect termination cfgs from subtasks with 'success' field removed - subtask_termination_cfgs = [] - for subtask in self.subtasks: - termination_cfg = subtask.get_termination_cfg() - cleaned_cfg = transform_configclass_instance( - termination_cfg, partial(self.remove_configclass_transform, exclude_fields={"success"}) - ) - if cleaned_cfg is not None: - subtask_termination_cfgs.append(cleaned_cfg) - - # Combine subtask terminations with the composite sequential task success - combined_termination_cfg = combine_configclass_instances( - "TerminationsCfg", *subtask_termination_cfgs, self.make_sequential_task_termination_cfg() - ) - - return combined_termination_cfg - - def combine_subtask_metrics(self, subtask_idxs: list[int]) -> list[MetricBase]: - "Combine metrics from subtasks with the given ids." 
- combined_metrics = [] - - for subtask_idx in subtask_idxs: - subtask_metrics = self.subtasks[subtask_idx].get_metrics() - for metric in subtask_metrics: - if metric.name != "success_rate": - metric.name = f"{metric.name}_subtask_{subtask_idx}" - metric.recorder_term_name = f"{metric.recorder_term_name}_subtask_{subtask_idx}" - combined_metrics.append(copy.copy(metric)) - else: - if not any(m.name == "success_rate" for m in combined_metrics): - combined_metrics.append(copy.copy(metric)) - - return combined_metrics - - def get_metrics(self) -> list[MetricBase]: - "Get metrics for the sequential task." - subtask_metrics = self.combine_subtask_metrics([i for i in range(len(self.subtasks))]) - # Add the sequential task's own metric for per-subtask success rates - subtask_metrics.append(SubtaskSuccessRateMetric()) - - return subtask_metrics - - def combine_mimic_subtask_configs(self, arm_mode: ArmMode) -> dict[str, list[SubTaskConfig]]: - "Combine the Mimic subtask configs for all subtasks." - # Check that all subtasks have the same Mimic eef_names - mimic_eef_names = set(self.subtasks[0].get_mimic_env_cfg(arm_mode).subtask_configs.keys()) - - for subtask in self.subtasks[1:]: - subtask_eef_names_set = set(subtask.get_mimic_env_cfg(arm_mode).subtask_configs.keys()) - if subtask_eef_names_set != mimic_eef_names: - raise ValueError( - f"All subtasks must have the same Mimic eef_names.\nSubtask 0 has eef_names: {mimic_eef_names}, but" - f" subtask {self.subtasks.index(subtask)} has eef_names: {subtask_eef_names_set}." 
- ) - - combined_mimic_subtask_configs = {eef_name: [] for eef_name in mimic_eef_names} - - # Combine the "Mimic subtask" cfgs from all subtasks - for i, subtask in enumerate(self.subtasks): - # Get the Mimic env cfg for the subtask - mimic_env_cfg = subtask.get_mimic_env_cfg(arm_mode) - for eef_name in mimic_eef_names: - # For each eef, get the "Mimic subtask" cfgs for the subtask, update the term signal name, - # and add it to the combined "Mimic subtask" list - for mimic_subtask in mimic_env_cfg.subtask_configs[eef_name]: - if not mimic_subtask.subtask_term_signal: - # The last Mimic subtasks may not have an explicit term signal name - # so give it a default name if it doesn't already have one. - mimic_subtask.subtask_term_signal = f"subtask_{i}_{eef_name}_last_mimic_subtask" - else: - mimic_subtask.subtask_term_signal = ( - f"subtask_{i}_{eef_name}_{mimic_subtask.subtask_term_signal}" - ) - combined_mimic_subtask_configs[eef_name].append(mimic_subtask) - - return combined_mimic_subtask_configs - - def get_mimic_env_cfg(self, arm_mode: ArmMode) -> MimicEnvCfg: - "Get the Mimic environment configuration for the sequential task." - mimic_env_cfg = MimicEnvCfg() - - # Assign all default config values to mimic_env_cfg.datagen_config - for key, value in MIMIC_DATAGEN_CONFIG_DEFAULTS.items(): - setattr(mimic_env_cfg.datagen_config, key, value) - - mimic_env_cfg.subtask_configs = self.combine_mimic_subtask_configs(arm_mode) - return mimic_env_cfg diff --git a/isaaclab_arena/tests/test_composite_open_door.py b/isaaclab_arena/tests/test_composite_open_door.py new file mode 100644 index 000000000..2ffb75ff3 --- /dev/null +++ b/isaaclab_arena/tests/test_composite_open_door.py @@ -0,0 +1,362 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import gymnasium as gym +import torch +import traceback + +from isaaclab_arena.tests.utils.subprocess import run_simulation_app_function + +NUM_STEPS = 10 +HEADLESS = True + + +def get_test_environment(remove_reset_door_state_event: bool, num_envs: int): + """Returns a scene which we use for these tests.""" + + from isaaclab_arena.assets.registries import AssetRegistry + from isaaclab_arena.cli.isaaclab_arena_cli import get_isaaclab_arena_cli_parser + from isaaclab_arena.embodiments.franka.franka import FrankaIKEmbodiment + from isaaclab_arena.environments.arena_env_builder import ArenaEnvBuilder + from isaaclab_arena.environments.isaaclab_arena_environment import IsaacLabArenaEnvironment + from isaaclab_arena.scene.scene import Scene + from isaaclab_arena.tasks.composite_task_base import CompositeTaskBase + from isaaclab_arena.tasks.open_door_task import OpenDoorTask + from isaaclab_arena.utils.pose import Pose + + args_parser = get_isaaclab_arena_cli_parser() + args_cli = args_parser.parse_args(["--num_envs", str(num_envs)]) + + asset_registry = AssetRegistry() + background = asset_registry.get_asset_by_name("packing_table")() + microwave_0 = asset_registry.get_asset_by_name("microwave")(prim_path="{ENV_REGEX_NS}/microwave_0") + microwave_1 = asset_registry.get_asset_by_name("microwave")(prim_path="{ENV_REGEX_NS}/microwave_1") + + microwave_0.name = "microwave_0" + microwave_1.name = "microwave_1" + + # Put the microwave on the packing table. 
+ microwave_0.set_initial_pose( + Pose( + position_xyz=(0.6, -0.00586, 0.22773), + rotation_xyzw=(0, 0, -0.7071068, 0.7071068), + ) + ) + microwave_1.set_initial_pose( + Pose( + position_xyz=(0.6, 0.70586, 0.22773), + rotation_xyzw=(0, 0, -0.7071068, 0.7071068), + ) + ) + + subtask_1 = OpenDoorTask(microwave_0) + subtask_2 = OpenDoorTask(microwave_1) + + scene = Scene(assets=[background, microwave_0, microwave_1]) + + isaaclab_arena_environment = IsaacLabArenaEnvironment( + name="composite_open_door", + embodiment=FrankaIKEmbodiment(), + scene=scene, + task=CompositeTaskBase([subtask_1, subtask_2]), + ) + + env_builder = ArenaEnvBuilder(isaaclab_arena_environment, args_cli) + name, cfg = env_builder.build_registered() + if remove_reset_door_state_event: + # Remove the reset door and subtask state events to allow us to inspect the scene without having it reset. + cfg.events.reset_door_state_subtask_0 = None + cfg.events.reset_door_state_subtask_1 = None + cfg.events.reset_subtask_success_state = None + env = gym.make(name, cfg=cfg).unwrapped + env.reset() + + return env, microwave_0, microwave_1 + + +def _test_composite_open_door_microwave(simulation_app) -> bool: + from isaaclab.envs.manager_based_env import ManagerBasedEnv + + from isaaclab_arena.tests.utils.simulation import step_zeros_and_call + + # Get the scene + env, microwave_0, microwave_1 = get_test_environment(remove_reset_door_state_event=True, num_envs=1) + + def assert_composite_task_incomplete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([1]) + assert not terminated.item() + if not terminated.item(): + print("Composite task is not completed") + + def assert_composite_task_complete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([1]) + assert terminated.item() + if terminated.item(): + print("Composite task is completed") + + try: + print("Closing both microwaves") + microwave_0.close(env, env_ids=None) + 
microwave_1.close(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 0 (completing subtask 0)") + microwave_0.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 1 (completing subtask 1, composite task should be complete)") + microwave_1.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_complete) + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + finally: + env.close() + + return True + + +def _test_reverse_order_composite_open_door_microwave(simulation_app) -> bool: + """Composite tasks complete regardless of subtask completion order — open microwave 1 + first, then microwave 0, and the composite task should still complete.""" + from isaaclab.envs.manager_based_env import ManagerBasedEnv + + from isaaclab_arena.tests.utils.simulation import step_zeros_and_call + + # Get the scene + env, microwave_0, microwave_1 = get_test_environment(remove_reset_door_state_event=True, num_envs=1) + + def assert_composite_task_incomplete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([1]) + assert not terminated.item() + if not terminated.item(): + print("Composite task is not completed") + + def assert_composite_task_complete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([1]) + assert terminated.item() + if terminated.item(): + print("Composite task is completed") + + try: + print("Closing both microwaves") + microwave_0.close(env, env_ids=None) + microwave_1.close(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 1 (completing subtask 1)") + microwave_1.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Closing microwave 1") + microwave_1.close(env, 
env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 0 (completing subtask 0, composite task should be complete)") + microwave_0.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_complete) + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + finally: + env.close() + + return True + + +def _test_composite_open_door_microwave_multiple_envs(simulation_app) -> bool: + from isaaclab.envs.manager_based_env import ManagerBasedEnv + + from isaaclab_arena.tests.utils.simulation import step_zeros_and_call + + # Get the scene + env, microwave_0, microwave_1 = get_test_environment(remove_reset_door_state_event=True, num_envs=2) + + def assert_composite_task_incomplete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([2]) + assert not torch.any(terminated) + if not torch.any(terminated): + print("Composite task is not completed") + + def assert_composite_task_complete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([2]) + assert torch.all(terminated) + if torch.all(terminated): + print("Composite task is completed") + + try: + print("Closing both microwaves") + microwave_0.close(env, env_ids=None) + microwave_1.close(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 0 (completing subtask 0)") + microwave_0.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 1 (completing subtask 1, composite task should be complete)") + microwave_1.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_complete) + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + finally: + env.close() + + return True + + +def 
_test_reverse_order_composite_open_door_microwave_multiple_envs(simulation_app) -> bool: + """Multi-env version of the reverse-order test: order independence must hold for every env.""" + from isaaclab.envs.manager_based_env import ManagerBasedEnv + + from isaaclab_arena.tests.utils.simulation import step_zeros_and_call + + # Get the scene + env, microwave_0, microwave_1 = get_test_environment(remove_reset_door_state_event=True, num_envs=2) + + def assert_composite_task_incomplete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([2]) + assert not torch.any(terminated) + if not torch.any(terminated): + print("Composite task is not completed") + + def assert_composite_task_complete(env: ManagerBasedEnv, terminated: torch.Tensor): + assert terminated.shape == torch.Size([2]) + assert torch.all(terminated) + if torch.all(terminated): + print("Composite task is completed") + + try: + print("Closing both microwaves") + microwave_0.close(env, env_ids=None) + microwave_1.close(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 1 (completing subtask 1)") + microwave_1.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Closing microwave 1") + microwave_1.close(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_incomplete) + + print("Opening microwave 0 (completing subtask 0, composite task should be complete)") + microwave_0.open(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS, assert_composite_task_complete) + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + finally: + env.close() + + return True + + +def _test_composite_open_door_microwave_reset_condition(simulation_app) -> bool: + from isaaclab_arena.tests.utils.simulation import step_zeros_and_call + + # Get the scene + env, microwave_0, microwave_1 = 
get_test_environment(remove_reset_door_state_event=False, num_envs=2) + + try: + print("Closing both microwaves") + microwave_0.close(env, env_ids=None) + microwave_1.close(env, env_ids=None) + step_zeros_and_call(env, NUM_STEPS) + is_open_0 = microwave_0.is_open(env) + is_open_1 = microwave_1.is_open(env) + print(f"expected: [False, False], [False, False]: got: {is_open_0}, {is_open_1}") + assert torch.all(is_open_0 == torch.tensor([False], device=env.device)) + assert torch.all(is_open_1 == torch.tensor([False], device=env.device)) + + print("Opening microwave 0 (completing subtask 0)") + microwave_0.open(env, None) + step_zeros_and_call(env, NUM_STEPS) + is_open_0 = microwave_0.is_open(env) + is_open_1 = microwave_1.is_open(env) + print(f"expected: [True, True], [False, False]: got: {is_open_0}, {is_open_1}") + assert torch.all(is_open_0 == torch.tensor([True], device=env.device)) + assert torch.all(is_open_1 == torch.tensor([False], device=env.device)) + + # Check that envs automatically reset to closed. 
+ print("Opening microwave 1 (completing subtask 1, composite task completes and env resets)") + microwave_1.open(env, None) + step_zeros_and_call(env, NUM_STEPS) + is_open_0 = microwave_0.is_open(env) + is_open_1 = microwave_1.is_open(env) + print(f"expected: [False, False], [False, False]: got: {is_open_0}, {is_open_1}") + assert torch.all(is_open_0 == torch.tensor([False], device=env.device)) + assert torch.all(is_open_1 == torch.tensor([False], device=env.device)) + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + finally: + env.close() + + return True + + +def test_composite_open_door_microwave(): + result = run_simulation_app_function( + _test_composite_open_door_microwave, + headless=HEADLESS, + ) + assert result, f"Test {_test_composite_open_door_microwave.__name__} failed" + + +def test_reverse_order_composite_open_door_microwave(): + result = run_simulation_app_function( + _test_reverse_order_composite_open_door_microwave, + headless=HEADLESS, + ) + assert result, f"Test {_test_reverse_order_composite_open_door_microwave.__name__} failed" + + +def test_composite_open_door_microwave_multiple_envs(): + result = run_simulation_app_function( + _test_composite_open_door_microwave_multiple_envs, + headless=HEADLESS, + ) + assert result, f"Test {_test_composite_open_door_microwave_multiple_envs.__name__} failed" + + +def test_reverse_order_composite_open_door_microwave_multiple_envs(): + result = run_simulation_app_function( + _test_reverse_order_composite_open_door_microwave_multiple_envs, + headless=HEADLESS, + ) + assert result, f"Test {_test_reverse_order_composite_open_door_microwave_multiple_envs.__name__} failed" + + +def test_composite_open_door_microwave_reset_condition(): + result = run_simulation_app_function( + _test_composite_open_door_microwave_reset_condition, + headless=HEADLESS, + ) + assert result, f"Test {_test_composite_open_door_microwave_reset_condition.__name__} failed" + + +if __name__ == 
"__main__": + test_composite_open_door_microwave() + test_reverse_order_composite_open_door_microwave() + test_composite_open_door_microwave_multiple_envs() + test_reverse_order_composite_open_door_microwave_multiple_envs() + test_composite_open_door_microwave_reset_condition() diff --git a/isaaclab_arena/tests/test_composite_task_base.py b/isaaclab_arena/tests/test_composite_task_base.py new file mode 100644 index 000000000..a951c7416 --- /dev/null +++ b/isaaclab_arena/tests/test_composite_task_base.py @@ -0,0 +1,272 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import traceback + +from isaaclab_arena.tests.utils.subprocess import run_simulation_app_function + +HEADLESS = True + + +def _test_add_suffix_configclass_transform(simulation_app) -> bool: + """Test that _add_suffix_configclass_transform correctly renames fields with suffix.""" + + from functools import partial + + from isaaclab.utils import configclass + + from isaaclab_arena.tasks.composite_task_base import CompositeTaskBase + from isaaclab_arena.utils.configclass import transform_configclass_instance + + @configclass + class FooCfg: + int_field: int = 123 + str_field: str = "123" + float_field: float = 1.23 + bool_field: bool = True + + try: + original_cfg = FooCfg() + edited_cfg = transform_configclass_instance( + original_cfg, + partial(CompositeTaskBase._add_suffix_configclass_transform, suffix="_suffix"), + ) + + # Check that new fields exist with suffix + assert hasattr(edited_cfg, "int_field_suffix") + assert hasattr(edited_cfg, "str_field_suffix") + assert hasattr(edited_cfg, "float_field_suffix") + assert hasattr(edited_cfg, "bool_field_suffix") + + # Check that values are preserved + assert edited_cfg.int_field_suffix == 123 + assert edited_cfg.str_field_suffix == "123" + assert edited_cfg.float_field_suffix == 1.23 + assert 
edited_cfg.bool_field_suffix is True + + # Check types are preserved + assert isinstance(edited_cfg.int_field_suffix, int) + assert isinstance(edited_cfg.str_field_suffix, str) + assert isinstance(edited_cfg.float_field_suffix, float) + assert isinstance(edited_cfg.bool_field_suffix, bool) + + # Check that old field names don't exist + assert not hasattr(edited_cfg, "int_field") + assert not hasattr(edited_cfg, "str_field") + assert not hasattr(edited_cfg, "float_field") + assert not hasattr(edited_cfg, "bool_field") + + # Test None input + edited_cfg = transform_configclass_instance( + None, + partial(CompositeTaskBase._add_suffix_configclass_transform, suffix="_suffix"), + ) + assert edited_cfg is None + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + return True + + +def _test_remove_configclass_transform(simulation_app) -> bool: + """Test that _remove_configclass_transform correctly removes specified fields.""" + + from functools import partial + + from isaaclab.utils import configclass + + from isaaclab_arena.tasks.composite_task_base import CompositeTaskBase + from isaaclab_arena.utils.configclass import transform_configclass_instance + + @configclass + class FooCfg: + field_a: int = 123 + field_b: str = "123" + field_c: float = 1.23 + + try: + original_cfg = FooCfg() + edited_cfg = transform_configclass_instance( + original_cfg, + partial(CompositeTaskBase._remove_configclass_transform, exclude_fields={"field_b"}), + ) + + # Check that remaining fields exist + assert hasattr(edited_cfg, "field_a") + assert hasattr(edited_cfg, "field_c") + + # Check that values are preserved + assert edited_cfg.field_a == 123 + assert edited_cfg.field_c == 1.23 + + # Check that removed field doesn't exist + assert not hasattr(edited_cfg, "field_b") + + # Test removing multiple fields + original_cfg = FooCfg() + edited_cfg = transform_configclass_instance( + original_cfg, + partial(CompositeTaskBase._remove_configclass_transform, 
exclude_fields={"field_a", "field_c"}), + ) + + # Check that only field_b remains + assert hasattr(edited_cfg, "field_b") + assert edited_cfg.field_b == "123" + assert not hasattr(edited_cfg, "field_a") + assert not hasattr(edited_cfg, "field_c") + + # Test None input + edited_cfg = transform_configclass_instance( + None, + partial(CompositeTaskBase._remove_configclass_transform, exclude_fields=set()), + ) + assert edited_cfg is None + + # Test removing all fields returns None + original_cfg = FooCfg() + edited_cfg = transform_configclass_instance( + original_cfg, + partial(CompositeTaskBase._remove_configclass_transform, exclude_fields={"field_a", "field_b", "field_c"}), + ) + assert edited_cfg is None + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + return True + + +class _MockSuccessFunc: + """Callable that returns a controlled per-env boolean tensor.""" + + def __init__(self, num_envs: int): + import torch + + self.num_envs = num_envs + self.return_value = torch.tensor([False] * num_envs) + + def set(self, values: list[bool]): + import torch + + assert len(values) == self.num_envs + self.return_value = torch.tensor(values) + + def __call__(self, env, **kwargs): + return self.return_value + + +class _MockSubtask: + """Minimal stand-in for a TaskBase with a controllable success function.""" + + def __init__(self, num_envs: int): + self.func = _MockSuccessFunc(num_envs) + + class _SuccessCfg: + pass + + class _TerminationCfg: + pass + + self._termination_cfg = _TerminationCfg() + self._termination_cfg.success = _SuccessCfg() + self._termination_cfg.success.func = self.func + self._termination_cfg.success.params = {} + + def get_termination_cfg(self): + return self._termination_cfg + + def set_success(self, values: list[bool]): + self.func.set(values) + + +class _MockEnv: + """Minimal stand-in for the env used by composite_task_success_func.""" + + def __init__(self, num_envs: int = 1, device: str = "cpu"): + 
self.num_envs = num_envs + self.device = device + self.extras = {} + + +def _test_composite_desired_subtask_success_state_with_none(simulation_app) -> bool: + """When ``desired_subtask_success_state`` contains None entries, those positions are + ignored and only positions with True/False are checked. Verifies the composite-task + matching logic (ordering does not matter for composite tasks).""" + + from isaaclab_arena.tasks.composite_task_base import CompositeTaskBase + + try: + env = _MockEnv(num_envs=1) + subtasks = [_MockSubtask(num_envs=1) for _ in range(3)] + + # Latch all three subtasks True simultaneously (composite doesn't require order). + subtasks[0].set_success([True]) + subtasks[1].set_success([True]) + subtasks[2].set_success([True]) + result = CompositeTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + assert env._subtask_ever_succeeded == [[True, True, True]] + assert result.tolist() == [True] + + # Subtask 0 currently False (don't-care) -> still success. + subtasks[0].set_success([False]) + result = CompositeTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + assert result.tolist() == [True] + + # Subtask 2 currently False breaks the [None, True, True] pattern -> failure. + subtasks[2].set_success([False]) + result = CompositeTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + assert result.tolist() == [False] + + # [None, False, None]: subtask 1 must be currently False AND latched True at + # some point. Drive subtask 1 False; it was latched True earlier -> success. + subtasks[1].set_success([False]) + result = CompositeTaskBase.composite_task_success_func(env, subtasks, [None, False, None]) + assert result.tolist() == [True] + + # All-None desired state matches trivially. 
+ result = CompositeTaskBase.composite_task_success_func(env, subtasks, [None, None, None]) + assert result.tolist() == [True] + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + return True + + +def test_composite_desired_subtask_success_state_with_none(): + result = run_simulation_app_function( + _test_composite_desired_subtask_success_state_with_none, + headless=HEADLESS, + ) + assert result, f"Test {_test_composite_desired_subtask_success_state_with_none.__name__} failed" + + +def test_add_suffix_configclass_transform(): + result = run_simulation_app_function( + _test_add_suffix_configclass_transform, + headless=HEADLESS, + ) + assert result, f"Test {_test_add_suffix_configclass_transform.__name__} failed" + + +def test_remove_configclass_transform(): + result = run_simulation_app_function( + _test_remove_configclass_transform, + headless=HEADLESS, + ) + assert result, f"Test {_test_remove_configclass_transform.__name__} failed" + + +if __name__ == "__main__": + test_add_suffix_configclass_transform() + test_remove_configclass_transform() + test_composite_desired_subtask_success_state_with_none() diff --git a/isaaclab_arena/tests/test_sequential_open_door.py b/isaaclab_arena/tests/test_sequential_open_door.py index 037237b3e..f8cf245fb 100644 --- a/isaaclab_arena/tests/test_sequential_open_door.py +++ b/isaaclab_arena/tests/test_sequential_open_door.py @@ -26,23 +26,6 @@ def get_test_environment(remove_reset_door_state_event: bool, num_envs: int): from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase from isaaclab_arena.utils.pose import Pose - class SequentialOpenDoorTask(SequentialTaskBase): - def __init__( - self, - subtasks, - episode_length_s=None, - ): - super().__init__(subtasks=subtasks, episode_length_s=episode_length_s) - - def get_metrics(self): - return [] - - def get_prompt(self): - return "" - - def get_mimic_env_cfg(self, embodiment_name: str): - return None - args_parser = 
get_isaaclab_arena_cli_parser() args_cli = args_parser.parse_args(["--num_envs", str(num_envs)]) @@ -77,7 +60,7 @@ def get_mimic_env_cfg(self, embodiment_name: str): name="sequential_open_door", embodiment=FrankaIKEmbodiment(), scene=scene, - task=SequentialOpenDoorTask([subtask_1, subtask_2]), + task=SequentialTaskBase([subtask_1, subtask_2]), ) env_builder = ArenaEnvBuilder(isaaclab_arena_environment, args_cli) diff --git a/isaaclab_arena/tests/test_sequential_task_base.py b/isaaclab_arena/tests/test_sequential_task_base.py index 082eeeb37..376abbe22 100644 --- a/isaaclab_arena/tests/test_sequential_task_base.py +++ b/isaaclab_arena/tests/test_sequential_task_base.py @@ -10,60 +10,138 @@ HEADLESS = True -def _test_add_suffix_configclass_transform(simulation_app) -> bool: - """Test that add_suffix_configclass_transform correctly renames fields with suffix.""" +class _MockSuccessFunc: + """Callable that can set and return a per-env boolean success state.""" - from functools import partial + def __init__(self, num_envs: int): + import torch - from isaaclab.utils import configclass + self.num_envs = num_envs + self.return_value = torch.tensor([False] * num_envs) + + def set(self, values: list[bool]): + import torch + + assert len(values) == self.num_envs + self.return_value = torch.tensor(values) + + def __call__(self, env, **kwargs): + return self.return_value + + +class _MockSubtask: + """Minimal stand-in for a TaskBase with a controllable success function.""" + + def __init__(self, num_envs: int): + self.func = _MockSuccessFunc(num_envs) + + class _SuccessCfg: + pass + + class _TerminationCfg: + pass + + self._termination_cfg = _TerminationCfg() + self._termination_cfg.success = _SuccessCfg() + self._termination_cfg.success.func = self.func + self._termination_cfg.success.params = {} + + def get_termination_cfg(self): + return self._termination_cfg + + def set_success(self, values: list[bool]): + self.func.set(values) + + +class _MockEnv: + """Minimal stand-in for 
the env used by composite_task_success_func.""" + + def __init__(self, num_envs: int = 1, device: str = "cpu"): + self.num_envs = num_envs + self.device = device + self.extras = {} + + +def _test_sequential_success_advances_in_order(simulation_app) -> bool: + """Subtask N+1 success must not count until subtask N has succeeded.""" from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase - from isaaclab_arena.utils.configclass import transform_configclass_instance - @configclass - class FooCfg: - int_field: int = 123 - str_field: str = "123" - float_field: float = 1.23 - bool_field: bool = True + try: + env = _MockEnv(num_envs=1) + subtasks = [_MockSubtask(num_envs=1) for _ in range(3)] + + # Subtask 0 fails, subtasks 1 and 2 "succeed" out of order. Sequential gating + # must ignore subtasks 1/2 because the state machine is still at index 0. + subtasks[0].set_success([False]) + subtasks[1].set_success([True]) + subtasks[2].set_success([True]) + + result = SequentialTaskBase.composite_task_success_func(env, subtasks, None) + + assert result.tolist() == [False] + assert env._current_subtask_idx == [0] + assert env._subtask_ever_succeeded == [[False, False, False]] + + # Subtask 0 succeeds, index advances to 1, state[0] becomes True. + subtasks[0].set_success([True]) + subtasks[1].set_success([False]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, None) + + assert result.tolist() == [False] + assert env._current_subtask_idx == [1] + assert env._subtask_ever_succeeded == [[True, False, False]] + + # Subtask 1 succeeds, index advances to 2, state[1] becomes True. + subtasks[1].set_success([True]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, None) + + assert result.tolist() == [False] + assert env._current_subtask_idx == [2] + assert env._subtask_ever_succeeded == [[True, True, False]] + + # Subtask 2 succeeds, state[2] becomes True, overall success is True. 
Index does + # not advance past the last subtask (it caps at len-1). + subtasks[2].set_success([True]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, None) + + assert result.tolist() == [True] + assert env._current_subtask_idx == [2] + assert env._subtask_ever_succeeded == [[True, True, True]] + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + return True + + +def _test_sequential_success_latches(simulation_app) -> bool: + """Once a subtask succeeds, ``_subtask_ever_succeeded`` must not un-set when the + underlying success function later returns False.""" + + from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase try: - original_cfg = FooCfg() - edited_cfg = transform_configclass_instance( - original_cfg, - partial(SequentialTaskBase.add_suffix_configclass_transform, suffix="_suffix"), - ) - - # Check that new fields exist with suffix - assert hasattr(edited_cfg, "int_field_suffix") - assert hasattr(edited_cfg, "str_field_suffix") - assert hasattr(edited_cfg, "float_field_suffix") - assert hasattr(edited_cfg, "bool_field_suffix") - - # Check that values are preserved - assert edited_cfg.int_field_suffix == 123 - assert edited_cfg.str_field_suffix == "123" - assert edited_cfg.float_field_suffix == 1.23 - assert edited_cfg.bool_field_suffix is True - - # Check types are preserved - assert isinstance(edited_cfg.int_field_suffix, int) - assert isinstance(edited_cfg.str_field_suffix, str) - assert isinstance(edited_cfg.float_field_suffix, float) - assert isinstance(edited_cfg.bool_field_suffix, bool) - - # Check that old field names don't exist - assert not hasattr(edited_cfg, "int_field") - assert not hasattr(edited_cfg, "str_field") - assert not hasattr(edited_cfg, "float_field") - assert not hasattr(edited_cfg, "bool_field") - - # Test None input - edited_cfg = transform_configclass_instance( - None, - partial(SequentialTaskBase.add_suffix_configclass_transform, suffix="_suffix"), - 
) - assert edited_cfg is None + env = _MockEnv(num_envs=1) + subtasks = [_MockSubtask(num_envs=1) for _ in range(2)] + + # Drive both subtasks to success in order. + subtasks[0].set_success([True]) + SequentialTaskBase.composite_task_success_func(env, subtasks, None) + subtasks[1].set_success([True]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, None) + assert result.tolist() == [True] + assert env._subtask_ever_succeeded == [[True, True]] + + # Even when the underlying success goes False, latched state stays True and + # overall success stays True. + subtasks[0].set_success([False]) + subtasks[1].set_success([False]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, None) + assert env._subtask_ever_succeeded == [[True, True]] + assert result.tolist() == [True] except Exception as e: print(f"Error: {e}") @@ -73,67 +151,125 @@ class FooCfg: return True -def _test_remove_configclass_transform(simulation_app) -> bool: - """Test that remove_configclass_transform correctly removes specified fields.""" +def _test_sequential_desired_subtask_success_state(simulation_app) -> bool: + """When ``desired_subtask_success_state`` is provided, overall success requires + both (a) all subtasks latched True and (b) the current success state equals the desired pattern. + """ - from functools import partial + from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase - from isaaclab.utils import configclass + try: + env = _MockEnv(num_envs=1) + subtasks = [_MockSubtask(num_envs=1) for _ in range(2)] + + # Set both subtasks to True. + subtasks[0].set_success([True]) + SequentialTaskBase.composite_task_success_func(env, subtasks, [True, True]) + subtasks[1].set_success([True]) + + # Current pattern matches desired -> success. 
+ result = SequentialTaskBase.composite_task_success_func(env, subtasks, [True, True]) + assert result.tolist() == [True] + + # Success state still True, but the current pattern no longer matches the desired + # pattern -> overall success is False even though success state is True. + subtasks[0].set_success([False]) + subtasks[1].set_success([True]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, [True, True]) + assert env._subtask_ever_succeeded == [[True, True]] + assert result.tolist() == [False] + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + return True + + +def _test_sequential_desired_subtask_success_state_with_none(simulation_app) -> bool: + """When ``desired_subtask_success_state`` contains None entries, those positions are + ignored and only positions with True/False are checked.""" from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase - from isaaclab_arena.utils.configclass import transform_configclass_instance - @configclass - class FooCfg: - field_a: int = 123 - field_b: str = "123" - field_c: float = 1.23 + try: + env = _MockEnv(num_envs=1) + subtasks = [_MockSubtask(num_envs=1) for _ in range(3)] + + # Latch subtasks 0, 1, 2 to True in order so all three are "ever succeeded". + subtasks[0].set_success([True]) + SequentialTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + subtasks[1].set_success([True]) + SequentialTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + subtasks[2].set_success([True]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + assert env._subtask_ever_succeeded == [[True, True, True]] + assert result.tolist() == [True] + + # Subtask 0 currently False (don't-care), 1 and 2 currently True -> success. 
+ subtasks[0].set_success([False]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + assert result.tolist() == [True] + + # Subtask 1 currently False breaks the [None, True, True] pattern -> failure. + subtasks[1].set_success([False]) + result = SequentialTaskBase.composite_task_success_func(env, subtasks, [None, True, True]) + assert result.tolist() == [False] + + # [None, False, None]: subtask 1 must be currently False AND latched True at some + # point. Subtask 1 is latched True and currently False -> success regardless of + # subtasks 0 and 2. + result = SequentialTaskBase.composite_task_success_func(env, subtasks, [None, False, None]) + assert result.tolist() == [True] + + # All-None desired state matches trivially. + result = SequentialTaskBase.composite_task_success_func(env, subtasks, [None, None, None]) + assert result.tolist() == [True] + + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + + return True + + +def _test_sequential_reset_clears_state_and_index(simulation_app) -> bool: + """``reset_subtask_success_state`` must clear both the success state vector + and the state-machine index for the given env_ids while leaving other envs alone.""" + + from isaaclab_arena.tasks.sequential_task_base import SequentialTaskBase try: - original_cfg = FooCfg() - edited_cfg = transform_configclass_instance( - original_cfg, - partial(SequentialTaskBase.remove_configclass_transform, exclude_fields={"field_b"}), - ) - - # Check that remaining fields exist - assert hasattr(edited_cfg, "field_a") - assert hasattr(edited_cfg, "field_c") - - # Check that values are preserved - assert edited_cfg.field_a == 123 - assert edited_cfg.field_c == 1.23 - - # Check that removed field doesn't exist - assert not hasattr(edited_cfg, "field_b") - - # Test removing multiple fields - original_cfg = FooCfg() - edited_cfg = transform_configclass_instance( - original_cfg, - 
partial(SequentialTaskBase.remove_configclass_transform, exclude_fields={"field_a", "field_c"}), - ) - - # Check that only field_b remains - assert hasattr(edited_cfg, "field_b") - assert edited_cfg.field_b == "123" - assert not hasattr(edited_cfg, "field_a") - assert not hasattr(edited_cfg, "field_c") - - # Test None input - edited_cfg = transform_configclass_instance( - None, - partial(SequentialTaskBase.remove_configclass_transform, exclude_fields=set()), - ) - assert edited_cfg is None - - # Test removing all fields returns None - original_cfg = FooCfg() - edited_cfg = transform_configclass_instance( - original_cfg, - partial(SequentialTaskBase.remove_configclass_transform, exclude_fields={"field_a", "field_b", "field_c"}), - ) - assert edited_cfg is None + env = _MockEnv(num_envs=2) + subtasks = [_MockSubtask(num_envs=2) for _ in range(2)] + + # Set env 0 to "subtask 0 True + index at 1", env 1 fully complete. + subtasks[0].set_success([True, True]) + SequentialTaskBase.composite_task_success_func(env, subtasks, None) + subtasks[0].set_success([False, True]) + subtasks[1].set_success([False, True]) + SequentialTaskBase.composite_task_success_func(env, subtasks, None) + + assert env._subtask_ever_succeeded == [[True, False], [True, True]] + assert env._current_subtask_idx == [1, 1] + + # Reset only env 0. + SequentialTaskBase.reset_subtask_success_state(env, env_ids=[0], subtasks=subtasks) + assert env._subtask_ever_succeeded == [[False, False], [True, True]] + assert env._current_subtask_idx == [0, 1] + + # Reset env 1 too. 
+ SequentialTaskBase.reset_subtask_success_state(env, env_ids=[1], subtasks=subtasks) + assert env._subtask_ever_succeeded == [[False, False], [False, False]] + assert env._current_subtask_idx == [0, 0] + + del env._subtask_ever_succeeded + del env._current_subtask_idx + SequentialTaskBase.reset_subtask_success_state(env, env_ids=[], subtasks=subtasks) + assert env._subtask_ever_succeeded == [[False, False], [False, False]] + assert env._current_subtask_idx == [0, 0] except Exception as e: print(f"Error: {e}") @@ -143,22 +279,49 @@ class FooCfg: return True -def test_add_suffix_configclass_transform(): +def test_sequential_success_advances_in_order(): + result = run_simulation_app_function( + _test_sequential_success_advances_in_order, + headless=HEADLESS, + ) + assert result, f"Test {_test_sequential_success_advances_in_order.__name__} failed" + + +def test_sequential_success_latches(): + result = run_simulation_app_function( + _test_sequential_success_latches, + headless=HEADLESS, + ) + assert result, f"Test {_test_sequential_success_latches.__name__} failed" + + +def test_sequential_desired_subtask_success_state(): + result = run_simulation_app_function( + _test_sequential_desired_subtask_success_state, + headless=HEADLESS, + ) + assert result, f"Test {_test_sequential_desired_subtask_success_state.__name__} failed" + + +def test_sequential_desired_subtask_success_state_with_none(): result = run_simulation_app_function( - _test_add_suffix_configclass_transform, + _test_sequential_desired_subtask_success_state_with_none, headless=HEADLESS, ) - assert result, f"Test {_test_add_suffix_configclass_transform.__name__} failed" + assert result, f"Test {_test_sequential_desired_subtask_success_state_with_none.__name__} failed" -def test_remove_configclass_transform(): +def test_sequential_reset_clears_state_and_index(): result = run_simulation_app_function( - _test_remove_configclass_transform, + _test_sequential_reset_clears_state_and_index, headless=HEADLESS, ) - 
assert result, f"Test {_test_remove_configclass_transform.__name__} failed" + assert result, f"Test {_test_sequential_reset_clears_state_and_index.__name__} failed" if __name__ == "__main__": - test_add_suffix_configclass_transform() - test_remove_configclass_transform() + test_sequential_success_advances_in_order() + test_sequential_success_latches() + test_sequential_desired_subtask_success_state() + test_sequential_desired_subtask_success_state_with_none() + test_sequential_reset_clears_state_and_index()