Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions examples/teleop/franka.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

DATASET_PATH = "test_iris"
INSTRUCTION = "pick up cube"
RECORD_FPS = 30

robot2world = {
"right": rcs.common.Pose(
Expand Down Expand Up @@ -143,7 +144,9 @@ def get_env():

scene = EmptyWorldFR3Duo()
sim_cfg_data = scene.config()
sim_cfg_data.sim_cfg = SimConfig(async_control=True, realtime=True, frequency=30, max_convergence_steps=500)
sim_cfg_data.sim_cfg = SimConfig(
async_control=True, realtime=True, frequency=RECORD_FPS, max_convergence_steps=500
)
sim_cfg_data.relative_to = RelativeTo.CONFIGURED_ORIGIN
if sim_cfg_data.root_frame_objects is None:
sim_cfg_data.root_frame_objects = {}
Expand All @@ -164,7 +167,7 @@ def get_env():
def main():
env_rel, operator = get_env()
env_rel.reset()
tele = TeleopLoop(env_rel, operator)
tele = TeleopLoop(env_rel, operator, env_frequency=RECORD_FPS, robot_platform=ROBOT_INSTANCE)
with env_rel, tele: # type: ignore
tele.environment_step_loop()

Expand Down
52 changes: 31 additions & 21 deletions python/rcs/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,11 @@ class RelativeTo(Enum):
NONE = auto()


class RelativeActionSpace(gym.ActionWrapper):
class RelativeActionSpace(ActObsInfoWrapper):
DEFAULT_MAX_CART_MOV = 0.5
DEFAULT_MAX_CART_ROT = np.deg2rad(90)
DEFAULT_MAX_JOINT_MOV = np.deg2rad(5)
ABSOLUTE_ACTION_KEY = "absolute_action"

def __init__(
self,
Expand Down Expand Up @@ -624,6 +625,7 @@ def __init__(
self.initial_obs: dict[str, Any] | None = None
self._origin: common.Pose | VecType | None = None
self._last_action: common.Pose | VecType | None = None
self._absolute_action: common.Pose | VecType | None = None

def set_origin(self, origin: common.Pose | VecType):
if self.get_wrapper_attr("get_control_mode")() == ControlMode.JOINTS:
Expand Down Expand Up @@ -669,7 +671,9 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]:
limited_joints_diff = np.clip(joints_diff, -self.max_mov, self.max_mov)
limited_joints = limited_joints_diff + self._last_action
self._last_action = limited_joints
action.update(JointsDictType(joints=np.clip(self._origin + limited_joints, low, high)))
clipped_joints = np.clip(self._origin + limited_joints, low, high)
action.update(JointsDictType(joints=clipped_joints))
self._absolute_action = clipped_joints

elif self.get_wrapper_attr("get_control_mode")() == ControlMode.CARTESIAN_TRPY and self.trpy_key in action:
assert isinstance(self._origin, common.Pose), "Invalid origin type given the control mode."
Expand Down Expand Up @@ -705,16 +709,15 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]:
translation=self._origin.translation() + clipped_pose_offset.translation(), # type: ignore
rpy_vector=(clipped_pose_offset * self._origin).rotation_rpy().as_vector(),
)
action.update(
TRPYDictType(
xyzrpy=np.concatenate( # type: ignore
[
np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]),
unclipped_pose.rotation_rpy().as_vector(),
],
)
)
clipped_pose = np.concatenate( # type: ignore
[
np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]),
unclipped_pose.rotation_rpy().as_vector(),
],
)
action.update(TRPYDictType(xyzrpy=clipped_pose))
self._absolute_action = clipped_pose

elif self.get_wrapper_attr("get_control_mode")() == ControlMode.CARTESIAN_TQuat and self.tquat_key in action:
assert isinstance(self._origin, common.Pose), "Invalid origin type given the control mode."
assert isinstance(self.max_mov, tuple)
Expand Down Expand Up @@ -749,22 +752,29 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]:
translation=self._origin.translation() + clipped_pose_offset.translation(), # type: ignore
quaternion=(clipped_pose_offset * self._origin).rotation_q(),
)

action.update(
TQuatDictType(
tquat=np.concatenate( # type: ignore
[
np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]),
unclipped_pose.rotation_q(),
],
)
)
clipped_pose = np.concatenate( # type: ignore
[
np.clip(unclipped_pose.translation(), pose_space.low[:3], pose_space.high[:3]),
unclipped_pose.rotation_q(),
],
)
action.update(TQuatDictType(tquat=clipped_pose)) # type: ignore
self._absolute_action = clipped_pose

