Fsdp2 stormscope [WIP] #1671
| forward_prefetch=True, # Optimization for faster training | ||
| backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Backward prefetching for overlap | ||
| ) | ||
| # FSDP2 rejects non-contiguous parameters (PyTorch <= 2.10): |
NOTE: this block exists solely for backward compatibility with PyTorch <= 2.10. Do we care about backward compatibility?
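If the block is kept for backward compatibility, one option is to gate it on the installed PyTorch version so it disappears automatically on newer releases. The following is only a sketch: `needs_contiguity_fix` is a hypothetical helper name, and the `(2, 10)` boundary is taken from the code comment in the diff, not verified independently. The version is parsed into integers because a plain string comparison would order "2.9" after "2.10".

```python
import re


def needs_contiguity_fix(torch_version: str, last_affected=(2, 10)) -> bool:
    """Return True when `torch_version` is at or below `last_affected`.

    `torch_version` is expected to look like "2.10.0+cu128"; only the
    leading major.minor pair is compared.
    """
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        # Unrecognised version string: keep the workaround to be safe.
        return True
    major_minor = (int(match.group(1)), int(match.group(2)))
    return major_minor <= last_affected
```

In the PR this would presumably be called as `needs_contiguity_fix(str(torch.__version__))` before entering the contiguity loop.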
Greptile Summary
This PR migrates StormScope/StormCast from FSDP1 (
Important Files Changed
| if isinstance(inner, FSDPModule): | ||
| bases = type(inner).__bases__ | ||
| if len(bases) >= 2 and bases[0] is FSDPModule: | ||
| return bases[1].__name__ | ||
| return type(inner).__name__ |
Fragile MRO assumption for FSDP2 class name
_unwrapped_class_name returns bases[1].__name__ only when bases[0] is FSDPModule. If a future PyTorch version changes the order of bases or introduces an intermediate mixin in the dynamically-generated class (e.g. (FSDPModule, SomeMixin, OriginalCls)), the condition bases[0] is FSDPModule still holds but bases[1].__name__ would return SomeMixin instead of the real user class, silently generating the wrong checkpoint filename. Using type(inner).__mro__ to find the first non-FSDPModule/torch.nn.Module base would be more resilient.
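A minimal sketch of the MRO-based lookup suggested above. To stay runnable without PyTorch it uses stand-in `Module` and `FSDPModule` classes; `StormNet` and `SomeMixin` are hypothetical. Note that simply taking the first MRO entry that is not `FSDPModule` would still return `SomeMixin` in the scenario described, so this variant goes one step further and picks the first entry that is a proper `Module` subclass, on the assumption that an injected mixin would not itself derive from `nn.Module`.

```python
class Module:
    """Stand-in for torch.nn.Module."""


class FSDPModule:
    """Stand-in for torch.distributed.fsdp.FSDPModule (not a Module subclass)."""


def unwrapped_class_name(inner) -> str:
    """Name of the user's module class, looking through an FSDP2 wrapper."""
    cls = type(inner)
    if not isinstance(inner, FSDPModule):
        return cls.__name__
    # Skip the dynamically generated class itself (index 0) and walk the MRO.
    for base in cls.__mro__[1:]:
        if base is Module or not issubclass(base, Module):
            continue  # FSDPModule, plain mixins, Module itself, object
        if issubclass(base, FSDPModule):
            continue  # another wrapper layer
        return base.__name__
    return cls.__name__  # nothing better found


class StormNet(Module):  # hypothetical user model
    pass


class SomeMixin:  # hypothetical mixin a future release might inject
    pass


# Mimic the dynamically generated wrapper class, with and without a mixin.
Wrapped = type("FSDPStormNet", (FSDPModule, StormNet), {})
WrappedWithMixin = type("FSDPStormNet", (FSDPModule, SomeMixin, StormNet), {})
```

With these stand-ins, both `Wrapped()` and `WrappedWithMixin()` resolve to the user class name rather than to the wrapper or the mixin.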
| with torch.no_grad(): | ||
| for p in model.parameters(): | ||
| if p.is_contiguous(): | ||
| continue | ||
| p.data = p.data.contiguous() |
When use_shard_tensor=True, distribute_module has already been called and model.parameters() yields DTensor-backed nn.Parameters. Assigning p.data = p.data.contiguous() on a DTensor parameter is not documented PyTorch API; Tensor.set_() (which backs the .data setter) with a DTensor argument may silently strip the DTensor's mesh/placements metadata, breaking the subsequent fully_shard call. In practice distribute_tensor normalises contiguity internally so the guard p.is_contiguous() is usually True for DTensor params and the assignment is skipped, but making the skip explicit prevents a silent breakage if that behaviour changes.
Suggested change, current code:
| with torch.no_grad(): | |
| for p in model.parameters(): | |
| if p.is_contiguous(): | |
| continue | |
| p.data = p.data.contiguous() | |
Suggested change, proposed replacement:
| with torch.no_grad(): | |
| for p in model.parameters(): | |
| if isinstance(p.data, DTensor): | |
| continue # distribute_module already normalises DTensor local shards | |
| if p.is_contiguous(): | |
| continue | |
| p.data = p.data.contiguous() |
PhysicsNeMo Pull Request
Migrates StormScope off FSDP1 onto FSDP2 (fully_shard/FSDPModule)....
Description
FSDP1's flat-param machinery doesn't compose with ShardTensor/DTensor. This is the immediate motivator: the refactored ShardTensor in #1556 breaks FSDP1's backward pass in StormScope, so until StormScope/StormCast move to FSDP2, domain-parallel implementations either do not work with FSDP or fall back to DDP entirely. The current implementation is DDP only.
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.