diff --git a/.gitignore b/.gitignore index f18390bf..ace8e6a5 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ docs/_build .codex CLAUDE.md uv.lock +*.parquet diff --git a/examples/teleop/franka.py b/examples/teleop/franka.py index 1645ccef..ef073761 100644 --- a/examples/teleop/franka.py +++ b/examples/teleop/franka.py @@ -10,8 +10,6 @@ from rcs.operator.gello import GelloConfig, GelloOperator from rcs.operator.interface import TeleopLoop from rcs.operator.quest import QuestConfig, QuestOperator -from rcs_fr3.configs import DefaultFR3MultiHardwareEnv -from rcs_fr3.creators import HardwareCameraCreatorConfig from simpub.sim.mj_publisher import MujocoPublisher import rcs @@ -67,7 +65,9 @@ } config: QuestConfig | GelloConfig -config = QuestConfig(mq3_addr=MQ3_ADDR, simulation=ROBOT_INSTANCE == RobotPlatform.SIMULATION, switched_left_right=True) +config = QuestConfig( + mq3_addr=MQ3_ADDR, simulation=ROBOT_INSTANCE == RobotPlatform.SIMULATION, switched_left_right=False +) # config = GelloConfig( # arms={ # "right": GelloArmConfig(com_port="/dev/serial/by-id/usb-ROBOTIS_OpenRB-150_E505008B503059384C2E3120FF07332D-if00"), @@ -79,6 +79,9 @@ def get_env(): if ROBOT_INSTANCE == RobotPlatform.HARDWARE: + from rcs_fr3.configs import DefaultFR3MultiHardwareEnv + from rcs_fr3.creators import HardwareCameraCreatorConfig + env_creator = DefaultFR3MultiHardwareEnv() env_creator.left_ip = ROBOT2IP["left"] env_creator.right_ip = ROBOT2IP["right"] @@ -151,7 +154,7 @@ def get_env(): if sim_cfg_data.root_frame_objects is None: sim_cfg_data.root_frame_objects = {} # cfg.root_frame_objects["green_cube"] = (rcs.OBJECT_PATHS["green_cube"], Pose(translation=[0.5, 0, 0.5], quaternion=[0, 0, 0, 1])) - sim_cfg_data.task_cfg = PickTaskConfig(robot_name="left") + sim_cfg_data.task_cfg = PickTaskConfig(robot_name="right") env_rel = scene.create_env(sim_cfg_data) env_rel = StorageWrapper( diff --git a/python/rcs/__main__.py b/python/rcs/__main__.py index a481aab1..743c68a0 100644 --- a/python/rcs/__main__.py +++ b/python/rcs/__main__.py @@ -3,10 +3,9 @@ import typer from rcs.envs.storage_wrapper import StorageWrapper -from rcs.sim_state_replay import replay as replay_command +from rcs.sim.replayer import replay as replay_dataset app = typer.Typer() -app.command()(replay_command) @app.command() @@ -34,5 +33,39 @@ def consolidate( typer.echo("Done.") +@app.command("replay") +def replay( + dataset: Annotated[ + Path, + typer.Argument( + exists=True, + help="Parquet dataset directory to replay.", + ), + ], + headless: Annotated[bool, typer.Option(help="Whether to run without GUI.")] = True, + frequency: Annotated[int, typer.Option(help="Simulation frequency to use during replay.")] = 30, + relative_to: Annotated[ + str, + typer.Option(help="RelativeTo enum name: CONFIGURED_ORIGIN, LAST_STEP, or NONE."), + ] = "CONFIGURED_ORIGIN", + scene: Annotated[ + str, + typer.Option(help="Python expression that evaluates to a scene instance."), + ] = "env_configs.EmptyWorldFR3Duo()", + task_cfg: Annotated[ + str, + typer.Option(help="Python expression that evaluates to a task config."), + ] = 'env_tasks.PickTaskConfig(robot_name="right")', +): + replay_dataset( + dataset=dataset, + headless=headless, + frequency=frequency, + relative_to=relative_to, + scene=scene, + task_cfg=task_cfg, + ) + + if __name__ == "__main__": app() diff --git a/python/rcs/envs/base.py b/python/rcs/envs/base.py index b3c21c05..49c85b24 100644 --- a/python/rcs/envs/base.py +++ b/python/rcs/envs/base.py @@ -200,23 +200,36 @@ class HardwareEnv(BaseEnv): class SimEnv(BaseEnv): PLATFORM = RobotPlatform.SIMULATION + STATE_KEY = "sim_state" + STATE_SPEC_KEY = "sim_state_spec" - def __init__(self, sim: simulation.Sim) -> None: + def __init__(self, sim: simulation.Sim, return_state=True) -> None: self.sim = sim cfg = self.sim.get_config() self.frame_rate = SimpleFrameRate(cfg.frequency, "MoJoCo Simulation Loop") self.main_greenlet: greenlet | None = None + self.return_state = return_state + self._replay_state: tuple[np.ndarray, int | None] | None = None + + def set_replay_state(self, state: np.ndarray, spec: int | None = None): + self._replay_state = (state, spec) def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]: if self.main_greenlet is not None: self.main_greenlet.switch() else: self.step_sim() - return super().step(action) + obs, reward, terminated, truncated, info = super().step(action) + if self.return_state: + obs, info = self.observation(obs, info) + return obs, reward, terminated, truncated, info def step_sim(self): cfg = self.sim.get_config() - if cfg.async_control: + if self._replay_state is not None: + self.sim.set_state(self._replay_state[0], self._replay_state[1]) + self._replay_state = None + elif cfg.async_control: self.sim.step(round(1 / cfg.frequency / self.sim.model.opt.timestep)) if cfg.realtime: self.frame_rate.frame_rate = cfg.frequency @@ -236,6 +249,12 @@ def reset( self.apply_sim_state() return super().reset(seed=seed, options=options) + def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + sim_state = self.sim.get_state() + info[self.STATE_KEY] = sim_state + info[self.STATE_SPEC_KEY] = self.sim.get_state_spec() + return observation, info + class CoverWrapper(gym.Wrapper): """The CoverWrapper must be the last wrapper on the stack diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 6ea2bf3d..77354065 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -40,23 +40,6 @@ def reset( return super().reset(seed=seed, options=options) -class SimStateObservationWrapper(ActObsInfoWrapper): - STATE_KEY = "sim_state" - STATE_SPEC_KEY = "sim_state_spec" - - def __init__(self, env): - super().__init__(env) - assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." - self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) - - def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: - observation = dict(observation) - sim_state = self.sim.get_state() - observation[self.STATE_KEY] = sim_state - observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec() - return observation, info - - class GripperWrapperSim(ActObsInfoWrapper): def __init__(self, env): super().__init__(env) diff --git a/python/rcs/envs/storage_wrapper.py b/python/rcs/envs/storage_wrapper.py index 1073cd04..6df7c3fa 100644 --- a/python/rcs/envs/storage_wrapper.py +++ b/python/rcs/envs/storage_wrapper.py @@ -14,7 +14,6 @@ import pyarrow.dataset as ds import simplejpeg from PIL import Image -from rcs.envs.base import RelativeActionSpace class StorageWrapper(gym.Wrapper): @@ -31,6 +30,7 @@ def __init__( basename_template: Optional[str] = None, max_rows_per_group: Optional[int] = None, max_rows_per_file: Optional[int] = None, + success_from_env: bool = False, ): """ Asynchronously log environment transitions to a Parquet dataset on disk. @@ -91,6 +91,7 @@ def __init__( self._success = False self._prev_action = None self._prev_absolute_action = None + self.success_from_env = success_from_env self.thread_pool = ThreadPoolExecutor() self.queue: Queue[pa.Table | pa.RecordBatch] = Queue(maxsize=2) @@ -249,42 +250,31 @@ def step(self, action): if not self._pause: assert isinstance(obs, dict) + if "frames" in obs and not obs["frames"]: + del obs["frames"] if "frames" in obs: self._encode_images(obs) self._flatten_arrays(obs) - if info.get("success"): + if info.get("success") and self.success_from_env: self.success() - self.buffer.append( - { - "obs": obs, - "info": info, - "reward": reward, - "step": self.step_cnt, - "uuid": self.uuid.hex, - "date": datetime.date.today().isoformat(), - "success": self._success, - "action": self._prev_action, - "instruction": self.instruction, - "timestamp": datetime.datetime.now().timestamp(), - RelativeActionSpace.ABSOLUTE_ACTION_KEY: self._prev_absolute_action, - } - ) + frame = { + "obs": obs, + "info": info, + "reward": reward, + "step": self.step_cnt, + "uuid": self.uuid.hex, + "date": datetime.date.today().isoformat(), + "success": self._success, + "action": self._prev_action, + "env_action": action, + "instruction": self.instruction, + "timestamp": datetime.datetime.now().timestamp(), + } + self._prev_action = action - if RelativeActionSpace.ABSOLUTE_ACTION_KEY in obs: - # single env - self._prev_absolute_action = obs[RelativeActionSpace.ABSOLUTE_ACTION_KEY] - else: - # multi env wrapper - act_dict = {} - try: - for key in self.get_wrapper_attr("envs"): - if RelativeActionSpace.ABSOLUTE_ACTION_KEY in obs[key]: - act_dict[key] = obs[key][RelativeActionSpace.ABSOLUTE_ACTION_KEY] - except AttributeError: - pass - if len(act_dict) != 0: - self._prev_absolute_action = act_dict # type: ignore + + self.buffer.append(frame) self.step_cnt += 1 if len(self.buffer) == self.batch_size: @@ -292,6 +282,9 @@ def step(self, action): return obs, reward, terminated, truncated, info + def set_instruction(self, instruction: str): + self.instruction = instruction + def success(self): self._success = True diff --git a/python/rcs/sim/replayer.py b/python/rcs/sim/replayer.py new file mode 100644 index 00000000..dc6c6be1 --- /dev/null +++ b/python/rcs/sim/replayer.py @@ -0,0 +1,166 @@ +import typing +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import duckdb +import gymnasium as gym +import numpy as np +import rcs.envs.configs as env_configs +import rcs.envs.tasks as env_tasks +from rcs._core.sim import SimConfig +from rcs.envs.base import RelativeTo, SimEnv +from rcs.envs.scenes import SimEnvCreator +from rcs.envs.storage_wrapper import StorageWrapper + +DATASET_PATH = "recorded_iris" + + +@dataclass(frozen=True) +class RecordedSimStep: + step: int + uuid: str + timestamp: float | None + observation: dict[str, Any] + info: dict[str, Any] + action: Any + instruction: str + success: bool + + @property + def sim_state(self) -> np.ndarray: + if SimEnv.STATE_KEY in self.info: + return np.asarray(self.info[SimEnv.STATE_KEY], dtype=np.float64) + + for value in self.info.values(): + if isinstance(value, dict) and SimEnv.STATE_KEY in value: + return np.asarray(value[SimEnv.STATE_KEY], dtype=np.float64) + + msg = f"Recorded step {self.step} does not contain a sim state in info." + raise KeyError(msg) + + @property + def sim_state_spec(self) -> int | None: + if SimEnv.STATE_SPEC_KEY in self.info: + return int(self.info[SimEnv.STATE_SPEC_KEY]) + + for value in self.info.values(): + if isinstance(value, dict) and SimEnv.STATE_SPEC_KEY in value: + return int(value[SimEnv.STATE_SPEC_KEY]) + + return None + + +def load_distinct_uuids(dataset_path: Path | str) -> list[str]: + connection = duckdb.connect() + try: + rows = connection.execute( + "SELECT DISTINCT uuid FROM read_parquet(?) ORDER BY uuid", + [str(dataset_path)], + ).fetchall() + finally: + connection.close() + return [str(row[0]) for row in rows] + + +def load_trajectory(dataset_path: Path | str, trajectory_uuid: str) -> list[RecordedSimStep]: + connection = duckdb.connect() + try: + rows = connection.execute( + "SELECT uuid, step, timestamp, obs, info, env_action, instruction, success " + "FROM read_parquet(?) WHERE uuid = ? ORDER BY step", + [str(dataset_path), trajectory_uuid], + ).fetchall() + finally: + connection.close() + + return [ + RecordedSimStep( + step=int(row[1]), + uuid=str(row[0]), + timestamp=float(row[2]) if row[2] is not None else None, + observation=row[3], + info=row[4], + action=row[5], + instruction=str(row[6]), + success=bool(row[7]), + ) + for row in rows + ] + + +def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep): + try: + lead_env = env.get_wrapper_attr("lead_env") + except AttributeError: + lead_env = None + + if lead_env is not None: + lead_env.set_replay_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + else: + env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + + +def replay_trajectory(env: gym.Env, recorded_steps: list[RecordedSimStep], headless: bool): + if not recorded_steps: + msg = "No recorded sim states found in the requested trajectory." + raise ValueError(msg) + + env.reset() + for recorded_step in recorded_steps: + restore_sim_step(env, recorded_step) + env.step(recorded_step.action) + if not headless: + env.get_wrapper_attr("sim").sync_gui() + if recorded_step.success: + env.get_wrapper_attr("success")() + + +def replay( + dataset: Path | str, + headless: bool = True, + frequency: int = 30, + relative_to: str = RelativeTo.CONFIGURED_ORIGIN.name, + scene: str = "env_configs.EmptyWorldFR3Duo()", + task_cfg: str = 'env_tasks.PickTaskConfig(robot_name="right")', +): + exec_scope = {**globals(), "__builtins__": __builtins__, "env_configs": env_configs, "env_tasks": env_tasks} + scene_locals: dict[str, Any] = {} + exec(f"_result = {scene}", exec_scope, scene_locals) + sc = typing.cast(SimEnvCreator, scene_locals["_result"]) + sim_cfg_data = sc.config() + sim_cfg_data.sim_cfg = SimConfig( + async_control=True, + realtime=not headless, + frequency=frequency, + max_convergence_steps=500, + ) + sim_cfg_data.headless = headless + sim_cfg_data.relative_to = RelativeTo[relative_to.upper()] + if sim_cfg_data.root_frame_objects is None: + sim_cfg_data.root_frame_objects = {} + task_cfg_locals: dict[str, Any] = {} + exec(f"_result = {task_cfg}", exec_scope, task_cfg_locals) + sim_cfg_data.task_cfg = task_cfg_locals["_result"] + + uuids = load_distinct_uuids(dataset) + + env_rel = sc.create_env(sim_cfg_data) + env_rel = StorageWrapper( + env_rel, + DATASET_PATH, + "", + batch_size=32, + max_rows_per_group=100, + max_rows_per_file=1000, + always_record=True, + ) + try: + for uuid in uuids: + recorded_steps = load_trajectory(dataset, uuid) + if not recorded_steps: + continue + env_rel.get_wrapper_attr("set_instruction")(recorded_steps[0].instruction) + replay_trajectory(env_rel, recorded_steps, headless) + finally: + env_rel.close() diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py deleted file mode 100644 index 7f5e6a32..00000000 --- a/python/rcs/sim_state_replay.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Annotated, Any - -import gymnasium as gym -import numpy as np -import pyarrow.compute as pc -import pyarrow.dataset as ds -import typer -from PIL import Image -from rcs.envs.base import ControlMode -from rcs.envs.sim import SimStateObservationWrapper - -import rcs # noqa: F401 - -app = typer.Typer(help="Replay recorded MuJoCo trajectories from a parquet dataset.") - -DATASET_ARGUMENT = typer.Argument(..., exists=True, file_okay=False, dir_okay=True) - -ENV_ID_OPTION = typer.Option(help="Gymnasium env id used for replay.") -TRAJECTORY_UUID_OPTION = typer.Option(help="UUID of the recorded trajectory to replay.") -CAMERA_OPTION = typer.Option("--camera", help="Camera names to enable on the replay env.") -RESOLUTION_OPTION = typer.Option(help="Replay camera resolution as WIDTH HEIGHT.") -FRAME_RATE_OPTION = typer.Option(help="Replay camera frame rate.") -RENDER_MODE_OPTION = typer.Option(help="Gym render mode for the replay env.") -CONTROL_MODE_OPTION = typer.Option(help="Control mode name for env creation.") -SLEEP_OPTION = typer.Option(help="Optional delay between restored states.") -OUTPUT_DIR_OPTION = typer.Option(help="Optional directory for re-rendered RGB frames.") -PREFER_DUCKDB_OPTION = typer.Option(help="Use duckdb for parquet loading when it is available.") - - -@dataclass(frozen=True) -class RecordedSimStep: - step: int - uuid: str - timestamp: float | None - observation: dict[str, Any] - - @property - def sim_state(self) -> np.ndarray: - return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) - - @property - def sim_state_spec(self) -> int: - return int(self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY, 0)) - - -class DuckDBUnavailableError(RuntimeError): - pass - - -def _get_duckdb_module(): - try: - import duckdb - except ModuleNotFoundError as exc: - msg = ( - "duckdb is required for the preferred parquet read path but is not installed. " - "Install the 'duckdb' Python package or rely on the pyarrow fallback in library calls." - ) - raise DuckDBUnavailableError(msg) from exc - return duckdb - - -def _load_distinct_uuids_with_duckdb(dataset_path: Path) -> list[str]: - duckdb = _get_duckdb_module() - connection = duckdb.connect() - try: - rows = connection.execute( - "SELECT DISTINCT uuid FROM read_parquet(?) ORDER BY uuid", - [str(dataset_path)], - ).fetchall() - finally: - connection.close() - return [row[0] for row in rows] - - -def _load_distinct_uuids_with_pyarrow(dataset_path: Path) -> list[str]: - dataset = ds.dataset(str(dataset_path), format="parquet") - uuids = dataset.to_table(columns=["uuid"])["uuid"] - return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) # type: ignore - - -def list_trajectory_ids(dataset_path: Path, prefer_duckdb: bool = True) -> list[str]: - if prefer_duckdb: - try: - return _load_distinct_uuids_with_duckdb(dataset_path) - except DuckDBUnavailableError: - pass - return _load_distinct_uuids_with_pyarrow(dataset_path) - - -def _load_trajectory_with_duckdb(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: - duckdb = _get_duckdb_module() - connection = duckdb.connect() - try: - table = connection.execute( - "SELECT uuid, step, timestamp, obs FROM read_parquet(?) WHERE uuid = ? ORDER BY step", - [str(dataset_path), trajectory_uuid], - ).to_arrow_table() - finally: - connection.close() - return [ - RecordedSimStep( - step=int(row["step"]), - uuid=str(row["uuid"]), - timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, - observation=row["obs"], - ) - for row in table.to_pylist() - ] - - -def _load_trajectory_with_pyarrow(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: - dataset = ds.dataset(str(dataset_path), format="parquet") - table = dataset.to_table(filter=pc.field("uuid") == trajectory_uuid, columns=["uuid", "step", "timestamp", "obs"]) - rows = table.sort_by([("step", "ascending")]).to_pylist() - return [ - RecordedSimStep( - step=int(row["step"]), - uuid=str(row["uuid"]), - timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, - observation=row["obs"], - ) - for row in rows - ] - - -def load_trajectory(dataset_path: Path, trajectory_uuid: str, prefer_duckdb: bool = True) -> list[RecordedSimStep]: - if prefer_duckdb: - try: - return _load_trajectory_with_duckdb(dataset_path, trajectory_uuid) - except DuckDBUnavailableError: - pass - return _load_trajectory_with_pyarrow(dataset_path, trajectory_uuid) - - -def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, prefer_duckdb: bool = True) -> str: - if trajectory_uuid is not None: - return trajectory_uuid - available_uuids = list_trajectory_ids(dataset_path, prefer_duckdb=prefer_duckdb) - if len(available_uuids) == 1: - return available_uuids[0] - msg = ( - f"Dataset {dataset_path} contains {len(available_uuids)} trajectories. " - f"Pass --trajectory-uuid and choose one of: {available_uuids}" - ) - raise ValueError(msg) - - -def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep): - sim = env.get_wrapper_attr("sim") - sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) - - -def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: - try: - camera_set = env.get_wrapper_attr("camera_set") - except AttributeError: - return {} - - frameset = camera_set.get_latest_frames() - if frameset is None: - return {} - - rgb_frames: dict[str, np.ndarray] = {} - for camera_name, frame in frameset.frames.items(): - lower_name = camera_name.lower() - if "digit" in lower_name or "tactile" in lower_name: - continue - rgb_frames[camera_name] = np.asarray(frame.camera.color.data) - return rgb_frames - - -def save_rgb_frames(output_dir: Path, recorded_step: RecordedSimStep, rgb_frames: dict[str, np.ndarray]): - output_dir.mkdir(parents=True, exist_ok=True) - for camera_name, rgb_frame in rgb_frames.items(): - Image.fromarray(rgb_frame).save(output_dir / f"step-{recorded_step.step:06d}-{camera_name}.png") - - -def replay_trajectory( - env: gym.Env, - recorded_steps: list[RecordedSimStep], - *, - sleep_s: float = 0.0, - output_dir: Path | None = None, -): - if not recorded_steps: - msg = "No recorded sim states found in the requested trajectory." - raise ValueError(msg) - - env.reset() - sim = env.get_wrapper_attr("sim") - for recorded_step in recorded_steps: - restore_sim_step(env, recorded_step) - if output_dir is not None: - save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) - sim.sync_gui() - if sleep_s > 0: - time.sleep(sleep_s) - - -@app.command() -def replay( - dataset: Annotated[Path, DATASET_ARGUMENT], - env_id: Annotated[str, ENV_ID_OPTION] = "rcs/FR3SimplePickUpSim-v0", - trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION] = None, - camera: Annotated[list[str] | None, CAMERA_OPTION] = None, - resolution: Annotated[tuple[int, int], RESOLUTION_OPTION] = (256, 256), - frame_rate: Annotated[int, FRAME_RATE_OPTION] = 0, - render_mode: Annotated[str, RENDER_MODE_OPTION] = "human", - control_mode: Annotated[str, CONTROL_MODE_OPTION] = ControlMode.CARTESIAN_TRPY.name, - sleep_s: Annotated[float, SLEEP_OPTION] = 0.0, - output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION] = None, - prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION] = True, -): - if camera is None: - camera = [] - resolved_uuid = resolve_trajectory_uuid(dataset, trajectory_uuid, prefer_duckdb=prefer_duckdb) - env = gym.make( - env_id, - render_mode=render_mode, - control_mode=ControlMode[control_mode], - resolution=resolution, - frame_rate=frame_rate, - cam_list=camera, - ) - try: - recorded_steps = load_trajectory(dataset, resolved_uuid, prefer_duckdb=prefer_duckdb) - replay_trajectory(env, recorded_steps, sleep_s=sleep_s, output_dir=output_dir) - finally: - env.close() - - typer.echo(f"Replayed {len(recorded_steps)} steps from trajectory {resolved_uuid}.") - if output_dir is not None: - typer.echo(f"Saved re-rendered RGB frames to {output_dir}.") - - -if __name__ == "__main__": - app() diff --git a/python/tests/test_replayer.py b/python/tests/test_replayer.py new file mode 100644 index 00000000..c819c6b2 --- /dev/null +++ b/python/tests/test_replayer.py @@ -0,0 +1,248 @@ +from pathlib import Path +from typing import Any + +import duckdb +import numpy as np +from rcs._core.sim import SimConfig +from rcs.envs.base import RelativeTo +from rcs.envs.configs import EmptyWorldFR3Duo +from rcs.envs.storage_wrapper import StorageWrapper +from rcs.envs.tasks import PickTaskConfig +from rcs.sim.replayer import load_distinct_uuids, load_trajectory, replay_trajectory + + +def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -> StorageWrapper: + scene = EmptyWorldFR3Duo() + cfg = scene.config() + cfg.sim_cfg = SimConfig(async_control=True, realtime=False, frequency=30, max_convergence_steps=500) + cfg.headless = True + cfg.relative_to = RelativeTo.CONFIGURED_ORIGIN + if cfg.root_frame_objects is None: + cfg.root_frame_objects = {} + cfg.task_cfg = PickTaskConfig(robot_name="right") + if not with_cameras: + cfg.camera_cfgs = {} + else: + assert cfg.camera_cfgs is not None + for camera_cfg in cfg.camera_cfgs.values(): + camera_cfg.resolution_width = 64 + camera_cfg.resolution_height = 48 + camera_cfg.frame_rate = 1 + + env = scene.create_env(cfg) + return StorageWrapper( + env, + str(output_dir), + instruction, + batch_size=2, + max_rows_per_group=10, + max_rows_per_file=10, + always_record=True, + ) + + +def _record_source_dataset(dataset_dir: Path, *, limit: int, instruction: str) -> None: + env = _build_env(dataset_dir, with_cameras=False, instruction=instruction) + try: + env.reset() + action = { + "left": { + "tquat": np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + "gripper": np.array([1.0], dtype=np.float32), + }, + "right": { + "tquat": np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + "gripper": np.array([1.0], dtype=np.float32), + }, + } + for _ in range(limit): + env.step(action) + finally: + env.close() + + +def _source_rows(dataset_dir: Path, limit: int): + connection = duckdb.connect() + try: + uuid = load_distinct_uuids(dataset_dir)[0] + return connection.execute( + """ + SELECT step, obs, info, reward, success, action, env_action, instruction + FROM read_parquet(?) + WHERE uuid = ? + ORDER BY step + LIMIT ? + """, + [str(dataset_dir), uuid, limit], + ).fetchall() + finally: + connection.close() + + +def _replay_rows(dataset_dir: Path): + connection = duckdb.connect() + try: + return connection.execute( + """ + SELECT step, obs, info, reward, success, action, env_action, instruction + FROM read_parquet(?) + ORDER BY step + """, + [str(dataset_dir)], + ).fetchall() + finally: + connection.close() + + +def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None: + source_dir = output_dir.parent / "source" + env = _build_env(output_dir, with_cameras=with_cameras) + try: + uuid = load_distinct_uuids(source_dir)[0] + recorded_steps = load_trajectory(source_dir, uuid)[:limit] + env.get_wrapper_attr("set_instruction")(recorded_steps[0].instruction) + replay_trajectory(env, recorded_steps, True) + finally: + env.close() + + +def _assert_nested_close(actual: Any, expected: Any, *, atol: float = 1e-6): + if isinstance(expected, dict): + assert isinstance(actual, dict) + assert actual.keys() == expected.keys() + for key in expected: + _assert_nested_close(actual[key], expected[key], atol=atol) + return + if isinstance(expected, list): + assert isinstance(actual, list) + assert len(actual) == len(expected) + for actual_item, expected_item in zip(actual, expected, strict=True): + _assert_nested_close(actual_item, expected_item, atol=atol) + return + if expected is None: + assert actual is None + return + if isinstance(expected, bool): + assert actual is expected + return + if isinstance(expected, (int, float)): + assert np.isclose(actual, expected, rtol=0.0, atol=atol) + return + assert actual == expected + + +def _strip_unstable_info(info: dict[str, Any]) -> dict[str, Any]: + cleaned = {} + for key, value in info.items(): + if key in {"camera_available", "frame_timestamp"}: + continue + if isinstance(value, dict): + cleaned[key] = { + nested_key: nested_value + for nested_key, nested_value in value.items() + if nested_key not in {"sim_state", "is_sim_converged", "absolute_action"} + } + else: + cleaned[key] = value + return cleaned + + +def _strip_frames(obs: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in obs.items() if key != "frames"} + + +def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: Path): + source_dir = tmp_path / "source" + replay_dir = tmp_path / "replayed" + limit = 3 + instruction = "pick up cube" + + _record_source_dataset(source_dir, limit=limit, instruction=instruction) + _replay_prefix(replay_dir, with_cameras=False, limit=limit) + + source_rows = _source_rows(source_dir, limit) + replay_rows = _replay_rows(replay_dir) + + assert len(source_rows) == len(replay_rows) == limit + + for replay_row, source_row in zip(replay_rows, source_rows, strict=True): + ( + replay_step, + replay_obs, + replay_info, + replay_reward, + replay_success, + replay_action, + replay_env_action, + replay_instruction, + ) = replay_row + ( + source_step, + source_obs, + source_info, + source_reward, + source_success, + source_action, + source_env_action, + source_instruction, + ) = source_row + + _assert_nested_close(replay_step, source_step) + _assert_nested_close(replay_obs, source_obs, atol=1e-5) + assert replay_info["camera_available"] is source_info["camera_available"] + _assert_nested_close(_strip_unstable_info(replay_info), _strip_unstable_info(source_info), atol=1e-5) + _assert_nested_close(replay_reward, source_reward, atol=1e-8) + _assert_nested_close(replay_success, source_success) + _assert_nested_close(replay_action, source_action, atol=1e-8) + _assert_nested_close(replay_env_action, source_env_action, atol=1e-8) + _assert_nested_close(replay_instruction, source_instruction) + + +def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Path): + source_dir = tmp_path / "source" + replay_dir = tmp_path / "replayed_with_cameras" + limit = 3 + instruction = "pick up cube" + + _record_source_dataset(source_dir, limit=limit, instruction=instruction) + _replay_prefix(replay_dir, with_cameras=True, limit=limit) + + source_rows = _source_rows(source_dir, limit) + replay_rows = _replay_rows(replay_dir) + + assert len(source_rows) == len(replay_rows) == limit + + for replay_row, source_row in zip(replay_rows, source_rows, strict=True): + ( + replay_step, + replay_obs, + replay_info, + replay_reward, + replay_success, + replay_action, + replay_env_action, + replay_instruction, + ) = replay_row + ( + source_step, + source_obs, + source_info, + source_reward, + source_success, + source_action, + source_env_action, + source_instruction, + ) = source_row + + assert "frames" in replay_obs + assert set(replay_obs["frames"]) == {"head", "left_wrist", "right_wrist"} + assert replay_info["camera_available"] is True + assert "frame_timestamp" in replay_info + _assert_nested_close(replay_step, source_step) + _assert_nested_close(_strip_frames(replay_obs), source_obs, atol=1e-5) + _assert_nested_close(_strip_unstable_info(replay_info), _strip_unstable_info(source_info), atol=1e-5) + _assert_nested_close(replay_reward, source_reward, atol=1e-8) + _assert_nested_close(replay_success, source_success) + _assert_nested_close(replay_action, source_action, atol=1e-8) + _assert_nested_close(replay_env_action, source_env_action, atol=1e-8) + _assert_nested_close(replay_instruction, source_instruction) diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py deleted file mode 100644 index 78abf1f3..00000000 --- a/python/tests/test_sim_state_record_replay.py +++ /dev/null @@ -1,195 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from pathlib import Path - -import gymnasium as gym -import mujoco as mj -import numpy as np -import pyarrow.dataset as ds -from rcs._core.common import RobotPlatform -from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet -from rcs.envs.sim import SimStateObservationWrapper -from rcs.envs.storage_wrapper import StorageWrapper -from rcs.sim.sim import Sim -from rcs.sim_state_replay import ( - RecordedSimStep, - load_trajectory, - replay_trajectory, - restore_sim_step, -) - -XML = """ - - - - - - - - - -""" - - -@dataclass -class DummyCameraSet: - sim: Sim - - def get_latest_frames(self) -> FrameSet: - color_value = int(np.clip(round((self.sim.data.qpos[0] + 1.0) * 80.0), 0, 255)) - rgb = np.full((8, 8, 3), color_value, dtype=np.uint8) - return FrameSet( - frames={ - "main": Frame( - camera=CameraFrame( - color=DataFrame(data=rgb), - depth=None, - ), - ) - }, - avg_timestamp=None, - ) - - -class DummySimEnv(gym.Env): - PLATFORM = RobotPlatform.SIMULATION - - def __init__(self, sim: Sim, camera_set: DummyCameraSet | None = None): - super().__init__() - self.sim = sim - self.camera_set = camera_set - self.action_space = gym.spaces.Dict( - { - "delta": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float64), - } - ) - self.observation_space = gym.spaces.Dict( - { - "qpos": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nq,), dtype=np.float64), - "qvel": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nv,), dtype=np.float64), - } - ) - - def _obs(self) -> dict[str, np.ndarray]: - return { - "qpos": self.sim.data.qpos.copy(), - "qvel": self.sim.data.qvel.copy(), - } - - def get_wrapper_attr(self, name: str): - return getattr(self, name) - - def reset(self, *, seed: int | None = None, options: dict | None = None): - super().reset(seed=seed) - mj.mj_resetData(self.sim.model, self.sim.data) - mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), {"dummy": True} - - def step(self, action: dict[str, np.ndarray]): - self.sim.data.qpos[0] += float(action["delta"][0]) - self.sim.data.qvel[:] = 0.0 - mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), 0.0, False, False, {"dummy": True} - - def close(self): - return None - - -class SpySim: - def __init__(self): - self.states: list[tuple[np.ndarray, int | None]] = [] - self.sync_calls = 0 - - def set_state(self, state: np.ndarray, spec: int | None = None): - self.states.append((np.asarray(state, dtype=np.float64), spec)) - - def sync_gui(self): - self.sync_calls += 1 - - -class SpyReplayEnv(gym.Env): - PLATFORM = RobotPlatform.SIMULATION - - def __init__(self, sim: SpySim): - super().__init__() - self.sim = sim - self.reset_calls = 0 - - def get_wrapper_attr(self, name: str): - return getattr(self, name) - - def reset(self, *, seed: int | None = None, options: dict | None = None): - self.reset_calls += 1 - return {}, {} - - -def test_replay_trajectory_syncs_gui_without_stepping(): - spy_sim = SpySim() - env = SpyReplayEnv(spy_sim) - recorded_steps = [ - RecordedSimStep( - step=3, - uuid="traj-1", - timestamp=1.23, - observation={ - SimStateObservationWrapper.STATE_KEY: np.array([1.0, 2.0, 3.0], dtype=np.float64), - SimStateObservationWrapper.STATE_SPEC_KEY: 7, - }, - ) - ] - - replay_trajectory(env, recorded_steps) - - assert env.reset_calls == 1 - assert len(spy_sim.states) == 1 - np.testing.assert_allclose(spy_sim.states[0][0], np.array([1.0, 2.0, 3.0], dtype=np.float64)) - assert spy_sim.states[0][1] == 7 - assert spy_sim.sync_calls == 1 - - -def test_record_and_replay_sim_state(tmp_path: Path): - model_path = tmp_path / "dummy.xml" - model_path.write_text(XML) - - dataset_path = tmp_path / "dataset" - record_env: gym.Env = DummySimEnv(Sim(model_path)) - record_env = SimStateObservationWrapper(record_env) - record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay") - obs, _ = record_env.reset() - record_env.start_record() - assert SimStateObservationWrapper.STATE_KEY in obs - record_env.step({"delta": np.array([0.125], dtype=np.float64)}) - record_env.stop_record() - record_env.close() - - table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) - rows = table.to_pylist() - assert len(rows) == 1 - - recorded_obs = rows[0]["obs"] - assert SimStateObservationWrapper.STATE_KEY in recorded_obs - assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs - - recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) - assert len(recorded_steps) == 1 - assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY])) - - replay_sim = Sim(model_path) - replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) - replay_env = SimStateObservationWrapper(replay_env) - render_dir = tmp_path / "rendered" - - replay_env.reset() - restore_sim_step(replay_env, recorded_steps[0]) - assert np.allclose( - replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 - ) - assert np.allclose( - replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 - ) - - replay_trajectory(replay_env, recorded_steps, output_dir=render_dir) - - rendered_files = sorted(path.name for path in render_dir.glob("*.png")) - assert rendered_files == ["step-000000-main.png"]