Fix/wan2.1 flash attention #153
Conversation
Pull request overview
Updates AMD inference Dockerfiles to adjust FlashAttention build/install behavior (notably for Wan2.1) and expands the supported ROCm arch list for Mochi.
Changes:
- Replaces the pinned/parameterized FlashAttention wheel build in the Wan2.1 Dockerfile with a direct `setup.py install` from an unpinned ROCm/flash-attention clone.
- Adds `gfx950` to the `PYTORCH_ROCM_ARCH` list in the Mochi inference Dockerfile.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile | Changes FlashAttention installation steps for Wan2.1 image builds. |
| docker/pyt_mochi_inference.ubuntu.amd.Dockerfile | Updates the ROCm architecture list used when building FlashAttention. |
docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile (new FlashAttention install block):

```dockerfile
#ARG BUILD_FA="1"
#ARG FA_BRANCH="v3.0.0.r1-cktile"
#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
#RUN if [ "$BUILD_FA" = "1" ]; then \
#    cd ${WORKSPACE_DIR} \
#    && pip uninstall -y flash-attention \
#    && rm -rf flash-attention \
#    && git clone ${FA_REPO} \
#    && cd flash-attention \
#    && git checkout ${FA_BRANCH} \
#    && git submodule update --init \
#    && GPU_ARCHS=${HIP_ARCHITECTURES} python3 setup.py bdist_wheel --dist-dir=dist \
#    && pip install dist/*.whl \
#    && python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"; \
#    fi

# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

RUN git clone https://github.com/ROCm/flash-attention.git &&\
    cd flash-attention &&\
    python setup.py install
```
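Note that the new `RUN` step drops the import check that the commented-out block performed. If that check is still wanted, a minimal sketch (a suggestion, not part of this PR) would be to append it as a separate step so a broken install fails the image build rather than surfacing at runtime:

```dockerfile
# Hypothetical follow-up step: reuse the version check from the old build block to
# verify that flash_attn imports correctly inside the image.
RUN python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"
```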
docker/pyt_mochi_inference.ubuntu.amd.Dockerfile:

```diff
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
+ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
```
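For context, `PYTORCH_ROCM_ARCH` is the variable PyTorch's extension build machinery consults when choosing ROCm compile targets. The sketch below is an illustration of how the expanded list would typically be wired into the FlashAttention build; the `ENV`/`RUN` lines are an assumption, not taken from the Mochi Dockerfile:

```dockerfile
# Illustrative sketch only: export the ARG so the extension compile sees it,
# then build FlashAttention for every listed architecture, now including gfx950.
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
RUN git clone ${FA_REPO} flash-attention && \
    cd flash-attention && \
    python setup.py install
```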
The lines the review comment below refers to, in docker/pyt_wan2.1_inference.ubuntu.amd.Dockerfile:

```dockerfile
# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

RUN git clone https://github.com/ROCm/flash-attention.git &&\
```
Please use the FA_BRANCH & FA_REPO arguments; they are meant to let the image be built from whatever branch is needed via build arguments. The existing branch is the latest tag from flash-attention; is there any specific reason to remove it?
@lcskrishna the args were removed as per the steps mentioned in SWDEV-564747; please also refer to the steps mentioned in this repo - https://github.com/Dao-AILab/flash-attention.
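A possible middle ground, sketched here only as a suggestion (the default branch value is copied from the old commented-out block, not from this PR), would keep the build-arg overrides while using the new direct `setup.py install`:

```dockerfile
# Hypothetical compromise: restore ARG-based overrides around the simpler install path.
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_BRANCH="v3.0.0.r1-cktile"
RUN git clone ${FA_REPO} && \
    cd flash-attention && \
    git checkout ${FA_BRANCH} && \
    python setup.py install && \
    python -c "import flash_attn; print(f'Flash Attention version == {flash_attn.__version__}')"
```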
Motivation
Updated the Wan2.1 Dockerfile with FlashAttention install steps taken from the ROCm flash-attention repo.
Technical Details
Test Plan
Test Result
Submission Checklist