Skip to content

feat: Support ZeRO-2 based on DistributedOptimizer#110

Open
Chamberlain0w0 wants to merge 5 commits into
masterfrom
feat/zero2
Open

feat: Support ZeRO-2 based on DistributedOptimizer#110
Chamberlain0w0 wants to merge 5 commits into
masterfrom
feat/zero2

Conversation

@Chamberlain0w0
Copy link
Copy Markdown
Contributor

@Chamberlain0w0 Chamberlain0w0 commented Feb 28, 2026

基于框架目前 DistOpt 的建设,实现的 ZeRO-2 的梯度分片显存优化策略。

  1. 用户接口修改:添加 zero_stage 的 gflag,在启用 --use_distributed_optimizer 的同时可以指定 zero 级别(目前 zero3 为占位符);zero_stage 的信息也作为成员变量存在 DDPConfig 类里。

  2. 实现上的修改
    a. 核心逻辑:ZeRO-2 的核心是对模型参数所对应的梯度信息也按 dp 来分片存储,每个 rank 拿到自己负责的那部分;考虑到原先的 DistOpt 实现依赖于一个大的一维连续 ParamAndGradBuffer,所以为了实现 ZeRO-2,也就需要在初始化时不构造全量的 grad_buffer,仅构造每个 shard 大小的 grad_buffer。

    b. grad 于 ParamAndGradBucketGroup 创建的时候构造(见 ParamAndGradBucketGroup 构造函数):每个 group 单独构造各自 rank 上面的 shard grad buffer,以 grad_shard_buffer_list_ 的成员变量存储(按 buckets 存成一个 list,但是实际上默认情况就是一个 group 一个 bucket,所以这里就是一个 size()==1 的 list)。

    c. Autograd 反向流程中,按需临时分配内存创建 full grad,用完后释放。考虑到之前修改了 tensor->grad 的 lazy init,以及每轮可能存在的 ZeroGrad(set_to_none=true),此 full grad 创建时机位于 AccumulateGrad::Backward。

    d. 补充上述 c 的细节:为了不在 AccumulateGrad::Backward 插入过多 zero2 相关的 if else 判断污染代码逻辑,直接定义一个 pre-accumulate-grad 的 bypass function,用于劫持原先的 AccumulateGrad::Backward,从而实现流程的重写,替换为:视情况创建 full grad;把其中与 tensor->grad 有关的操作,都改为 full grad 的操作;正常完成梯度更新。

  3. 局限性分析:目前的 ZeRO-2 实现上,相当于把 full grad 从原先“显存上的永久存储”的角色转换为了一个“autograd 反向流程中随用随分配并及时释放的激活内存”的角色,从而使得显存下降。但是最坏情况下,仍然可能存在一种情况导致显存优化效果不太明显:计算/通信太慢,full grad 的 cudaFree 操作排在流后面太久还没释放,full grad 作为激活值会较长时间占据内存。目前的做法是适当减小 bucket size,这样 full grad 的 reduce scatter 操作会变得更细粒度一些,更能贴近随用随释放的目标。

@Chamberlain0w0 Chamberlain0w0 changed the title [WIP] feat: Support ZeRO-2 based on DistributedOptimizer feat: Support ZeRO-2 based on DistributedOptimizer Apr 27, 2026
StartGradSync();
}

