Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions docs/source/overview/gym/dataset_functors.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,10 @@ The ``LeRobotRecorder`` functor enables recording robot learning episodes in the

The LeRobotRecorder saves the following data for each frame:

- ``observation.qpos``: Joint positions
- ``observation.qvel``: Joint velocities
- ``observation.qf``: Joint forces/torques
- ``observation.state``: Joint positions (proprioceptive state)
- ``action``: Applied action
- ``{sensor_name}.color``: Camera images (if sensors present)
- ``{sensor_name}.color_right``: Right camera images (for stereo cameras)
- ``observation.images.{sensor_name}``: Camera images (if sensors present)
- ``observation.images.{sensor_name}_right``: Right camera images (for stereo cameras)

## Usage Example

Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorial/data_generation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Important parameters are:
- **env.control_parts**: Controlled robot parts in the environment.


In the current implementation, ``LeRobotRecorder`` stores robot state and action features such as ``observation.qpos``, ``observation.qvel``, ``observation.qf``, ``action``, and camera images when sensors are present.
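In the current implementation, ``LeRobotRecorder`` stores robot state and action features following the official LeRobot format: ``observation.state`` for joint positions, ``action`` for applied actions, and ``observation.images.{sensor_name}`` for camera images. Joint velocities and joint forces are additionally stored as ``observation.qvel`` and ``observation.qf``, which are not part of the official format.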

Step 2: Prepare the Action Configuration
----------------------------------------
Expand Down
27 changes: 27 additions & 0 deletions embodichain/data/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,30 @@ class EefType(Enum):
class ActionMode(Enum):
    """Whether an action is absolute or relative to the last state.

    NOTE(review): the values look like key prefixes (``""`` vs ``"delta_"``);
    confirm against the call sites before relying on that.
    """

    ABSOLUTE = ""
    # The action is a relative change with respect to the last state.
    RELATIVE = "delta_"


class LeRobotKey(Enum):
    """Feature keys used when writing datasets in the LeRobot layout.

    Most members mirror the key names of the official LeRobot dataset format.
    ``OBS_QVEL`` and ``OBS_QF`` are additional proprioceptive keys kept by this
    project; they are not part of the official LeRobot format.

    Members are plain ``Enum`` values, so use ``.value`` to obtain the string
    key, e.g. ``features[LeRobotKey.OBS_STATE.value]``.
    """

    # Observation namespace. ``OBS_PREFIX`` is prepended to nested observation
    # keys to form ``observation.{key}.{sub_key}``.
    OBS_STR = "observation"
    OBS_PREFIX = "observation."
    OBS_ENV_STATE = "observation.environment_state"
    # Robot joint positions are stored under the standard state key.
    OBS_STATE = "observation.state"
    # Project-specific extras (joint velocities / joint forces), not official.
    OBS_QVEL = "observation.qvel"
    OBS_QF = "observation.qf"
    # Camera features are written as ``observation.images.{sensor_name}``.
    OBS_IMAGE = "observation.image"
    OBS_IMAGES = "observation.images"

    # Language and subtask keys.
    OBS_LANGUAGE = "observation.language"
    OBS_LANGUAGE_TOKENS = "observation.language.tokens"
    OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
    OBS_LANGUAGE_SUBTASK = "observation.subtask"
    OBS_LANGUAGE_SUBTASK_TOKENS = "observation.subtask.tokens"
    OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = "observation.subtask.attention_mask"

    # Action namespace.
    ACTION = "action"
    ACTION_PREFIX = "action."
    ACTION_TOKENS = "action.tokens"
    ACTION_TOKEN_MASK = "action.token_mask"

    # Keys in the ``next.*`` namespace, plus the free-form info key.
    REWARD = "next.reward"
    TRUNCATED = "next.truncated"
    DONE = "next.done"
    INFO = "info"
36 changes: 21 additions & 15 deletions embodichain/lab/gym/envs/managers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from embodichain.utils import logger
from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATASET_ROOT
from embodichain.data.enum import LeRobotKey
from embodichain.lab.gym.utils.misc import is_stereocam
from embodichain.lab.sim.sensors import Camera, ContactSensor
from .manager_base import Functor
Expand Down Expand Up @@ -275,25 +276,25 @@ def _build_features(self) -> Dict:
self._env.robot.joint_names[i] for i in self._env.active_joint_ids
]

features["observation.qpos"] = {
features[LeRobotKey.OBS_STATE.value] = {
"dtype": "float32",
"shape": (state_dim,),
"names": joint_names,
}
features["observation.qvel"] = {
features[LeRobotKey.OBS_QVEL.value] = {
"dtype": "float32",
"shape": (state_dim,),
"names": joint_names,
}
features["observation.qf"] = {
features[LeRobotKey.OBS_QF.value] = {
"dtype": "float32",
"shape": (state_dim,),
"names": joint_names,
}

# Use full qpos dimension for action (includes gripper)
action_dim = state_dim
features["action"] = {
features[LeRobotKey.ACTION.value] = {
"dtype": "float32",
"shape": (action_dim,),
"names": joint_names,
Expand All @@ -316,14 +317,16 @@ def _build_features(self) -> Dict:
f"Only support 'color' frame for vision sensors, but got '{frame_name}' in sensor '{sensor_name}'"
)

features[f"{sensor_name}.{frame_name}"] = {
features[f"{LeRobotKey.OBS_IMAGES.value}.{sensor_name}"] = {
"dtype": "video" if self.use_videos else "image",
"shape": (sensor.cfg.height, sensor.cfg.width, 3),
"names": ["height", "width", "channel"],
}

if is_stereo:
features[f"{sensor_name}.{frame_name}_right"] = {
features[
f"{LeRobotKey.OBS_IMAGES.value}.{sensor_name}_right"
] = {
"dtype": "video" if self.use_videos else "image",
"shape": (sensor.cfg.height, sensor.cfg.width, 3),
"names": ["height", "width", "channel"],
Expand Down Expand Up @@ -379,7 +382,7 @@ def _add_nested_features(
# Recursively handle deeper nesting
self._add_nested_features(features, f"{key}.{sub_key}", sub_space)
else:
feature_name = f"observation.{key}.{sub_key}"
feature_name = f"{LeRobotKey.OBS_PREFIX.value}{key}.{sub_key}"
# Handle empty shapes for scalar values (e.g., mass, friction, damping)
# LeRobot requires non-empty shapes, so convert () to (1,)
shape = sub_space.shape if sub_space.shape else (1,)
Expand Down Expand Up @@ -463,12 +466,14 @@ def _convert_frame_to_lerobot(

color_data = obs["sensor"][sensor_name]["color"]
color_img = color_data[:, :, :3].cpu()
frame[f"{sensor_name}.color"] = color_img
frame[f"{LeRobotKey.OBS_IMAGES.value}.{sensor_name}"] = color_img

if is_stereo:
color_right_data = obs["sensor"][sensor_name]["color_right"]
color_right_img = color_right_data[:, :, :3].cpu()
frame[f"{sensor_name}.color_right"] = color_right_img
frame[f"{LeRobotKey.OBS_IMAGES.value}.{sensor_name}_right"] = (
color_right_img
)
elif isinstance(sensor, ContactSensor):
for frame_name in value.keys():
frame[f"{sensor_name}.{frame_name}"] = obs["sensor"][
Expand All @@ -481,10 +486,11 @@ def _convert_frame_to_lerobot(
f"Unsupported sensor type for '{sensor_name}' when converting to LeRobot format. Currently only support Camera and ContactSensor."
)

# Add state
frame["observation.qpos"] = obs["robot"]["qpos"].cpu()
frame["observation.qvel"] = obs["robot"]["qvel"].cpu()
frame["observation.qf"] = obs["robot"]["qf"].cpu()
# Add state (use LeRobot standard key "observation.state")
frame[LeRobotKey.OBS_STATE.value] = obs["robot"]["qpos"].cpu()
# Keep additional proprio data that may be useful even though not in official LeRobot format
frame[LeRobotKey.OBS_QVEL.value] = obs["robot"]["qvel"].cpu()
frame[LeRobotKey.OBS_QF.value] = obs["robot"]["qf"].cpu()

# Add extra observation features if they exist
for key in obs.keys():
Expand Down Expand Up @@ -516,7 +522,7 @@ def _convert_frame_to_lerobot(
if isinstance(action_tensor, torch.Tensor):
action_data = action_tensor.cpu()

frame["action"] = action_data
frame[LeRobotKey.ACTION.value] = action_data

return frame

Expand Down Expand Up @@ -548,7 +554,7 @@ def _add_nested_obs_to_frame(
# Handle 0D tensors (scalars) - convert to 1D for LeRobot compatibility
if isinstance(value, torch.Tensor) and value.ndim == 0:
value = value.unsqueeze(0)
frame[f"observation.{key}.{sub_key}"] = value
frame[f"{LeRobotKey.OBS_PREFIX.value}{key}.{sub_key}"] = value

def _update_dataset_info(self, updates: dict) -> bool:
"""Update dataset metadata."""
Expand Down
24 changes: 12 additions & 12 deletions tests/gym/envs/managers/test_dataset_functors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@
LEROBOT_AVAILABLE,
)

from embodichain.data.enum import LeRobotKey

LEROBOT_AVAILABLE = True
except ImportError:
LEROBOT_AVAILABLE = False
LeRobotRecorder = None

LeRobotKey = None

# Import Camera for mocking (only if available)

try:
from embodichain.lab.sim.sensors import Camera

Expand Down Expand Up @@ -228,15 +231,12 @@ def test_build_features_creates_correct_structure(self, mock_lerobot_dataset):
# Access the private method through the instance
features = recorder._build_features()

# Check expected features exist
assert "observation.qpos" in features
assert "observation.qvel" in features
assert "observation.qf" in features
assert "action" in features
assert LeRobotKey.OBS_STATE.value in features
assert LeRobotKey.ACTION.value in features

# Check shapes
assert features["observation.qpos"]["shape"] == (6,)
assert features["action"]["shape"] == (6,)
assert features[LeRobotKey.OBS_STATE.value]["shape"] == (6,)
assert features[LeRobotKey.ACTION.value]["shape"] == (6,)

@patch("embodichain.lab.gym.envs.managers.datasets.LeRobotDataset")
def test_build_features_with_sensor(self, mock_lerobot_dataset):
Expand Down Expand Up @@ -276,8 +276,8 @@ def mock_isinstance(obj, class_or_tuple):
recorder = LeRobotRecorder(cfg, env)
features = recorder._build_features()

# Check camera feature exists
assert "camera.color" in features
# Check camera feature exists (use LeRobot standard key format)
assert f"{LeRobotKey.OBS_IMAGES.value}.camera" in features


@pytest.mark.skipif(not LEROBOT_AVAILABLE, reason="LeRobot not installed")
Expand Down Expand Up @@ -328,8 +328,8 @@ def test_convert_frame_with_tensor_action(self, mock_lerobot_dataset):

assert "task" in frame
assert frame["task"] == "test_task"
assert "observation.qpos" in frame
assert "action" in frame
assert LeRobotKey.OBS_STATE.value in frame
assert LeRobotKey.ACTION.value in frame


class TestDatasetFunctorCfg:
Expand Down
Loading