else:
msg = "Given type is not matching control mode!"
raise RuntimeError(msg)
return action

def observation(self, observation: dict, info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
if self._absolute_action is not None:
info[self.ABSOLUTE_ACTION_KEY] = (
list(self._absolute_action.translation()) + list(self._absolute_action.rotation_q())
if isinstance(self._absolute_action, common.Pose)
else self._absolute_action
)
return observation, info


class CameraSetWrapper(ActObsInfoWrapper):
RGB_KEY = "rgb"
Expand Down
19 changes: 19 additions & 0 deletions python/rcs/envs/storage_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pyarrow.dataset as ds
import simplejpeg
from PIL import Image
from rcs.envs.base import RelativeActionSpace


class StorageWrapper(gym.Wrapper):
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
self.instruction = instruction
self._success = False
self._prev_action = None
self._prev_absolute_action = None

self.thread_pool = ThreadPoolExecutor()
self.queue: Queue[pa.Table | pa.RecordBatch] = Queue(maxsize=2)
Expand Down Expand Up @@ -265,9 +267,25 @@ def step(self, action):
"action": self._prev_action,
"instruction": self.instruction,
"timestamp": datetime.datetime.now().timestamp(),
RelativeActionSpace.ABSOLUTE_ACTION_KEY: self._prev_absolute_action,
}
)
self._prev_action = action
if RelativeActionSpace.ABSOLUTE_ACTION_KEY in obs:
# single env
self._prev_absolute_action = obs[RelativeActionSpace.ABSOLUTE_ACTION_KEY]
else:
# multi env wrapper
act_dict = {}
try:
for key in self.get_wrapper_attr("envs"):
if RelativeActionSpace.ABSOLUTE_ACTION_KEY in obs[key]:
act_dict[key] = obs[key][RelativeActionSpace.ABSOLUTE_ACTION_KEY]
except AttributeError:
pass
if len(act_dict) != 0:
self._prev_absolute_action = act_dict # type: ignore

self.step_cnt += 1
if len(self.buffer) == self.batch_size:
self._flush()
Expand All @@ -291,6 +309,7 @@ def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = Non
self._pause = not self.always_record
self._success = False
self._prev_action = None
self._prev_absolute_action = None
obs, info = self.env.reset()
self.step_cnt = 0
self.uuid = uuid4()
Expand Down
7 changes: 6 additions & 1 deletion python/rcs/operator/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from time import sleep

import gymnasium as gym
from rcs._core.common import RobotPlatform
from rcs.envs.base import ArmWithGripper, ControlMode, RelativeTo
from rcs.sim.sim import Sim
from rcs.utils import SimpleFrameRate
Expand Down Expand Up @@ -71,12 +72,14 @@ def __init__(
env: gym.Env,
operator: BaseOperator,
env_frequency: int = 30,
robot_platform: RobotPlatform = RobotPlatform.HARDWARE,
key_translation: dict[str, str] | None = None,
):
super().__init__()
self.env = env
self.operator = operator
self._exit_requested = False
self.robot_platform = robot_platform
self.env_frequency = env_frequency
if key_translation is None:
# controller to robot translation
Expand Down Expand Up @@ -114,7 +117,9 @@ def _translate_keys(self, actions):
return translated

def environment_step_loop(self):
rate_limiter = SimpleFrameRate(self.env_frequency, "env loop")
rate_limiter = SimpleFrameRate(
self.env_frequency if self.robot_platform == RobotPlatform.HARDWARE else None, "env loop"
)

# 0. Initial Reset to get current positions for untracked robots
self._last_obs, _ = self.env.reset()
Expand Down
4 changes: 3 additions & 1 deletion python/rcs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class SimpleFrameRate:
def __init__(self, frame_rate: float, loop_name: str = "SimpleFrameRate"):
def __init__(self, frame_rate: float | None, loop_name: str = "SimpleFrameRate"):
"""SimpleFrameRate is a utility class to manage frame rates in a simple way.
It allows you to call it in a loop, and it will sleep the necessary time to maintain the desired frame rate.

Expand All @@ -22,6 +22,8 @@ def reset(self):
self.t = None

def __call__(self):
if self.frame_rate is None:
return
if self.t is None:
self.t = perf_counter()
self._last_print = self.t
Expand Down
Loading