fix: correct LoRA initialization and forward pass under tensor parallelism #150

Open
chen2021673 wants to merge 4 commits into master from lora_ddp_loss

chen2021673 commented Apr 30, 2026

Summary

Fix LoRA loss divergence under tensor/data parallel training by aligning LoRA initialization, forward collective ordering, and test-time weight loading.

Changes

  • Fuse base linear and LoRA linear contributions locally before running TP collectives:
    • LoRAColumnParallelLinear: compute base shard + LoRA shard first, then run one gather when needed.
    • LoRARowParallelLinear: compute base shard + LoRA shard first, then run one reduce/reduce-scatter when needed.
  • Make TP LoRA initialization deterministic:
    • Broadcast the replicated ColumnParallel lora_A from TP rank 0.
    • Initialize the full RowParallel lora_A on TP rank 0 and scatter its shards by TP rank.
  • Add in-place ProcessGroup::BroadCast and ScatterFromRank helpers for TP-aware initialization.
  • Extend LoRA weight loading to slice full saved tensors into TP-local shards when target shapes differ.
  • Update model/profile test config to load fixed LoRA weights for selected GPT2/LLaMA3 LoRA cases.
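The rank-consistent initialization described in the list above can be sketched in plain Python. This is a single-process simulation, not the PR's implementation: the helper names (`init_matrix`, `broadcast_from_root`, `scatter_columns_from_root`) are hypothetical, ranks are modeled as list indices, and the RowParallel `lora_A` is assumed to be sharded along its input dimension.

```python
import random

TP_SIZE = 4
LORA_RANK, IN_FEATURES = 2, 8


def init_matrix(rows, cols, seed):
    """Hypothetical stand-in for a random LoRA initializer."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def broadcast_from_root(per_rank, root=0):
    """Simulated in-place broadcast: non-root ranks receive a copy of root's tensor."""
    src = per_rank[root]
    for rank in range(len(per_rank)):
        if rank != root:
            per_rank[rank] = [row[:] for row in src]


def scatter_columns_from_root(full, tp_size):
    """Simulated scatter: rank r receives the r-th column block of the full tensor."""
    cols = len(full[0])
    if cols % tp_size != 0:
        raise ValueError("input dimension must be divisible by tp_size")
    width = cols // tp_size
    return [[row[r * width:(r + 1) * width] for row in full] for r in range(tp_size)]


# ColumnParallel lora_A is replicated: without a broadcast, each rank's local
# random init (modeled here by a per-rank seed) produces different values.
replicated = [init_matrix(LORA_RANK, IN_FEATURES, seed=rank) for rank in range(TP_SIZE)]
diverged_before = any(m != replicated[0] for m in replicated[1:])
broadcast_from_root(replicated)

# RowParallel lora_A is sharded (assumed along the input dimension): rank 0
# builds the full tensor once, then every rank receives its own slice.
full_a = init_matrix(LORA_RANK, IN_FEATURES, seed=0)
shards = scatter_columns_from_root(full_a, TP_SIZE)
reassembled = [sum((shards[r][i] for r in range(TP_SIZE)), []) for i in range(LORA_RANK)]
```

After the broadcast every simulated rank holds the same replicated matrix, and concatenating the scattered shards reproduces the full matrix built on rank 0, so the logical LoRA weights no longer depend on the TP layout.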

Motivation

Previously, LoRA and base linear paths could perform separate TP collectives. That changed the floating-point reduction/gather ordering and made LoRA runs diverge across TP/DDP configurations. Replicated and sharded LoRA initialization also needed to be rank-consistent so every TP rank starts from the same logical LoRA weights.
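Why the summation order matters can be seen with ordinary double-precision floats. The values below are illustrative only (they are not taken from the PR) and are chosen so that rounding is easy to observe; `base` and `lora` stand for hypothetical per-rank partial outputs on two TP ranks.

```python
# Per-rank partial sums on two simulated TP ranks (illustrative values).
base = [1e16, -1e16]   # base linear partial outputs on rank 0 and rank 1
lora = [1.0, 1.0]      # LoRA partial outputs on rank 0 and rank 1

# Two collectives: reduce the base path and the LoRA path separately, then add.
separate = (base[0] + base[1]) + (lora[0] + lora[1])

# One collective: add base and LoRA locally on each rank, then reduce once.
fused = (base[0] + lora[0]) + (base[1] + lora[1])

print(separate, fused)
```

Both expressions are the same sum mathematically, but floating-point addition is not associative, so the two groupings round differently. Neither order is inherently more accurate; the point is that the result depends on how the terms are grouped across collectives, which is why keeping the grouping fixed helps keep loss curves comparable across configurations.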

…ence

Inline base and LoRA matmuls, add locally, then issue a single
AllGather/AllReduce instead of two separate collective ops. The prior
two-collective approach caused floating-point divergence in DDP loss.
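A minimal single-process sketch of the fused row-parallel forward follows. It assumes a layout the PR text does not spell out: the base weight and `lora_A` are sharded along the input dimension and `lora_B` is replicated. All function names are hypothetical, and the all-reduce is simulated with a counter so the number of collectives is visible.

```python
def matvec(matrix, vector):
    """Dense matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def split_columns(matrix, tp_size):
    """Split a matrix into tp_size equal column blocks (one per simulated rank)."""
    width = len(matrix[0]) // tp_size
    return [[row[r * width:(r + 1) * width] for row in matrix] for r in range(tp_size)]


def all_reduce_sum(per_rank, counter):
    """Simulated all-reduce: every rank receives the element-wise sum."""
    counter["collectives"] += 1
    total = [sum(values) for values in zip(*per_rank)]
    return [total[:] for _ in per_rank]


def fused_row_parallel_forward(x_shards, w_shards, a_shards, lora_b, counter):
    """Add base and LoRA partial outputs locally, then reduce exactly once."""
    partials = []
    for x, w, a in zip(x_shards, w_shards, a_shards):
        base = matvec(w, x)                    # partial base output on this rank
        lora = matvec(lora_b, matvec(a, x))    # partial LoRA output on this rank
        partials.append([p + q for p, q in zip(base, lora)])
    return all_reduce_sum(partials, counter)


TP_SIZE = 2
x = [1, 2, 3, 4]
weight = [[1, 0, 2, 1], [0, 1, 1, 3]]   # [out_features, in_features]
lora_a = [[1, 1, 1, 1]]                  # [lora_rank, in_features]
lora_b = [[2], [-1]]                     # [out_features, lora_rank]

counter = {"collectives": 0}
outputs = fused_row_parallel_forward(
    [x[:2], x[2:]],
    split_columns(weight, TP_SIZE),
    split_columns(lora_a, TP_SIZE),
    lora_b,
    counter,
)
reference = [b + l for b, l in zip(matvec(weight, x), matvec(lora_b, matvec(lora_a, x)))]
```

Because `lora_B` is applied to a partial `lora_A` product and the operation is linear, summing the fused per-rank outputs gives the same logical result as the unsharded computation while issuing a single collective.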

Also fix LoadLoRAWeights to slice sharded tensors by tp_rank when the
checkpoint shape differs from the partitioned model shape.

Replace BroadCast's allocate-then-return signature with an in-place form
(void return) that takes pre-grouped tensors per local device. Lets root
ranks broadcast directly out of the source tensor with no self-copy and
no extra allocation. Add ScatterFromRank as the multi-process counterpart
to Scatter for the same reason. Use both in LoRA*ParallelLinear so TP
rank-0 init no longer pays a tp_size-fold communication or scratch cost.

Remove the device-grouped tensor layout requirement from BroadCast and
ScatterFromRank — derive each tensor's group rank from its own device
instead. LoadLoRAWeights now infers TP rank from the destination
tensor's device with strict shape/divisibility checks, rather than
relying on global tp_rank.
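The slicing behavior with strict shape and divisibility checks might look roughly like the sketch below. It is not the PR's `LoadLoRAWeights`: the function name `slice_for_tp` is hypothetical, tensors are modeled as 2-D nested lists, and `tp_rank` is passed in directly, whereas the PR infers it from the destination tensor's device.

```python
def shape_of(matrix):
    return (len(matrix), len(matrix[0]))


def slice_for_tp(full, target_shape, tp_rank, tp_size):
    """Return the TP-local shard of a full saved tensor, or a copy if shapes match."""
    full_shape = shape_of(full)
    target_shape = tuple(target_shape)
    if full_shape == target_shape:
        return [row[:] for row in full]
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank {tp_rank} out of range for tp_size {tp_size}")

    # Exactly one dimension may differ, and it must split evenly across ranks.
    differing = [d for d in range(2) if full_shape[d] != target_shape[d]]
    if len(differing) != 1:
        raise ValueError(f"cannot map saved shape {full_shape} to {target_shape}")
    dim = differing[0]
    if target_shape[dim] * tp_size != full_shape[dim]:
        raise ValueError(
            f"dim {dim}: saved size {full_shape[dim]} != {target_shape[dim]} * {tp_size}"
        )

    start = tp_rank * target_shape[dim]
    stop = start + target_shape[dim]
    if dim == 0:
        return [row[:] for row in full[start:stop]]
    return [row[start:stop] for row in full]


saved = [[0, 1, 2, 3], [4, 5, 6, 7]]
column_shard = slice_for_tp(saved, (2, 2), tp_rank=1, tp_size=2)
row_shard = slice_for_tp(saved, (1, 4), tp_rank=0, tp_size=2)
```

Rejecting any shape that is not an exact `tp_size`-fold split of a single dimension makes a mismatched checkpoint fail loudly instead of silently loading a wrong slice.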