if (params_with_grad_.empty()) {
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没有强 wait ReduceScatter 完成?optimizer 可能读到还没同步完成的 grad_shard_buffer

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是下面第三条发现梯度累积下面的问题后,修改的时候忘记改这部分了,我统一下逻辑

}
std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second;
param->SetGradAccumulateBypass(
[weak_group, param](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) {
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里param是一个shared_ptr,param Tensor 持有 grad_accumulate_bypass,这个 lambda 函数里面又捕获 param,会不会形成引用环导致该tensor不析构?这里weak_group 应该就是解决这一问题的,param会不会有同样问题

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

太细了,改了

return;
}
// TODO(zbl): check this if sync is only done in last mircobatch
// if (!inserted) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块不需要的话可以删掉,留TODO就行

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块其实是之前实现的一个遗漏的小问题:如果存在梯度累积,其实本质上只需要最后一个 microbatch 才进行梯度通信同步就可以了,但我们的 ddp 实现里似乎每一个 microbatch 都正常通信同步,这块不影响数值正确性,但确实有很多不必要的通信操作。当时想了下怎么做这块,但是可能会涉及到 DDP 与 PP 的再耦合、以及可能需要加一个全局 config 来控制这个,所以考虑后续单独再提 PR 修改了,改的就是这块的逻辑,所以这里先留了 TODO,到时候取消注释就可以。


if (params_with_grad_.size() == params_.size()) {
// All param grads are ready in this group, trigger grad sync
StartGradSync();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里注释是说ready in this group,但我看实现里遍历了所有bucket group,这个是符合预期的吗

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个表示这个 group 里面所有 bucket(通常状况就是一个)的 param 集合都 ready 了


// Only register grads as ready when processing the last microbatch
// TODO(zbl): Only register grads as ready and trigger grad sync when processing the last microbatch
// For now, is_last_microbatch_ is always true
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么is_last_microbatch_ always true 呀

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

跟上面说的问题一致:

如果存在梯度累积,其实本质上只需要最后一个 microbatch 才进行梯度通信同步就可以了,但我们的 ddp 实现里似乎每一个 microbatch 都正常通信同步,这块不影响数值正确性,但确实有很多不必要的通信操作。

对标 Megatron 的接口,ParamAndGradBucketGroup 提供一个 is_last_microbatch 的 flag,来判断要不要触发通信,标准做法是梯度累积下只有 last_microbatch=True 才进行通信。目前的实现里,由于每个 microbatch 都同步,所以每个 microbatch 都是 last_microbatch=True,同时也没有加任何对 last_microbatch 做 set 的操作,相当于是个占位符

}

void ParamAndGradBucket::ScaleGradients(float scaling_factor) {
if (!grad_data_ || scaling_factor == 1.f) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grad_data_不一定有了,这里的判断还合理吗?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是,我加下对应逻辑吧

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不太好搞,zero2 下面这个关系不太绑定,暂时考虑在调用的地方做分叉,zero1 还调这个 function,zero2 的话手动对 full_grad_buffer 做 scale

Comment thread example/llama3/main.cc
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那应该加个检查,如果没开 use_distributed_optimizer但设置了zero_stage直接报错,而不是静默

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是,我加上

temp_full_grad_initialized_[bucket_idx] = false;
}

if (!temp_full_grad_initialized_[bucket_idx]) {
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有点怪,temp_full_grad_buffer_ 在第一次使用时必然需要清零,那在 AllocateFlatBuffer 后立即 Fill(0.0f)就可以,为什么还需要 temp_full_grad_initialized_ 这个flag。另外在并行状态下这个写入可靠吗

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实,可能改岔忘删了。我感觉这块不涉及并行上的问题,本质就是 autograd 流程上会调这个(具体在 AccumulateGrad::Backward 里),目前 autograd 本身流程还是串行的(只是后面触发的通信可能存在重叠)

void ResetAccumulator();

// ZeRO-2: Use this function to take over AccumulateGrad::Backward
using GradAccumulateBypass
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里我认为不要让Tensor 持有 GradAccumulateBypass,用 post_accumulate_grad_hook_ :把 hook 接口扩展一个 TryBypassAccumulate(grad, overwrite, lr),默认返回 false。AccumulateGrad::Backward() 先问 hook 是否接管,ZeRO2 hook 返回 true。这样不污染 Tensor 基类太多,只复用已有 hook 生命周期。

class PostAccumulateGradHook {
public:
    virtual void operator()(const std::shared_ptr<Tensor> &grad) = 0;

    virtual bool TryBypassAccumulate(const std::shared_ptr<Tensor> &grad_output,
                                     bool overwrite,
                                     float learning_rate) {
        return false;
    }

    virtual ~PostAccumulateGradHook() = default;
};

然后 AccumulateGrad ::Backward() 里变成:

const bool overwrite = tensor_->ConsumeGradOverwriteFlag();

auto hook = tensor_->post_accumulate_grad_hook();
if (hook && hook->TryBypassAccumulate(grad_output, overwrite, learning_rate_)) {
    tensor_->ResetAccumulator();
    return {};
}

再给ZeRO2 定义一个 hook

class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook {
public:
    Zero2AccumulateGradHook(std::weak_ptr<ParamAndGradBucketGroup> group,
                            std::shared_ptr<Tensor> param)
        : group_(std::move(group)), param_(std::move(param)) {}

    bool TryBypassAccumulate(const std::shared_ptr<Tensor> &grad_output,
                             bool overwrite,
                             float learning_rate) override {
        if (auto group = group_.lock()) {
            group->AccumulateParamGrad(param_, grad_output, overwrite, learning_rate);
            if (group->config().overlap_grad_reduce) {
                group->RegisterGradReady(param_);
            }
            return true;
        }
        return false;
    }

    void operator()(const std::shared_ptr<Tensor> &) override {}

private:
    std::weak_ptr<ParamAndGradBucketGroup> group_;
    std::shared_ptr<Tensor> param_;
};

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

很妙,这样其实既能不改 tensor,也省得到处解释这个 bypass 与正常 hook 的区别(其实他俩本质就应该是一个东西,一起创建,一个地方调用,就是一个后面带个 break,一个不带)。这样统一收口了以后读起来也很容易理解。我照着改一下

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants