diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index fa5a8899..71f79cba 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -97,6 +97,7 @@ class Sim: def is_converged(self) -> bool: ... def reset(self) -> None: ... def set_config(self, cfg: SimConfig) -> bool: ... + def set_model_data(self, mjmdl: int, mjdata: int) -> None: ... def step(self, k: int) -> None: ... def step_until_convergence(self) -> None: ... def sync_gui(self) -> None: ... diff --git a/python/rcs/envs/scenes.py b/python/rcs/envs/scenes.py index b661204d..2c8c42b0 100644 --- a/python/rcs/envs/scenes.py +++ b/python/rcs/envs/scenes.py @@ -70,8 +70,9 @@ def add_task_mujoco(cfg: TaskConfig, composer: ModelComposer, env_cfg: "SimEnvCr """Add task-specific elements to the Mujoco scene.""" @staticmethod - def add_task_env(_cfg: TaskConfig, env: gym.Env, _simulation: Sim, _env_cfg: "SimEnvCreatorConfig") -> gym.Env: + def add_task_env(cfg: TaskConfig, env: gym.Env, simulation: Sim, env_cfg: "SimEnvCreatorConfig") -> gym.Env: """Add task-specific wrappers to the environment.""" + _ = (cfg, simulation, env_cfg) return env diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 07963e4d..5bce943e 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -48,19 +48,7 @@ class Sim(_Sim): STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None = None): - if isinstance(mjmdl, ModelComposer): - self.model = mjmdl.get_model() - else: - mjmdl = Path(mjmdl) - if mjmdl.suffix == ".xml": - self.model = mj.MjModel.from_xml_path(str(mjmdl)) - elif mjmdl.suffix == ".mjb": - self.model = mj.MjModel.from_binary_path(str(mjmdl)) - else: - msg = f"Filetype {mjmdl.suffix} is unknown" - logger.error(msg) - - self.data = mj.MjData(self.model) + self.model, self.data = self.get_model_data(mjmdl) super().__init__(self.model._address, self.data._address) self._mp_context = mp.get_context("spawn") self._gui_uuid: Optional[str] = None @@ -71,6 +59,28 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None if cfg is not None: self.set_config(cfg) + def get_model_data(self, mjmdl: str | PathLike | ModelComposer) -> tuple[mj.MjModel, mj.MjData]: + if isinstance(mjmdl, ModelComposer): + model = mjmdl.get_model() + else: + mjmdl = Path(mjmdl) + if mjmdl.suffix == ".xml": + model = mj.MjModel.from_xml_path(str(mjmdl)) + elif mjmdl.suffix == ".mjb": + model = mj.MjModel.from_binary_path(str(mjmdl)) + else: + msg = f"Filetype {mjmdl.suffix} is unknown" + logger.error(msg) + data = mj.MjData(model) + return model, data + + def set_model_data(self, model: mj.MjModel, data: mj.MjData): + super().set_model_data(model._address, data._address) + + def instantiate_new_sim(self, mjmdl: str | PathLike | ModelComposer): + self.model, self.data = self.get_model_data(mjmdl) + self.set_model_data(self.model, self.data) + def get_state_spec(self) -> int: return int(self.STATE_SPEC) diff --git a/src/pybind/rcs.cpp b/src/pybind/rcs.cpp index ab7772c0..23e42944 100644 --- a/src/pybind/rcs.cpp +++ b/src/pybind/rcs.cpp @@ -735,6 +735,12 @@ PYBIND11_MODULE(_core, m) { return std::make_shared((mjModel*)m, (mjData*)d); }), py::arg("mjmdl"), py::arg("mjdata")) + .def( + "set_model_data", + [](rcs::sim::Sim& self, long model, long data) { + self.set_model_data((mjModel*)model, (mjData*)data); + }, + py::arg("mjmdl"), py::arg("mjdata")) .def("step_until_convergence", &rcs::sim::Sim::step_until_convergence, py::call_guard()) .def("is_converged", &rcs::sim::Sim::is_converged) diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 8facd93d..bfecf014 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -25,7 +25,12 @@ bool get_last_return_value(ConditionCallback cb) { return cb.last_return_value; } -Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {}; +Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {} + +void Sim::set_model_data(mjModel* model, mjData* data) { + this->m = model; + this->d = data; +} bool Sim::set_config(const SimConfig& cfg) { this->cfg = cfg; diff --git a/src/sim/sim.h b/src/sim/sim.h index 4ed35e60..bcc364c1 100644 --- a/src/sim/sim.h +++ b/src/sim/sim.h @@ -76,6 +76,7 @@ class Sim { mjModel* m; mjData* d; Sim(mjModel* m, mjData* d); + void set_model_data(mjModel* model, mjData* data); bool set_config(const SimConfig& cfg); SimConfig get_config(); bool is_converged();