Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions example/ck_tile/50_sparse_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CMakeLists.txt for sparse attention (Jenga and VSA)
#Copyright(c) Advanced Micro Devices, Inc., or its affiliates.
#SPDX - License - Identifier : MIT
#CMakeLists.txt for sparse attention(Jenga and VSA)

# Use SUPPORTED_GPU_TARGETS directly
#Use SUPPORTED_GPU_TARGETS directly
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})

Expand All @@ -16,7 +16,7 @@ endif()

message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")

# Code generation scripts
#Code generation scripts
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/generate.py
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
Expand Down Expand Up @@ -153,4 +153,47 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)

# ============================================================================
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
# ============================================================================
set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")

# Object library holding the hand-written kernel instantiation; built only on
# demand (EXCLUDE_FROM_ALL), matching the generated instance libraries above.
add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
    ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
)
target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
# The source contains HIP device code, so force the HIP toolchain for it.
set_source_files_properties(
    ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
    PROPERTIES LANGUAGE HIP
)
set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

# Preprocessor switches belong in target_compile_definitions rather than raw
# -D flags smuggled through target_compile_options.
target_compile_definitions(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
    CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
    CK_TILE_FMHA_FWD_FAST_EXP2
)
target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
)

# ----------------------------------------------------------------------------
# Build unified Sparge test: combines blockmap, Jenga, and VSA attention
# for end-to-end evaluation and timing in a single executable.
# ----------------------------------------------------------------------------
set(EXAMPLE_SPARGE "tile_example_sparge")
message(DEBUG "adding example ${EXAMPLE_SPARGE}")
add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp)
# FIX: the original call used the keyword-less legacy signature; mixing that
# with keyworded calls on the same target is an error, so make linkage
# visibility explicit.
target_link_libraries(${EXAMPLE_SPARGE} PRIVATE
    ${SPARSE_ATTN_JENGA_INSTANCES}
    ${SPARSE_ATTN_VSA_INSTANCES}
    ${SPARGE_BLOCKMAP_INSTANCES}
)
target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
)

set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
179 changes: 156 additions & 23 deletions example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ def update_file(file_path, content):
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template<>
void fmha_jenga_fwd_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
"""

FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
Expand Down Expand Up @@ -219,6 +230,45 @@ def update_file(file_path, content):
}}
"""

# Output filename for the generated one-shot dispatch translation unit.
FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp"

# Top-level C++ template for the one-shot API: {F_dispatch} is filled with the
# nested if/else-if dispatch chain; the trailing std::cerr message is reached
# only when no generated trait matches the runtime traits/args.
FMHA_FWD_ONESHOT_API = """
#include "fmha_fwd_trek.hpp"
#include <iostream>

void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{

const bool has_load_tr = ck_tile::is_load_tr_supported();

{F_dispatch}
std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type
<< " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v
<< " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k
<< " mask=" << static_cast<int>(t.mask_type) << ")" << std::endl;
}}
"""

# Outermost dispatch level: branch on transpose-load hardware support
# ({F_trload_cond} is "has_load_tr" or "true").
FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""

# Second dispatch level: branch on the requested data type string.
FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
# Third dispatch level: branch on head dimensions (q and v bounds).
FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""

# Innermost level: check layout/mask/padding/constraint predicates, then
# instantiate the matching trait and launch the one-shot kernel.
FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
fmha_jenga_fwd_oneshot_<trait_>(s, a);
return;
}}
"""


@dataclass
class CppConstraint:
Expand Down Expand Up @@ -274,10 +324,7 @@ def scheck(self) -> str:

@property
def seqtune(self) -> str:
if self.bm0 == 128:
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
else:
return f"a.seqlen_q <= {self.bm0}"
return "true"

@property
def skcheck(self) -> str:
Expand Down Expand Up @@ -447,6 +494,67 @@ def api(self) -> str:
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)

@property
def oneshot_api(self) -> str:
# Generate the complete C++ source for the one-shot dispatch API.
#
# Builds a four-level nested if/else-if chain (transpose-load support ->
# dtype -> head-dim -> per-trait predicates) by formatting the
# FMHA_FWD_ONESHOT_API_* templates over self.pool, then prepends the
# shared kernel header. Mirrors the structure of the `api` property
# above, but emits fmha_jenga_fwd_oneshot_ launches instead.
# Map tr_load flag to the runtime C++ condition guarding that branch:
# "t" requires hardware support, "f" is unconditional.
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}

per_tr_load = str()
for tr_load in ["t", "f"]:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
# Only traits matching the current transpose-load mode
# belong in this branch of the dispatch tree.
traits = [
t
for t in self.pool[dtype][(hdim, hdim_v)]
if tr_load == t.tr_load
]
inners = str()
for k, trait in enumerate(traits):
# First predicate in a chain opens with "if"; the
# rest chain with "else if".
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_trload=BOOL_MAP[trait.tr_load],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
# NOTE(review): unlike the inner levels, both tr_load branches use
# F_if="if" — presumably intentional since the conditions are not
# mutually exclusive; confirm against the generated C++.
per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format(
F_if="if",
F_trload_cond=tr_load_cond_map[tr_load],
F_dtype_case=per_dtypes,
)
# Empty pool: silence unused-parameter warnings in the generated C++.
if not per_tr_load:
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load)


@dataclass
class FmhaFwdTileSize:
Expand Down Expand Up @@ -582,38 +690,39 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip
16,
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
128,
128,
32,
64,
128,
32,
128,
4,
1,
1,
4,
1,
1,
1,
1,
16,
16,
32,
16,
32,
16,
32,
32,
16,
-1,
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
),
FmhaFwdTileSize( # fmt: skip
32,
32,
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
64,
128,
32,
128,
32,
128,
2,
1,
1,
1,
1,
2,
1,
1,
32,
Expand All @@ -623,18 +732,40 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
32,
16,
-1,
CppConstraint("t.bm0 == 64"),
),
FmhaFwdTileSize( # fmt: skip
128,
16,
32,
64,
128,
32,
128,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
32,
-1,
),
FmhaFwdTileSize( # fmt: skip
32,
32,
128,
128,
32,
128,
4,
1,
1,
4,
1,
1,
1,
1,
32,
Expand All @@ -647,10 +778,10 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
),
FmhaFwdTileSize( # fmt: skip
128,
128,
64,
32,
128,
32,
16,
128,
4,
1,
Expand Down Expand Up @@ -780,7 +911,7 @@ def get_fwd_blobs(
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async":
continue
Expand Down Expand Up @@ -846,6 +977,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:

def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Emit both generated dispatch sources (regular and one-shot API)
    from the collected kernel pool into the autogen directory."""
    outputs = (
        (FMHA_FWD_API_FILENAME, api_pool.api),
        (FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api),
    )
    for filename, content in outputs:
        update_file(autogen_dir / filename, content)


def write_blobs(
Expand All @@ -865,3 +997,4 @@ def list_blobs(
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n")
Loading