diff --git a/examples/teleop/franka.py b/examples/teleop/franka.py index 979b55fe..1645ccef 100644 --- a/examples/teleop/franka.py +++ b/examples/teleop/franka.py @@ -55,6 +55,7 @@ DATASET_PATH = "test_iris" INSTRUCTION = "pick up cube" +RECORD_FPS = 30 robot2world = { "right": rcs.common.Pose( @@ -143,7 +144,9 @@ def get_env(): scene = EmptyWorldFR3Duo() sim_cfg_data = scene.config() - sim_cfg_data.sim_cfg = SimConfig(async_control=True, realtime=True, frequency=30, max_convergence_steps=500) + sim_cfg_data.sim_cfg = SimConfig( + async_control=True, realtime=True, frequency=RECORD_FPS, max_convergence_steps=500 + ) sim_cfg_data.relative_to = RelativeTo.CONFIGURED_ORIGIN if sim_cfg_data.root_frame_objects is None: sim_cfg_data.root_frame_objects = {} @@ -164,7 +167,7 @@ def get_env(): def main(): env_rel, operator = get_env() env_rel.reset() - tele = TeleopLoop(env_rel, operator) + tele = TeleopLoop(env_rel, operator, env_frequency=RECORD_FPS, robot_platform=ROBOT_INSTANCE) with env_rel, tele: # type: ignore tele.environment_step_loop() diff --git a/python/rcs/envs/base.py b/python/rcs/envs/base.py index a62628a9..b3c21c05 100644 --- a/python/rcs/envs/base.py +++ b/python/rcs/envs/base.py @@ -544,10 +544,11 @@ class RelativeTo(Enum): NONE = auto() -class RelativeActionSpace(gym.ActionWrapper): +class RelativeActionSpace(ActObsInfoWrapper): DEFAULT_MAX_CART_MOV = 0.5 DEFAULT_MAX_CART_ROT = np.deg2rad(90) DEFAULT_MAX_JOINT_MOV = np.deg2rad(5) + ABSOLUTE_ACTION_KEY = "absolute_action" def __init__( self, @@ -624,6 +625,7 @@ def __init__( self.initial_obs: dict[str, Any] | None = None self._origin: common.Pose | VecType | None = None self._last_action: common.Pose | VecType | None = None + self._absolute_action: common.Pose | VecType | None = None def set_origin(self, origin: common.Pose | VecType): if self.get_wrapper_attr("get_control_mode")() == ControlMode.JOINTS: @@ -669,7 +671,9 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]: limited_joints_diff = np.clip(joints_diff, -self.max_mov, self.max_mov) limited_joints = limited_joints_diff + self._last_action self._last_action = limited_joints - action.update(JointsDictType(joints=np.clip(self._origin + limited_joints, low, high))) + clipped_joints = np.clip(self._origin + limited_joints, low, high) + action.update(JointsDictType(joints=clipped_joints)) + self._absolute_action = clipped_joints elif self.get_wrapper_attr("get_control_mode")() == ControlMode.CARTESIAN_TRPY and self.trpy_key in action: assert isinstance(self._origin, common.Pose), "Invalid origin type given the control mode." @@ -705,16 +709,15 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]: translation=self._origin.translation() + clipped_pose_offset.translation(), # type: ignore rpy_vector=(clipped_pose_offset * self._origin).rotation_rpy().as_vector(), ) - action.update( - TRPYDictType( - xyzrpy=np.concatenate( # type: ignore - [ - np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]), - unclipped_pose.rotation_rpy().as_vector(), - ], - ) - ) + clipped_pose = np.concatenate( # type: ignore + [ + np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]), + unclipped_pose.rotation_rpy().as_vector(), + ], ) + action.update(TRPYDictType(xyzrpy=clipped_pose)) + self._absolute_action = clipped_pose + elif self.get_wrapper_attr("get_control_mode")() == ControlMode.CARTESIAN_TQuat and self.tquat_key in action: assert isinstance(self._origin, common.Pose), "Invalid origin type given the control mode." assert isinstance(self.max_mov, tuple) @@ -749,22 +752,29 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]: translation=self._origin.translation() + clipped_pose_offset.translation(), # type: ignore quaternion=(clipped_pose_offset * self._origin).rotation_q(), ) - - action.update( - TQuatDictType( - tquat=np.concatenate( # type: ignore - [ - np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]), - unclipped_pose.rotation_q(), - ], - ) - ) + clipped_pose = np.concatenate( # type: ignore + [ + np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]), + unclipped_pose.rotation_q(), + ], ) + action.update(TQuatDictType(tquat=clipped_pose)) # type: ignore + self._absolute_action = clipped_pose + else: msg = "Given type is not matching control mode!" raise RuntimeError(msg) return action + def observation(self, observation: dict, info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + if self._absolute_action is not None: + info[self.ABSOLUTE_ACTION_KEY] = ( + list(self._absolute_action.translation()) + list(self._absolute_action.rotation_q()) + if isinstance(self._absolute_action, common.Pose) + else self._absolute_action + ) + return observation, info + class CameraSetWrapper(ActObsInfoWrapper): RGB_KEY = "rgb" diff --git a/python/rcs/envs/storage_wrapper.py b/python/rcs/envs/storage_wrapper.py index 01253bc9..1073cd04 100644 --- a/python/rcs/envs/storage_wrapper.py +++ b/python/rcs/envs/storage_wrapper.py @@ -14,6 +14,7 @@ import pyarrow.dataset as ds import simplejpeg from PIL import Image +from rcs.envs.base import RelativeActionSpace class StorageWrapper(gym.Wrapper): @@ -89,6 +90,7 @@ def __init__( self.instruction = instruction self._success = False self._prev_action = None + self._prev_absolute_action = None self.thread_pool = ThreadPoolExecutor() self.queue: Queue[pa.Table | pa.RecordBatch] = Queue(maxsize=2) @@ -265,9 +267,25 @@ def step(self, action): "action": self._prev_action, "instruction": self.instruction, "timestamp": datetime.datetime.now().timestamp(), + RelativeActionSpace.ABSOLUTE_ACTION_KEY: self._prev_absolute_action, } ) self._prev_action = action + if RelativeActionSpace.ABSOLUTE_ACTION_KEY in obs: + # single env + self._prev_absolute_action = obs[RelativeActionSpace.ABSOLUTE_ACTION_KEY] + else: + # multi env wrapper + act_dict = {} + try: + for key in self.get_wrapper_attr("envs"): + if RelativeActionSpace.ABSOLUTE_ACTION_KEY in obs[key]: + act_dict[key] = obs[key][RelativeActionSpace.ABSOLUTE_ACTION_KEY] + except AttributeError: + pass + if len(act_dict) != 0: + self._prev_absolute_action = act_dict # type: ignore + self.step_cnt += 1 if len(self.buffer) == self.batch_size: self._flush() @@ -291,6 +309,7 @@ def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = Non self._pause = not self.always_record self._success = False self._prev_action = None + self._prev_absolute_action = None obs, info = self.env.reset() self.step_cnt = 0 self.uuid = uuid4() diff --git a/python/rcs/operator/interface.py b/python/rcs/operator/interface.py index e7e6b3c0..37debc92 100644 --- a/python/rcs/operator/interface.py +++ b/python/rcs/operator/interface.py @@ -7,6 +7,7 @@ from time import sleep import gymnasium as gym +from rcs._core.common import RobotPlatform from rcs.envs.base import ArmWithGripper, ControlMode, RelativeTo from rcs.sim.sim import Sim from rcs.utils import SimpleFrameRate @@ -71,12 +72,14 @@ def __init__( env: gym.Env, operator: BaseOperator, env_frequency: int = 30, + robot_platform: RobotPlatform = RobotPlatform.HARDWARE, key_translation: dict[str, str] | None = None, ): super().__init__() self.env = env self.operator = operator self._exit_requested = False + self.robot_platform = robot_platform self.env_frequency = env_frequency if key_translation is None: # controller to robot translation @@ -114,7 +117,9 @@ def _translate_keys(self, actions): return translated def environment_step_loop(self): - rate_limiter = SimpleFrameRate(self.env_frequency, "env loop") + rate_limiter = SimpleFrameRate( + self.env_frequency if self.robot_platform == RobotPlatform.HARDWARE else None, "env loop" + ) # 0. Initial Reset to get current positions for untracked robots self._last_obs, _ = self.env.reset() diff --git a/python/rcs/utils.py b/python/rcs/utils.py index f2d3bebc..9f9b7e7f 100644 --- a/python/rcs/utils.py +++ b/python/rcs/utils.py @@ -6,7 +6,7 @@ class SimpleFrameRate: - def __init__(self, frame_rate: float, loop_name: str = "SimpleFrameRate"): + def __init__(self, frame_rate: float | None, loop_name: str = "SimpleFrameRate"): """SimpleFrameRate is a utility class to manage frame rates in a simple way. It allows you to call it in a loop, and it will sleep the necessary time to maintain the desired frame rate. @@ -22,6 +22,8 @@ def reset(self): self.t = None def __call__(self): + if self.frame_rate is None: + return if self.t is None: self.t = perf_counter() self._last_print = self.t