diff --git a/docs/source/api/lab/isaaclab.utils.rst b/docs/source/api/lab/isaaclab.utils.rst index 5b352152e0b5..f236ebcb6a15 100644 --- a/docs/source/api/lab/isaaclab.utils.rst +++ b/docs/source/api/lab/isaaclab.utils.rst @@ -188,3 +188,16 @@ Warp operations :members: :imported-members: :show-inheritance: + +Warp Fabric kernels +^^^^^^^^^^^^^^^^^^^ + +Warp kernels for reading and writing Fabric ``Matrix4d`` attributes +(``omni:fabric:worldMatrix`` / ``omni:fabric:localMatrix``) via +:class:`wp.fabricarray` and :class:`wp.indexedfabricarray`. Will be used by +:class:`~isaaclab_physx.sim.views.FabricFrameView` to keep child world and +local matrices consistent without round-tripping through USD. + +.. automodule:: isaaclab.utils.warp.fabric + :members: + :show-inheritance: diff --git a/scripts/benchmarks/benchmark_view_comparison.py b/scripts/benchmarks/benchmark_view_comparison.py index aa5927e10b6a..2f19554d5ffc 100644 --- a/scripts/benchmarks/benchmark_view_comparison.py +++ b/scripts/benchmarks/benchmark_view_comparison.py @@ -284,7 +284,8 @@ def _run_pose_benchmarks( start_time = time.perf_counter() for _ in range(num_iterations): - view.set_world_poses(new_positions, orientations) + with view.xform_world_space_writer() as w: + w.set_poses(new_positions, orientations) timing_results["set_world_poses"] = (time.perf_counter() - start_time) / num_iterations ret_pos, ret_quat = view.get_world_poses() diff --git a/scripts/benchmarks/benchmark_xform_prim_view.py b/scripts/benchmarks/benchmark_xform_prim_view.py index fee3b9642c79..ae6656e3d8d0 100644 --- a/scripts/benchmarks/benchmark_xform_prim_view.py +++ b/scripts/benchmarks/benchmark_xform_prim_view.py @@ -139,11 +139,16 @@ def benchmark_frame_view( # noqa: C901 is_newton = api == "isaaclab-newton-site" def to_torch(a): - return wp.to_torch(a) if isinstance(a, wp.array) else a + if isinstance(a, wp.array): + return wp.to_torch(a) + if hasattr(a, "torch"): + return a.torch + return a try: # -- Warmup -------------------------------------------------------- xform_view.get_world_poses() + xform_view.get_world_scales() # -- get_world_poses ----------------------------------------------- if is_newton: @@ -162,7 +167,7 @@ def to_torch(a): # -- set_world_poses ----------------------------------------------- if is_newton: - new_positions = wp.clone(positions) + new_positions = wp.clone(positions.warp) wp.to_torch(new_positions)[:, 2] += 0.1 else: new_positions = positions_t.clone() @@ -172,7 +177,8 @@ def to_torch(a): torch.cuda.synchronize() start_time = time.perf_counter() for _ in range(num_iterations): - xform_view.set_world_poses(new_positions, orientations) + with xform_view.xform_world_space_writer() as w: + w.set_poses(new_positions, orientations) if is_newton: torch.cuda.synchronize() timing_results["set_world_poses"] = (time.perf_counter() - start_time) / num_iterations @@ -198,7 +204,7 @@ def to_torch(a): # -- set_local_poses ----------------------------------------------- if is_newton: - new_translations = wp.clone(translations) + new_translations = wp.clone(translations.warp) wp.to_torch(new_translations)[:, 2] += 0.1 else: new_translations = translations_t.clone() @@ -208,7 +214,8 @@ def to_torch(a): torch.cuda.synchronize() start_time = time.perf_counter() for _ in range(num_iterations): - xform_view.set_local_poses(new_translations, orientations_local) + with xform_view.xform_local_space_writer() as w: + w.set_poses(new_translations, orientations_local) if is_newton: torch.cuda.synchronize() timing_results["set_local_poses"] = (time.perf_counter() - start_time) / num_iterations @@ -217,6 +224,72 @@ def to_torch(a): computed_results["local_translations_after_set"] = to_torch(ta).clone() computed_results["local_orientations_after_set"] = to_torch(ola).clone() + # -- get_world_scales ---------------------------------------------- + if is_newton: + torch.cuda.synchronize() + start_time = time.perf_counter() + for _ in range(num_iterations): + world_scales = xform_view.get_world_scales() + if is_newton: + torch.cuda.synchronize() + timing_results["get_world_scales"] = (time.perf_counter() - start_time) / num_iterations + + world_scales_t = to_torch(world_scales) + computed_results["initial_world_scales"] = world_scales_t.clone() + + # -- set_world_scales ---------------------------------------------- + if is_newton: + new_world_scales = wp.clone(world_scales.warp) + wp.to_torch(new_world_scales)[:] = 1.1 + else: + new_world_scales = world_scales_t.clone() + new_world_scales[:] = 1.1 + + if is_newton: + torch.cuda.synchronize() + start_time = time.perf_counter() + for _ in range(num_iterations): + with xform_view.xform_world_space_writer() as w: + w.set_scales(new_world_scales) + if is_newton: + torch.cuda.synchronize() + timing_results["set_world_scales"] = (time.perf_counter() - start_time) / num_iterations + + computed_results["world_scales_after_set"] = to_torch(xform_view.get_world_scales()).clone() + + # -- get_local_scales ---------------------------------------------- + if is_newton: + torch.cuda.synchronize() + start_time = time.perf_counter() + for _ in range(num_iterations): + local_scales = xform_view.get_local_scales() + if is_newton: + torch.cuda.synchronize() + timing_results["get_local_scales"] = (time.perf_counter() - start_time) / num_iterations + + local_scales_t = to_torch(local_scales) + computed_results["initial_local_scales"] = local_scales_t.clone() + + # -- set_local_scales ---------------------------------------------- + if is_newton: + new_local_scales = wp.clone(local_scales.warp) + wp.to_torch(new_local_scales)[:] = 0.9 + else: + new_local_scales = local_scales_t.clone() + new_local_scales[:] = 0.9 + + if is_newton: + torch.cuda.synchronize() + start_time = time.perf_counter() + for _ in range(num_iterations): + with xform_view.xform_local_space_writer() as w: + w.set_scales(new_local_scales) + if is_newton: + torch.cuda.synchronize() + timing_results["set_local_scales"] = (time.perf_counter() - start_time) / num_iterations + + computed_results["local_scales_after_set"] = to_torch(xform_view.get_local_scales()).clone() + # -- get_both (world + local) -------------------------------------- if is_newton: torch.cuda.synchronize() @@ -233,7 +306,8 @@ def to_torch(a): torch.cuda.synchronize() start_time = time.perf_counter() for _ in range(num_iterations): - xform_view.set_world_poses(new_positions, orientations) + with xform_view.xform_world_space_writer() as w: + w.set_poses(new_positions, orientations) xform_view.get_world_poses() if is_newton: torch.cuda.synchronize() @@ -273,6 +347,10 @@ def print_results(results_dict: dict[str, dict[str, float]], num_prims: int, num ("Set World Poses", "set_world_poses"), ("Get Local Poses", "get_local_poses"), ("Set Local Poses", "set_local_poses"), + ("Get World Scales", "get_world_scales"), + ("Set World Scales", "set_world_scales"), + ("Get Local Scales", "get_local_scales"), + ("Set Local Scales", "set_local_scales"), ("Get Both (World+Local)", "get_both"), ("Interleaved World Set->Get", "interleaved_world_set_get"), ] diff --git a/source/isaaclab/changelog.d/fabric-local-poses.rst b/source/isaaclab/changelog.d/fabric-local-poses.rst new file mode 100644 index 000000000000..1ca61cfda97b --- /dev/null +++ b/source/isaaclab/changelog.d/fabric-local-poses.rst @@ -0,0 +1,28 @@ +Added +^^^^^ + +* Added explicit local/world scale methods + :meth:`~isaaclab.sim.views.BaseFrameView.get_local_scales`, + :meth:`~isaaclab.sim.views.BaseFrameView.set_local_scales`, + :meth:`~isaaclab.sim.views.BaseFrameView.get_world_scales`, and + :meth:`~isaaclab.sim.views.BaseFrameView.set_world_scales` to the FrameView + API, implemented for :class:`~isaaclab.sim.views.UsdFrameView`. + +* Added :func:`~isaaclab.utils.warp.fabric.decompose_indexed_fabric_transforms`, + :func:`~isaaclab.utils.warp.fabric.compose_indexed_fabric_transforms`, + :func:`~isaaclab.utils.warp.fabric.update_indexed_local_matrix_from_world`, and + :func:`~isaaclab.utils.warp.fabric.update_indexed_world_matrix_from_local` + Warp kernels operating on :class:`wp.indexedfabricarray` for reading and + writing Fabric ``Matrix4d`` attributes (``omni:fabric:worldMatrix`` / + ``omni:fabric:localMatrix``). + +Deprecated +^^^^^^^^^^ + +* Deprecated :meth:`~isaaclab.sim.views.BaseFrameView.get_scales` and + :meth:`~isaaclab.sim.views.BaseFrameView.set_scales` in favor of the explicit + ``get_local_scales`` / ``set_local_scales`` (operates on ``xformOp:scale``) or + ``get_world_scales`` / ``set_world_scales`` (operates on composed world-space + scale). The deprecated methods still work but emit a ``DeprecationWarning``; + :class:`~isaaclab.sim.views.UsdFrameView` preserves prior behavior by + defaulting to local scales. diff --git a/source/isaaclab/changelog.d/xform-space-writer.rst b/source/isaaclab/changelog.d/xform-space-writer.rst new file mode 100644 index 000000000000..41c9c97a55cc --- /dev/null +++ b/source/isaaclab/changelog.d/xform-space-writer.rst @@ -0,0 +1,41 @@ +Added +^^^^^ + +* Added :class:`~isaaclab.sim.views.FrameViewSpaceWriterBase`, the new context-managed + write API for ``FrameView``-managed prim transforms. Open with + ``view.xform_world_space_writer()`` or ``view.xform_local_space_writer()`` and call + :meth:`~isaaclab.sim.views.FrameViewSpaceWriterBase.set_poses` / + :meth:`~isaaclab.sim.views.FrameViewSpaceWriterBase.set_scales` inside the scope; + the writer's ``__exit__`` derives the opposite-space matrices once and + synchronizes once. Only one writer scope may be active per view at a + time. View-level getters + (:meth:`~isaaclab.sim.views.BaseFrameView.get_world_poses` etc.) raise + :class:`RuntimeError` while a writer scope is active. + +* Added the two concrete tag classes + :class:`~isaaclab.sim.views.FrameViewWorldSpaceWriter` and + :class:`~isaaclab.sim.views.FrameViewLocalSpaceWriter` returned by + :meth:`~isaaclab.sim.views.BaseFrameView.xform_world_space_writer` / + :meth:`~isaaclab.sim.views.BaseFrameView.xform_local_space_writer`. + +Deprecated +^^^^^^^^^^ + +* Deprecated :meth:`~isaaclab.sim.views.BaseFrameView.set_world_poses` and + :meth:`~isaaclab.sim.views.BaseFrameView.set_local_poses`. Use + ``with view.xform_world_space_writer() as w: w.set_poses(...)`` (or + :meth:`~isaaclab.sim.views.BaseFrameView.xform_local_space_writer`) + instead. The deprecated methods still work but emit a one-time + ``DeprecationWarning`` per class and open a single-statement writer scope + internally. + +Removed +^^^^^^^ + +* **Breaking:** Removed ``set_world_scales`` and ``set_local_scales`` + from :class:`~isaaclab.sim.views.BaseFrameView` (and all subclasses). + These were introduced in this release cycle without a stable downstream + user, so they are removed outright (no deprecation cycle). Use + ``with view.xform_world_space_writer() as w: w.set_scales(...)`` (or + :meth:`~isaaclab.sim.views.BaseFrameView.xform_local_space_writer`) + instead. diff --git a/source/isaaclab/isaaclab/sensors/camera/camera.py b/source/isaaclab/isaaclab/sensors/camera/camera.py index 5ed97a3825f5..58cf4423eaf5 100644 --- a/source/isaaclab/isaaclab/sensors/camera/camera.py +++ b/source/isaaclab/isaaclab/sensors/camera/camera.py @@ -376,7 +376,8 @@ def set_world_poses( orientations = convert_camera_frame_orientation_convention(orientations, origin=convention, target="opengl") ori_wp = wp.from_torch(orientations.contiguous(), dtype=wp.vec4f) idx_wp = self._resolve_env_ids_wp(env_ids) - self._view.set_world_poses(pos_wp, ori_wp, idx_wp) + with self._view.xform_world_space_writer() as writer: + writer.set_poses(pos_wp, ori_wp, idx_wp) def set_world_poses_from_view( self, eyes: torch.Tensor, targets: torch.Tensor, env_ids: Sequence[int] | None = None @@ -434,11 +435,12 @@ def set_world_poses_from_view( env_ids_torch = env_ids_torch.index_select(0, valid_indices) orientations = quat_from_matrix(rotation_matrix) idx_wp = wp.from_torch(env_ids_torch.contiguous(), dtype=wp.int32) - self._view.set_world_poses( - wp.from_torch(eyes.contiguous(), dtype=wp.vec3f), - wp.from_torch(orientations.contiguous(), dtype=wp.vec4f), - idx_wp, - ) + with self._view.xform_world_space_writer() as writer: + writer.set_poses( + wp.from_torch(eyes.contiguous(), dtype=wp.vec3f), + wp.from_torch(orientations.contiguous(), dtype=wp.vec4f), + idx_wp, + ) """ Operations diff --git a/source/isaaclab/isaaclab/sim/views/__init__.pyi b/source/isaaclab/isaaclab/sim/views/__init__.pyi index d578f85d6ada..734e925c19bb 100644 --- a/source/isaaclab/isaaclab/sim/views/__init__.pyi +++ b/source/isaaclab/isaaclab/sim/views/__init__.pyi @@ -7,6 +7,9 @@ __all__ = [ "BaseFrameView", "UsdFrameView", "FrameView", + "FrameViewSpaceWriterBase", + "FrameViewWorldSpaceWriter", + "FrameViewLocalSpaceWriter", # Deprecated alias "XformPrimView", ] @@ -14,5 +17,6 @@ __all__ = [ from .base_frame_view import BaseFrameView from .usd_frame_view import UsdFrameView from .frame_view import FrameView +from .xform_space_writer import FrameViewSpaceWriterBase, FrameViewWorldSpaceWriter, FrameViewLocalSpaceWriter # Deprecated alias from .xform_prim_view import XformPrimView diff --git a/source/isaaclab/isaaclab/sim/views/base_frame_view.py b/source/isaaclab/isaaclab/sim/views/base_frame_view.py index 656108f24d2c..dc7cc324f53a 100644 --- a/source/isaaclab/isaaclab/sim/views/base_frame_view.py +++ b/source/isaaclab/isaaclab/sim/views/base_frame_view.py @@ -8,23 +8,49 @@ from __future__ import annotations import abc +import warnings +from typing import TYPE_CHECKING import warp as wp from isaaclab.utils.warp import ProxyArray +if TYPE_CHECKING: + from .xform_space_writer import FrameViewLocalSpaceWriter, FrameViewSpaceWriterBase, FrameViewWorldSpaceWriter + class BaseFrameView(abc.ABC): - """Abstract interface for reading and writing world-space transforms of multiple prims. + """Abstract interface for reading and writing transforms of multiple prims. Backend-specific implementations (USD/Fabric, Newton GPU state, etc.) subclass this to provide efficient batched pose queries. The factory :class:`~isaaclab.sim.views.FrameView` selects the correct implementation at runtime based on the active physics backend. - All getters return :class:`~isaaclab.utils.warp.ProxyArray`. Setters accept ``wp.array``. + All getters return :class:`~isaaclab.utils.warp.ProxyArray`. All writes go + through the writer-scope API -- :meth:`xform_world_space_writer` or + :meth:`xform_local_space_writer`: + + .. code-block:: python + + with view.xform_world_space_writer() as writer: + writer.set_poses(positions=p, orientations=o) + writer.set_scales(scales=s) + # Derived-space matrices are recomputed and the writer scope is closed. + + Only one writer scope may be active per view at a time. While a writer + scope is active, the view-level getters + (:meth:`get_world_poses`, :meth:`get_local_poses`, + :meth:`get_world_scales`, :meth:`get_local_scales`) raise + :class:`RuntimeError` -- use the writer's :meth:`~FrameViewSpaceWriterBase.get_poses` + or :meth:`~FrameViewSpaceWriterBase.get_scales` inside the scope, or exit the + scope first. """ + # Class-level default; instance-level value is set by the writer's + # __enter__ / __exit__ to track the active scope on this view. + _active_writer: FrameViewSpaceWriterBase | None = None + @property @abc.abstractmethod def count(self) -> int: @@ -37,7 +63,77 @@ def device(self) -> str: """Device where arrays are allocated (``"cpu"`` or ``"cuda:0"``).""" ... + # ------------------------------------------------------------------ + # Write scope -- recommended API for all transform writes. + # ------------------------------------------------------------------ + + def xform_world_space_writer(self) -> FrameViewWorldSpaceWriter: + """Open a world-space write scope on this view (recommended write API). + + Inside the scope, :meth:`~FrameViewSpaceWriterBase.set_poses` / + :meth:`~FrameViewSpaceWriterBase.set_scales` write world-space values. + + Returns: + A :class:`~isaaclab.sim.views.FrameViewWorldSpaceWriter` context manager. + + Raises: + RuntimeError: On ``__enter__``, if another writer is already active + on this view. + + Example: + .. code-block:: python + + with view.xform_world_space_writer() as w: + w.set_poses(positions=p, orientations=o) + w.set_scales(scales=s) + """ + return self._make_world_space_writer() + + def xform_local_space_writer(self) -> FrameViewLocalSpaceWriter: + """Open a local-space write scope on this view (recommended write API). + + Inside the scope, :meth:`~FrameViewSpaceWriterBase.set_poses` / + :meth:`~FrameViewSpaceWriterBase.set_scales` write local-space values. + + Returns: + A :class:`~isaaclab.sim.views.FrameViewLocalSpaceWriter` context manager. + + Raises: + RuntimeError: On ``__enter__``, if another writer is already active + on this view. + + Example: + .. code-block:: python + + with view.xform_local_space_writer() as w: + w.set_poses(translations=t, orientations=o) + w.set_scales(scales=s) + """ + return self._make_local_space_writer() + + @abc.abstractmethod + def _make_world_space_writer(self) -> FrameViewWorldSpaceWriter: + """Backend hook: return a fresh :class:`FrameViewWorldSpaceWriter` for this view.""" + ... + @abc.abstractmethod + def _make_local_space_writer(self) -> FrameViewLocalSpaceWriter: + """Backend hook: return a fresh :class:`FrameViewLocalSpaceWriter` for this view.""" + ... + + def _assert_no_active_writer(self, method_name: str) -> None: + """Raise :class:`RuntimeError` if a writer scope is currently active on this view.""" + if self._active_writer is not None: + raise RuntimeError( + f"{type(self).__name__}.{method_name}() is not allowed while a writer " + f"scope is active ({type(self._active_writer).__name__}). Use the writer's " + f"get_poses / get_scales inside the scope, or exit the scope first." + ) + + # ------------------------------------------------------------------ + # Public getters -- guarded; delegate to backend ``_*_impl`` hooks. + # ------------------------------------------------------------------ + def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: """Get world-space positions and orientations for prims in the view. @@ -46,12 +142,101 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, Returns: A tuple ``(positions, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` - wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a - cached zero-copy ``torch.Tensor`` view. + wrappers. + + Raises: + RuntimeError: If a writer scope is active on this view. + """ + self._assert_no_active_writer("get_world_poses") + return self._get_world_poses_impl(indices) + + def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + """Get local-space translations and orientations for prims in the view. + + Args: + indices: Subset of prims to query. ``None`` means all prims. + + Returns: + A tuple ``(translations, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` + wrappers. + + Raises: + RuntimeError: If a writer scope is active on this view. + """ + self._assert_no_active_writer("get_local_poses") + return self._get_local_poses_impl(indices) + + def get_local_scales(self, indices: wp.array | None = None) -> ProxyArray: + """Get local-space scales for prims in the view. + + Args: + indices: Subset of prims to query. ``None`` means all prims. + + Returns: + A :class:`~isaaclab.utils.warp.ProxyArray` of shape ``(M, 3)``. + + Raises: + RuntimeError: If a writer scope is active on this view. + """ + self._assert_no_active_writer("get_local_scales") + return self._get_local_scales_impl(indices) + + def get_world_scales(self, indices: wp.array | None = None) -> ProxyArray: + """Get world-space (composed) scales for prims in the view. + + Returns the effective scale in world space (``parent_scale * local_scale``). + + .. note:: + Scale extraction uses TRS (Translation-Rotation-Scale) decomposition, + which assumes no shear/skew in the transform matrix. If a prim's + world transform contains shear, the extracted scale values will be + approximate. + + Args: + indices: Subset of prims to query. ``None`` means all prims. + + Returns: + A :class:`~isaaclab.utils.warp.ProxyArray` of shape ``(M, 3)``. + + Raises: + RuntimeError: If a writer scope is active on this view. """ + self._assert_no_active_writer("get_world_scales") + return self._get_world_scales_impl(indices) + + # ------------------------------------------------------------------ + # Backend hooks for the public getters above. + # ------------------------------------------------------------------ + + @abc.abstractmethod + def _get_world_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + """Backend implementation of :meth:`get_world_poses`.""" + ... + + @abc.abstractmethod + def _get_local_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + """Backend implementation of :meth:`get_local_poses`.""" ... @abc.abstractmethod + def _get_local_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Backend implementation of :meth:`get_local_scales`.""" + ... + + @abc.abstractmethod + def _get_world_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Backend implementation of :meth:`get_world_scales`.""" + ... + + # ------------------------------------------------------------------ + # Deprecated pose setters -- route through the writer scope. + # ------------------------------------------------------------------ + + _set_world_poses_deprecated_warned: bool = False + _set_local_poses_deprecated_warned: bool = False + _set_scales_deprecated_warned: bool = False + _get_scales_deprecated_warned: bool = False + def set_world_poses( self, positions: wp.array | None = None, @@ -60,28 +245,26 @@ def set_world_poses( ) -> None: """Set world-space positions and/or orientations for prims in the view. + .. deprecated:: + Use ``with view.xform_world_space_writer() as w: w.set_poses(...)`` instead. + This method opens a single-statement writer scope internally. + Args: positions: World-space positions ``(M, 3)``. ``None`` leaves positions unchanged. orientations: World-space quaternions ``(M, 4)``. ``None`` leaves orientations unchanged. indices: Subset of prims to update. ``None`` means all prims. """ - ... + if not BaseFrameView._set_world_poses_deprecated_warned: + BaseFrameView._set_world_poses_deprecated_warned = True + warnings.warn( + "set_world_poses() is deprecated. Use 'with view.xform_world_space_writer() as w:" + " w.set_poses(...)' instead.", + DeprecationWarning, + stacklevel=2, + ) + with self.xform_world_space_writer() as writer: + writer.set_poses(positions, orientations, indices) - @abc.abstractmethod - def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - """Get local-space positions and orientations for prims in the view. - - Args: - indices: Subset of prims to query. ``None`` means all prims. - - Returns: - A tuple ``(translations, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` - wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a - cached zero-copy ``torch.Tensor`` view. - """ - ... - - @abc.abstractmethod def set_local_poses( self, translations: wp.array | None = None, @@ -90,31 +273,86 @@ def set_local_poses( ) -> None: """Set local-space translations and/or orientations for prims in the view. + .. deprecated:: + Use ``with view.xform_local_space_writer() as w: w.set_poses(...)`` instead. + This method opens a single-statement writer scope internally. + Args: translations: Local-space translations ``(M, 3)``. ``None`` leaves translations unchanged. orientations: Local-space quaternions ``(M, 4)``. ``None`` leaves orientations unchanged. indices: Subset of prims to update. ``None`` means all prims. """ - ... + if not BaseFrameView._set_local_poses_deprecated_warned: + BaseFrameView._set_local_poses_deprecated_warned = True + warnings.warn( + "set_local_poses() is deprecated. Use 'with view.xform_local_space_writer() as w:" + " w.set_poses(...)' instead.", + DeprecationWarning, + stacklevel=2, + ) + with self.xform_local_space_writer() as writer: + writer.set_poses(translations, orientations, indices) + + # ------------------------------------------------------------------ + # Deprecated -- use writer scope or get_local_scales / get_world_scales. + # ------------------------------------------------------------------ - @abc.abstractmethod def get_scales(self, indices: wp.array | None = None) -> ProxyArray: """Get scales for prims in the view. + .. deprecated:: + Use :meth:`get_local_scales` or :meth:`get_world_scales` instead. + This method delegates to :meth:`_get_scales_impl` which preserves + each backend's legacy behavior. + Args: indices: Subset of prims to query. ``None`` means all prims. Returns: - A :class:`~isaaclab.utils.warp.ProxyArray` of shape ``(M, 3)``. + A ``ProxyArray`` of shape ``(M, 3)``. + + Raises: + RuntimeError: If a writer scope is active on this view. """ - ... + if not BaseFrameView._get_scales_deprecated_warned: + BaseFrameView._get_scales_deprecated_warned = True + warnings.warn( + "get_scales() is deprecated. Use get_local_scales() or get_world_scales() instead.", + DeprecationWarning, + stacklevel=2, + ) + self._assert_no_active_writer("get_scales") + return self._get_scales_impl(indices) - @abc.abstractmethod def set_scales(self, scales: wp.array, indices: wp.array | None = None) -> None: """Set scales for prims in the view. + .. deprecated:: + Use ``with view.xform_world_space_writer() as w: w.set_scales(...)`` (or + :meth:`xform_local_space_writer`) instead. This method delegates to + :meth:`_set_scales_impl` which opens the backend's legacy space + (world for Fabric, local for USD) and calls ``writer.set_scales``. + Args: scales: Scales ``(M, 3)`` as ``wp.array``. indices: Subset of prims to update. ``None`` means all prims. """ + if not BaseFrameView._set_scales_deprecated_warned: + BaseFrameView._set_scales_deprecated_warned = True + warnings.warn( + "set_scales() is deprecated. Use 'with view.xform_world_space_writer() as w:" + " w.set_scales(...)' (or xform_local_space_writer()) instead.", + DeprecationWarning, + stacklevel=2, + ) + self._set_scales_impl(scales, indices) + + @abc.abstractmethod + def _get_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Backend-specific implementation for deprecated :meth:`get_scales`.""" + ... + + @abc.abstractmethod + def _set_scales_impl(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Backend-specific implementation for deprecated :meth:`set_scales`.""" ... diff --git a/source/isaaclab/isaaclab/sim/views/usd_frame_view.py b/source/isaaclab/isaaclab/sim/views/usd_frame_view.py index 31331221e46f..7145d4e07b00 100644 --- a/source/isaaclab/isaaclab/sim/views/usd_frame_view.py +++ b/source/isaaclab/isaaclab/sim/views/usd_frame_view.py @@ -17,6 +17,7 @@ from isaaclab.utils.warp import ProxyArray from .base_frame_view import BaseFrameView +from .xform_space_writer import FrameViewLocalSpaceWriter, FrameViewWorldSpaceWriter logger = logging.getLogger(__name__) @@ -35,7 +36,14 @@ class UsdFrameView(BaseFrameView): For GPU-accelerated Fabric operations, use the PhysX backend variant obtained via :class:`~isaaclab.sim.views.FrameView`. - Getters return :class:`~isaaclab.utils.warp.ProxyArray`. Setters accept ``wp.array``. + All writes go through the writer-scope API (:meth:`xform_world_space_writer` + / :meth:`xform_local_space_writer`). The + USD backend's writers are pass-throughs: each :meth:`set_poses` / + :meth:`set_scales` call directly modifies the prim's USD ``xformOp:*`` + attributes (no batching, no derivation step on exit) -- USD has no + separate world-matrix storage to keep in sync. + + Getters return :class:`~isaaclab.utils.warp.ProxyArray`. .. note:: **Transform Requirements:** @@ -126,24 +134,70 @@ def prim_paths(self) -> list[str]: return self._prim_paths # ------------------------------------------------------------------ - # Setters + # Writer factory hooks (pass-through writers; USD has no derived state) + # ------------------------------------------------------------------ + + def _make_world_space_writer(self) -> FrameViewWorldSpaceWriter: + return _UsdWorldSpaceWriter(self) + + def _make_local_space_writer(self) -> FrameViewLocalSpaceWriter: + return _UsdLocalSpaceWriter(self) + + # ------------------------------------------------------------------ + # Visibility (USD-only, no writer scope) # ------------------------------------------------------------------ - def set_world_poses( + def set_visibility(self, visibility: torch.Tensor, indices: wp.array | None = None): + """Set visibility for prims in the view. + + Args: + visibility: Visibility as a boolean tensor of shape ``(M,)``. + indices: Indices of prims to set visibility for. Defaults to None (all prims). + """ + indices_list = self._resolve_indices(indices) + + if visibility.shape != (len(indices_list),): + raise ValueError(f"Expected visibility shape ({len(indices_list)},), got {visibility.shape}.") + + with Sdf.ChangeBlock(): + for idx, prim_idx in enumerate(indices_list): + imageable = UsdGeom.Imageable(self._prims[prim_idx]) + if visibility[idx]: + imageable.MakeVisible() + else: + imageable.MakeInvisible() + + def get_visibility(self, indices: wp.array | None = None) -> torch.Tensor: + """Get visibility for prims in the view. + + Args: + indices: Indices of prims to get visibility for. Defaults to None (all prims). + + Returns: + A tensor of shape ``(M,)`` containing the visibility of each prim (bool). + """ + indices_list = self._resolve_indices(indices) + + visibility = torch.zeros(len(indices_list), dtype=torch.bool, device=self._device) + for idx, prim_idx in enumerate(indices_list): + imageable = UsdGeom.Imageable(self._prims[prim_idx]) + visibility[idx] = imageable.ComputeVisibility() != UsdGeom.Tokens.invisible + return visibility + + # ------------------------------------------------------------------ + # Backend hooks: pose / scale writes (called by writers). + # ------------------------------------------------------------------ + + def _apply_world_pose_write( self, positions: wp.array | None = None, orientations: wp.array | None = None, indices: wp.array | None = None, - ): - """Set world-space poses for prims in the view. + ) -> None: + """Apply a world-space pose write directly to USD xform ops. Converts the desired world pose to local-space relative to each prim's - parent before writing to USD xform ops. - - Args: - positions: World-space positions of shape ``(M, 3)``. - orientations: World-space quaternions ``(w, x, y, z)`` of shape ``(M, 4)``. - indices: Indices of prims to set poses for. Defaults to None (all prims). + parent before writing. """ indices_list = self._resolve_indices(indices) @@ -187,19 +241,13 @@ def set_world_poses( if local_quat is not None: prim.GetAttribute("xformOp:orient").Set(local_quat) - def set_local_poses( + def _apply_local_pose_write( self, translations: wp.array | None = None, orientations: wp.array | None = None, indices: wp.array | None = None, - ): - """Set local-space poses for prims in the view. - - Args: - translations: Local-space translations of shape ``(M, 3)``. - orientations: Local-space quaternions ``(w, x, y, z)`` of shape ``(M, 4)``. - indices: Indices of prims to set poses for. Defaults to None (all prims). - """ + ) -> None: + """Apply a local-space pose write directly to USD xform ops.""" indices_list = self._resolve_indices(indices) translations_array = Vt.Vec3dArray.FromNumpy(self._to_numpy(translations)) if translations is not None else None @@ -213,13 +261,8 @@ def set_local_poses( if orientations_array is not None: prim.GetAttribute("xformOp:orient").Set(orientations_array[idx]) - def set_scales(self, scales: wp.array, indices: wp.array | None = None): - """Set scales for prims in the view. - - Args: - scales: Scales of shape ``(M, 3)``. - indices: Indices of prims to set scales for. Defaults to None (all prims). - """ + def _apply_local_scale_write(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Apply a local-space scale write (``xformOp:scale``).""" indices_list = self._resolve_indices(indices) scales_array = Vt.Vec3dArray.FromNumpy(self._to_numpy(scales)) @@ -228,41 +271,41 @@ def set_scales(self, scales: wp.array, indices: wp.array | None = None): prim = self._prims[prim_idx] prim.GetAttribute("xformOp:scale").Set(scales_array[idx]) - def set_visibility(self, visibility: torch.Tensor, indices: wp.array | None = None): - """Set visibility for prims in the view. + def _apply_world_scale_write(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Apply a world-space scale write. - Args: - visibility: Visibility as a boolean tensor of shape ``(M,)``. - indices: Indices of prims to set visibility for. Defaults to None (all prims). + Computes ``local_scale = world_scale / parent_world_scale`` and writes + to ``xformOp:scale``. """ indices_list = self._resolve_indices(indices) - - if visibility.shape != (len(indices_list),): - raise ValueError(f"Expected visibility shape ({len(indices_list)},), got {visibility.shape}.") + scales_np = self._to_numpy(scales) + xf_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) with Sdf.ChangeBlock(): for idx, prim_idx in enumerate(indices_list): - imageable = UsdGeom.Imageable(self._prims[prim_idx]) - if visibility[idx]: - imageable.MakeVisible() + prim = self._prims[prim_idx] + parent = prim.GetParent() + if parent and parent.IsValid() and parent.GetPath() != Sdf.Path.absoluteRootPath: + parent_world = xf_cache.GetLocalToWorldTransform(parent) + parent_scale = Gf.Vec3d( + Gf.Vec3d(parent_world[0][0], parent_world[0][1], parent_world[0][2]).GetLength(), + Gf.Vec3d(parent_world[1][0], parent_world[1][1], parent_world[1][2]).GetLength(), + Gf.Vec3d(parent_world[2][0], parent_world[2][1], parent_world[2][2]).GetLength(), + ) else: - imageable.MakeInvisible() + parent_scale = Gf.Vec3d(1.0, 1.0, 1.0) + local_scale = Gf.Vec3d( + float(scales_np[idx][0] / parent_scale[0]), + float(scales_np[idx][1] / parent_scale[1]), + float(scales_np[idx][2] / parent_scale[2]), + ) + prim.GetAttribute("xformOp:scale").Set(local_scale) # ------------------------------------------------------------------ - # Getters + # Backend hooks: pose / scale reads. # ------------------------------------------------------------------ - def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - """Get world-space poses for prims in the view. - - Args: - indices: Indices of prims to get poses for. Defaults to None (all prims). - - Returns: - A tuple ``(positions, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` - wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a - cached zero-copy ``torch.Tensor`` view. - """ + def _get_world_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: indices_list = self._resolve_indices(indices) positions = Vt.Vec3dArray(len(indices_list)) @@ -280,17 +323,7 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, quat_wp = wp.array(np.array(orientations, dtype=np.float32), dtype=wp.float32, device=self._device) return ProxyArray(pos_wp), ProxyArray(quat_wp) - def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - """Get local-space poses for prims in the view. - - Args: - indices: Indices of prims to get poses for. Defaults to None (all prims). - - Returns: - A tuple ``(translations, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` - wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a - cached zero-copy ``torch.Tensor`` view. - """ + def _get_local_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: indices_list = self._resolve_indices(indices) translations = Vt.Vec3dArray(len(indices_list)) @@ -308,15 +341,7 @@ def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, quat_wp = wp.array(np.array(orientations, dtype=np.float32), dtype=wp.float32, device=self._device) return ProxyArray(pos_wp), ProxyArray(quat_wp) - def get_scales(self, indices: wp.array | None = None) -> ProxyArray: - """Get scales for prims in the view. - - Args: - indices: Indices of prims to get scales for. Defaults to None (all prims). - - Returns: - A :class:`~isaaclab.utils.warp.ProxyArray` of shape ``(M, 3)``. - """ + def _get_local_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: indices_list = self._resolve_indices(indices) scales = Vt.Vec3dArray(len(indices_list)) @@ -324,25 +349,34 @@ def get_scales(self, indices: wp.array | None = None) -> ProxyArray: prim = self._prims[prim_idx] scales[idx] = prim.GetAttribute("xformOp:scale").Get() - scales_wp = wp.array(np.array(scales, dtype=np.float32), dtype=wp.float32, device=self._device) - return ProxyArray(scales_wp) - - def get_visibility(self, indices: wp.array | None = None) -> torch.Tensor: - """Get visibility for prims in the view. - - Args: - indices: Indices of prims to get visibility for. Defaults to None (all prims). + return ProxyArray(wp.array(np.array(scales, dtype=np.float32), dtype=wp.float32, device=self._device)) - Returns: - A tensor of shape ``(M,)`` containing the visibility of each prim (bool). - """ + def _get_world_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: indices_list = self._resolve_indices(indices) + xf_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) - visibility = torch.zeros(len(indices_list), dtype=torch.bool, device=self._device) + scales = np.empty((len(indices_list), 3), dtype=np.float32) for idx, prim_idx in enumerate(indices_list): - imageable = UsdGeom.Imageable(self._prims[prim_idx]) - visibility[idx] = imageable.ComputeVisibility() != UsdGeom.Tokens.invisible - return visibility + prim = self._prims[prim_idx] + world_mtx = xf_cache.GetLocalToWorldTransform(prim) + scales[idx, 0] = Gf.Vec3d(world_mtx[0][0], world_mtx[0][1], world_mtx[0][2]).GetLength() + scales[idx, 1] = Gf.Vec3d(world_mtx[1][0], world_mtx[1][1], world_mtx[1][2]).GetLength() + scales[idx, 2] = Gf.Vec3d(world_mtx[2][0], world_mtx[2][1], world_mtx[2][2]).GetLength() + + return ProxyArray(wp.array(scales, dtype=wp.float32, device=self._device)) + + # ------------------------------------------------------------------ + # Deprecated get_scales / set_scales hooks + # ------------------------------------------------------------------ + + def _get_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """USD legacy: deprecated get_scales returns local scales.""" + return self._get_local_scales_impl(indices) + + def _set_scales_impl(self, scales: wp.array, indices: wp.array | None = None) -> None: + """USD legacy: deprecated set_scales writes local scales via a one-shot writer scope.""" + with self.xform_local_space_writer() as writer: + writer.set_scales(scales, indices) # ------------------------------------------------------------------ # Helpers @@ -360,3 +394,44 @@ def _to_numpy(data: wp.array | torch.Tensor) -> np.ndarray: if isinstance(data, wp.array): return data.numpy() return data.cpu().numpy() + + +# ---------------------------------------------------------------------- +# Pass-through writer classes +# ---------------------------------------------------------------------- + + +class _UsdWorldSpaceWriter(FrameViewWorldSpaceWriter): + """USD world-space writer: pass-through to backend ``_apply_*`` hooks. + + USD has no separate world-matrix storage to keep in sync; ``__exit__`` + is a no-op beyond releasing the single-writer lock. + """ + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._apply_world_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._apply_world_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_world_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_world_scales_impl(indices) # type: ignore[attr-defined] + + +class _UsdLocalSpaceWriter(FrameViewLocalSpaceWriter): + """USD local-space writer: pass-through to backend ``_apply_*`` hooks.""" + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._apply_local_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._apply_local_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_local_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_local_scales_impl(indices) # type: ignore[attr-defined] diff --git a/source/isaaclab/isaaclab/sim/views/xform_space_writer.py b/source/isaaclab/isaaclab/sim/views/xform_space_writer.py new file mode 100644 index 000000000000..17c437294a0b --- /dev/null +++ b/source/isaaclab/isaaclab/sim/views/xform_space_writer.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Context-managed transform writers for :class:`~isaaclab.sim.views.BaseFrameView`. + +This module defines the recommended write API for FrameView poses and scales: + +.. code-block:: python + + with view.xform_world_space_writer() as writer: + writer.set_poses(positions=p, orientations=o) + writer.set_scales(scales=s) + # ... any number of writes ... + # On exit the writer derives the opposite-space matrices once, + # synchronizes once, and restores any saved Fabric tracking state. + +Only one writer may be active per view at a time. While a writer scope is +active on a view, view-level getters (``view.get_world_poses``, +``view.get_local_poses``, ``view.get_world_scales``, +``view.get_local_scales``) raise :class:`RuntimeError` -- use the writer's own +:meth:`~FrameViewSpaceWriterBase.get_poses` / :meth:`~FrameViewSpaceWriterBase.get_scales` +inside the scope, or exit the scope first. +""" + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab.utils.warp import ProxyArray + +if TYPE_CHECKING: + from .base_frame_view import BaseFrameView + + +class FrameViewSpaceWriterBase(abc.ABC): + """Abstract context-managed writer for a single transform space. + + Subclasses are returned by :meth:`BaseFrameView.xform_world_space_writer` / + :meth:`BaseFrameView.xform_local_space_writer`; they + are not constructed directly. The class is intentionally minimal -- the + pose/scale semantics depend on the writer's space (world or local), which + is conveyed by the concrete tag class :class:`FrameViewWorldSpaceWriter` or + :class:`FrameViewLocalSpaceWriter`. + """ + + def __init__(self, view: BaseFrameView): + self._view = view + + @abc.abstractmethod + def set_poses( + self, + positions: wp.array | None = None, + orientations: wp.array | None = None, + indices: wp.array | None = None, + ) -> None: + """Set positions and/or orientations in this writer's space. + + Args: + positions: Positions ``(M, 3)``. ``None`` leaves positions unchanged. + orientations: Quaternions ``(M, 4)`` in ``(x, y, z, w)``. + ``None`` leaves orientations unchanged. + indices: Subset of prims to update. ``None`` means all prims. + """ + ... + + @abc.abstractmethod + def set_scales(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Set scales in this writer's space. + + Args: + scales: Scales ``(M, 3)`` as ``wp.array``. + indices: Subset of prims to update. ``None`` means all prims. + """ + ... + + @abc.abstractmethod + def get_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + """Return ``(positions, orientations)`` in this writer's space. + + Reflects any in-scope writes that have already been queued on the + underlying device stream. + """ + ... + + @abc.abstractmethod + def get_scales(self, indices: wp.array | None = None) -> ProxyArray: + """Return scales in this writer's space.""" + ... + + def __enter__(self) -> FrameViewSpaceWriterBase: + if self._view._active_writer is not None: + raise RuntimeError( + f"{type(self._view).__name__} already has an active writer scope " + f"({type(self._view._active_writer).__name__}). Exit the existing scope before " + "opening a new one." + ) + self._view._active_writer = self + self._enter_impl() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + try: + self._exit_impl(exc_type, exc_val, exc_tb) + finally: + self._view._active_writer = None + + def _enter_impl(self) -> None: + """Backend hook called after the single-active-writer lock is claimed.""" + + def _exit_impl(self, exc_type, exc_val, exc_tb) -> None: + """Backend hook called before the single-active-writer lock is released.""" + + +class FrameViewWorldSpaceWriter(FrameViewSpaceWriterBase): + """Writer whose :meth:`set_poses` / :meth:`set_scales` write world-space values. + + On context exit the opposite-space (``local``) matrices are derived from + the just-written world matrices in a single Warp kernel launch. + """ + + +class FrameViewLocalSpaceWriter(FrameViewSpaceWriterBase): + """Writer whose :meth:`set_poses` / :meth:`set_scales` write local-space values. + + On context exit the opposite-space (``world``) matrices are derived from + the just-written local matrices in a single Warp kernel launch. + """ diff --git a/source/isaaclab/isaaclab/utils/warp/fabric.py b/source/isaaclab/isaaclab/utils/warp/fabric.py index a48f773f4991..e0519d98c338 100644 --- a/source/isaaclab/isaaclab/utils/warp/fabric.py +++ b/source/isaaclab/isaaclab/utils/warp/fabric.py @@ -18,12 +18,14 @@ if TYPE_CHECKING: FabricArrayUInt32 = Any FabricArrayMat44d = Any + IndexedFabricArrayMat44d = Any ArrayUInt32 = Any ArrayUInt32_1d = Any ArrayFloat32_2d = Any else: FabricArrayUInt32 = wp.fabricarray(dtype=wp.uint32) FabricArrayMat44d = wp.fabricarray(dtype=wp.mat44d) + IndexedFabricArrayMat44d = wp.indexedfabricarray(dtype=wp.mat44d) ArrayUInt32 = wp.array(ndim=1, dtype=wp.uint32) ArrayUInt32_1d = wp.array(dtype=wp.uint32) ArrayFloat32_2d = wp.array(ndim=2, dtype=wp.float32) @@ -163,6 +165,180 @@ def compose_fabric_transformation_matrix_from_warp_arrays( ) +@wp.kernel(enable_backward=False) +def decompose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + indices: ArrayUInt32, +): + """Decompose indexed Fabric transformation matrices into position, orientation, and scale. + + Like :func:`decompose_fabric_transformation_matrix_to_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices. + array_positions: Output array for positions [m], shape (N, 3). + array_orientations: Output array for quaternions in xyzw format, shape (N, 4). + array_scales: Output array for scales, shape (N, 3). + indices: View indices to process (subset selection). + """ + output_index = wp.tid() + view_index = indices[output_index] + + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + array_positions[output_index, 0] = position[0] + array_positions[output_index, 1] = position[1] + array_positions[output_index, 2] = position[2] + if array_orientations.shape[0] > 0: + array_orientations[output_index, 0] = rotation[0] + array_orientations[output_index, 1] = rotation[1] + array_orientations[output_index, 2] = rotation[2] + array_orientations[output_index, 3] = rotation[3] + if array_scales.shape[0] > 0: + array_scales[output_index, 0] = scale[0] + array_scales[output_index, 1] = scale[1] + array_scales[output_index, 2] = scale[2] + + +@wp.kernel(enable_backward=False) +def compose_indexed_fabric_transforms( + fabric_matrices: IndexedFabricArrayMat44d, + array_positions: ArrayFloat32_2d, + array_orientations: ArrayFloat32_2d, + array_scales: ArrayFloat32_2d, + broadcast_positions: bool, + broadcast_orientations: bool, + broadcast_scales: bool, + indices: ArrayUInt32, +): + """Compose indexed Fabric transformation matrices from position, orientation, and scale. + + Like :func:`compose_fabric_transformation_matrix_from_warp_arrays` but operates on a + :class:`wp.indexedfabricarray` that already encodes the view-to-fabric mapping, removing + the need for a separate ``mapping`` array. + + Args: + fabric_matrices: Indexed fabric array containing 4x4 transformation matrices to update. + array_positions: Input array for positions [m], shape (N, 3). + array_orientations: Input array for quaternions in xyzw format, shape (N, 4). + array_scales: Input array for scales, shape (N, 3). + broadcast_positions: If True, use first position for all prims. + broadcast_orientations: If True, use first orientation for all prims. + broadcast_scales: If True, use first scale for all prims. + indices: View indices to process (subset selection). + """ + i = wp.tid() + view_index = indices[i] + position, rotation, scale = _decompose_transformation_matrix(wp.mat44f(fabric_matrices[view_index])) + + if array_positions.shape[0] > 0: + if broadcast_positions: + index = 0 + else: + index = i + position[0] = array_positions[index, 0] + position[1] = array_positions[index, 1] + position[2] = array_positions[index, 2] + if array_orientations.shape[0] > 0: + if broadcast_orientations: + index = 0 + else: + index = i + rotation[0] = array_orientations[index, 0] + rotation[1] = array_orientations[index, 1] + rotation[2] = array_orientations[index, 2] + rotation[3] = array_orientations[index, 3] + if array_scales.shape[0] > 0: + if broadcast_scales: + index = 0 + else: + index = i + scale[0] = array_scales[index, 0] + scale[1] = array_scales[index, 1] + scale[2] = array_scales[index, 2] + + fabric_matrices[view_index] = wp.mat44d( # type: ignore[arg-type] + wp.transpose(wp.transform_compose(position, rotation, scale)) # type: ignore[arg-type] + ) + + +@wp.kernel(enable_backward=False) +def update_indexed_local_matrix_from_world( + child_world_matrices: IndexedFabricArrayMat44d, + parent_world_matrices: IndexedFabricArrayMat44d, + child_local_matrices: IndexedFabricArrayMat44d, + indices: ArrayUInt32, +): + """Recompute child localMatrix from (parent worldMatrix, child worldMatrix). + + Computes ``child_local = inv(parent_world) * child_world`` per prim and writes the + result back to the child's :data:`omni:fabric:localMatrix` so that subsequent + ``get_local_poses`` calls see consistent values after a world-pose write. + + All three indexed arrays are expected to be indexed by the same per-view indices + (i.e. ``view_to_child_fabric``, ``view_to_parent_fabric``, ``view_to_child_fabric``) + so the kernel only needs the view-side indices. + + Storage convention: Fabric matrices are stored as the transpose of the standard + column-major math convention. Math is ``local = inv(parent) * world``; under + the transpose identity ``(A * B)^T = B^T * A^T`` (and ``inv(A^T) = inv(A)^T``) + that is equivalent to storage-side ``local^T = world^T * inv(parent^T)``, so we + can compute it directly on the stored matrices without explicit transposes. + + Args: + child_world_matrices: Indexed fabric array of child world matrices (read). + parent_world_matrices: Indexed fabric array of parent world matrices (read). + child_local_matrices: Indexed fabric array of child local matrices (written). + indices: View indices to process. + """ + i = wp.tid() + view_index = indices[i] + child_world = wp.mat44f(child_world_matrices[view_index]) + parent_world = wp.mat44f(parent_world_matrices[view_index]) + child_local_matrices[view_index] = wp.mat44d( # type: ignore[arg-type] + child_world * wp.inverse(parent_world) + ) + + +@wp.kernel(enable_backward=False) +def update_indexed_world_matrix_from_local( + child_local_matrices: IndexedFabricArrayMat44d, + parent_world_matrices: IndexedFabricArrayMat44d, + child_world_matrices: IndexedFabricArrayMat44d, + indices: ArrayUInt32, +): + """Recompute child worldMatrix from (parent worldMatrix, child localMatrix). + + Computes ``child_world = parent_world * child_local`` per prim and writes the + result back to the child's :data:`omni:fabric:worldMatrix`. Used after a + ``set_local_poses`` write so that subsequent ``get_world_poses`` calls see + consistent values. Mirror of :func:`update_indexed_local_matrix_from_world`. + + Args: + child_local_matrices: Indexed fabric array of child local matrices (read). + parent_world_matrices: Indexed fabric array of parent world matrices (read). + child_world_matrices: Indexed fabric array of child world matrices (written). + indices: View indices to process. + + Storage convention: same as :func:`update_indexed_local_matrix_from_world`. + Math is ``world = parent * local``; under the transpose identity that becomes + storage-side ``world^T = local^T * parent^T``, no explicit transposes needed. + """ + i = wp.tid() + view_index = indices[i] + child_local = wp.mat44f(child_local_matrices[view_index]) + parent_world = wp.mat44f(parent_world_matrices[view_index]) + child_world_matrices[view_index] = wp.mat44d( # type: ignore[arg-type] + child_local * parent_world + ) + + @wp.func def _decompose_transformation_matrix(m: Any): # -> tuple[wp.vec3f, wp.quatf, wp.vec3f] """Decompose a 4x4 transformation matrix into position, orientation, and scale. diff --git a/source/isaaclab/test/sim/frame_view_contract_utils.py b/source/isaaclab/test/sim/frame_view_contract_utils.py index 8cd73c02b6f6..e67330f36791 100644 --- a/source/isaaclab/test/sim/frame_view_contract_utils.py +++ b/source/isaaclab/test/sim/frame_view_contract_utils.py @@ -193,7 +193,8 @@ def test_set_world_roundtrip(device, view_factory): try: new_pos = _wp_vec3f([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], device=device) new_quat = _wp_vec4f([[0.0, 0.0, 0.7071068, 0.7071068], [0.0, 0.0, 0.0, 1.0]], device=device) - bundle.view.set_world_poses(new_pos, new_quat) + with bundle.view.xform_world_space_writer() as w: + w.set_poses(new_pos, new_quat) ret_pos, ret_quat = bundle.view.get_world_poses() torch.testing.assert_close(_t(ret_pos), _t(new_pos), atol=ATOL, rtol=0) @@ -209,7 +210,8 @@ def test_set_local_roundtrip(device, view_factory): try: new_pos = _wp_vec3f([[0.5, 0.3, 0.1], [0.2, 0.7, 0.4]], device=device) new_quat = _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device) - bundle.view.set_local_poses(new_pos, new_quat) + with bundle.view.xform_local_space_writer() as w: + w.set_poses(new_pos, new_quat) ret_pos, ret_quat = bundle.view.get_local_poses() torch.testing.assert_close(_t(ret_pos), _t(new_pos), atol=ATOL, rtol=0) @@ -224,10 +226,11 @@ def test_set_world_does_not_move_parent(device, view_factory): bundle = view_factory(num_envs=2, device=device) try: parent_before = bundle.get_parent_pos(2, device).clone() - bundle.view.set_world_poses( - _wp_vec3f([[99.0, 99.0, 99.0], [88.0, 88.0, 88.0]], device=device), - _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), - ) + with bundle.view.xform_world_space_writer() as w: + w.set_poses( + _wp_vec3f([[99.0, 99.0, 99.0], [88.0, 88.0, 88.0]], device=device), + _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), + ) parent_after = bundle.get_parent_pos(2, device) torch.testing.assert_close(parent_after, parent_before, atol=0, rtol=0) @@ -241,10 +244,11 @@ def test_set_local_does_not_move_parent(device, view_factory): bundle = view_factory(num_envs=2, device=device) try: parent_before = bundle.get_parent_pos(2, device).clone() - bundle.view.set_local_poses( - _wp_vec3f([[0.5, 0.5, 0.5], [1.0, 1.0, 1.0]], device=device), - _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), - ) + with bundle.view.xform_local_space_writer() as w: + w.set_poses( + _wp_vec3f([[0.5, 0.5, 0.5], [1.0, 1.0, 1.0]], device=device), + _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), + ) parent_after = bundle.get_parent_pos(2, device) torch.testing.assert_close(parent_after, parent_before, atol=0, rtol=0) @@ -264,10 +268,11 @@ def test_set_world_updates_local(device, view_factory): desired_offset = torch.tensor([[0.3, 0.7, 0.2], [0.8, 0.1, 0.6]], device=device) new_world = parent_pos + desired_offset - bundle.view.set_world_poses( - _wp_vec3f(new_world.tolist(), device=device), - _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), - ) + with bundle.view.xform_world_space_writer() as w: + w.set_poses( + _wp_vec3f(new_world.tolist(), device=device), + _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), + ) local_pos = _t(bundle.view.get_local_poses()[0]) torch.testing.assert_close(local_pos, desired_offset, atol=ATOL, rtol=0) @@ -285,10 +290,11 @@ def test_set_local_updates_world(device, view_factory): try: parent_pos = bundle.get_parent_pos(2, device) new_offset = torch.tensor([[0.4, 0.9, 0.15], [0.6, 0.2, 0.85]], device=device) - bundle.view.set_local_poses( - _wp_vec3f(new_offset.tolist(), device=device), - _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), - ) + with bundle.view.xform_local_space_writer() as w: + w.set_poses( + _wp_vec3f(new_offset.tolist(), device=device), + _wp_vec4f([[0.0, 0.0, 0.0, 1.0]] * 2, device=device), + ) world_pos = _t(bundle.view.get_world_poses()[0]) torch.testing.assert_close(world_pos, parent_pos + new_offset, atol=ATOL, rtol=0) @@ -303,7 +309,8 @@ def test_set_world_partial_position_only(device, view_factory): try: _, orig_quat = bundle.view.get_world_poses() new_pos = _wp_vec3f([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=device) - bundle.view.set_world_poses(positions=new_pos) + with bundle.view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos) ret_pos, ret_quat = bundle.view.get_world_poses() torch.testing.assert_close(_t(ret_pos), _t(new_pos), atol=ATOL, rtol=0) @@ -319,7 +326,8 @@ def test_set_world_partial_orientation_only(device, view_factory): try: orig_pos, _ = bundle.view.get_world_poses() new_quat = _wp_vec4f([[0.0, 0.0, 0.7071068, 0.7071068], [0.7071068, 0.0, 0.0, 0.7071068]], device=device) - bundle.view.set_world_poses(orientations=new_quat) + with bundle.view.xform_world_space_writer() as w: + w.set_poses(orientations=new_quat) ret_pos, ret_quat = bundle.view.get_world_poses() torch.testing.assert_close(_t(ret_pos), _t(orig_pos), atol=ATOL, rtol=0) @@ -335,7 +343,8 @@ def test_set_local_partial_position_only(device, view_factory): try: _, orig_quat = bundle.view.get_local_poses() new_pos = _wp_vec3f([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], device=device) - bundle.view.set_local_poses(translations=new_pos) + with bundle.view.xform_local_space_writer() as w: + w.set_poses(positions=new_pos) ret_pos, ret_quat = bundle.view.get_local_poses() torch.testing.assert_close(_t(ret_pos), _t(new_pos), atol=ATOL, rtol=0) @@ -352,7 +361,8 @@ def test_set_world_indexed_only_affects_subset(device, view_factory): orig_pos = _t(bundle.view.get_world_poses()[0]).clone() indices = wp.array([1, 3], dtype=wp.int32, device=device) new_pos = _wp_vec3f([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], device=device) - bundle.view.set_world_poses(positions=new_pos, indices=indices) + with bundle.view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos, indices=indices) updated = _t(bundle.view.get_world_poses()[0]) torch.testing.assert_close(updated[0], orig_pos[0], atol=0, rtol=0) @@ -401,6 +411,24 @@ def test_return_types_are_torcharray(device, view_factory): f"get_local_poses(indices)[1] must be ProxyArray, got {type(lquat_idx).__name__}" ) + world_scales_full = bundle.view.get_world_scales() + assert isinstance(world_scales_full, ProxyArray), ( + f"get_world_scales() must be ProxyArray, got {type(world_scales_full).__name__}" + ) + world_scales_idx = bundle.view.get_world_scales(indices) + assert isinstance(world_scales_idx, ProxyArray), ( + f"get_world_scales(indices) must be ProxyArray, got {type(world_scales_idx).__name__}" + ) + + local_scales_full = bundle.view.get_local_scales() + assert isinstance(local_scales_full, ProxyArray), ( + f"get_local_scales() must be ProxyArray, got {type(local_scales_full).__name__}" + ) + local_scales_idx = bundle.view.get_local_scales(indices) + assert isinstance(local_scales_idx, ProxyArray), ( + f"get_local_scales(indices) must be ProxyArray, got {type(local_scales_idx).__name__}" + ) + scales_full = bundle.view.get_scales() assert isinstance(scales_full, ProxyArray), f"get_scales() must be ProxyArray, got {type(scales_full).__name__}" scales_idx = bundle.view.get_scales(indices) @@ -409,3 +437,110 @@ def test_return_types_are_torcharray(device, view_factory): ) finally: bundle.teardown() + + +# ================================================================== +# Contract: Scales +# ================================================================== + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_local_scales_default_identity(device, view_factory): + """Local scales are (1, 1, 1) by default (no authored scale transforms).""" + bundle = view_factory(num_envs=2, device=device) + try: + scales = _t(bundle.view.get_local_scales()) + expected = torch.ones(2, 3, device=device) + torch.testing.assert_close(scales, expected, atol=ATOL, rtol=0) + finally: + bundle.teardown() + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_world_scales_default_identity(device, view_factory): + """World scales are (1, 1, 1) by default (no authored scale transforms).""" + bundle = view_factory(num_envs=2, device=device) + try: + scales = _t(bundle.view.get_world_scales()) + expected = torch.ones(2, 3, device=device) + torch.testing.assert_close(scales, expected, atol=ATOL, rtol=0) + finally: + bundle.teardown() + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_set_local_scales_roundtrip(device, view_factory): + """set_local_scales -> get_local_scales returns the same values.""" + bundle = view_factory(num_envs=2, device=device) + try: + new_scales = _wp_vec3f([[2.0, 3.0, 4.0], [0.5, 1.5, 2.5]], device=device) + with bundle.view.xform_local_space_writer() as w: + w.set_scales(new_scales) + + ret_scales = _t(bundle.view.get_local_scales()) + torch.testing.assert_close(ret_scales, _t(new_scales), atol=ATOL, rtol=0) + finally: + bundle.teardown() + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_set_world_scales_roundtrip(device, view_factory): + """set_world_scales -> get_world_scales returns the same values.""" + bundle = view_factory(num_envs=2, device=device) + try: + new_scales = _wp_vec3f([[2.0, 3.0, 4.0], [0.5, 1.5, 2.5]], device=device) + with bundle.view.xform_world_space_writer() as w: + w.set_scales(new_scales) + + ret_scales = _t(bundle.view.get_world_scales()) + torch.testing.assert_close(ret_scales, _t(new_scales), atol=ATOL, rtol=0) + finally: + bundle.teardown() + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_local_scales_do_not_affect_local_poses(device, view_factory): + """Changing scales does not change local pose translations/orientations.""" + bundle = view_factory(num_envs=2, device=device) + try: + local_pos_before = _t(bundle.view.get_local_poses()[0]).clone() + local_ori_before = _t(bundle.view.get_local_poses()[1]).clone() + + new_scales = _wp_vec3f([[3.0, 3.0, 3.0], [5.0, 5.0, 5.0]], device=device) + with bundle.view.xform_local_space_writer() as w: + w.set_scales(new_scales) + + local_pos_after = _t(bundle.view.get_local_poses()[0]) + local_ori_after = _t(bundle.view.get_local_poses()[1]) + + torch.testing.assert_close(local_pos_after, local_pos_before, atol=ATOL, rtol=0) + torch.testing.assert_close(local_ori_after, local_ori_before, atol=ATOL, rtol=0) + finally: + bundle.teardown() + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_scale_getters_return_proxyarray(device, view_factory): + """Public API contract -- scale getters return ProxyArray.""" + bundle = view_factory(num_envs=2, device=device) + try: + local_scales = bundle.view.get_local_scales() + assert isinstance(local_scales, ProxyArray), ( + f"get_local_scales() must return ProxyArray, got {type(local_scales).__name__}" + ) + world_scales = bundle.view.get_world_scales() + assert isinstance(world_scales, ProxyArray), ( + f"get_world_scales() must return ProxyArray, got {type(world_scales).__name__}" + ) + + indices = wp.array([0], dtype=wp.int32, device=bundle.view.device) + local_indexed = bundle.view.get_local_scales(indices) + assert isinstance(local_indexed, ProxyArray), ( + f"get_local_scales(indices) must return ProxyArray, got {type(local_indexed).__name__}" + ) + world_indexed = bundle.view.get_world_scales(indices) + assert isinstance(world_indexed, ProxyArray), ( + f"get_world_scales(indices) must return ProxyArray, got {type(world_indexed).__name__}" + ) + finally: + bundle.teardown() diff --git a/source/isaaclab/test/sim/test_views_xform_prim.py b/source/isaaclab/test/sim/test_views_xform_prim.py index 64cd86a7466f..2fa404f26a7f 100644 --- a/source/isaaclab/test/sim/test_views_xform_prim.py +++ b/source/isaaclab/test/sim/test_views_xform_prim.py @@ -226,8 +226,10 @@ def test_nested_hierarchy_world_poses(device): frames_view = FrameView("/World/Frame_.*", device=device) targets_view = FrameView("/World/Frame_.*/Target", device=device) - frames_view.set_local_poses(translations=torch.tensor(frame_positions, device=device)) - targets_view.set_local_poses(translations=torch.tensor(target_positions, device=device)) + with frames_view.xform_local_space_writer() as w: + w.set_poses(positions=torch.tensor(frame_positions, device=device)) + with targets_view.xform_local_space_writer() as w: + w.set_poses(positions=torch.tensor(target_positions, device=device)) world_pos = targets_view.get_world_poses()[0].torch expected = torch.tensor( @@ -237,6 +239,58 @@ def test_nested_hierarchy_world_poses(device): torch.testing.assert_close(world_pos, expected, atol=1e-5, rtol=0) +# ================================================================== +# USD-only: Cross-space scale conversion under a scaled parent +# ================================================================== +# +# These exercise the USD-specific world<->local scale math that the shared +# contract suite cannot cover: the contract fixtures only expose a unit-scale +# parent, and Newton has no independent local scale (local == world), so the +# parent-aware conversions below are not universal invariants. OvPhysxFrameView +# inherits this behavior by delegating to UsdFrameView. + + +def _make_scaled_parent_child_view(device, parent_scale, child_scale=None): + """Build a 1-prim view with a scaled parent (and optional authored child scale).""" + stage = sim_utils.get_current_stage() + sim_utils.create_prim("/World/Parent_0", "Xform", translation=PARENT_POS, scale=parent_scale, stage=stage) + child_kwargs = {} if child_scale is None else {"scale": child_scale} + sim_utils.create_prim("/World/Parent_0/Child", "Xform", translation=CHILD_OFFSET, stage=stage, **child_kwargs) + return FrameView("/World/Parent_.*/Child", device=device) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_local_scales_then_get_world_scales(device): + """Under a scaled parent, world scale == parent_scale * local_scale.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + view = _make_scaled_parent_child_view(device, parent_scale=(2.0, 1.0, 1.0)) + local_scales = wp.array([wp.vec3f(3.0, 1.0, 1.0)], dtype=wp.vec3f, device=device) + with view.xform_local_space_writer() as w: + w.set_scales(local_scales) + + world_scales = view.get_world_scales().torch + expected = torch.tensor([[6.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(world_scales, expected, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_set_world_scales_then_get_local_scales(device): + """Under a scaled parent, set_world_scales writes local = world / parent_scale.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + view = _make_scaled_parent_child_view(device, parent_scale=(2.0, 1.0, 1.0)) + world_scales = wp.array([wp.vec3f(6.0, 1.0, 1.0)], dtype=wp.vec3f, device=device) + with view.xform_world_space_writer() as w: + w.set_scales(world_scales) + + local_scales = view.get_local_scales().torch + expected = torch.tensor([[3.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(local_scales, expected, atol=1e-5, rtol=0) + + # ================================================================== # USD-only: Comparison with Isaac Sim # ================================================================== @@ -291,7 +345,8 @@ def test_with_franka_robots(device): new_pos = torch.tensor([[10.0, 10.0, 0.0], [-40.0, -40.0, 0.0]], device=device) new_quat = torch.tensor([[0.0, 0.0, 0.7071068, 0.7071068], [0.0, 0.0, -0.7071068, 0.7071068]], device=device) - view.set_world_poses(positions=new_pos, orientations=new_quat) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos, orientations=new_quat) ret_pos = view.get_world_poses()[0].torch torch.testing.assert_close(ret_pos, new_pos, atol=1e-5, rtol=0) diff --git a/source/isaaclab/test/terrains/check_terrain_importer.py b/source/isaaclab/test/terrains/check_terrain_importer.py index a8229023f90c..901cfc420be3 100644 --- a/source/isaaclab/test/terrains/check_terrain_importer.py +++ b/source/isaaclab/test/terrains/check_terrain_importer.py @@ -159,7 +159,8 @@ def main(): ball_initial_positions = terrain_importer.env_origins.clone() ball_initial_positions[:, 2] += 5.0 # set initial poses (writes to USD before simulation) - xform_view.set_world_poses(positions=ball_initial_positions) + with xform_view.xform_world_space_writer() as w: + w.set_poses(positions=ball_initial_positions) # Play simulator sim.reset() diff --git a/source/isaaclab/test/terrains/test_terrain_importer.py b/source/isaaclab/test/terrains/test_terrain_importer.py index 5234df4cae51..96375cab0758 100644 --- a/source/isaaclab/test/terrains/test_terrain_importer.py +++ b/source/isaaclab/test/terrains/test_terrain_importer.py @@ -316,4 +316,5 @@ def _populate_scene(sim: SimulationContext, num_balls: int = 2048, geom_sphere: ball_initial_positions[:, 2] += 5.0 # set initial poses # note: setting here writes to USD :) - ball_view.set_world_poses(positions=wp.from_torch(ball_initial_positions)) + with ball_view.xform_world_space_writer() as w: + w.set_poses(positions=wp.from_torch(ball_initial_positions)) diff --git a/source/isaaclab_mimic/changelog.d/xform-space-writer.skip b/source/isaaclab_mimic/changelog.d/xform-space-writer.skip new file mode 100644 index 000000000000..8556272d818a --- /dev/null +++ b/source/isaaclab_mimic/changelog.d/xform-space-writer.skip @@ -0,0 +1 @@ +no user-facing change: internal migration of one set_world_poses call site to the new FrameViewSpaceWriter context API diff --git a/source/isaaclab_mimic/isaaclab_mimic/locomanipulation_sdg/scene_utils.py b/source/isaaclab_mimic/isaaclab_mimic/locomanipulation_sdg/scene_utils.py index 4ba068fc8f56..00a29bc04afc 100644 --- a/source/isaaclab_mimic/isaaclab_mimic/locomanipulation_sdg/scene_utils.py +++ b/source/isaaclab_mimic/isaaclab_mimic/locomanipulation_sdg/scene_utils.py @@ -126,7 +126,8 @@ def set_pose(self, pose: torch.Tensor): xform_prim = self._get_xform_view() position = pose[..., :3] orientation = pose[..., 3:] - xform_prim.set_world_poses(wp.from_torch(position.contiguous()), wp.from_torch(orientation.contiguous()), None) + with xform_prim.xform_world_space_writer() as writer: + writer.set_poses(wp.from_torch(position.contiguous()), wp.from_torch(orientation.contiguous()), None) class RelativePose(HasPose): diff --git a/source/isaaclab_newton/changelog.d/fabric-local-poses.rst b/source/isaaclab_newton/changelog.d/fabric-local-poses.rst new file mode 100644 index 000000000000..4e18af38eb2b --- /dev/null +++ b/source/isaaclab_newton/changelog.d/fabric-local-poses.rst @@ -0,0 +1,19 @@ +Added +^^^^^ + +* Added :meth:`~isaaclab_newton.sim.views.NewtonSiteFrameView.get_local_scales`, + :meth:`~isaaclab_newton.sim.views.NewtonSiteFrameView.set_local_scales`, + :meth:`~isaaclab_newton.sim.views.NewtonSiteFrameView.get_world_scales`, and + :meth:`~isaaclab_newton.sim.views.NewtonSiteFrameView.set_world_scales` for + transform (xform) scales. These explicit APIs are intentionally separate from + Newton collision shape geometry sizes. + +Deprecated +^^^^^^^^^^ + +* Deprecated :meth:`~isaaclab_newton.sim.views.NewtonSiteFrameView.get_scales` + and :meth:`~isaaclab_newton.sim.views.NewtonSiteFrameView.set_scales` in favor + of the explicit xform-scale ``get_world_scales`` / ``set_world_scales`` (or + their local equivalents). The deprecated methods still work but emit a + ``DeprecationWarning`` and preserve Newton's legacy collision shape + geometry-scale behavior. diff --git a/source/isaaclab_newton/changelog.d/xform-space-writer.rst b/source/isaaclab_newton/changelog.d/xform-space-writer.rst new file mode 100644 index 000000000000..e04d5325e6aa --- /dev/null +++ b/source/isaaclab_newton/changelog.d/xform-space-writer.rst @@ -0,0 +1,13 @@ +Changed +^^^^^^^ + +* :class:`~isaaclab_newton.sim.views.NewtonSiteFrameView` now ships + pass-through ``FrameViewWorldSpaceWriter`` / ``FrameViewLocalSpaceWriter`` + implementations so writes follow the new + :meth:`~isaaclab.sim.views.BaseFrameView.xform_world_space_writer` / + :meth:`~isaaclab.sim.views.BaseFrameView.xform_local_space_writer` context API. + ``set_world_poses`` / ``set_local_poses`` shims still work (one-time + ``DeprecationWarning`` per class). The legacy ``set_scales`` / + ``get_scales`` paths continue to operate on Newton collision-shape + geometry sizes -- they are not routed through the writer because the + writer's ``set_scales`` writes the transform-scale state. diff --git a/source/isaaclab_newton/isaaclab_newton/sim/views/newton_site_frame_view.py b/source/isaaclab_newton/isaaclab_newton/sim/views/newton_site_frame_view.py index c7c31ae36928..d28d27f34730 100644 --- a/source/isaaclab_newton/isaaclab_newton/sim/views/newton_site_frame_view.py +++ b/source/isaaclab_newton/isaaclab_newton/sim/views/newton_site_frame_view.py @@ -14,9 +14,10 @@ from pxr import UsdPhysics import isaaclab.sim as sim_utils -from isaaclab.cloner.cloner_utils import get_suffix, iter_clone_plan_matches, split_clone_template +from isaaclab.cloner.cloner_utils import iter_clone_plan_matches from isaaclab.physics import PhysicsEvent from isaaclab.sim.views.base_frame_view import BaseFrameView +from isaaclab.sim.views.xform_space_writer import FrameViewLocalSpaceWriter, FrameViewWorldSpaceWriter from isaaclab.utils.string import resolve_matching_names from isaaclab.utils.warp import ProxyArray @@ -104,7 +105,7 @@ def _write_site_local_from_local_poses( @wp.kernel -def _gather_scales( +def _gather_shape_scales( shape_scale: wp.array(dtype=wp.vec3f), shape_body: wp.array(dtype=wp.int32), site_body: wp.array(dtype=wp.int32), @@ -112,7 +113,7 @@ def _gather_scales( num_shapes: wp.int32, out_scales: wp.array(dtype=wp.vec3f), ): - """Gather per-site scales from collision shapes on the same body.""" + """Gather legacy per-site geometry scales from collision shapes on the same body.""" i = wp.tid() si = indices[i] bid = site_body[si] @@ -126,7 +127,7 @@ def _gather_scales( @wp.kernel -def _scatter_scales( +def _scatter_shape_scales( site_body: wp.array(dtype=wp.int32), indices: wp.array(dtype=wp.int32), new_scales: wp.array(dtype=wp.vec3f), @@ -134,7 +135,7 @@ def _scatter_scales( num_shapes: wp.int32, shape_scale: wp.array(dtype=wp.vec3f), ): - """Scatter per-site scales to collision shapes on the same body.""" + """Scatter legacy per-site geometry scales to collision shapes on the same body.""" i = wp.tid() si = indices[i] bid = site_body[si] @@ -143,6 +144,28 @@ def _scatter_scales( shape_scale[s] = new_scales[i] +@wp.kernel +def _gather_xform_scales( + site_xform_scale: wp.array(dtype=wp.vec3f), + indices: wp.array(dtype=wp.int32), + out_scales: wp.array(dtype=wp.vec3f), +): + """Gather per-site xform scales.""" + i = wp.tid() + out_scales[i] = site_xform_scale[indices[i]] + + +@wp.kernel +def _scatter_xform_scales( + indices: wp.array(dtype=wp.int32), + new_scales: wp.array(dtype=wp.vec3f), + site_xform_scale: wp.array(dtype=wp.vec3f), +): + """Scatter per-site xform scales.""" + i = wp.tid() + site_xform_scale[indices[i]] = new_scales[i] + + class NewtonSiteFrameView(BaseFrameView): """Batched Newton site view for non-physics frames. @@ -178,43 +201,51 @@ def __init__( stage = sim_utils.get_current_stage() if stage is None else stage self._site_specs = self._resolve_site_specs(stage, validate_xform_ops) self._site_labels: list[str] = [] + self._site_label_scales: list[tuple[float, float, float]] = [] self._site_body: wp.array | None = None self._site_local: wp.array | None = None + self._site_xform_scale: wp.array | None = None self._site_indices: wp.array | None = None self._pos_buf: wp.array | None = None self._quat_buf: wp.array | None = None self._local_pos_buf: wp.array | None = None self._local_quat_buf: wp.array | None = None + self._scale_buf: wp.array | None = None self._pos_ta: ProxyArray | None = None self._quat_ta: ProxyArray | None = None self._local_pos_ta: ProxyArray | None = None self._local_quat_ta: ProxyArray | None = None + self._scale_ta: ProxyArray | None = None self._count = 0 model = NewtonManager.get_model() if model is not None: self._initialize_from_specs(model) else: - for body_patterns, xform, per_world, _env_ids in self._site_specs: + for body_patterns, xform, scale, per_world, _env_ids in self._site_specs: if body_patterns is None: self._site_labels.append(NewtonManager.cl_register_site(None, xform, per_world=per_world)) + self._site_label_scales.append(scale) else: for body_pattern in body_patterns: self._site_labels.append(NewtonManager.cl_register_site(body_pattern, xform)) + self._site_label_scales.append(scale) self._physics_ready_handle = NewtonManager.register_callback( self._on_physics_ready, PhysicsEvent.PHYSICS_READY, name=f"site_view_{self._prim_path}" ) def _resolve_site_specs( self, stage, validate_xform_ops: bool - ) -> list[tuple[tuple[str, ...] | None, wp.transform, bool, tuple[int, ...] | None]]: + ) -> list[tuple[tuple[str, ...] | None, wp.transform, tuple[float, float, float], bool, tuple[int, ...] | None]]: """Resolve source prims into Newton site registration specs.""" plan = sim_utils.SimulationContext.instance().get_clone_plan() model = NewtonManager.get_model() body_labels = list(model.body_label) if model is not None else () shape_labels = list(model.shape_label) if model is not None else () use_clone_body_pattern = model is None - specs: list[tuple[tuple[str, ...] | None, wp.transform, bool, tuple[int, ...] | None]] = [] + specs: list[ + tuple[tuple[str, ...] | None, wp.transform, tuple[float, float, float], bool, tuple[int, ...] | None] + ] = [] for path_expr in self._prim_paths: if resolve_matching_names(path_expr, body_labels, raise_when_no_match=False)[1]: @@ -268,8 +299,8 @@ def _resolve_source_prim( env_ids: tuple[int, ...] | None, use_clone_body_pattern: bool, stage, - ) -> tuple[tuple[str, ...] | None, wp.transform, bool, tuple[int, ...] | None]: - """Resolve one source prim into body patterns and a local frame.""" + ) -> tuple[tuple[str, ...] | None, wp.transform, tuple[float, float, float], bool, tuple[int, ...] | None]: + """Resolve one source prim into body patterns, local frame, and xform scale.""" prim_path = prim.GetPath().pathString if prim.HasAPI(UsdPhysics.RigidBodyAPI) or prim.HasAPI(UsdPhysics.ArticulationRootAPI): raise ValueError( @@ -281,6 +312,13 @@ def _resolve_source_prim( if not sim_utils.validate_standard_xform_ops(prim): raise ValueError(f"FrameView prim '{prim_path}' does not have standard xform ops.") + scale_attr = prim.GetAttribute("xformOp:scale") + scale = ( + tuple(float(v) for v in scale_attr.Get()) + if scale_attr and scale_attr.HasAuthoredValue() + else (1.0, 1.0, 1.0) + ) + body_prim = prim.GetParent() while body_prim and body_prim.IsValid(): if body_prim.HasAPI(UsdPhysics.RigidBodyAPI) or body_prim.HasAPI(UsdPhysics.ArticulationRootAPI): @@ -300,7 +338,7 @@ def _resolve_source_prim( raise RuntimeError( f"FrameView destination root '{destination_root}' does not end with '{suffix}'." ) - return (destination_root[: -len(suffix)],), wp.transform(pos, quat), False, env_ids + return (destination_root[: -len(suffix)],), wp.transform(pos, quat), scale, False, env_ids body_patterns = [] for env_id in env_ids: destination_root = destination_template.format(env_id) @@ -309,7 +347,7 @@ def _resolve_source_prim( f"FrameView destination root '{destination_root}' does not end with '{suffix}'." ) body_patterns.append(destination_root[: -len(suffix)]) - return tuple(body_patterns), wp.transform(pos, quat), False, env_ids + return tuple(body_patterns), wp.transform(pos, quat), scale, False, env_ids else: raise RuntimeError(f"FrameView source body '{body_path}' is not under '{source_root}'.") if use_clone_body_pattern: @@ -318,18 +356,12 @@ def _resolve_source_prim( body_patterns = tuple(destination_template.format(env_id) + suffix for env_id in env_ids) else: body_patterns = (body_path,) - return body_patterns, wp.transform(pos, quat), False, env_ids + return body_patterns, wp.transform(pos, quat), scale, False, env_ids body_prim = body_prim.GetParent() - ref_path = source_root - if source_root is not None and destination_template is not None: - template_prefix, _ = split_clone_template(destination_template) - source_suffix = get_suffix(source_root, template_prefix + "{}") - if source_suffix is not None: - ref_path = source_root[: -len(source_suffix)] if source_suffix else source_root - ref_prim = stage.GetPrimAtPath(ref_path) if ref_path is not None else None + ref_prim = stage.GetPrimAtPath(source_root) if source_root is not None else None pos, quat = sim_utils.resolve_prim_pose(prim, ref_prim if ref_prim and ref_prim.IsValid() else None) - return None, wp.transform(pos, quat), source_root is not None, env_ids + return None, wp.transform(pos, quat), scale, source_root is not None, env_ids def _on_physics_ready(self, _event) -> None: """Callback invoked when the Newton model becomes available.""" @@ -342,8 +374,9 @@ def _initialize_from_site_map(self, model) -> None: xform_t = wp.to_torch(model.shape_transform) site_bodies: list[int] = [] site_locals: list[list[float]] = [] + site_scales: list[tuple[float, float, float]] = [] - for site_label in self._site_labels: + for site_label, scale in zip(self._site_labels, self._site_label_scales, strict=True): global_idx, per_world = site_map[site_label] site_indices = ( [global_idx] if per_world is None else [site_idx for sites in per_world for site_idx in sites] @@ -351,16 +384,18 @@ def _initialize_from_site_map(self, model) -> None: for site_idx in site_indices: site_bodies.append(int(body_t[site_idx].item())) site_locals.append([float(v) for v in xform_t[site_idx].tolist()]) + site_scales.append(scale) - self._create_buffers(site_bodies, site_locals) + self._create_buffers(site_bodies, site_locals, site_scales) def _initialize_from_specs(self, model) -> None: """Initialize arrays directly from resolved specs and Newton body labels.""" body_labels = list(model.body_label) site_bodies: list[int] = [] site_locals: list[list[float]] = [] + site_scales: list[tuple[float, float, float]] = [] - for body_patterns, xform, per_world, env_ids in self._site_specs: + for body_patterns, xform, scale, per_world, env_ids in self._site_specs: if body_patterns is None: if per_world: if NewtonManager._world_xforms is None: @@ -370,9 +405,11 @@ def _initialize_from_specs(self, model) -> None: world_xform = NewtonManager._world_xforms[world_id] site_bodies.append(WORLD_BODY_INDEX) site_locals.append([float(v) for v in wp.transform_multiply(world_xform, xform)]) + site_scales.append(scale) else: site_bodies.append(WORLD_BODY_INDEX) site_locals.append([float(v) for v in xform]) + site_scales.append(scale) continue for body_pattern in body_patterns: @@ -385,24 +422,33 @@ def _initialize_from_specs(self, model) -> None: for body_idx in matched_indices: site_bodies.append(body_idx) site_locals.append([float(v) for v in xform]) + site_scales.append(scale) - self._create_buffers(site_bodies, site_locals) + self._create_buffers(site_bodies, site_locals, site_scales) - def _create_buffers(self, site_bodies: list[int], site_locals: list[list[float]]) -> None: + def _create_buffers( + self, + site_bodies: list[int], + site_locals: list[list[float]], + site_scales: list[tuple[float, float, float]], + ) -> None: """Allocate view buffers from body indices and local transforms.""" self._count = len(site_bodies) device = self._device self._site_body = wp.array(site_bodies, dtype=wp.int32, device=device) self._site_local = wp.array([wp.transform(*x) for x in site_locals], dtype=wp.transformf, device=device) + self._site_xform_scale = wp.array([wp.vec3f(*scale) for scale in site_scales], dtype=wp.vec3f, device=device) self._site_indices = wp.array(list(range(self._count)), dtype=wp.int32, device=device) self._pos_buf = wp.zeros(self._count, dtype=wp.vec3f, device=device) self._quat_buf = wp.zeros(self._count, dtype=wp.vec4f, device=device) self._local_pos_buf = wp.zeros(self._count, dtype=wp.vec3f, device=device) self._local_quat_buf = wp.zeros(self._count, dtype=wp.vec4f, device=device) + self._scale_buf = wp.zeros(self._count, dtype=wp.vec3f, device=device) self._pos_ta = ProxyArray(self._pos_buf) self._quat_ta = ProxyArray(self._quat_buf) self._local_pos_ta = ProxyArray(self._local_pos_buf) self._local_quat_ta = ProxyArray(self._local_quat_buf) + self._scale_ta = ProxyArray(self._site_xform_scale) @property def prims(self) -> list: @@ -422,7 +468,21 @@ def device(self) -> str: """Device where arrays are allocated.""" return self._device - def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + # ------------------------------------------------------------------ + # Writer factory hooks (pass-through; Newton has no separate Fabric storage) + # ------------------------------------------------------------------ + + def _make_world_space_writer(self) -> FrameViewWorldSpaceWriter: + return _NewtonWorldSpaceWriter(self) + + def _make_local_space_writer(self) -> FrameViewLocalSpaceWriter: + return _NewtonLocalSpaceWriter(self) + + # ------------------------------------------------------------------ + # Backend hooks + # ------------------------------------------------------------------ + + def _get_world_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: """Get world-space positions and orientations.""" state = NewtonManager.get_state_0() site_indices = self._site_indices if indices is None else indices @@ -441,7 +501,7 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, return self._pos_ta, self._quat_ta return ProxyArray(pos_buf), ProxyArray(quat_buf) - def set_world_poses( + def _apply_world_pose_write( self, positions: wp.array | None = None, orientations: wp.array | None = None, @@ -453,7 +513,7 @@ def set_world_poses( state = NewtonManager.get_state_0() if positions is None or orientations is None: - cur_pos_ta, cur_quat_ta = self.get_world_poses(indices) + cur_pos_ta, cur_quat_ta = self._get_world_poses_impl(indices) if positions is None: positions = cur_pos_ta.warp if orientations is None: @@ -468,7 +528,7 @@ def set_world_poses( device=self._device, ) - def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + def _get_local_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: """Get body-local positions and orientations.""" site_indices = self._site_indices if indices is None else indices n = self.count if indices is None else len(indices) @@ -486,7 +546,7 @@ def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, return self._local_pos_ta, self._local_quat_ta return ProxyArray(pos_buf), ProxyArray(quat_buf) - def set_local_poses( + def _apply_local_pose_write( self, translations: wp.array | None = None, orientations: wp.array | None = None, @@ -497,7 +557,7 @@ def set_local_poses( return if translations is None or orientations is None: - cur_pos_ta, cur_quat_ta = self.get_local_poses(indices) + cur_pos_ta, cur_quat_ta = self._get_local_poses_impl(indices) if translations is None: translations = cur_pos_ta.warp if orientations is None: @@ -512,15 +572,70 @@ def set_local_poses( device=self._device, ) - def get_scales(self, indices: wp.array | None = None) -> ProxyArray: - """Get per-site scales by reading from the first collision shape on the same body.""" + # ------------------------------------------------------------------ + # Scales + # ------------------------------------------------------------------ + + def _get_world_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Get per-site world xform scales. + + These are transform scales, matching the USD FrameView scale API. They + are intentionally separate from Newton collision shape geometry sizes. + """ + if indices is None: + return self._scale_ta + n = len(indices) + out = wp.zeros(n, dtype=wp.vec3f, device=self._device) + wp.launch( + _gather_xform_scales, + dim=n, + inputs=[self._site_xform_scale, indices], + outputs=[out], + device=self._device, + ) + return ProxyArray(out) + + def _get_local_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Get per-site local xform scales. + + These are transform scales, matching the USD FrameView scale API. They + are intentionally separate from Newton collision shape geometry sizes. + """ + return self._get_world_scales_impl(indices) + + def _apply_world_scale_write(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Set per-site world xform scales. + + These update transform scale state only; use deprecated ``set_scales`` if + legacy Newton collision shape geometry-scale behavior is required. + """ + if indices is None: + indices = self._site_indices + n = self.count if indices is self._site_indices else len(indices) + wp.launch( + _scatter_xform_scales, + dim=n, + inputs=[indices, scales, self._site_xform_scale], + device=self._device, + ) + + def _apply_local_scale_write(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Set per-site local xform scales. + + These update transform scale state only; use deprecated ``set_scales`` if + legacy Newton collision shape geometry-scale behavior is required. + """ + self._apply_world_scale_write(scales, indices) + + def _get_legacy_shape_scales(self, indices: wp.array | None = None) -> ProxyArray: + """Get Newton legacy geometry scales from collision shapes.""" model = NewtonManager.get_model() num_shapes = model.shape_count site_indices = self._site_indices if indices is None else indices n = self.count if indices is None else len(indices) out = wp.zeros(n, dtype=wp.vec3f, device=self._device) wp.launch( - _gather_scales, + _gather_shape_scales, dim=n, inputs=[model.shape_scale, model.shape_body, self._site_body, site_indices, num_shapes], outputs=[out], @@ -528,15 +643,66 @@ def get_scales(self, indices: wp.array | None = None) -> ProxyArray: ) return ProxyArray(out) - def set_scales(self, scales: wp.array, indices: wp.array | None = None) -> None: - """Set per-site scales by writing to all collision shapes on the same body.""" + def _set_legacy_shape_scales(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Set Newton legacy geometry scales on collision shapes.""" model = NewtonManager.get_model() num_shapes = model.shape_count site_indices = self._site_indices if indices is None else indices n = self.count if indices is None else len(indices) wp.launch( - _scatter_scales, + _scatter_shape_scales, dim=n, inputs=[self._site_body, site_indices, scales, model.shape_body, num_shapes, model.shape_scale], device=self._device, ) + + def _get_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Newton legacy: get_scales returns collision shape geometry scales.""" + return self._get_legacy_shape_scales(indices) + + def _set_scales_impl(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Newton legacy: deprecated set_scales writes collision shape geometry scales. + + Newton's legacy ``set_scales`` path is *not* routed through the + :class:`FrameViewSpaceWriterBase` API because it targets a different state + (collision-shape geometry sizes) than the transform-scale state that + the writer's :meth:`~FrameViewSpaceWriterBase.set_scales` operates on. + """ + self._set_legacy_shape_scales(scales, indices) + + +# ---------------------------------------------------------------------- +# Pass-through writer classes +# ---------------------------------------------------------------------- + + +class _NewtonWorldSpaceWriter(FrameViewWorldSpaceWriter): + """Newton world-space writer: pass-through to backend ``_apply_*`` hooks.""" + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._apply_world_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._apply_world_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_world_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_world_scales_impl(indices) # type: ignore[attr-defined] + + +class _NewtonLocalSpaceWriter(FrameViewLocalSpaceWriter): + """Newton local-space writer: pass-through to backend ``_apply_*`` hooks.""" + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._apply_local_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._apply_local_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_local_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_local_scales_impl(indices) # type: ignore[attr-defined] diff --git a/source/isaaclab_newton/test/sim/test_views_xform_prim_newton.py b/source/isaaclab_newton/test/sim/test_views_xform_prim_newton.py index d114a1da2a80..174f0684fb55 100644 --- a/source/isaaclab_newton/test/sim/test_views_xform_prim_newton.py +++ b/source/isaaclab_newton/test/sim/test_views_xform_prim_newton.py @@ -215,7 +215,8 @@ def test_world_attached_set_world_roundtrip(device): new_pos = _wp_vec3f([[10.0, 20.0, 30.0]], device=device) new_quat = _wp_vec4f([[0.0, 0.0, 0.0, 1.0]], device=device) - view.set_world_poses(new_pos, new_quat) + with view.xform_world_space_writer() as w: + w.set_poses(new_pos, new_quat) ret_pos, ret_quat = view.get_world_poses() torch.testing.assert_close(ret_pos.torch, wp.to_torch(new_pos), atol=1e-5, rtol=0) diff --git a/source/isaaclab_ovphysx/changelog.d/fabric-local-poses.rst b/source/isaaclab_ovphysx/changelog.d/fabric-local-poses.rst new file mode 100644 index 000000000000..eb8caa5d254f --- /dev/null +++ b/source/isaaclab_ovphysx/changelog.d/fabric-local-poses.rst @@ -0,0 +1,18 @@ +Added +^^^^^ + +* Added :meth:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView.get_local_scales`, + :meth:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView.set_local_scales`, + :meth:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView.get_world_scales`, and + :meth:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView.set_world_scales`, which + delegate to the internal :class:`~isaaclab.sim.views.UsdFrameView`. + +Deprecated +^^^^^^^^^^ + +* Deprecated :meth:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView.get_scales` and + :meth:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView.set_scales` in favor of the + explicit ``get_local_scales`` / ``set_local_scales`` (operates on + ``xformOp:scale``) or ``get_world_scales`` / ``set_world_scales``. The + deprecated methods still work but emit a ``DeprecationWarning`` and default to + local scales, preserving prior behavior. diff --git a/source/isaaclab_ovphysx/changelog.d/xform-space-writer.rst b/source/isaaclab_ovphysx/changelog.d/xform-space-writer.rst new file mode 100644 index 000000000000..3561232a311e --- /dev/null +++ b/source/isaaclab_ovphysx/changelog.d/xform-space-writer.rst @@ -0,0 +1,13 @@ +Changed +^^^^^^^ + +* :class:`~isaaclab_ovphysx.sim.views.OvPhysxFrameView` now ships + pass-through ``FrameViewWorldSpaceWriter`` / ``FrameViewLocalSpaceWriter`` + implementations so writes follow the new + :meth:`~isaaclab.sim.views.BaseFrameView.xform_world_space_writer` / + :meth:`~isaaclab.sim.views.BaseFrameView.xform_local_space_writer` context API. + ``set_world_poses`` / ``set_local_poses`` shims still work (one-time + ``DeprecationWarning`` per class). Scale writes inside the writer scope + delegate to the internal :class:`~isaaclab.sim.views.UsdFrameView` and + land in the USD stage (no propagation to OVPhysX-side collision-shape + scales). diff --git a/source/isaaclab_ovphysx/isaaclab_ovphysx/sim/views/ovphysx_frame_view.py b/source/isaaclab_ovphysx/isaaclab_ovphysx/sim/views/ovphysx_frame_view.py index 879ecddf385f..4eb60809ad6e 100644 --- a/source/isaaclab_ovphysx/isaaclab_ovphysx/sim/views/ovphysx_frame_view.py +++ b/source/isaaclab_ovphysx/isaaclab_ovphysx/sim/views/ovphysx_frame_view.py @@ -19,6 +19,7 @@ from isaaclab.physics import PhysicsEvent from isaaclab.sim.views.base_frame_view import BaseFrameView from isaaclab.sim.views.usd_frame_view import UsdFrameView +from isaaclab.sim.views.xform_space_writer import FrameViewLocalSpaceWriter, FrameViewWorldSpaceWriter from isaaclab.utils.warp import ProxyArray from isaaclab_ovphysx.physics import OvPhysxManager @@ -624,17 +625,22 @@ def _current_body_q(self) -> wp.array: # World poses # ------------------------------------------------------------------ - def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - """Get world-space positions and orientations. + # ------------------------------------------------------------------ + # Writer factory hooks (pass-through; OvPhysX has no separate Fabric storage) + # ------------------------------------------------------------------ - Args: - indices: Subset of sites to query. ``None`` means all sites. + def _make_world_space_writer(self) -> FrameViewWorldSpaceWriter: + return _OvPhysxWorldSpaceWriter(self) - Returns: - A tuple ``(positions, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` - wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a - cached zero-copy ``torch.Tensor`` view. - """ + def _make_local_space_writer(self) -> FrameViewLocalSpaceWriter: + return _OvPhysxLocalSpaceWriter(self) + + # ------------------------------------------------------------------ + # Backend hooks + # ------------------------------------------------------------------ + + def _get_world_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + """Get world-space positions and orientations.""" self._require_initialized() body_q = self._current_body_q() @@ -660,31 +666,20 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ) return self._pos_ta, self._quat_ta - def set_world_poses( + def _apply_world_pose_write( self, positions: wp.array | None = None, orientations: wp.array | None = None, indices: wp.array | None = None, ) -> None: - """Set world-space positions and/or orientations. - - Updates ``site_local`` so that ``body_q[body] * site_local`` yields the - desired world pose. Does **not** modify ``body_q``. - - Args: - positions: Desired world positions ``(M, 3)`` [m]. ``None`` leaves - positions unchanged. - orientations: Desired world quaternions ``(M, 4)`` as - ``(qx, qy, qz, qw)``. ``None`` leaves orientations unchanged. - indices: Subset of sites to update. ``None`` means all sites. - """ + """Set world-space positions and/or orientations.""" if positions is None and orientations is None: return self._require_initialized() body_q = self._current_body_q() if positions is None or orientations is None: - cur_pos_ta, cur_quat_ta = self.get_world_poses(indices) + cur_pos_ta, cur_quat_ta = self._get_world_poses_impl(indices) if positions is None: positions = cur_pos_ta.warp if orientations is None: @@ -709,18 +704,8 @@ def set_world_poses( # Local poses (parent-relative) # ------------------------------------------------------------------ - def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - """Get parent-relative positions and orientations. - - Computes ``inv(parent_world) * prim_world`` for each site. - - Args: - indices: Subset of sites to query. ``None`` means all sites. - - Returns: - A tuple ``(translations, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray` - wrappers. - """ + def _get_local_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + """Get parent-relative positions and orientations.""" self._require_initialized() body_q = self._current_body_q() @@ -759,30 +744,20 @@ def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ) return self._local_pos_ta, self._local_quat_ta - def set_local_poses( + def _apply_local_pose_write( self, translations: wp.array | None = None, orientations: wp.array | None = None, indices: wp.array | None = None, ) -> None: - """Set parent-relative translations and/or orientations. - - Updates ``site_local`` only; does **not** modify ``body_q``. - - Args: - translations: Desired parent-relative translations ``(M, 3)`` [m]. - ``None`` leaves translations unchanged. - orientations: Desired parent-relative quaternions ``(M, 4)`` as - ``(qx, qy, qz, qw)``. ``None`` leaves orientations unchanged. - indices: Subset of sites to update. ``None`` means all sites. - """ + """Set parent-relative translations and/or orientations.""" if translations is None and orientations is None: return self._require_initialized() body_q = self._current_body_q() if translations is None or orientations is None: - cur_pos_ta, cur_quat_ta = self.get_local_poses(indices) + cur_pos_ta, cur_quat_ta = self._get_local_poses_impl(indices) if translations is None: translations = cur_pos_ta.warp if orientations is None: @@ -834,39 +809,43 @@ def _ensure_usd_view(self) -> UsdFrameView: ) return self._usd_view - def get_scales(self, indices: wp.array | None = None) -> ProxyArray: - """Get prim scales from the USD stage's ``xformOp:scale`` attribute. + def _get_local_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Get local-space scales (xformOp:scale) via the USD view. .. note:: This reads the *static* USD authored value, not a live physics-state value. OVPhysX does not maintain a per-shape ``shape_scale`` array equivalent to Newton's ``model.shape_scale``, so sim-driven scale - updates are not reflected here. For sites under ``clone_usd=False`` - envs without authored USD prims, the read returns the env_0 - template's scale via the lazy internal :class:`UsdFrameView`. - - Args: - indices: Subset of sites to query. ``None`` means all sites. - - Returns: - A :class:`~isaaclab.utils.warp.ProxyArray` of shape ``(M, 3)``. + updates are not reflected here. """ - return self._ensure_usd_view().get_scales(indices) + return self._ensure_usd_view()._get_local_scales_impl(indices) - def set_scales(self, scales: wp.array, indices: wp.array | None = None) -> None: - """Set prim scales by writing the USD ``xformOp:scale`` attribute. + def _get_world_scales_impl(self, indices: wp.array | None = None) -> ProxyArray: + """Get world-space (composed) scales via the USD view.""" + return self._ensure_usd_view()._get_world_scales_impl(indices) + + def _apply_local_scale_write(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Set local-space scales (xformOp:scale) via the USD view. .. note:: The write lands in the USD stage but does *not* propagate to any OVPhysX-side collision-shape scale. PhysX is unaffected; this is a - stage-only annotation. Use :class:`~isaaclab_ovphysx.assets.RigidObject` - APIs if you need to change physics-effective shape sizes. - - Args: - scales: Scales ``(M, 3)`` as ``wp.array``. - indices: Subset of sites to update. ``None`` means all sites. + stage-only annotation. """ - self._ensure_usd_view().set_scales(scales, indices) + self._ensure_usd_view()._apply_local_scale_write(scales, indices) + + def _apply_world_scale_write(self, scales: wp.array, indices: wp.array | None = None) -> None: + """Set world-space scales via the USD view.""" + self._ensure_usd_view()._apply_world_scale_write(scales, indices) + + def _get_scales_impl(self, indices=None): + """OvPhysX legacy: deprecated get_scales returns local scales.""" + return self._get_local_scales_impl(indices) + + def _set_scales_impl(self, scales, indices=None): + """OvPhysX legacy: deprecated set_scales writes local scales via a one-shot writer scope.""" + with self.xform_local_space_writer() as writer: + writer.set_scales(scales, indices) def get_visibility(self, indices: wp.array | None = None): """Get visibility for prims in the view (USD-backed). @@ -888,3 +867,40 @@ def _gf_matrix_to_xform7(mat: Gf.Matrix4d) -> list[float]: q = mat.ExtractRotationQuat() imag = q.GetImaginary() return [float(t[0]), float(t[1]), float(t[2]), float(imag[0]), float(imag[1]), float(imag[2]), float(q.GetReal())] + + +# ---------------------------------------------------------------------- +# Pass-through writer classes +# ---------------------------------------------------------------------- + + +class _OvPhysxWorldSpaceWriter(FrameViewWorldSpaceWriter): + """OvPhysX world-space writer: pass-through to backend ``_apply_*`` hooks.""" + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._apply_world_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._apply_world_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_world_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_world_scales_impl(indices) # type: ignore[attr-defined] + + +class _OvPhysxLocalSpaceWriter(FrameViewLocalSpaceWriter): + """OvPhysX local-space writer: pass-through to backend ``_apply_*`` hooks.""" + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._apply_local_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._apply_local_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_local_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_local_scales_impl(indices) # type: ignore[attr-defined] diff --git a/source/isaaclab_physx/changelog.d/fabric-local-poses.rst b/source/isaaclab_physx/changelog.d/fabric-local-poses.rst new file mode 100644 index 000000000000..546d395cb78e --- /dev/null +++ b/source/isaaclab_physx/changelog.d/fabric-local-poses.rst @@ -0,0 +1,46 @@ +Added +^^^^^ + +* Added :func:`~isaaclab.utils.warp.fabric.decompose_indexed_fabric_transforms` + and :func:`~isaaclab.utils.warp.fabric.compose_indexed_fabric_transforms` + Warp kernels. They mirror the existing + ``decompose_fabric_transformation_matrix_to_warp_arrays`` / + ``compose_fabric_transformation_matrix_from_warp_arrays`` kernels but + operate on :class:`wp.indexedfabricarray`, so the view-to-fabric mapping + is baked into the array and the kernel just dereferences + ``ifa[view_index]`` instead of taking a separate ``mapping`` argument. + +* Added :func:`~isaaclab.utils.warp.fabric.update_indexed_local_matrix_from_world` + and :func:`~isaaclab.utils.warp.fabric.update_indexed_world_matrix_from_local` + Warp kernels that propagate ``local = world * inv(parent)`` and + ``world = local * parent`` directly on Fabric storage matrices. + +* Added Fabric-accelerated ``get_local_poses`` / ``set_local_poses`` to + :class:`~isaaclab_physx.sim.views.FabricFrameView`. + + Local-pose operations now use ``wp.indexedfabricarray`` to read/write + ``omni:fabric:localMatrix`` directly on the GPU, propagating between + parent world matrices and child local/world matrices via Warp kernels + without round-tripping through USD. + +* Added lazy per-view dirty tracking: ``set_local_poses`` marks the world + matrix dirty and vice-versa, triggering automatic re-propagation only on + the next read (no eager kernel launches on the write path). + +* Added interleave detection: interleaving ``set_world_poses`` and + ``set_local_poses`` on the same view within a frame flushes the stale + direction automatically and emits a one-time performance warning. + +* Added topology-change recovery via automatic ``PrepareForReuse`` detection + and per-selection index rebuild. + +Deprecated +^^^^^^^^^^ + +* Deprecated ``get_scales`` / ``set_scales`` on all ``BaseFrameView`` subclasses. + Use the new explicit ``get_local_scales`` / ``set_local_scales`` (operates on + ``xformOp:scale`` / ``localMatrix``) or ``get_world_scales`` / + ``set_world_scales`` (operates on composed world-space scale) instead. + The deprecated methods still work but emit a ``DeprecationWarning``; + ``UsdFrameView`` defaults to local, ``FabricFrameView`` defaults to world + (preserving prior behavior). diff --git a/source/isaaclab_physx/changelog.d/xform-space-writer.rst b/source/isaaclab_physx/changelog.d/xform-space-writer.rst new file mode 100644 index 000000000000..2ff88d32d5b4 --- /dev/null +++ b/source/isaaclab_physx/changelog.d/xform-space-writer.rst @@ -0,0 +1,28 @@ +Changed +^^^^^^^ + +* :class:`~isaaclab_physx.sim.views.FabricFrameView` now writes Fabric + ``omni:fabric:worldMatrix`` and ``omni:fabric:localMatrix`` through the + new context-managed + :class:`~isaaclab.sim.views.FrameViewSpaceWriterBase` scope. Each scope: + + - eagerly writes both the primary matrix (world or local, per the + chosen space) and derives the opposite-space matrix in a single Warp + kernel on ``__exit__``; + - calls ``wp.synchronize()`` once on ``__exit__``; + - pauses :meth:`IFabricHierarchy.track_local_xform_changes` and + :meth:`track_world_xform_changes` while the scope is active and + restores their prior state on exit, so Kit's per-tick + ``updateWorldXforms()`` does not redundantly recompute matrices the + user just wrote. The renderer's independent ``omni:fabric:worldMatrix`` + listener is unaffected and observes the writes. + + The lazy-dirty-flag mechanism (the ``_DirtyFlag`` enum, ``_dirty`` field, + ``_sync_*_if_dirty`` helpers, and the one-time + ``interleaved set_world_poses / set_local_poses`` warning) has been + removed -- the eager dual-write inside the scope makes all of that + unnecessary. + + The three-selection RO/RW layout (``_trans_sel_ro``, + ``_world_sel_rw``, ``_local_sel_rw``) is kept as a defensive layer and + for clarity of authoring intent. diff --git a/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py b/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py index 3d3c0a3b0d9e..8d4d64720349 100644 --- a/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py +++ b/source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py @@ -12,12 +12,12 @@ import torch import warp as wp -from pxr import Usd +from pxr import Gf, Usd, UsdGeom -import isaaclab.sim as sim_utils from isaaclab.app.settings_manager import SettingsManager from isaaclab.sim.views.base_frame_view import BaseFrameView from isaaclab.sim.views.usd_frame_view import UsdFrameView +from isaaclab.sim.views.xform_space_writer import FrameViewLocalSpaceWriter, FrameViewWorldSpaceWriter from isaaclab.utils.warp import ProxyArray from isaaclab.utils.warp import fabric as fabric_utils @@ -44,25 +44,77 @@ class FabricFrameView(BaseFrameView): """FrameView with Fabric GPU acceleration for the PhysX backend. Uses composition: holds a :class:`UsdFrameView` internally for USD - fallback and non-accelerated operations (local poses, visibility, scales - when Fabric is disabled). - - When Fabric is enabled, world-pose and scale operations use Warp kernels - operating on ``omni:fabric:worldMatrix``. Fabric acceleration runs on - the same CUDA device the view was constructed with — ``cuda:0``, - ``cuda:1``, or any other available CUDA index — so this view is safe - to use from distributed-training workers pinned to non-primary GPUs. - All other operations delegate to the internal USD view. + fallback and non-accelerated operations (visibility, and all pose/scale + operations when Fabric is disabled). - After every Fabric write (``set_world_poses``, ``set_scales``), - :meth:`PrepareForReuse` is called on the ``PrimSelection`` to notify - the FSD renderer that Fabric data has changed and to detect topology - changes that require rebuilding internal mappings. Read operations - do not call PrepareForReuse to avoid unnecessary renderer invalidation. + When Fabric is enabled, world-pose, local-pose, and scale operations run + on the GPU via Warp kernels that read and write + ``omni:fabric:worldMatrix`` and ``omni:fabric:localMatrix`` directly. + All other operations delegate to the internal USD view. - Pose getters return :class:`~isaaclab.utils.warp.ProxyArray`. Setters accept ``wp.array``. + All writes go through the writer-scope API + (:meth:`xform_world_space_writer` / :meth:`xform_local_space_writer`, + recommended) or the + deprecated :meth:`set_world_poses` / :meth:`set_local_poses` / etc. shims + inherited from :class:`BaseFrameView`. + + Behavior (Fabric path): + + * **Leaf-prim assumption.** This view manages a flat set of sibling prims + (e.g. all cameras under ``/World/Env_*/Camera``). It does NOT propagate + transforms to child prims. If a managed prim has children whose world + matrices depend on the parent, those children must be updated via a + separate view, a physics step, or ``IFabricHierarchy.update_world_xforms``. + * **No write-back to USD.** Fabric writes update only + ``omni:fabric:worldMatrix`` / ``omni:fabric:localMatrix``; the prim's + USD ``xformOp:*`` attributes are unchanged. Downstream consumers that + read the prim's USD attributes after a Fabric write will see stale + values until the next USD-side sync. + * **Eager dual-write inside a writer scope (no dirty tracking).** + When a writer scope is open, all writes go to the primary attribute + (``worldMatrix`` for the world writer, ``localMatrix`` for the local + writer). On scope exit, a single Warp kernel derives the opposite + attribute and a single ``wp.synchronize()`` runs. After the scope + exits, both Fabric matrices are self-consistent; getters read directly + from Fabric storage without any further synchronization. + * **Hierarchy listeners are paused while a writer scope is active.** + The writer's ``__enter__`` calls + :meth:`IFabricHierarchy.track_local_xform_changes(False)` / + :meth:`track_world_xform_changes(False)` (saving the prior state) so + that Kit's per-tick ``updateWorldXforms()`` does not redundantly + recompute matrices we just wrote. ``__exit__`` restores the prior + tracking state (so we do not re-enable listeners the caller had + previously paused). The renderer's own independent worldMatrix + listener is unaffected and still observes our writes. + * **Three selections with asymmetric RO/RW access.** Despite the + pause/restore above, we keep three selections as a defensive layer: + + .. code-block:: text + + _trans_sel_ro : worldMatrix=RO, localMatrix=RO (reads) + _world_sel_rw : worldMatrix=RW, localMatrix=RO (world writer) + _local_sel_rw : worldMatrix=RO, localMatrix=RW (local writer) + + A combined ``ReadWrite(world, local)`` selection is unsafe even with + tracking pause -- if a refactor accidentally re-enables tracking, + Fabric would see both attributes as user-authored and fall back to the + hierarchy's canonical direction (local -> world), clobbering our world + write. The separate RO/RW layout makes the intended authoring + direction explicit. + * **Topology-adaptive.** Fabric topology changes are detected on each + access via per-selection ``PrepareForReuse()`` polls; the affected + indexed arrays rebuild automatically and no manual refresh is required. + + Pose getters return :class:`~isaaclab.utils.warp.ProxyArray`; the + deprecated :meth:`set_world_poses` / :meth:`set_local_poses` shims accept + :class:`wp.array`. Inside a writer scope, the writer's + :meth:`~FrameViewSpaceWriterBase.set_poses` / :meth:`~FrameViewSpaceWriterBase.set_scales` + accept :class:`wp.array`. """ + _WORLD_MATRIX_NAME = "omni:fabric:worldMatrix" + _LOCAL_MATRIX_NAME = "omni:fabric:localMatrix" + def __init__( self, prim_path: str, @@ -76,7 +128,7 @@ def __init__( Args: prim_path: USD prim-path pattern to match. device: Device for Warp arrays. Either ``"cpu"`` or any CUDA - device string (``"cuda:0"``, ``"cuda:1"``, …); Fabric + device string (``"cuda:0"``, ``"cuda:1"``, ...); Fabric acceleration is supported on every CUDA index. validate_xform_ops: Whether to validate prim xform-ops. stage: USD stage; defaults to the current sim context's stage. @@ -90,18 +142,35 @@ def __init__( settings = SettingsManager.instance() self._use_fabric = bool(settings.get("/physics/fabricEnabled", False)) - # TODO(pv): Misleading abstraction — FabricFrameView can fall back to USD internally; + + # TODO(pv): Misleading abstraction -- FabricFrameView can fall back to USD internally; # the concrete class should be determined by the factory instead. (PR #5673 pv/fabric-view-no-fallback) - # TODO(pv): Fuse set_world_poses/set_scales into single kernel launch (PR #5674 pv/fabric-fused-compose) self._fabric_initialized = False - self._fabric_usd_sync_done = False - self._fabric_selection = None - self._fabric_to_view: wp.array | None = None - self._view_to_fabric: wp.array | None = None - self._default_view_indices: wp.array | None = None + self._stage = None self._fabric_hierarchy = None - self._view_index_attr = f"isaaclab:view_index:{abs(hash(self))}" + + # Three persistent Fabric selections with asymmetric access flags. + self._trans_sel_ro = None + self._world_sel_rw = None + self._local_sel_rw = None + + # Index arrays (view-side indices and per-selection view->fabric mappings). + self._view_indices: wp.array | None = None + self._trans_ro_fabric_indices: wp.array | None = None + self._world_rw_fabric_indices: wp.array | None = None + self._local_rw_fabric_indices: wp.array | None = None + self._parent_fabric_indices: wp.array | None = None + + # Indexed fabric arrays per (selection, attribute) pair. + self._world_ifa_ro = None + self._local_ifa_ro = None + self._world_ifa_rw = None + self._local_ifa_rw = None + self._parent_world_ifa_ro = None + + # Sentinel passed to compose/decompose kernels for unused slots. + self._fabric_empty_2d_array_sentinel: wp.array | None = None # ------------------------------------------------------------------ # Delegated properties @@ -135,59 +204,29 @@ def set_visibility(self, visibility, indices=None): self._usd_view.set_visibility(visibility, indices) # ------------------------------------------------------------------ - # World poses — Fabric-accelerated or USD fallback + # Writer factory hooks # ------------------------------------------------------------------ - def set_world_poses(self, positions=None, orientations=None, indices=None): + def _make_world_space_writer(self) -> FrameViewWorldSpaceWriter: if not self._use_fabric: - self._usd_view.set_world_poses(positions, orientations, indices) - return - - if not self._fabric_initialized: - self._initialize_fabric() - - self._prepare_for_reuse() - - indices_wp = self._resolve_indices_wp(indices) - count = indices_wp.shape[0] - - dummy = wp.zeros((0, 3), dtype=wp.float32, device=self._device) - positions_wp = _to_float32_2d(positions) if positions is not None else dummy - orientations_wp = ( - _to_float32_2d(orientations) - if orientations is not None - else wp.zeros((0, 4), dtype=wp.float32, device=self._device) - ) + return _FabricFallbackWorldWriter(self) + return _FabricWorldSpaceWriter(self) - wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, - dim=count, - inputs=[ - self._fabric_world_matrices, - positions_wp, - orientations_wp, - dummy, - False, - False, - False, - indices_wp, - self._view_to_fabric, - ], - device=self._fabric_device, - ) - wp.synchronize() + def _make_local_space_writer(self) -> FrameViewLocalSpaceWriter: + if not self._use_fabric: + return _FabricFallbackLocalWriter(self) + return _FabricLocalSpaceWriter(self) - self._fabric_hierarchy.update_world_xforms() - self._fabric_usd_sync_done = True + # ------------------------------------------------------------------ + # Getter hooks -- read directly from Fabric (no lazy sync) + # ------------------------------------------------------------------ - def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: + def _get_world_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: if not self._use_fabric: - return self._usd_view.get_world_poses(indices) + return self._usd_view._get_world_poses_impl(indices) if not self._fabric_initialized: self._initialize_fabric() - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -201,17 +240,16 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, orientations_wp = wp.zeros((count, 4), dtype=wp.float32, device=self._device) wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + kernel=fabric_utils.decompose_indexed_fabric_transforms, dim=count, inputs=[ - self._fabric_world_matrices, + self._get_world_ro_array(), positions_wp, orientations_wp, - self._fabric_dummy_buffer, + self._fabric_empty_2d_array_sentinel, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) if use_cached: @@ -219,67 +257,62 @@ def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, return self._fabric_positions_ta, self._fabric_orientations_ta return ProxyArray(positions_wp), ProxyArray(orientations_wp) - # ------------------------------------------------------------------ - # Local poses — USD fallback (Fabric only accelerates world poses) - # ------------------------------------------------------------------ - - def set_local_poses(self, translations=None, orientations=None, indices=None): - self._usd_view.set_local_poses(translations, orientations, indices) - - def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: - return self._usd_view.get_local_poses(indices) - - # ------------------------------------------------------------------ - # Scales — Fabric-accelerated or USD fallback - # ------------------------------------------------------------------ - - def set_scales(self, scales, indices=None): + def _get_local_poses_impl(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]: if not self._use_fabric: - self._usd_view.set_scales(scales, indices) - return + return self._usd_view._get_local_poses_impl(indices) if not self._fabric_initialized: self._initialize_fabric() - self._prepare_for_reuse() - indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] - dummy3 = wp.zeros((0, 3), dtype=wp.float32, device=self._device) - dummy4 = wp.zeros((0, 4), dtype=wp.float32, device=self._device) - scales_wp = _to_float32_2d(scales) + use_cached = indices is None or indices == slice(None) + if use_cached: + translations_wp = self._fabric_local_translations_buf + orientations_wp = self._fabric_local_orientations_buf + else: + translations_wp = wp.zeros((count, 3), dtype=wp.float32, device=self._device) + orientations_wp = wp.zeros((count, 4), dtype=wp.float32, device=self._device) wp.launch( - kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays, + kernel=fabric_utils.decompose_indexed_fabric_transforms, dim=count, inputs=[ - self._fabric_world_matrices, - dummy3, - dummy4, - scales_wp, - False, - False, - False, + self._get_local_ro_array(), + translations_wp, + orientations_wp, + self._fabric_empty_2d_array_sentinel, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) - wp.synchronize() - self._fabric_hierarchy.update_world_xforms() - self._fabric_usd_sync_done = True + if use_cached: + wp.synchronize() + return self._fabric_local_translations_ta, self._fabric_local_orientations_ta + return ProxyArray(translations_wp), ProxyArray(orientations_wp) - def get_scales(self, indices: wp.array | None = None) -> ProxyArray: + def _get_world_scales_impl(self, indices=None) -> ProxyArray: if not self._use_fabric: - return self._usd_view.get_scales(indices) + return self._usd_view._get_world_scales_impl(indices) if not self._fabric_initialized: self._initialize_fabric() - if not self._fabric_usd_sync_done: - self._sync_fabric_from_usd_once() + return self._decompose_scales(self._get_world_ro_array(), indices) + + def _get_local_scales_impl(self, indices=None) -> ProxyArray: + if not self._use_fabric: + return self._usd_view._get_local_scales_impl(indices) + + if not self._fabric_initialized: + self._initialize_fabric() + + return self._decompose_scales(self._get_local_ro_array(), indices) + + def _decompose_scales(self, ro_array, indices) -> ProxyArray: + """Shared scale-decompose path for world / local getters.""" indices_wp = self._resolve_indices_wp(indices) count = indices_wp.shape[0] @@ -290,168 +323,593 @@ def get_scales(self, indices: wp.array | None = None) -> ProxyArray: scales_wp = wp.zeros((count, 3), dtype=wp.float32, device=self._device) wp.launch( - kernel=fabric_utils.decompose_fabric_transformation_matrix_to_warp_arrays, + kernel=fabric_utils.decompose_indexed_fabric_transforms, dim=count, inputs=[ - self._fabric_world_matrices, - self._fabric_dummy_buffer, - self._fabric_dummy_buffer, + ro_array, + self._fabric_empty_2d_array_sentinel, + self._fabric_empty_2d_array_sentinel, scales_wp, indices_wp, - self._view_to_fabric, ], - device=self._fabric_device, + device=self._device, ) if use_cached: wp.synchronize() + return self._fabric_scales_ta return ProxyArray(scales_wp) # ------------------------------------------------------------------ - # Internal — PrepareForReuse (renderer notification + topology tracking) + # Deprecated get_scales / set_scales hooks # ------------------------------------------------------------------ - def _prepare_for_reuse(self) -> None: - """Call PrepareForReuse on the PrimSelection to notify the renderer. + def _get_scales_impl(self, indices=None) -> ProxyArray: + """Fabric: deprecated get_scales returns world-space scales (legacy behavior).""" + return self._get_world_scales_impl(indices) - PrepareForReuse serves two purposes: + def _set_scales_impl(self, scales, indices=None) -> None: + """Fabric: deprecated set_scales writes world-space scales via a one-shot writer scope.""" + with self.xform_world_space_writer() as writer: + writer.set_scales(scales, indices) - 1. **Renderer notification**: Tells FSD/Storm that Fabric data has - been (or will be) modified, so the next rendered frame reflects - the updated transforms. - 2. **Topology change detection**: Returns True when Fabric's - internal memory layout changed (e.g., prims added/removed). - In that case, view-to-fabric index mappings and fabricarrays - must be rebuilt. - """ - if self._fabric_selection is None: - return + # ------------------------------------------------------------------ + # Internal -- helpers shared by writers + initialization + # ------------------------------------------------------------------ - topology_changed = self._fabric_selection.PrepareForReuse() - if topology_changed: - logger.debug("Fabric topology changed — rebuilding view-to-fabric index mapping.") - self._rebuild_fabric_arrays() + def _to_float32_2d_or_empty(self, data): + return self._fabric_empty_2d_array_sentinel if data is None else _to_float32_2d(data) - def _rebuild_fabric_arrays(self) -> None: - """Rebuild fabricarray and view↔fabric mappings after a topology change. + def _recompute_local_from_world_all(self) -> None: + """Derive ``localMatrix = inv(parent) * worldMatrix`` for every prim in the view. - Note: Only index mappings and fabricarrays are rebuilt. Position/orientation/scale - buffers are *not* resized because ``self.count`` is derived from the USD prim-path - pattern (via ``_usd_view.count``) and does not change when Fabric rearranges its - internal memory layout. The assertion below guards this invariant. + Called from :class:`_FabricWorldSpaceWriter` ``__exit__`` to keep the + (world, local) pair self-consistent after a world-space write. + Storage convention: see + :func:`isaaclab.utils.warp.fabric.update_indexed_local_matrix_from_world`. """ - assert self.count == self._default_view_indices.shape[0], ( - f"Prim count changed ({self.count} vs {self._default_view_indices.shape[0]}). " - "Fabric topology change added/removed tracked prims — full re-initialization required." + if self._trans_sel_ro.PrepareForReuse() or self._parent_world_ifa_ro is None: + self._rebuild_trans_ro_arrays() + wp.launch( + kernel=fabric_utils.update_indexed_local_matrix_from_world, + dim=self.count, + inputs=[ + self._world_ifa_ro, + self._parent_world_ifa_ro, + self._get_local_rw_array(), + self._view_indices, + ], + device=self._device, ) - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32, device=self._fabric_device) - self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) + def _recompute_world_from_local_all(self) -> None: + """Derive ``worldMatrix = parent * localMatrix`` for every prim in the view. + + Called from :class:`_FabricLocalSpaceWriter` ``__exit__`` and from + :meth:`_sync_fabric_from_usd_initial` after seeding local matrices. + Storage convention: see + :func:`isaaclab.utils.warp.fabric.update_indexed_world_matrix_from_local`. + """ + if self._trans_sel_ro.PrepareForReuse() or self._parent_world_ifa_ro is None: + self._rebuild_trans_ro_arrays() wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=self._fabric_to_view.shape[0], - inputs=[self._fabric_to_view, self._view_to_fabric], - device=self._fabric_device, + kernel=fabric_utils.update_indexed_world_matrix_from_local, + dim=self.count, + inputs=[ + self._local_ifa_ro, + self._parent_world_ifa_ro, + self._get_world_rw_array(), + self._view_indices, + ], + device=self._device, ) - wp.synchronize() - self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") + # ------------------------------------------------------------------ + # Internal -- selection accessors with on-demand index rebuild + # ------------------------------------------------------------------ + + def _get_world_ro_array(self): + if self._trans_sel_ro.PrepareForReuse(): + self._rebuild_trans_ro_arrays() + return self._world_ifa_ro + + def _get_local_ro_array(self): + if self._trans_sel_ro.PrepareForReuse(): + self._rebuild_trans_ro_arrays() + return self._local_ifa_ro + + def _get_world_rw_array(self): + if self._world_sel_rw.PrepareForReuse(): + self._world_rw_fabric_indices = self._compute_fabric_indices(self._world_sel_rw) + self._world_ifa_rw = self._build_indexed_array( + self._world_sel_rw, self._WORLD_MATRIX_NAME, self._world_rw_fabric_indices + ) + return self._world_ifa_rw + + def _get_local_rw_array(self): + if self._local_sel_rw.PrepareForReuse(): + self._local_rw_fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + self._local_ifa_rw = self._build_indexed_array( + self._local_sel_rw, self._LOCAL_MATRIX_NAME, self._local_rw_fabric_indices + ) + return self._local_ifa_rw + + def _get_parent_world_ro_array(self): + # Built and refreshed alongside the trans_ro selection (parents share that selection). + if self._parent_world_ifa_ro is None or self._trans_sel_ro.PrepareForReuse(): + self._rebuild_trans_ro_arrays() + return self._parent_world_ifa_ro + + def _rebuild_trans_ro_arrays(self) -> None: + """Rebuild the trans_ro indices and the three indexed arrays that depend on them. + + ``_world_ifa_ro``, ``_local_ifa_ro`` and ``_parent_world_ifa_ro`` are all + keyed off the ``trans_sel_ro`` path ordering, so they are refreshed together. + """ + self._trans_ro_fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._WORLD_MATRIX_NAME, self._trans_ro_fabric_indices + ) + self._local_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._LOCAL_MATRIX_NAME, self._trans_ro_fabric_indices + ) + self._parent_world_ifa_ro = self._build_parent_indexed_array(self._trans_sel_ro) # ------------------------------------------------------------------ - # Internal — Fabric initialization + # Internal -- index computation + # ------------------------------------------------------------------ + + def _compute_fabric_indices(self, selection) -> wp.array: + fabric_paths = selection.GetPaths() + path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} + indices: list[int] = [] + for prim_path in self.prim_paths: + fabric_idx = path_to_fabric_idx.get(prim_path) + if fabric_idx is None: + raise RuntimeError( + f"Prim '{prim_path}' not found in Fabric selection. Ensure the hierarchy has been populated." + ) + indices.append(fabric_idx) + return wp.array(indices, dtype=wp.int32, device=self._device) + + def _compute_parent_fabric_indices(self, selection) -> wp.array: + """For each child in this view, look up the parent prim's fabric index.""" + fabric_paths = selection.GetPaths() + path_to_fabric_idx: dict[str, int] = {str(p): i for i, p in enumerate(fabric_paths)} + indices: list[int] = [] + for prim_path in self.prim_paths: + parent_path = prim_path.rsplit("/", 1)[0] + if parent_path == "": + raise RuntimeError( + f"Child prim '{prim_path}' is at stage root and has no parent prim. " + "FabricFrameView requires every prim to have a non-pseudoroot parent " + "with Fabric world+local matrices." + ) + fabric_idx = path_to_fabric_idx.get(parent_path) + if fabric_idx is None: + raise RuntimeError( + f"Parent prim '{parent_path}' (for child '{prim_path}') not found in Fabric selection. " + "Ensure parents have Fabric world+local matrices populated." + ) + indices.append(fabric_idx) + return wp.array(indices, dtype=wp.int32, device=self._device) + + def _build_indexed_array(self, selection, attribute_name: str, fabric_indices: wp.array) -> wp.indexedfabricarray: + fa = wp.fabricarray(selection, attribute_name) + return wp.indexedfabricarray(fa=fa, indices=fabric_indices) + + def _build_parent_indexed_array(self, selection) -> wp.indexedfabricarray: + self._parent_fabric_indices = self._compute_parent_fabric_indices(selection) + fa = wp.fabricarray(selection, self._WORLD_MATRIX_NAME) + return wp.indexedfabricarray(fa=fa, indices=self._parent_fabric_indices) + + def _resolve_indices_wp(self, indices: wp.array | None) -> wp.array: + """Resolve view indices as a Warp uint32 array.""" + if indices is None or indices == slice(None): + if self._view_indices is None: + raise RuntimeError("Fabric view indices are not initialized.") + return self._view_indices + if indices.dtype != wp.uint32: + return wp.array(indices.numpy().astype("uint32"), dtype=wp.uint32, device=self._device) + return indices + + # ------------------------------------------------------------------ + # Internal -- Fabric initialization # ------------------------------------------------------------------ def _initialize_fabric(self) -> None: - """Initialize Fabric batch infrastructure for GPU-accelerated pose queries.""" + """One-time Fabric setup: hierarchy handle, attribute population, selections, indexed arrays.""" import usdrt # noqa: PLC0415 from usdrt import Rt # noqa: PLC0415 - stage_id = sim_utils.get_current_stage_id() - fabric_stage = usdrt.Usd.Stage.Attach(stage_id) + from isaaclab.sim.utils import get_current_stage_id # noqa: PLC0415 - for i in range(self.count): - rt_prim = fabric_stage.GetPrimAtPath(self.prim_paths[i]) - rt_xformable = Rt.Xformable(rt_prim) + # Attach usdrt stage and create hierarchy handle. + stage_id = get_current_stage_id() + self._stage = usdrt.Usd.Stage.Attach(stage_id) + fabric_id = self._stage.GetFabricId() + self._fabric_id = fabric_id.id + self._fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( + fabric_id, self._stage.GetStageIdAsStageId() + ) - has_attr = ( - rt_xformable.HasFabricHierarchyWorldMatrixAttr() - if hasattr(rt_xformable, "HasFabricHierarchyWorldMatrixAttr") - else False - ) - if not has_attr: + # Ensure each child prim AND its parent have BOTH Fabric world and local matrix + # attributes. ``Create*Attr`` calls are idempotent. + seen_paths: set[str] = set() + for child_path in self.prim_paths: + for path in (child_path, child_path.rsplit("/", 1)[0]): + if path in seen_paths: + continue + seen_paths.add(path) + rt_prim = self._stage.GetPrimAtPath(path) + if not rt_prim.IsValid(): + continue + rt_xformable = Rt.Xformable(rt_prim) rt_xformable.CreateFabricHierarchyWorldMatrixAttr() + rt_xformable.CreateFabricHierarchyLocalMatrixAttr() + rt_xformable.SetLocalXformFromUsd() + rt_xformable.SetWorldXformFromUsd() + + # Three persistent selections with asymmetric access flags. + matrix = usdrt.Sdf.ValueTypeNames.Matrix4d + ro = usdrt.Usd.Access.Read + rw = usdrt.Usd.Access.ReadWrite + wm_ro = (matrix, self._WORLD_MATRIX_NAME, ro) + lm_ro = (matrix, self._LOCAL_MATRIX_NAME, ro) + wm_rw = (matrix, self._WORLD_MATRIX_NAME, rw) + lm_rw = (matrix, self._LOCAL_MATRIX_NAME, rw) + self._trans_sel_ro = self._stage.SelectPrims(require_attrs=[wm_ro, lm_ro], device=self._device, want_paths=True) + self._world_sel_rw = self._stage.SelectPrims(require_attrs=[wm_rw, lm_ro], device=self._device, want_paths=True) + self._local_sel_rw = self._stage.SelectPrims(require_attrs=[wm_ro, lm_rw], device=self._device, want_paths=True) + + # Build the view-side indices array and per-selection view->fabric mappings. + self._view_indices = wp.array(list(range(self.count)), dtype=wp.uint32, device=self._device) + self._trans_ro_fabric_indices = self._compute_fabric_indices(self._trans_sel_ro) + self._world_rw_fabric_indices = self._compute_fabric_indices(self._world_sel_rw) + self._local_rw_fabric_indices = self._compute_fabric_indices(self._local_sel_rw) + + # Indexed fabric arrays per (selection x attribute). + self._world_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._WORLD_MATRIX_NAME, self._trans_ro_fabric_indices + ) + self._local_ifa_ro = self._build_indexed_array( + self._trans_sel_ro, self._LOCAL_MATRIX_NAME, self._trans_ro_fabric_indices + ) + self._world_ifa_rw = self._build_indexed_array( + self._world_sel_rw, self._WORLD_MATRIX_NAME, self._world_rw_fabric_indices + ) + self._local_ifa_rw = self._build_indexed_array( + self._local_sel_rw, self._LOCAL_MATRIX_NAME, self._local_rw_fabric_indices + ) + self._parent_world_ifa_ro = self._build_parent_indexed_array(self._trans_sel_ro) - rt_xformable.SetWorldXformFromUsd() + # Pre-allocated reusable output buffers (world + local + scales). + self._fabric_positions_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) + self._fabric_orientations_buf = wp.zeros((self.count, 4), dtype=wp.float32, device=self._device) + self._fabric_scales_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) + self._fabric_local_translations_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) + self._fabric_local_orientations_buf = wp.zeros((self.count, 4), dtype=wp.float32, device=self._device) + self._fabric_empty_2d_array_sentinel = wp.zeros((0, 0), dtype=wp.float32, device=self._device) - rt_prim.CreateAttribute(self._view_index_attr, usdrt.Sdf.ValueTypeNames.UInt, custom=True) - rt_prim.GetAttribute(self._view_index_attr).Set(i) + self._fabric_positions_ta = ProxyArray(self._fabric_positions_buf) + self._fabric_orientations_ta = ProxyArray(self._fabric_orientations_buf) + self._fabric_scales_ta = ProxyArray(self._fabric_scales_buf) + self._fabric_local_translations_ta = ProxyArray(self._fabric_local_translations_buf) + self._fabric_local_orientations_ta = ProxyArray(self._fabric_local_orientations_buf) - self._fabric_hierarchy = usdrt.hierarchy.IFabricHierarchy().get_fabric_hierarchy( - fabric_stage.GetFabricId(), fabric_stage.GetStageIdAsStageId() - ) - self._fabric_hierarchy.update_world_xforms() + self._fabric_initialized = True + + # Seed Fabric matrices from USD authoritatively. + self._sync_fabric_from_usd_initial() + + def _sync_fabric_from_usd_initial(self) -> None: + """Populate Fabric world+local matrices for children and parents from USD. - self._default_view_indices = wp.zeros((self.count,), dtype=wp.uint32, device=self._device) + Performed once during ``_initialize_fabric``. Without this step Fabric's + matrices are identity for stages that haven't been rendered yet, and our + getters (which read from Fabric) would return wrong values. + """ + # --- Children: compose child localMatrix from USD-authored local transforms. + scales_wp = _to_float32_2d(self._usd_view.get_local_scales().warp) + local_pos_ta, local_ori_ta = self._usd_view.get_local_poses() wp.launch( - kernel=fabric_utils.arange_k, dim=self.count, inputs=[self._default_view_indices], device=self._device + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=self.count, + inputs=[ + self._local_ifa_rw, + _to_float32_2d(local_pos_ta.warp), + _to_float32_2d(local_ori_ta.warp), + _to_float32_2d(scales_wp), + False, + False, + False, + self._view_indices, + ], + device=self._device, ) + + # --- Parents (one entry per unique parent path) --- + unique_parent_paths = list(dict.fromkeys(p.rsplit("/", 1)[0] for p in self.prim_paths)) + if unique_parent_paths: + from isaaclab.sim.utils import get_current_stage # noqa: PLC0415 + + usd_stage = get_current_stage() + xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default()) + world_pos_rows: list[list[float]] = [] + world_ori_rows: list[list[float]] = [] + world_scale_rows: list[list[float]] = [] + decomposer = Gf.Transform() + warned_shear = False + for path in unique_parent_paths: + prim = usd_stage.GetPrimAtPath(path) + tf = xform_cache.GetLocalToWorldTransform(prim) + decomposer.SetMatrix(tf) + s = decomposer.GetScale() + if not warned_shear: + row0 = Gf.Vec3d(tf[0][0], tf[0][1], tf[0][2]).GetNormalized() + row1 = Gf.Vec3d(tf[1][0], tf[1][1], tf[1][2]).GetNormalized() + row2 = Gf.Vec3d(tf[2][0], tf[2][1], tf[2][2]).GetNormalized() + if ( + abs(Gf.Dot(row0, row1)) > 1e-3 + or abs(Gf.Dot(row0, row2)) > 1e-3 + or abs(Gf.Dot(row1, row2)) > 1e-3 + ): + warned_shear = True + logger.warning( + "FabricFrameView: parent prim '%s' has a sheared/skewed world " + "transform. TRS decomposition (used by scale getters and world<->local " + "propagation) does not support shear -- extracted scales and rotations " + "will be approximate. Avoid shear in parent transforms for correct results.", + path, + ) + tf.Orthonormalize() + t = tf.ExtractTranslation() + q = tf.ExtractRotationQuat() + img, real = q.GetImaginary(), q.GetReal() + world_pos_rows.append([float(t[0]), float(t[1]), float(t[2])]) + world_ori_rows.append([float(img[0]), float(img[1]), float(img[2]), float(real)]) + world_scale_rows.append([float(s[0]), float(s[1]), float(s[2])]) + parent_view_indices = wp.array(list(range(len(unique_parent_paths))), dtype=wp.uint32, device=self._device) + parent_pos_wp = wp.array(world_pos_rows, dtype=wp.float32, device=self._device) + parent_ori_wp = wp.array(world_ori_rows, dtype=wp.float32, device=self._device) + parent_scale_wp = wp.array(world_scale_rows, dtype=wp.float32, device=self._device) + parent_world_rw = wp.indexedfabricarray( + fa=wp.fabricarray(self._world_sel_rw, self._WORLD_MATRIX_NAME), + indices=self._compute_fabric_indices_for(self._world_sel_rw, unique_parent_paths), + ) + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=len(unique_parent_paths), + inputs=[ + parent_world_rw, + parent_pos_wp, + parent_ori_wp, + parent_scale_wp, + False, + False, + False, + parent_view_indices, + ], + device=self._device, + ) wp.synchronize() - self._fabric_selection = fabric_stage.SelectPrims( - require_attrs=[ - (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read), - (usdrt.Sdf.ValueTypeNames.Matrix4d, "omni:fabric:worldMatrix", usdrt.Usd.Access.ReadWrite), + # After seeding local matrices from USD, recompute world matrices so + # the view starts with consistent state. + self._recompute_world_from_local_all() + wp.synchronize() + + def _compute_fabric_indices_for(self, selection, paths: list[str]) -> wp.array: + """Path-dict lookup helper used to build one-shot indexed arrays for a custom path set.""" + fabric_paths = selection.GetPaths() + path_to_idx = {str(p): i for i, p in enumerate(fabric_paths)} + indices: list[int] = [] + for path in paths: + idx = path_to_idx.get(path) + if idx is None: + raise RuntimeError(f"Path '{path}' not found in Fabric selection.") + indices.append(idx) + return wp.array(indices, dtype=wp.int32, device=self._device) + + +# ---------------------------------------------------------------------- +# Concrete writer classes for FabricFrameView +# ---------------------------------------------------------------------- + + +class _FabricWriterMixin: + """Common ``__enter__`` / ``__exit__`` for the Fabric world / local writers. + + Pauses ``track_local_xform_changes`` / ``track_world_xform_changes`` on + the Fabric hierarchy while the scope is active so Kit does not redundantly + recompute the matrices we just wrote, then restores the prior state on + exit. + """ + + def _enter_impl(self) -> None: + view: FabricFrameView = self._view # type: ignore[assignment] + if not view._fabric_initialized: + view._initialize_fabric() + self._wrote_anything = False + h = view._fabric_hierarchy + self._was_tracking_local = h.tracking_local_xform_changes + self._was_tracking_world = h.tracking_world_xform_changes + if self._was_tracking_local: + h.track_local_xform_changes(False) + if self._was_tracking_world: + h.track_world_xform_changes(False) + + def _exit_impl(self, exc_type, exc_val, exc_tb) -> None: + view: FabricFrameView = self._view # type: ignore[assignment] + try: + if self._wrote_anything and exc_type is None: + self._derive_opposite() + wp.synchronize() + finally: + h = view._fabric_hierarchy + if self._was_tracking_world: + h.track_world_xform_changes(True) + if self._was_tracking_local: + h.track_local_xform_changes(True) + + def _derive_opposite(self) -> None: + raise NotImplementedError + + +class _FabricWorldSpaceWriter(_FabricWriterMixin, FrameViewWorldSpaceWriter): + """World-space writer for :class:`FabricFrameView`. + + Writes flow through ``_world_sel_rw``; on exit ``localMatrix`` is derived + from the just-written ``worldMatrix`` via + :func:`update_indexed_local_matrix_from_world`. + """ + + def _derive_opposite(self) -> None: + self._view._recompute_local_from_world_all() # type: ignore[attr-defined] + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + view: FabricFrameView = self._view # type: ignore[assignment] + indices_wp = view._resolve_indices_wp(indices) + positions_wp = view._to_float32_2d_or_empty(positions) + orientations_wp = view._to_float32_2d_or_empty(orientations) + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[ + view._get_world_rw_array(), + positions_wp, + orientations_wp, + view._fabric_empty_2d_array_sentinel, + False, + False, + False, + indices_wp, ], - device=self._device, + device=view._device, ) + self._wrote_anything = True - self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32, device=self._device) - self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr) + def set_scales(self, scales, indices=None) -> None: + view: FabricFrameView = self._view # type: ignore[assignment] + indices_wp = view._resolve_indices_wp(indices) + scales_wp = view._to_float32_2d_or_empty(scales) + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[ + view._get_world_rw_array(), + view._fabric_empty_2d_array_sentinel, + view._fabric_empty_2d_array_sentinel, + scales_wp, + False, + False, + False, + indices_wp, + ], + device=view._device, + ) + self._wrote_anything = True + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_world_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_world_scales_impl(indices) # type: ignore[attr-defined] + + +class _FabricLocalSpaceWriter(_FabricWriterMixin, FrameViewLocalSpaceWriter): + """Local-space writer for :class:`FabricFrameView`. + + Writes flow through ``_local_sel_rw``; on exit ``worldMatrix`` is derived + from the just-written ``localMatrix`` via + :func:`update_indexed_world_matrix_from_local`. + """ + def _derive_opposite(self) -> None: + self._view._recompute_world_from_local_all() # type: ignore[attr-defined] + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + view: FabricFrameView = self._view # type: ignore[assignment] + indices_wp = view._resolve_indices_wp(indices) + translations_wp = view._to_float32_2d_or_empty(positions) + orientations_wp = view._to_float32_2d_or_empty(orientations) wp.launch( - kernel=fabric_utils.set_view_to_fabric_array, - dim=self._fabric_to_view.shape[0], - inputs=[self._fabric_to_view, self._view_to_fabric], - device=self._device, + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[ + view._get_local_rw_array(), + translations_wp, + orientations_wp, + view._fabric_empty_2d_array_sentinel, + False, + False, + False, + indices_wp, + ], + device=view._device, ) - wp.synchronize() + self._wrote_anything = True - self._fabric_positions_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) - self._fabric_orientations_buf = wp.zeros((self.count, 4), dtype=wp.float32, device=self._device) - self._fabric_positions_ta = ProxyArray(self._fabric_positions_buf) - self._fabric_orientations_ta = ProxyArray(self._fabric_orientations_buf) - self._fabric_scales_buf = wp.zeros((self.count, 3), dtype=wp.float32, device=self._device) - self._fabric_dummy_buffer = wp.zeros((0, 3), dtype=wp.float32, device=self._device) - self._fabric_world_matrices = wp.fabricarray(self._fabric_selection, "omni:fabric:worldMatrix") - self._fabric_stage = fabric_stage - self._fabric_device = self._device + def set_scales(self, scales, indices=None) -> None: + view: FabricFrameView = self._view # type: ignore[assignment] + indices_wp = view._resolve_indices_wp(indices) + scales_wp = view._to_float32_2d_or_empty(scales) + wp.launch( + kernel=fabric_utils.compose_indexed_fabric_transforms, + dim=indices_wp.shape[0], + inputs=[ + view._get_local_rw_array(), + view._fabric_empty_2d_array_sentinel, + view._fabric_empty_2d_array_sentinel, + scales_wp, + False, + False, + False, + indices_wp, + ], + device=view._device, + ) + self._wrote_anything = True - self._fabric_initialized = True - self._fabric_usd_sync_done = False + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._get_local_poses_impl(indices) # type: ignore[attr-defined] - def _sync_fabric_from_usd_once(self) -> None: - """Sync Fabric world matrices from USD once, on the first read. + def get_scales(self, indices=None) -> ProxyArray: + return self._view._get_local_scales_impl(indices) # type: ignore[attr-defined] - ``set_world_poses`` and ``set_scales`` each set ``_fabric_usd_sync_done`` - themselves, so no explicit flag assignment is needed here. - """ - if not self._fabric_initialized: - self._initialize_fabric() - positions_usd_ta, orientations_usd_ta = self._usd_view.get_world_poses() - positions_usd = positions_usd_ta.warp - orientations_usd = orientations_usd_ta.warp - scales_usd = self._usd_view.get_scales().warp +class _FabricFallbackWorldWriter(FrameViewWorldSpaceWriter): + """Fallback world-space writer used when Fabric is disabled. - self.set_world_poses(positions_usd, orientations_usd) - self.set_scales(scales_usd) + Delegates set/get calls to the internal :class:`UsdFrameView`'s backend + hooks directly. No batching, no listener pausing -- there's no Fabric to + confuse. + """ - def _resolve_indices_wp(self, indices: wp.array | None) -> wp.array: - """Resolve view indices as a Warp uint32 array.""" - if indices is None or indices == slice(None): - if self._default_view_indices is None: - raise RuntimeError("Fabric indices are not initialized.") - return self._default_view_indices - if indices.dtype != wp.uint32: - return wp.array(indices.numpy().astype("uint32"), dtype=wp.uint32, device=self._device) - return indices + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._usd_view._apply_world_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._usd_view._apply_world_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._usd_view._get_world_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._usd_view._get_world_scales_impl(indices) # type: ignore[attr-defined] + + +class _FabricFallbackLocalWriter(FrameViewLocalSpaceWriter): + """Fallback local-space writer used when Fabric is disabled.""" + + def set_poses(self, positions=None, orientations=None, indices=None) -> None: + self._view._usd_view._apply_local_pose_write(positions, orientations, indices) # type: ignore[attr-defined] + + def set_scales(self, scales, indices=None) -> None: + self._view._usd_view._apply_local_scale_write(scales, indices) # type: ignore[attr-defined] + + def get_poses(self, indices=None) -> tuple[ProxyArray, ProxyArray]: + return self._view._usd_view._get_local_poses_impl(indices) # type: ignore[attr-defined] + + def get_scales(self, indices=None) -> ProxyArray: + return self._view._usd_view._get_local_scales_impl(indices) # type: ignore[attr-defined] diff --git a/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py b/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py index 3cfe70095fd3..4b030abc20c6 100644 --- a/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py +++ b/source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py @@ -24,7 +24,7 @@ import torch # noqa: E402 import warp as wp # noqa: E402 from frame_view_contract_utils import * # noqa: F401, F403, E402 -from frame_view_contract_utils import CHILD_OFFSET, ViewBundle, test_set_world_updates_local # noqa: E402 +from frame_view_contract_utils import CHILD_OFFSET, ViewBundle # noqa: E402 from isaaclab_physx.sim.views import FabricFrameView as FrameView # noqa: E402 from pxr import Gf, UsdGeom # noqa: E402 @@ -57,7 +57,7 @@ def _skip_if_unavailable(device: str): # a misconfigured multi-GPU runner is already caught there. Failing here would # only break the standard single-GPU CI runners that legitimately can't run # ``cuda:1+`` tests. - pytest.skip(f"{device} not available (device_count={n}) — multi-GPU test skipped") + pytest.skip(f"{device} not available (device_count={n}) -- multi-GPU test skipped") # ------------------------------------------------------------------ @@ -118,28 +118,11 @@ def factory(num_envs: int, device: str) -> ViewBundle: # ------------------------------------------------------------------ -# Override shared contract test with expected failure for Fabric. -# FabricFrameView.set_world_poses writes to Fabric worldMatrix only; the local -# pose (read via USD) does not reflect the change because there is no -# Fabric → USD writeback for local poses. This is tracked as Issue #5 -# (localMatrix: set_local_poses falls back to USD). +# Override: ensure the shared contract test runs without xfail now that +# get_local_poses computes local from Fabric world matrices. # ------------------------------------------------------------------ - - -@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) -@pytest.mark.xfail( - reason=( - "Issue #5: FabricFrameView.set_world_poses writes to Fabric worldMatrix only. " - "get_local_poses reads from stale USD because there is no Fabric→USD " - "writeback for local poses." - ), - strict=True, -) -def test_set_world_updates_local(device, view_factory): # noqa: F811 - """Override the shared test to mark it as expected failure.""" - from frame_view_contract_utils import test_set_world_updates_local as _impl # noqa: PLC0415 - - _impl(device, view_factory) +# (No override needed -- the shared test_set_world_updates_local from +# frame_view_contract_utils is imported via wildcard and will run as-is.) # ------------------------------------------------------------------ @@ -174,10 +157,11 @@ def test_fabric_set_world_does_not_write_back_to_usd(device, view_factory): usd_t_before = usd_tf_before.ExtractTranslation() orig_usd_pos = torch.tensor([float(usd_t_before[0]), float(usd_t_before[1]), float(usd_t_before[2])]) - # Write to Fabric — move to (99, 99, 99) + # Write to Fabric -- move to (99, 99, 99) new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 99.0, 99.0, 99.0], device=device) - view.set_world_poses(positions=new_pos) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos) # Verify Fabric has the new position fab_pos, _ = view.get_world_poses() @@ -187,7 +171,7 @@ def test_fabric_set_world_does_not_write_back_to_usd(device, view_factory): ) # Verify USD still has the ORIGINAL position (no writeback). Equality, not - # approximate — USD should literally not have moved, so any drift would + # approximate -- USD should literally not have moved, so any drift would # indicate a residual writeback path. xform_cache_after = UsdGeom.XformCache() usd_tf_after = xform_cache_after.GetLocalToWorldTransform(prim) @@ -200,55 +184,405 @@ def test_fabric_set_world_does_not_write_back_to_usd(device, view_factory): @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) -def test_fabric_rebuild_after_topology_change(device, view_factory, monkeypatch): - """Forcing the topology-changed branch on a write triggers - :meth:`_rebuild_fabric_arrays` and leaves the view in a state where - subsequent writes/reads still produce correct data. - - Real ``PrimSelection.PrepareForReuse`` reports topology change only when - Fabric reallocates internally, which is hard to provoke from a unit test. - Instead we monkeypatch ``_prepare_for_reuse`` on the instance to always - take the rebuild branch and verify the view remains usable. +def test_fabric_rebuild_after_topology_change(device, view_factory): + """A simulated topology change rebuilds the indexed fabric arrays and leaves + the view in a state where subsequent writes/reads still produce correct data. + + Real ``PrimSelection.PrepareForReuse`` reports topology change only when Fabric + reallocates internally, which is hard to provoke from a unit test. Instead we + invoke :meth:`FabricFrameView._compute_fabric_indices` and rebuild the indexed + arrays manually, mimicking what ``_get_*_array`` would do on a real topology + event, then verify a roundtrip still works. """ bundle = view_factory(2, device) view = bundle.view - # First write — initializes Fabric and binds _fabric_selection. + # First write -- initializes Fabric. initial = wp.zeros((2, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=2, inputs=[initial, 1.0, 2.0, 3.0], device=device) - view.set_world_poses(positions=initial) - - rebuild_calls = [] - real_rebuild = view._rebuild_fabric_arrays - - def spy_rebuild(): - rebuild_calls.append(True) - real_rebuild() - - def force_topology_changed(): - if view._fabric_selection is not None: - view._fabric_selection.PrepareForReuse() - spy_rebuild() - - monkeypatch.setattr(view, "_prepare_for_reuse", force_topology_changed) + with view.xform_world_space_writer() as w: + w.set_poses(positions=initial) + + # Simulate topology change: recompute per-selection fabric indices and rebuild + # every indexed array, mirroring the lazy paths in the ``_get_*_array`` accessors. + view._rebuild_trans_ro_arrays() + view._world_rw_fabric_indices = view._compute_fabric_indices(view._world_sel_rw) + view._world_ifa_rw = view._build_indexed_array( + view._world_sel_rw, view._WORLD_MATRIX_NAME, view._world_rw_fabric_indices + ) + view._local_rw_fabric_indices = view._compute_fabric_indices(view._local_sel_rw) + view._local_ifa_rw = view._build_indexed_array( + view._local_sel_rw, view._LOCAL_MATRIX_NAME, view._local_rw_fabric_indices + ) - # Trigger another write — goes through the forced topology-change branch. + # Trigger another write through the rebuilt arrays. new = wp.zeros((2, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=2, inputs=[new, 4.0, 5.0, 6.0], device=device) - view.set_world_poses(positions=new) - - assert rebuild_calls, "Forced topology-change branch did not invoke _rebuild_fabric_arrays" + with view.xform_world_space_writer() as w: + w.set_poses(positions=new) - # Read back — proves the rebuilt _view_to_fabric and _fabric_world_matrices - # are still consistent. ret_pos, _ = view.get_world_poses() pos_torch = torch.as_tensor(ret_pos, device=device) expected = torch.tensor([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]], device=device) - assert torch.allclose(pos_torch, expected, atol=1e-7), f"Read after rebuild failed on {device}: {pos_torch}" + # 1e-5 ≈ 20 ULP at magnitudes ~4-6; absorbs float32 SRT compose/decompose drift. + assert torch.allclose(pos_torch, expected, atol=1e-5), f"Read after rebuild failed on {device}: {pos_torch}" + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_prepare_for_reuse_detects_topology_change(device, view_factory): + """Each persistent ``PrimSelection`` exposes ``PrepareForReuse`` and returns a + bool. When the underlying Fabric topology is unchanged it returns False. + """ + bundle = view_factory(1, device) + view = bundle.view + view.get_world_poses() # trigger Fabric init + + assert view._trans_sel_ro is not None, "trans_sel_ro selection not initialized" + for selection in (view._trans_sel_ro, view._world_sel_rw, view._local_sel_rw): + result = selection.PrepareForReuse() + assert isinstance(result, bool), f"PrepareForReuse should return bool, got {type(result)}" + assert not result, "PrepareForReuse should return False when no topology change" + + +def _read_fabric_world_matrix_translation(view, prim_index=0): + """Read cached Fabric worldMatrix directly, without FrameView getter sync.""" + rt_prim = view._stage.GetPrimAtPath(view.prim_paths[prim_index]) + world_attr = rt_prim.GetAttribute(view._WORLD_MATRIX_NAME) + matrix = world_attr.Get() + translation = matrix.ExtractTranslation() + return torch.tensor( + [[float(translation[0]), float(translation[1]), float(translation[2])]], + dtype=torch.float32, + device=view._device, + ) + + +def _read_fabric_world_matrix_scale(view, prim_index=0): + """Read cached Fabric worldMatrix scale directly, without FrameView getter sync.""" + import usdrt # noqa: PLC0415 + + rt_prim = view._stage.GetPrimAtPath(view.prim_paths[prim_index]) + world_attr = rt_prim.GetAttribute(view._WORLD_MATRIX_NAME) + matrix = world_attr.Get() + scale = usdrt.Gf.Transform(matrix).GetScale() + return torch.tensor( + [[float(scale[0]), float(scale[1]), float(scale[2])]], + dtype=torch.float32, + device=view._device, + ) + + +def _read_fabric_local_matrix_translation(view, prim_index=0): + """Read cached Fabric localMatrix directly, without FrameView getter sync.""" + rt_prim = view._stage.GetPrimAtPath(view.prim_paths[prim_index]) + local_attr = rt_prim.GetAttribute(view._LOCAL_MATRIX_NAME) + matrix = local_attr.Get() + translation = matrix.ExtractTranslation() + return torch.tensor( + [[float(translation[0]), float(translation[1]), float(translation[2])]], + dtype=torch.float32, + device=view._device, + ) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_set_local_via_fabric_path(device, view_factory): + """Exercise the Fabric-native set_local_poses path. + + Ensures set_local_poses computes child_world = parent_world * local + entirely within Fabric (not falling back to USD) by first triggering + the Fabric sync via get_world_poses. + """ + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + + # Trigger lazy `_initialize_fabric()` so subsequent calls take the Fabric path. + view.get_world_poses() + + # Now write via the writer scope (Fabric path). + new_local_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_pos, 1.0, 2.0, 3.0], device=device) + ori = torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device) + new_local_ori = wp.from_torch(ori) + + with view.xform_local_space_writer() as w: + w.set_poses(positions=new_local_pos, orientations=new_local_ori) + + # Verify: world = parent(0,0,1) + local(1,2,3) = (1,2,4) + world_pos, _ = view.get_world_poses() + expected = torch.tensor([[1.0, 2.0, 4.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(torch.as_tensor(world_pos, device=device), expected, atol=1e-4, rtol=0) + + # Verify get_local_poses returns the local offset + local_pos, _ = view.get_local_poses() + expected_local = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(torch.as_tensor(local_pos, device=device), expected_local, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_get_scales_fabric_path(device, view_factory): + """Exercise the Fabric-native get_world_scales path.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + + # Trigger lazy `_initialize_fabric()` so the get_world_scales call below uses Fabric. + view.get_world_poses() + + scales = view.get_world_scales() + scales_t = scales.torch + # Default scale should be (1, 1, 1) + expected = torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(scales_t, expected, atol=1e-4, rtol=0) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_local_scales_roundtrip(device, view_factory): + """set_local_scales -> get_local_scales roundtrip via localMatrix.""" + bundle = view_factory(num_envs=2, device=device) + view = bundle.view + + # Force Fabric init + view.get_world_poses() + + new_scales = wp.zeros((2, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=2, inputs=[new_scales, 2.0, 3.0, 4.0], device=device) + with view.xform_local_space_writer() as w: + w.set_scales(new_scales) + + ret_scales = view.get_local_scales() + scales_torch = ret_scales.torch + expected = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]], device=device) + torch.testing.assert_close(scales_torch, expected, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_world_scales_roundtrip(device, view_factory): + """set_world_scales -> get_world_scales roundtrip via worldMatrix.""" + bundle = view_factory(num_envs=2, device=device) + view = bundle.view + + # Force Fabric init + view.get_world_poses() + + new_scales = wp.zeros((2, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=2, inputs=[new_scales, 5.0, 6.0, 7.0], device=device) + with view.xform_world_space_writer() as w: + w.set_scales(new_scales) + + ret_scales = view.get_world_scales() + scales_torch = ret_scales.torch + expected = torch.tensor([[5.0, 6.0, 7.0], [5.0, 6.0, 7.0]], device=device) + torch.testing.assert_close(scales_torch, expected, atol=1e-5, rtol=0) + + +# ------------------------------------------------------------------ +# Transpose-convention verification: world ↔ local kernels rely on the +# identity ``(A·B)ᵀ = Bᵀ·Aᵀ`` to drop explicit transposes when operating +# on Fabric's column-transposed matrix storage. The translation-only +# parents used by the standard fixture cannot distinguish the right +# convention from the wrong one -- the rotation block is identity and +# equals its own transpose. These tests use a parent rotated 90° around +# Z so that an incorrect storage convention would produce a clearly +# wrong child pose. +# ------------------------------------------------------------------ + + +# Parent at (0, 0, 1) rotated +90° around Z (so the parent X axis points +# along world +Y). Quaternion components in (x, y, z, w) order. +_ROTATED_PARENT_POS = (0.0, 0.0, 1.0) +_ROTATED_PARENT_QUAT_XYZW = (0.0, 0.0, 0.70710678, 0.70710678) + + +def _build_rotated_parent_view(device: str) -> "FrameView": + """Build a 1-env FabricFrameView whose parent is rotated 90° around Z.""" + stage = sim_utils.get_current_stage() + sim_utils.create_prim( + "/World/Parent_0", + "Xform", + translation=_ROTATED_PARENT_POS, + orientation=_ROTATED_PARENT_QUAT_XYZW, + stage=stage, + ) + sim_utils.create_prim("/World/Parent_0/Child", "Camera", translation=(0.0, 0.0, 0.0), stage=stage) + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = FrameView("/World/Parent_.*/Child", device=device) + view.get_world_poses() # force Fabric init and USD→Fabric seed + return view + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_set_local_then_get_world_with_rotated_parent(device): + """Verify ``update_indexed_world_matrix_from_local`` under non-identity parent rotation. + + With parent rotated +90° around Z, a child local translation of (1, 0, 0) + must produce world translation (0, 1, 1) -- parent_pos + R · local. If the + transpose convention in the kernel were wrong, the rotation would flip + direction and the world position would land at (0, -1, 1) instead. + """ + _skip_if_unavailable(device) + view = _build_rotated_parent_view(device) + + new_local = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local, 1.0, 0.0, 0.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + with view.xform_local_space_writer() as w: + w.set_poses(positions=new_local, orientations=identity_quat) + + world_pos, _ = view.get_world_poses() + expected = torch.tensor([[0.0, 1.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(torch.as_tensor(world_pos, device=device), expected, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_set_world_then_get_local_with_rotated_parent(device): + """Verify ``update_indexed_local_matrix_from_world`` under non-identity parent rotation. + + With parent rotated +90° around Z and at (0, 0, 1), writing child world + translation (5, 0, 2) must yield child local translation Rᵀ · (5, 0, 1) = + (0, -5, 1). A wrong transpose convention would invert the rotation in the + wrong direction and produce (0, 5, 1) instead. + """ + _skip_if_unavailable(device) + view = _build_rotated_parent_view(device) + + new_world = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_world, 5.0, 0.0, 2.0], device=device) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_world) + + local_pos, _ = view.get_local_poses() + expected = torch.tensor([[0.0, -5.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close(torch.as_tensor(local_pos, device=device), expected, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_initial_seed_with_scaled_parent(device): + """Verify the initial USD→Fabric seed handles non-unit scales correctly. + + Sets up a parent with world scale (2, 1, 1) and a child with local scale + (3, 1, 1) at local translation (1, 0, 0). Expected world-space values for + the child: + + * world scale = parent_scale * child_local_scale = (6, 1, 1) + * world position = parent_pos + parent_scale * child_local_pos + = (0, 0, 1) + (2 * 1, 0, 0) = (2, 0, 1) + + If the parent's worldMatrix is seeded with a hardcoded unit scale, + ``get_scales`` returns (3, 1, 1) instead of (6, 1, 1) and ``get_world_poses`` + returns (1, 0, 1) instead of (2, 0, 1). If the child's localMatrix is + seeded without scale, after ``_sync_world_from_local_if_dirty`` the world + scale collapses to (2, 1, 1). This test catches both regressions. + """ + _skip_if_unavailable(device) + stage = sim_utils.get_current_stage() + sim_utils.create_prim("/World/Parent_0", "Xform", translation=(0.0, 0.0, 1.0), scale=(2.0, 1.0, 1.0), stage=stage) + sim_utils.create_prim( + "/World/Parent_0/Child", + "Camera", + translation=(1.0, 0.0, 0.0), + scale=(3.0, 1.0, 1.0), + stage=stage, + ) + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = FrameView("/World/Parent_.*/Child", device=device) + + world_pos, _ = view.get_world_poses() + torch.testing.assert_close( + torch.as_tensor(world_pos, device=device), + torch.tensor([[2.0, 0.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + scales = view.get_world_scales().torch + torch.testing.assert_close( + scales, + torch.tensor([[6.0, 1.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + +# ------------------------------------------------------------------ +# Multi-view per stage: per-view single-writer isolation +# ------------------------------------------------------------------ + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_multi_view_writer_isolation(device): + """Two ``FabricFrameView`` instances on the same stage have independent writer scopes. + + Each view's ``_active_writer`` is per-instance, so view B's read or writer + scope must not interfere with view A's pending writes. Verifies that + writes through one view do not corrupt the other view's poses, and that + each view can hold its own writer scope concurrently with reads on the + other view. + """ + _skip_if_unavailable(device) + stage = sim_utils.get_current_stage() + + sim_utils.create_prim("/World/EnvA_0", "Xform", translation=(0.0, 0.0, 1.0), stage=stage) + sim_utils.create_prim("/World/EnvA_0/ChildA", "Camera", translation=(0.1, 0.0, 0.0), stage=stage) + sim_utils.create_prim("/World/EnvB_0", "Xform", translation=(0.0, 0.0, 2.0), stage=stage) + sim_utils.create_prim("/World/EnvB_0/ChildB", "Camera", translation=(0.2, 0.0, 0.0), stage=stage) + + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view_a = FrameView("/World/EnvA_.*/ChildA", device=device) + view_b = FrameView("/World/EnvB_.*/ChildB", device=device) + + expected_a0 = torch.tensor([[0.1, 0.0, 1.0]], dtype=torch.float32, device=device) + expected_b0 = torch.tensor([[0.2, 0.0, 2.0]], dtype=torch.float32, device=device) + torch.testing.assert_close( + torch.as_tensor(view_a.get_world_poses()[0], device=device), expected_a0, atol=1e-5, rtol=0 + ) + torch.testing.assert_close( + torch.as_tensor(view_b.get_world_poses()[0], device=device), expected_b0, atol=1e-5, rtol=0 + ) + + # Write a new local pose on view A only via a writer scope. + new_local_a = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_a, 1.0, 0.0, 0.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + with view_a.xform_local_space_writer() as w: + w.set_poses(positions=new_local_a, orientations=identity_quat) + + # View B remains undisturbed. + torch.testing.assert_close( + torch.as_tensor(view_b.get_world_poses()[0], device=device), expected_b0, atol=1e-5, rtol=0 + ) + + # View A's world reflects the new local. + expected_a1 = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32, device=device) + torch.testing.assert_close( + torch.as_tensor(view_a.get_world_poses()[0], device=device), expected_a1, atol=1e-5, rtol=0 + ) + + # Write a new local pose on view B; view A unaffected. + new_local_b = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_b, 3.0, 0.0, 0.0], device=device) + with view_b.xform_local_space_writer() as w: + w.set_poses(positions=new_local_b, orientations=identity_quat) + torch.testing.assert_close( + torch.as_tensor(view_a.get_world_poses()[0], device=device), expected_a1, atol=1e-5, rtol=0 + ) + expected_b1 = torch.tensor([[3.0, 0.0, 2.0]], dtype=torch.float32, device=device) + torch.testing.assert_close( + torch.as_tensor(view_b.get_world_poses()[0], device=device), expected_b1, atol=1e-5, rtol=0 + ) + + # Single-active-writer is per-view: opening a writer on A leaves B free. + with view_a.xform_world_space_writer(): + assert view_a._active_writer is not None + assert view_b._active_writer is None + # B can still open its own writer concurrently. + with view_b.xform_world_space_writer(): + assert view_b._active_writer is not None + assert view_a._active_writer is None + assert view_b._active_writer is None # ------------------------------------------------------------------ -# Multi-GPU tests (cuda:1) — skipped automatically on single-GPU workstations +# Multi-GPU tests (cuda:1) -- skipped automatically on single-GPU workstations # ------------------------------------------------------------------ @@ -268,7 +602,8 @@ def test_fabric_cuda1_world_pose_roundtrip(device, view_factory): new_pos = wp.zeros((2, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=2, inputs=[new_pos, 10.0, 20.0, 30.0], device=device) - view.set_world_poses(positions=new_pos) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos) ret_pos, _ = view.get_world_poses() pos_torch = torch.as_tensor(ret_pos, device=device) @@ -298,9 +633,10 @@ def test_fabric_cuda1_no_usd_writeback(device, view_factory): new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 99.0, 99.0, 99.0], device=device) - view.set_world_poses(positions=new_pos) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos) - # USD must not have moved at all — equality, not approximate. + # USD must not have moved at all -- equality, not approximate. t_after = UsdGeom.XformCache().GetLocalToWorldTransform(prim).ExtractTranslation() usd_pos_after = torch.tensor([float(t_after[0]), float(t_after[1]), float(t_after[2])]) assert torch.allclose(usd_pos_after, orig_usd_pos, atol=0.0), ( @@ -314,9 +650,9 @@ def test_fabric_cuda1_no_usd_writeback(device, view_factory): ) @pytest.mark.parametrize("device", ["cuda:1"]) def test_fabric_cuda1_scales_roundtrip(device, view_factory): - """set_scales -> get_scales roundtrip works on cuda:1. + """set_world_scales -> get_world_scales roundtrip works on cuda:1. - Both write paths (``set_world_poses`` and ``set_scales``) call + Both write paths (``set_world_poses`` and ``set_world_scales``) call ``_prepare_for_reuse`` and launch on ``self._device``; this test covers the scales path on the non-primary CUDA device. """ @@ -325,9 +661,322 @@ def test_fabric_cuda1_scales_roundtrip(device, view_factory): new_scales = wp.zeros((2, 3), dtype=wp.float32, device=device) wp.launch(kernel=_fill_position, dim=2, inputs=[new_scales, 2.0, 3.0, 4.0], device=device) - view.set_scales(new_scales) + with view.xform_world_space_writer() as w: + w.set_scales(new_scales) - ret_scales = view.get_scales() - scales_torch = torch.as_tensor(ret_scales, device=device) + ret_scales = view.get_world_scales() + scales_torch = ret_scales.torch expected = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]], device=device) assert torch.allclose(scales_torch, expected, atol=1e-7), f"Scales roundtrip failed on {device}: {scales_torch}" + + +# ------------------------------------------------------------------ +# Sequential writer scopes (interleaved world / local writes via two scopes) +# ------------------------------------------------------------------ + + +def _build_two_child_view(device: str) -> "FrameView": + """Build a 2-env FabricFrameView with rotated parent for interleave tests. + + Parent at (0, 0, 1) rotated 90° around Z. Two child prims at identity local. + """ + _skip_if_unavailable(device) + stage = sim_utils.get_current_stage() + for i in range(2): + sim_utils.create_prim( + f"/World/Parent_{i}", + "Xform", + translation=_ROTATED_PARENT_POS, + orientation=_ROTATED_PARENT_QUAT_XYZW, + stage=stage, + ) + sim_utils.create_prim(f"/World/Parent_{i}/Child", "Camera", translation=(0.0, 0.0, 0.0), stage=stage) + sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.01, device=device, use_fabric=True)) + view = FrameView("/World/Parent_.*/Child", device=device) + view.get_world_poses() # force init + return view + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_sequential_world_then_local_scopes_partial_indices(device): + """A world writer scope (idx 0), then a local writer scope (idx 1). Both correct.""" + view = _build_two_child_view(device) + + new_world_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_world_pos, 5.0, 0.0, 2.0], device=device) + idx0 = wp.from_torch(torch.tensor([0], dtype=torch.int32, device=device)) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_world_pos, indices=idx0) + + new_local_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_pos, 1.0, 0.0, 0.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + idx1 = wp.from_torch(torch.tensor([1], dtype=torch.int32, device=device)) + with view.xform_local_space_writer() as w: + w.set_poses(positions=new_local_pos, orientations=identity_quat, indices=idx1) + + # Verify index 0's world pose is still (5, 0, 2) -- index 1's local-scope write + # derives world from the just-written local; index 0 was outside the derived + # set on entry but the world-scope already wrote it and the second scope + # re-derives world from local for all prims (including idx 0). After the + # second scope, idx 0's local was derived from its world (= (0, -5, 1)), + # so re-deriving world = parent * local lands back on (5, 0, 2). + world_pos, _ = view.get_world_poses(indices=idx0) + torch.testing.assert_close( + torch.as_tensor(world_pos, device=device), + torch.tensor([[5.0, 0.0, 2.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + # Index 0's local (derived from its world): + # local = Rᵀ · (child_world_pos - parent_pos) = Rz(-90)·(5, 0, 1) = (0, -5, 1) + local_pos_0, _ = view.get_local_poses(indices=idx0) + torch.testing.assert_close( + torch.as_tensor(local_pos_0, device=device), + torch.tensor([[0.0, -5.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + # Index 1's world (derived from local): + # world = parent_world * local = Rz(90)·(1, 0, 0) + parent_pos = (0, 1, 0) + (0, 0, 1) = (0, 1, 1) + world_pos_1, _ = view.get_world_poses(indices=idx1) + torch.testing.assert_close( + torch.as_tensor(world_pos_1, device=device), + torch.tensor([[0.0, 1.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_sequential_local_then_world_scopes_partial_indices(device): + """A local writer scope (idx 0), then a world writer scope (idx 1). Both correct.""" + view = _build_two_child_view(device) + + new_local_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_local_pos, 2.0, 3.0, 0.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + idx0 = wp.from_torch(torch.tensor([0], dtype=torch.int32, device=device)) + with view.xform_local_space_writer() as w: + w.set_poses(positions=new_local_pos, orientations=identity_quat, indices=idx0) + + new_world_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_world_pos, 10.0, 20.0, 30.0], device=device) + idx1 = wp.from_torch(torch.tensor([1], dtype=torch.int32, device=device)) + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_world_pos, indices=idx1) + + # Index 0's world (derived from local): + # world = Rz(90)·(2, 3, 0) + (0, 0, 1) = (-3, 2, 0) + (0, 0, 1) = (-3, 2, 1) + world_pos_0, _ = view.get_world_poses(indices=idx0) + torch.testing.assert_close( + torch.as_tensor(world_pos_0, device=device), + torch.tensor([[-3.0, 2.0, 1.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + # Index 1's world is still (10, 20, 30). + world_pos_1, _ = view.get_world_poses(indices=idx1) + torch.testing.assert_close( + torch.as_tensor(world_pos_1, device=device), + torch.tensor([[10.0, 20.0, 30.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + # Index 1's local (derived from world): + # local = Rᵀ·(world - parent) = Rz(-90)·(10, 20, 29) = (20, -10, 29) + local_pos_1, _ = view.get_local_poses(indices=idx1) + torch.testing.assert_close( + torch.as_tensor(local_pos_1, device=device), + torch.tensor([[20.0, -10.0, 29.0]], dtype=torch.float32, device=device), + atol=1e-5, + rtol=0, + ) + + +# ------------------------------------------------------------------ +# FrameViewSpaceWriterBase contract tests +# ------------------------------------------------------------------ + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_world_writer_writes_world_and_derives_local(device, view_factory): + """A world writer's set_poses + set_scales updates cached Fabric world AND derives local on exit.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() # trigger Fabric init + + new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 1.0, 2.0, 4.0], device=device) + new_scales = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_scales, 2.0, 3.0, 4.0], device=device) + + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos) + w.set_scales(new_scales) + + # Cached world reflects the writes (parent is at (0, 0, 1) so child world pos + # is whatever we wrote). + expected_world = torch.tensor([[1.0, 2.0, 4.0]], dtype=torch.float32, device=device) + cached_world = _read_fabric_world_matrix_translation(view) + torch.testing.assert_close(cached_world, expected_world, atol=1e-5, rtol=0) + + expected_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float32, device=device) + cached_scale = _read_fabric_world_matrix_scale(view) + torch.testing.assert_close(cached_scale, expected_scale, atol=1e-5, rtol=0) + + # Local derived: with parent at (0, 0, 1) (identity rotation, unit scale), + # local translation = world - parent = (1, 2, 3). + expected_local = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32, device=device) + cached_local = _read_fabric_local_matrix_translation(view) + torch.testing.assert_close(cached_local, expected_local, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_local_writer_writes_local_and_derives_world(device, view_factory): + """A local writer's set_poses updates cached Fabric local AND derives world on exit.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() + + new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 1.0, 2.0, 3.0], device=device) + identity_quat = wp.from_torch(torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=device)) + + with view.xform_local_space_writer() as w: + w.set_poses(positions=new_pos, orientations=identity_quat) + + expected_local = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32, device=device) + cached_local = _read_fabric_local_matrix_translation(view) + torch.testing.assert_close(cached_local, expected_local, atol=1e-5, rtol=0) + + # World derived: with parent at (0, 0, 1), world = parent + local = (1, 2, 4). + expected_world = torch.tensor([[1.0, 2.0, 4.0]], dtype=torch.float32, device=device) + cached_world = _read_fabric_world_matrix_translation(view) + torch.testing.assert_close(cached_world, expected_world, atol=1e-5, rtol=0) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_writer_single_derivation_per_scope(device, view_factory, monkeypatch): + """Multiple set_* calls inside one scope produce exactly one derive-kernel launch.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() + + calls = 0 + original = view._recompute_local_from_world_all + + def counted(): + nonlocal calls + calls += 1 + original() + + monkeypatch.setattr(view, "_recompute_local_from_world_all", counted) + + new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 1.0, 2.0, 4.0], device=device) + new_scales = wp.zeros((1, 3), dtype=wp.float32, device=device) + wp.launch(kernel=_fill_position, dim=1, inputs=[new_scales, 2.0, 3.0, 4.0], device=device) + + with view.xform_world_space_writer() as w: + w.set_poses(positions=new_pos) + w.set_scales(new_scales) + assert calls == 0 # no derive yet + + assert calls == 1 # exactly one derive on exit + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_writer_single_active_invariant(device, view_factory): + """Only one writer scope may be active per view at a time.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() + + with view.xform_world_space_writer(): + with pytest.raises(RuntimeError, match="already has an active writer"): + view.xform_world_space_writer().__enter__() + with pytest.raises(RuntimeError, match="already has an active writer"): + view.xform_local_space_writer().__enter__() + # After the outer scope exits, the lock is released and a new scope succeeds. + with view.xform_local_space_writer(): + pass + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_writer_restores_hierarchy_change_tracking(device, view_factory): + """``__exit__`` restores the prior ``track_*_xform_changes`` state (don't re-enable paused listeners).""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() + h = view._fabric_hierarchy + + # Case 1: pre-paused local stays paused after exit. + h.track_local_xform_changes(False) + assert not h.tracking_local_xform_changes + with view.xform_world_space_writer(): + pass + assert not h.tracking_local_xform_changes, "writer must not re-enable a pre-paused local listener" + + # Case 2: pre-enabled local stays enabled after exit. + h.track_local_xform_changes(True) + with view.xform_world_space_writer(): + pass + assert h.tracking_local_xform_changes, "writer must restore the pre-enabled local listener" + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_writer_empty_scope_does_no_derivation(device, view_factory, monkeypatch): + """Entering and exiting a writer scope without any ``set_*`` call must not launch the derive kernel.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() + + calls = 0 + original = view._recompute_local_from_world_all + + def counted(): + nonlocal calls + calls += 1 + original() + + monkeypatch.setattr(view, "_recompute_local_from_world_all", counted) + + with view.xform_world_space_writer(): + pass + + assert calls == 0 + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_view_getter_inside_scope_raises(device, view_factory): + """View-level getters raise ``RuntimeError`` while a writer scope is active.""" + bundle = view_factory(num_envs=1, device=device) + view = bundle.view + view.get_world_poses() + + with view.xform_world_space_writer(): + with pytest.raises(RuntimeError, match="while a writer scope is active"): + view.get_world_poses() + with pytest.raises(RuntimeError, match="while a writer scope is active"): + view.get_local_poses() + with pytest.raises(RuntimeError, match="while a writer scope is active"): + view.get_world_scales() + with pytest.raises(RuntimeError, match="while a writer scope is active"): + view.get_local_scales() + # After the scope exits, view getters work again. + view.get_world_poses() + + +def test_set_world_scales_method_no_longer_exists(): + """``set_world_scales`` / ``set_local_scales`` were deleted from this PR's surface.""" + # The deprecated set_world_poses / set_local_poses shims remain (with warnings), + # but the never-shipped set_world_scales / set_local_scales were removed. + from isaaclab.sim.views import BaseFrameView # noqa: PLC0415 + + assert not hasattr(BaseFrameView, "set_world_scales") + assert not hasattr(BaseFrameView, "set_local_scales")