Add multigpu training entrypoint for easier use#5661
Conversation
There was a problem hiding this comment.
Code Review Summary
This PR adds a convenient train_multigpu entrypoint for launching distributed RL training with torch.distributed.run. The implementation is clean and well-structured. Below are observations and suggestions:
✅ What Looks Good
- Clean architecture: Clear separation between launcher arguments and training arguments forwarded to the underlying script
- Comprehensive torchrun support: Exposes all major torchrun parameters (nnodes, rdzv_*, master_addr/port, etc.) for multi-node scenarios
- Dry-run mode: Useful for debugging and CI verification
- Good test coverage: Tests cover single-node, multi-node, and CLI integration paths
- Documentation: Clear examples in
uv_run.rstshowing basic and advanced usage
🔍 Observations and Suggestions
1. Script Path Resolution (Minor)
File: scripts/reinforcement_learning/train_multigpu.py (lines 16-17)
SCRIPT_DIR = Path(__file__).resolve().parent
TRAIN_SCRIPT = SCRIPT_DIR / "train.py"This relies on the script being run from the source tree layout. When installed as a package entry point, __file__ resolves correctly, but consider adding a fallback or validation:
if not TRAIN_SCRIPT.exists():
raise FileNotFoundError(f"Training script not found: {TRAIN_SCRIPT}")2. Distributed Flag Position (Minor)
File: scripts/reinforcement_learning/train_multigpu.py (lines 84-87)
def _with_distributed_arg(train_args: list[str]) -> list[str]:
if "--distributed" in train_args:
return train_args
return ["--distributed", *train_args]Prepending --distributed to the front of train_args should work with argparse, but some scripts may have positional arguments. Consider appending instead if any future training libraries have positional args, though this is fine for current usage.
3. Subprocess Signal Handling (Suggestion)
File: scripts/reinforcement_learning/train_multigpu.py (lines 137-138)
return subprocess.run(command, check=False).returncodeFor distributed training, Ctrl+C propagation is important. Consider forwarding signals to the subprocess:
import signal
def main(argv: list[str] | None = None) -> int:
# ... existing code ...
proc = subprocess.Popen(command)
signal.signal(signal.SIGTERM, lambda *_: proc.terminate())
signal.signal(signal.SIGINT, lambda *_: proc.terminate())
return proc.wait()This ensures clean shutdown when the user cancels training.
4. Test Coverage Enhancement (Optional)
File: source/isaaclab/test/cli/test_train_multigpu.py
Consider adding a test for invalid library selection:
def test_invalid_library_rejected():
"""Invalid --rl_library should be rejected by argparse."""
with pytest.raises(SystemExit):
TRAIN_MULTIGPU._parse_args(["--rl_library", "invalid_lib"])Summary
This is a well-implemented feature that significantly improves the multi-GPU training UX. The suggestions above are minor improvements rather than blockers. The PR is ready for merge after addressing any items the author considers worthwhile.
Verdict: Approve ✅
Update (c0bf9e6): Reviewed incremental changes since ea57cad. New commits add unrelated documentation updates (ecosystem docs, installation docs), CI test infrastructure (install workflow tests, conda Docker layer), and changelog entries for other features. No changes to the multigpu training entrypoint or its tests. Original approval stands.
Update (db1d900):
- ✅ Signal handling fixed — Added
_run_torchrun_command()with proper SIGTERM/SIGINT forwarding (addresses suggestion #3) - ✅ Distributed flag position fixed — Now appends
--distributedinstead of prepending (addresses observation #2) ⚠️ Test file removed —test_train_multigpu.pywas deleted (133 lines). Was this intentional? Tests were previously noted as a positive. Consider restoring or relocating them.
Overall improvements are good, but test removal should be clarified.
Update (d28b82e): Reviewed incremental changes. No modifications to the multi-GPU training entrypoint in this push. New commits contain: wheel workflow CI improvements, pyproject.toml metadata cleanups, camera sensor refactors, version bumps, and test harness updates for other features. Previous question about test file removal remains open. Original approval stands.
Update (f89b4d3): Reviewed incremental changes since d28b82e.
New in this push:
- ✅ skrl JAX launcher support — Added
_build_skrl_jax_command()and_is_skrl_jax_launcher()to supportskrl.utils.distributed.jaxas an alternative to torchrun for JAX-based training - ✅ Refactored dispatcher — New
_build_distributed_command()cleanly routes to the appropriate launcher based on library/framework - ✅ Validation — Added
_validate_launcher_args()to reject torchrun-only args when using skrl JAX - ✅ New CLI options — Added
--coordinator_addressand--ml_frameworkforwarding for JAX multi-node setup - ✅ Changelog updated — Now mentions skrl JAX support
The implementation is well-structured and maintains backward compatibility with existing torchrun usage. Original approval stands.
Update (3ffbbd2): Reviewed incremental changes since f89b4d3. No modifications to the multi-GPU training entrypoint in this push. New commits contain: OVRTX renderer cleanup removing legacy 0.2.x code paths (imports, kernels, and test coverage for deprecated functionality). This is unrelated maintenance work merged into the branch. Original approval stands.
Update (299b79e): Reviewed incremental changes since 3ffbbd2.
Documentation consolidation:
- ✅ New comprehensive docs — Added
docs/source/features/multi_gpu.rstsection with detailedtrain_multigpucommand documentation including tab-set examples for bothisaaclab.shanduv runapproaches - ✅ Coverage of all modes — Single-node, custom GPU count, skrl JAX, and multi-node configurations all documented
- ✅ De-duplicated uv_run.rst — Removed redundant multi-GPU section from installation docs (now consolidated in features/multi_gpu.rst)
- ✅ Help text cleanup — Removed
--headlessfrom example strings in script epilog (aligns with code patterns where headless is not always required)
Good documentation housekeeping. The consolidated location in multi_gpu.rst is the appropriate home for this content. Original approval stands.
Update (797c4d6): Reviewed incremental changes since 299b79e.
Changes in this push:
1. Multi-GPU validation (train_multigpu.py) ✅
- Added
_get_visible_cuda_device_count()to detect available CUDA devices viaCUDA_VISIBLE_DEVICESortorch.cuda.device_count() - Added
_validate_num_gpus_against_visible_devices()that errors early if--num_gpusexceeds visible devices - Good UX improvement — users get clear error messages instead of cryptic torchrun failures
2. Export script cleanup (export.py) ✅
- Fixed parser description: "Train" → "Export" (was copy-paste artifact)
- Renamed
train_task_name→checkpoint_task_namefor semantic clarity - Added null check:
if not resume_path: return Falsewith helpful message - These are polish improvements that enhance code quality
3. LEAPP annotation (anymal_c_env.py) ✅
- Added
height_datato LEAPP input annotations when available - Removed outdated comment about height data not being used
4. Docs (exporting_direct_workflow_policies_with_leapp.rst) ✅
- Updated emphasize-lines range to match code changes (92-105 → 92-107)
All changes are solid improvements. Original approval stands.
Signed-off-by: Mustafa H <34825877+StafaH@users.noreply.github.com>
| --run_name gpu4_vis presets=newton | ||
|
|
||
| Use ``--rl_library`` for other distributed-capable libraries: ``rsl_rl``, ``rl_games``, or ``skrl``. | ||
| For multi-node jobs, pass torchrun settings such as ``--nnodes``, ``--node_rank``, |
There was a problem hiding this comment.
What about if I want to run multi-gpu with JAX: https://isaac-sim.github.io/IsaacLab/main/source/features/multi_gpu.html#jax-implementation ?
Also, in the JAX multi-GPU setup, parameters such as rdzv_backend, rdzv_endpoint and rdzv_id do not exist
There was a problem hiding this comment.
Hey @Toni-SM, good question! There's a couple choices we can try. We can detect in the arguments if the user has skrl + jax and modify the command for them. We could also create a special simple argument for jax (--jax) that can be used in combination with --rl_library skrl. For args validation, we can add that as well, that's what makes this new entry point script really strong, we can do very quick early parsing and make sure the correct args are there and error out very early if they arent.
Which option would you prefer?
|
Update (550f9f4): Reviewed latest merge from develop. No modifications to the multi-GPU training entrypoint code in this sync. All previous observations remain valid:
CI checks are passing. The feature implementation is clean and ready for merge once the test removal is addressed or confirmed intentional. Verdict: Approve ✅ |
Description
Adds a
train_multigpuentry point for launching distributed RL training withtorch.distributed.run.The wrapper defaults to
rsl_rl, automatically forwards--distributed, and keeps normal trainingarguments available without requiring users to type the full torch distributed command.
Example commands (using the new uv workflow, which requires no installation):
No new dependencies are required.
Fixes #
Type of change
Screenshots
Not applicable.
Checklist
pre-commitchecks with./isaaclab.sh --formatCONTRIBUTORS.mdor my name already exists there