From eed42a9dfa2d81358344689f04489517ee8c0510 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Thu, 19 Mar 2026 23:28:36 -0400 Subject: [PATCH 1/6] Add host-side Sparge block-map pipeline for sparse attention examples - Add sparge_tool.hpp: host-side Sparge block-map builder (mean-sim scoring, CDF/topk selection) and VSA delta-LUT converter. - Add test_sparge_jenga_sparse_attn.cpp and test_sparge_vsa_sparse_attn.cpp as end-to-end demos. - Update CMakeLists.txt to register both new executables. Note: block size is currently fixed at 128; flexible block size support is not yet addressed. --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 22 + .../ck_tile/50_sparse_attn/sparge_tool.hpp | 408 +++++++++++++++++ .../test_sparge_jenga_sparse_attn.cpp | 422 +++++++++++++++++ .../test_sparge_vsa_sparse_attn.cpp | 429 ++++++++++++++++++ 4 files changed, 1281 insertions(+) create mode 100644 example/ck_tile/50_sparse_attn/sparge_tool.hpp create mode 100644 example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp create mode 100644 example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 65bb2077642..c916f642ebb 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -88,6 +88,17 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# Sparge + Jenga Example executable +set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn") +message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}") +add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp) +target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) +target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE + 
-Wno-undefined-func-template + -Wno-float-equal +) + # ============================================================================ # VSA Sparse Attention # ============================================================================ @@ -153,4 +164,15 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# Sparge + VSA Example executable +set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") +message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") +add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) +target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) +target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal +) + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/50_sparse_attn/sparge_tool.hpp b/example/ck_tile/50_sparse_attn/sparge_tool.hpp new file mode 100644 index 00000000000..49c69cc6f74 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_tool.hpp @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace sparge { + +struct SpargeParams +{ + int BLKQ = 128; + int BLKK = 128; + + // Similarity gate threshold (TODO: per-head support). + float simthreshd1 = 0.6f; + + // Exactly one of the following should be used: + // - Use CDF threshold if topk < 0 + // - Both should be in [0, 1] <-- NEED TO CHECK THIS + float cdfthreshd = 0.98f; + float topk = -1.0f; + + // If true, treat Q/K as BHSD; otherwise BSHD (same convention as CK examples). + bool i_perm = true; +}; + +// Output format CK VSA expects. 
+struct VSALut +{ + ck_tile::HostTensor lut; // [B, Hq, Q_blk, K_blk] delta-encoded + ck_tile::HostTensor valid_block_num; // [B, Hq, Q_blk] +}; + +namespace detail { + +template +inline float to_f32(const T& x) +{ + return ck_tile::type_convert(x); +} + +// Read element from HostTensor with either BHSD or BSHD layout. +// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D] +// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D] +template +inline float load(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s, int d) +{ + return i_perm ? to_f32(X(b, h, s, d)) : to_f32(X(b, s, h, d)); +} + +// Compute pooled mean vector of one block: mean over tokens in [s0, s1). +template +std::vector +pooled_mean_block(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s0, int s1, int d) +{ + std::vector mean(d, 0.0f); + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return mean; + + for(int s = s0; s < s1; ++s) + { + for(int d_ = 0; d_ < d; ++d_) + { + mean[d_] += load(X, i_perm, b, h, s, d_); + } + } + const float inv = 1.0f / static_cast(bs); + for(int d_ = 0; d_ < d; ++d_) + mean[d_] *= inv; + return mean; +} + +// Compute "sim" flag of one block following SpargeAttn's intent: +// mean_sim = sum(Gram(x_hat)) / (BS_*BS_), where x_hat are token vectors normalized along D. +// +// Important: sum(Gram) = ||sum_i x_hat_i||^2, so we can compute it in O(BS_*D) exactly +// instead of O(BS_^2 * D). +template +bool sim_block_flag(const ck_tile::HostTensor& X, + bool i_perm, + int b, + int h, + int s0, + int s1, + int d, + float simthreshd1) +{ + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return false; + + std::vector sum_hat(d, 0.0f); + + for(int s = s0; s < s1; ++s) + { + // Compute L2 norm over D. 
+ float norm2 = 0.0f; + for(int d_ = 0; d_ < d; ++d_) + { + const float v = load(X, i_perm, b, h, s, d_); + norm2 += v * v; + } + float inv_norm = 1.0f; + // spargeAttn use eps to prevent division by zero + if(norm2 > 0.0f) + inv_norm = 1.0f / std::sqrt(norm2); + + // Accumulate normalized vector. + for(int d_ = 0; d_ < d; ++d_) + { + sum_hat[d_] += load(X, i_perm, b, h, s, d_) * inv_norm; + } + } + + float sum_gram = 0.0f; + for(int d_ = 0; d_ < d; ++d_) + sum_gram += sum_hat[d_] * sum_hat[d_]; + + const float denom = static_cast(bs) * static_cast(bs); + const float mean_sim = sum_gram / denom; + + return mean_sim > simthreshd1; +} + +inline int select_count_from_cdf(const std::vector& sorted_probs, float cdfthreshd) +{ + // Choose the smallest n such that cdf[n-1] >= cdfthreshd. + // Ensure at least 1. + if(sorted_probs.empty()) + return 0; + if(cdfthreshd <= 0.0f) + return 1; + + float c = 0.0f; + for(int i = 0; i < static_cast(sorted_probs.size()); ++i) + { + c += sorted_probs[i]; + if(c >= cdfthreshd) + return i + 1; + } + return static_cast(sorted_probs.size()); +} + +inline int select_count_from_topk(int K_blk, float topk) +{ + if(K_blk <= 0) + return 0; + int n = static_cast(std::floor(topk * static_cast(K_blk))); + n = std::max(1, n); + return n; +} + +} // namespace detail + +// Build one-hot block_map[b,hq,qb,kb] in {0,1}. +// - No causal mask +// - No attention sink +// - Logic matches SpargeAttn's structure: +// - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON later +// - if a Q-block is not "similar", force the whole row ON +template +ck_tile::HostTensor build_block_map_meansim(const ck_tile::HostTensor& Q, + const ck_tile::HostTensor& K, + const SpargeParams& p) +{ + const auto qlens = Q.get_lengths(); + const auto klens = K.get_lengths(); + + const int B = static_cast(qlens[0]); + const int Hq = p.i_perm ? static_cast(qlens[1]) : static_cast(qlens[2]); + const int Sq = p.i_perm ? 
static_cast(qlens[2]) : static_cast(qlens[1]); + const int D = static_cast(qlens[3]); + + [[maybe_unused]] const int Bk = static_cast(klens[0]); + const int Hk = p.i_perm ? static_cast(klens[1]) : static_cast(klens[2]); + const int Sk = p.i_perm ? static_cast(klens[2]) : static_cast(klens[1]); + [[maybe_unused]] const int Dk = static_cast(klens[3]); + + assert(B == Bk && D == Dk && Hq % Hk == 0); + assert(p.BLKQ > 0 && p.BLKK > 0); + + const int nhead_ratio_qk = Hq / Hk; + const int Q_blk = ck_tile::integer_divide_ceil(Sq, p.BLKQ); + const int K_blk = ck_tile::integer_divide_ceil(Sk, p.BLKK); + + ck_tile::HostTensor block_map({B, Hq, Q_blk, K_blk}); + + // pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D] + // sim_q: [B,Hq,Q_blk], sim_k: [B,Hk,K_blk] + std::vector pooled_q(static_cast(B) * Hq * Q_blk * D, 0.0f); + std::vector pooled_k(static_cast(B) * Hk * K_blk * D, 0.0f); + std::vector sim_q(static_cast(B) * Hq * Q_blk, 0); + std::vector sim_k(static_cast(B) * Hk * K_blk, 0); + + auto idx_pq = [&](int b, int hq, int qb, int d) { + return (((b * Hq + hq) * Q_blk + qb) * D + d); + }; + auto idx_pk = [&](int b, int hk, int kb, int d) { + return (((b * Hk + hk) * K_blk + kb) * D + d); + }; + auto idx_sq = [&](int b, int hq, int qb) { return ((b * Hq + hq) * Q_blk + qb); }; + auto idx_sk = [&](int b, int hk, int kb) { return ((b * Hk + hk) * K_blk + kb); }; + + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + // Q blocks + for(int qb = 0; qb < Q_blk; ++qb) + { + const int s0 = qb * p.BLKQ; + const int s1 = std::min(Sq, (qb + 1) * p.BLKQ); + + // pooled mean + auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_q[idx_pq(b, hq, qb, d)] = mean[d]; + + // sim flag + sim_q[idx_sq(b, hq, qb)] = + detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 
1 : 0; + } + } + + for(int hk = 0; hk < Hk; ++hk) + { + // K blocks + for(int kb = 0; kb < K_blk; ++kb) + { + const int s0 = kb * p.BLKK; + const int s1 = std::min(Sk, (kb + 1) * p.BLKK); + + auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_k[idx_pk(b, hk, kb, d)] = mean[d]; + + sim_k[idx_sk(b, hk, kb)] = + detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0; + } + } + } + + const float scale = 1.0f / std::sqrt(static_cast(D)); + + // Main loop + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + const int hk = hq / nhead_ratio_qk; + + for(int qb = 0; qb < Q_blk; ++qb) + { + const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0); + + // If Q-block is not "similar", force dense row. + if(!q_is_sim) + { + for(int kb = 0; kb < K_blk; ++kb) + block_map(b, hq, qb, kb) = 1; + continue; + } + + // Compute scores over K blocks (only sim_kblocks participate in softmax; others set + // to -inf). + std::vector score(K_blk, -std::numeric_limits::infinity()); + for(int kb = 0; kb < K_blk; ++kb) + { + const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0); + if(!k_is_sim) + { + block_map(b, hq, qb, kb) = 1; + continue; + } + + float dot = 0.0f; + for(int d = 0; d < D; ++d) + { + dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)]; + } + score[kb] = dot * scale; + } + + // Softmax over K_blk (numerically stable). If all -inf, probs become all zeros. 
+ float maxv = -std::numeric_limits::infinity(); + for(int kb = 0; kb < K_blk; ++kb) + maxv = std::max(maxv, score[kb]); + + std::vector prob(K_blk, 0.0f); + if(std::isfinite(maxv)) + { + float sumexp = 0.0f; + for(int kb = 0; kb < K_blk; ++kb) + { + if(!std::isfinite(score[kb])) + continue; + const float e = std::exp(score[kb] - maxv); + prob[kb] = e; + sumexp += e; + } + if(sumexp > 0.0f) + { + const float inv = 1.0f / sumexp; + for(int kb = 0; kb < K_blk; ++kb) + prob[kb] *= inv; + } + else + { + // All exponentials underflowed: keep zeros. + std::fill(prob.begin(), prob.end(), 0.0f); + } + } + + // Sort indices by prob descending. + std::vector order(K_blk); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](int a, int c) { + if(prob[a] != prob[c]) + return prob[a] > prob[c]; + return a < c; // tie-breaker for determinism + }); + + // Determine how many to select. + int num_to_select = 0; + if(p.topk > 0.0f) + { + num_to_select = detail::select_count_from_topk(K_blk, p.topk); + } + else + { + // Use CDF threshold selection (smallest n s.t. cumulative prob >= cdfthreshd). + std::vector sorted_probs(K_blk); + for(int i = 0; i < K_blk; ++i) + sorted_probs[i] = prob[order[i]]; + num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd); + num_to_select = std::max(1, num_to_select); + } + + // Select top-kb blocks by order[0..num_to_select-1]. + for(int i = 0; i < num_to_select; ++i) + { + const int kb = order[i]; + block_map(b, hq, qb, kb) = 1; + } + } + } + } + + return block_map; +} + +// Convert one-hot block_map -> delta-encoded LUT + valid_block_num (CK VSA format). 
+template +VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor& block_map) +{ + const auto lens = block_map.get_lengths(); + const int B = static_cast(lens[0]); + const int H = static_cast(lens[1]); + const int Q = static_cast(lens[2]); + const int K = static_cast(lens[3]); + + VSALut out{ + ck_tile::HostTensor({B, H, Q, K}), + ck_tile::HostTensor({B, H, Q}), + }; + + for(int b = 0; b < B; ++b) + { + for(int h = 0; h < H; ++h) + { + for(int q = 0; q < Q; ++q) + { + int32_t valid = 0; + int32_t prev = 0; + + for(int k = 0; k < K; ++k) + { + const bool on = static_cast(block_map(b, h, q, k)) != 0; + if(on) + { + out.lut(b, h, q, valid) = static_cast(k - prev); + prev = static_cast(k); + ++valid; + } + } + + out.valid_block_num(b, h, q) = valid; + + // Optional: zero-fill the unused tail for determinism. + for(int i = valid; i < K; ++i) + out.lut(b, h, q, i) = 0; + } + } + } + + return out; +} + +} // namespace sparge diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp new file mode 100644 index 00000000000..0bd664adf68 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp @@ -0,0 +1,422 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +// Demo: Sparge block-map -> Jenga sparse attention + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "jenga_sparse_attention.h" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + // Sparge-specific + .insert("blkq", "128", "Sparge BLKQ") + .insert("blkk", "128", "Sparge BLKK") + .insert("simthreshd1", "0.6", "Sparge sim threshold") + .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") + .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, 
arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + // Sparge params + ck_tile::index_t blkq = arg_parser.get_int("blkq"); + ck_tile::index_t blkk = arg_parser.get_int("blkk"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float topk = arg_parser.get_float("topk"); + + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, " + "hdim_q=128, hdim_v=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + ck_tile::index_t BLKQ = blkq; + ck_tile::index_t BLKK = blkk; + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd + << ", topk=" << topk << ")" << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + std::cout << "\nInitializing tensors..." 
<< std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Build block map using Sparge tool + std::cout << "Building Sparge block map..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + // Print actual sparsity + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_relation_onehot(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; + + try + { + if(kname) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + + for(int i = 0; i < warmup; ++i) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + 
jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." << std::endl; + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + std::size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + + return test_result ? 
0 : -1; +} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp new file mode 100644 index 00000000000..dd1d3e60bee --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -0,0 +1,429 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "jenga_sparse_attention.h" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + // Sparge-specific + .insert("blkq", "128", "Sparge BLKQ") + .insert("blkk", "128", "Sparge BLKK") + .insert("simthreshd1", "0.6", "Sparge sim threshold") + .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") + .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, 
arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + // Sparge params + ck_tile::index_t blkq = arg_parser.get_int("blkq"); + ck_tile::index_t blkk = arg_parser.get_int("blkk"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float topk = arg_parser.get_float("topk"); + + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, " + "hdim_q=128, hdim_v=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + ck_tile::index_t BLKQ = blkq; + ck_tile::index_t BLKK = blkk; + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd + << ", topk=" << topk << ")" << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + std::cout << "\nInitializing tensors..." 
<< std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Build block map using Sparge tool + std::cout << "Building Sparge block map..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + // Convert to VSA LUT (delta-encoded) + valid_block_num + std::cout << "Converting block map to VSA LUT (delta)..." << std::endl; + auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + + // Print actual sparsity (based on one-hot) + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_relation_onehot(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; + + try + { + if(kname) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + + for(int i = 0; i < warmup; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + 
seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." << std::endl; + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + std::size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + + return test_result ? 
0 : -1; +} From 9317fc4a8508cb53ec9bd829781d4781d84ce428 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Tue, 24 Mar 2026 05:57:54 -0400 Subject: [PATCH 2/6] Support 64x128 tile size in sparge fwd for Jenga and VSA paths --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 116 ++- .../codegen/ops/sparge_fwd_jenga.py | 799 ++++++++++++++++++ .../codegen/ops/sparge_fwd_vsa.py | 799 ++++++++++++++++++ .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 6 + .../50_sparse_attn/jenga_sparge_attention.cpp | 189 +++++ .../50_sparse_attn/jenga_sparge_attention.h | 27 + .../test_sparge_jenga_sparse_attn.cpp | 14 +- .../test_sparge_vsa_sparse_attn.cpp | 14 +- .../50_sparse_attn/vsa_sparge_attention.cpp | 195 +++++ .../50_sparse_attn/vsa_sparge_attention.h | 28 + ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 2 +- 11 files changed, 2167 insertions(+), 22 deletions(-) create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.h create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.h diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index c916f642ebb..0ac86f6affa 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -1,8 +1,8 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# CMakeLists.txt for sparse attention (Jenga and VSA) +#Copyright(c) Advanced Micro Devices, Inc., or its affiliates. 
+#SPDX - License - Identifier : MIT +#CMakeLists.txt for sparse attention(Jenga and VSA) -# Use SUPPORTED_GPU_TARGETS directly +#Use SUPPORTED_GPU_TARGETS directly set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -16,7 +16,7 @@ endif() message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}") -# Code generation scripts +#Code generation scripts file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/generate.py ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py @@ -88,11 +88,62 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# ============================================================================ +# Sparge Jenga (64x128 tile) +# ============================================================================ +set(SPARGE_JENGA_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api sparge_fwd_jenga + --receipt 600 +) + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS) + +add_custom_command( + OUTPUT ${SPARGE_JENGA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge Jenga kernels" +) + +message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}") + +set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances") + +add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARGE_JENGA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp +) +target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + 
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + # Sparge + Jenga Example executable set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES}) target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template @@ -164,11 +215,62 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# ============================================================================ +# Sparge VSA (64x128 tile) +# ============================================================================ +set(SPARGE_VSA_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api sparge_fwd_vsa + --receipt 600 +) + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge VSA kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS) + +add_custom_command( + OUTPUT 
${SPARGE_VSA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge VSA kernels" +) + +message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}") + +set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") + +add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARGE_VSA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp +) +target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + # Sparge + VSA Example executable set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES}) target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py new file mode 100644 index 00000000000..872da2326ea --- /dev/null +++ 
b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py @@ -0,0 +1,799 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "kernel/fmha_fwd_jenga_kernel.hpp" + +""" + +# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features 
accordingly. + +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + 
fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdJengaKernel; + +using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +template<> +float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + 
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + return fmha_jenga_fwd_(s, a); + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: 
CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] 
== "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + 
F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx 
here + return ( + f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 
32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # jenga fmha only supports dim <= 192 for now. + continue + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + 
api_pool = FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. + # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if pipeline.tag != "qr_async": + continue + k = FmhaFwdKernel( + F_idx=2, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: 
FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py new file mode 100644 index 00000000000..c9a389df3fa --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py @@ -0,0 +1,799 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. 
+ + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "kernel/fmha_fwd_vsa_kernel.hpp" + +""" + +# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. 
+ +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx}, + ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType, + typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA< + 
fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include <iostream> + +template<> +float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp" +FMHA_FWD_API = """ +#include <hip/hip_runtime.h> + +#include <cstdio> + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const 
float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + return fmha_vsa_fwd_<trait_>(s, a); + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: CppConstraint + + @property + def 
name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] == "s_": + if self.F_mask == 
"s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + 
F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx 
here + return ( + f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 
32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # vsa fmha only supports dim <= 192 for now. + continue + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = 
FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. + # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if pipeline.tag != "qr_async_vsa": + continue + k = FmhaFwdKernel( + F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, 
autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 7349c3576e8..25e3513d2fa 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -277,6 +277,9 @@ struct fmha_jenga_fwd_traits float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); +// sparge jenga +float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); + template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); @@ -322,6 +325,9 @@ using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits; float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); +// sparge vsa +float sparge_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); + template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp 
b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp new file mode 100644 index 00000000000..88f3e08204e --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "jenga_sparge_attention.h" +#include "fmha_fwd_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include + +template +ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level) +{ + static_assert(std::is_same_v || + std::is_same_v, + "Jenga sparse attention supports fp16/bf16 only."); + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } + + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; + + ck_tile::stream_config stream_config{nullptr, + false, // time_kernel + log_level, + 0, + 1, + false}; + + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); + + 
q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + v_buf.ToDevice(TV.data()); + block_relation_buf.ToDevice(Tblock_relation_onehot.data()); + + const auto init_args = [&](auto& args) { + assert(nhead % nhead_k == 0); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqlen_k = shape_seqlen_k; + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + }; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + traits.mask_type = mask.type; + }; + + fmha_jenga_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_jenga_fwd_args args; + init_args(args); + + sparge_jenga_fwd(fmha_traits, args, stream_config); + + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); + + return Y; +} + +template ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const 
ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); + +template ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h new file mode 100644 index 00000000000..6259fcc73cf --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp index 0bd664adf68..590e51db144 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp @@ -16,7 +16,7 @@ #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "jenga_sparse_attention.h" +#include "jenga_sparge_attention.h" #include "sparge_tool.hpp" // ============================================================================ @@ -115,7 
+115,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "20", "benchmark iterations") .insert("kname", "0", "print kernel name") // Sparge-specific - .insert("blkq", "128", "Sparge BLKQ") + .insert("blkq", "64", "Sparge BLKQ") .insert("blkk", "128", "Sparge BLKK") .insert("simthreshd1", "0.6", "Sparge sim threshold") .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") @@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(hdim_v < 0) hdim_v = hdim_q; - if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) { std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, " + std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, " "hdim_q=128, hdim_v=128 only." << std::endl; std::cout << "TEST SKIPPED" << std::endl; @@ -247,7 +247,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { if(kname) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, @@ -268,7 +268,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < warmup; ++i) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, @@ -292,7 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index dd1d3e60bee..c0feb23e581 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -16,7 +16,7 @@ #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include 
"ck_tile/core/utility/bit_cast.hpp" -#include "jenga_sparse_attention.h" +#include "vsa_sparge_attention.h" #include "sparge_tool.hpp" // ============================================================================ @@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "20", "benchmark iterations") .insert("kname", "0", "print kernel name") // Sparge-specific - .insert("blkq", "128", "Sparge BLKQ") + .insert("blkq", "64", "Sparge BLKQ") .insert("blkk", "128", "Sparge BLKK") .insert("simthreshd1", "0.6", "Sparge sim threshold") .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") @@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(hdim_v < 0) hdim_v = hdim_q; - if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) { std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, " + std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, " "hdim_q=128, hdim_v=128 only." 
<< std::endl; std::cout << "TEST SKIPPED" << std::endl; @@ -251,7 +251,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { if(kname) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, @@ -273,7 +273,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < warmup; ++i) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, @@ -298,7 +298,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp new file mode 100644 index 00000000000..5f9c2676ddb --- /dev/null +++ b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp @@ -0,0 +1,195 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "vsa_sparge_attention.h" +#include "fmha_fwd_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include + +template +ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& TKV_block_idx, + const ck_tile::HostTensor& TKV_blocks, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level) +{ + static_assert(std::is_same_v || + std::is_same_v, + "VSA sparse attention supports fp16/bf16 only."); + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } + + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + float scale_s = 1.0 / 
ck_tile::sqrt(static_cast(hdim_q)); + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; + + ck_tile::stream_config stream_config{nullptr, + false, // time_kernel + log_level, + 0, + 1, + false}; + + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes()); + ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); + + q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + v_buf.ToDevice(TV.data()); + lut_buf.ToDevice(TKV_block_idx.data()); + valid_block_num_buf.ToDevice(TKV_blocks.data()); + + const auto init_args = [&](auto& args) { + assert(nhead % nhead_k == 0); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqlen_k = shape_seqlen_k; + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + }; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + traits.mask_type = mask.type; + }; + + fmha_vsa_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_vsa_fwd_args args; + init_args(args); + + sparge_vsa_fwd(fmha_traits, args, stream_config); + + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); + + return Y; +} + +template ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor&, + const 
ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); + +template ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h new file mode 100644 index 00000000000..d51a7e8c00b --- /dev/null +++ b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h @@ -0,0 +1,28 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& TKV_block_idx, + const ck_tile::HostTensor& TKV_blocks, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level = 0); diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 2b097ae5827..578ad7e6039 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -200,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto 
gemm_1 = Policy::template GetKVBlockGemm(); - int seqlen_k_start = kv_block_idx_ptr[0] * kM0; + int seqlen_k_start = kv_block_idx_ptr[0] * kN0; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), From d1d457b82a63fdf6e68461194fa2c1098ace5f93 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Mon, 13 Apr 2026 03:34:08 -0400 Subject: [PATCH 3/6] Add sparge gpu pipeline in tile_example_sparge_vsa_sparse_attn --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 34 +- .../50_sparse_attn/sparge_blockmap.cpp | 156 ++++++ .../ck_tile/50_sparse_attn/sparge_blockmap.h | 26 + .../50_sparse_attn/sparge_blockmap_inst.cpp | 88 +++ .../50_sparse_attn/sparge_blockmap_trek.hpp | 93 ++++ .../test_sparge_vsa_sparse_attn.cpp | 234 ++++++-- .../kernel/sparge_blockmap_kernel.hpp | 195 +++++++ .../pipeline/sparge_blockmap_pipeline.hpp | 521 ++++++++++++++++++ 8 files changed, 1296 insertions(+), 51 deletions(-) create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.cpp create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.h create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp create mode 100644 include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp create mode 100644 include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 0ac86f6affa..169ed87ac3b 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -266,11 +266,41 @@ target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE -Wno-float-equal ) -# Sparge + VSA Example executable +# ============================================================================ +# Sparge BlockMap GPU Kernel (hand-written 
instantiation, no codegen) +# ============================================================================ +set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") + +add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp +) +target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties( + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp + PROPERTIES LANGUAGE HIP +) +set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# Sparge + VSA Example executable (now links blockmap kernel too) set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} + ${SPARGE_VSA_INSTANCES} + ${SPARGE_BLOCKMAP_INSTANCES} +) target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp new file mode 100644 index 00000000000..b9ac56c533c --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp @@ -0,0 +1,156 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#include "sparge_blockmap.h" +#include "sparge_blockmap_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include +#include + +template +sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + ck_tile::HostTensor& block_map_out, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + bool i_perm, + float simthreshd1, + float cdfthreshd, + float topk, + int blkq, + int blkk, + int log_level) +{ + static_assert(std::is_same_v || + std::is_same_v, + "sparge_blockmap_gpu supports fp16/bf16 only."); + + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } + + const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq); + const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk); + + const float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + // Allocate device memory + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + + const std::size_t bmap_bytes = + static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t); + const std::size_t lut_bytes = + static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t); + const std::size_t valid_bytes = + static_cast(batch) * nhead_q * num_q_blocks * sizeof(int32_t); + + ck_tile::DeviceMem bmap_buf(bmap_bytes); + ck_tile::DeviceMem lut_buf(lut_bytes); + ck_tile::DeviceMem valid_buf(valid_bytes); + + q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + bmap_buf.SetZero(); + lut_buf.SetZero(); + valid_buf.SetZero(); + + // Compute strides (assumes BHSD if i_perm, BSHD otherwise) + const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q; + const ck_tile::index_t stride_k = i_perm ? 
hdim_q : nhead_k * hdim_q; + const ck_tile::index_t nhead_stride_q = + i_perm ? static_cast(seqlen_q) * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_k = + i_perm ? static_cast(seqlen_k) * hdim_q : hdim_q; + const ck_tile::index_t batch_stride_q = + static_cast(nhead_q) * seqlen_q * hdim_q; + const ck_tile::index_t batch_stride_k = + static_cast(nhead_k) * seqlen_k * hdim_q; + + ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false}; + + sparge_blockmap_args args; + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.batch = batch; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.hdim_q = hdim_q; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.simthreshd1 = simthreshd1; + args.cdfthreshd = cdfthreshd; + args.topk = topk; + args.scale = scale; + args.block_map_ptr = bmap_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + + sparge_blockmap_traits traits; + traits.data_type = data_type; + traits.hdim_q = hdim_q; + + sparge_blockmap_fwd(traits, args, stream_config); + + // Copy results back to host + bmap_buf.FromDevice(block_map_out.data(), bmap_bytes); + + sparge::VSALut vsa_lut{ + ck_tile::HostTensor({batch, nhead_q, num_q_blocks, num_k_blocks}), + ck_tile::HostTensor({batch, nhead_q, num_q_blocks}), + }; + lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes); + valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes); + + return vsa_lut; +} + +// Explicit template instantiations +template sparge::VSALut +sparge_blockmap_gpu(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + bool, + float, + float, + float, + 
int, + int, + int); + +template sparge::VSALut +sparge_blockmap_gpu(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + bool, + float, + float, + float, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.h b/example/ck_tile/50_sparse_attn/sparge_blockmap.h new file mode 100644 index 00000000000..3057257ca14 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap.h @@ -0,0 +1,26 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "sparge_tool.hpp" + +template +sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + ck_tile::HostTensor& block_map_out, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + bool i_perm, + float simthreshd1, + float cdfthreshd, + float topk, + int blkq, + int blkk, + int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp new file mode 100644 index 00000000000..fbd18b9ff24 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128). 
+ +#include "sparge_blockmap_trek.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include + +// ============================================================================ +// Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig) +// ============================================================================ + +// fp16: D=128, kM0=64, kN0=128 +using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; +// kM0 kN0 kK0 kN1 kK1 kQKHeaddim(D) + +using bmap_fp16_shape = + ck_tile::TileFmhaShape, // Gemm0BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but + // needed by shape) + ck_tile::sequence<4, 1, 1>, // Gemm1BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile + true>; // VLayout row-major + +using bmap_fp16_trait = ck_tile::TileFmhaTraits; // kIsVRowMajorSkip + +using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_fp16_mask = ck_tile::GenericAttentionMask; + +using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel; + +// ============================================================================ +// Dispatch +// ============================================================================ + +float sparge_blockmap_fwd(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + using k_ = bmap_fp16_kernel; + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_fp16_d128" << std::flush; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); + } + + if(s.log_level_ > 0) + std::cerr << "sparge_blockmap_fwd: 
unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; + return -1.f; +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp new file mode 100644 index 00000000000..1e7e33248a2 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp" +#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp" + +#include "fmha_fwd_trek.hpp" + +#include +#include + +// ============================================================================ +// Args and traits for sparge block map GPU kernel +// ============================================================================ +struct sparge_blockmap_args +{ + const void* q_ptr; + const void* k_ptr; + + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; +}; + +struct sparge_blockmap_traits +{ + std::string data_type; + int hdim_q; +}; + +// ============================================================================ +// Create kernel args and grid dimensions +// 
============================================================================ +template +auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = BlockMapKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.stride_q, + args.stride_k, + args.nhead_stride_q, + args.nhead_stride_k, + args.batch_stride_q, + args.batch_stride_k, + args.simthreshd1, + args.cdfthreshd, + args.topk, + args.scale, + args.block_map_ptr, + args.lut_ptr, + args.valid_block_num_ptr); + + dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// ============================================================================ +// Hand-written template instantiation dispatch +// ============================================================================ +float sparge_blockmap_fwd(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index c0feb23e581..638a867b0f3 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -17,6 +17,7 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include "vsa_sparge_attention.h" +#include "sparge_blockmap.h" #include "sparge_tool.hpp" // ============================================================================ @@ -198,53 +199,37 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor output_host = o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); - ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); std::cout << "\nInitializing tensors..." 
<< std::endl; ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - // Build block map using Sparge tool - std::cout << "Building Sparge block map..." << std::endl; - sparge::SpargeParams p; - p.BLKQ = static_cast(BLKQ); - p.BLKK = static_cast(BLKK); - p.simthreshd1 = simthreshd1; - p.cdfthreshd = cdfthreshd; - p.topk = topk; - p.i_perm = i_perm; - - ck_tile::HostTensor block_relation_onehot = - sparge::build_block_map_meansim(q_host, k_host, p); - - // Convert to VSA LUT (delta-encoded) + valid_block_num - std::cout << "Converting block map to VSA LUT (delta)..." << std::endl; - auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); - - // Print actual sparsity (based on one-hot) - std::size_t total_blocks = 0; - std::size_t active_blocks = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - total_blocks++; - if(block_relation_onehot(b, h, qb, kb) != 0) - active_blocks++; - } - } - } - } - float actual_sparsity = - 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" - << total_blocks << " blocks active)" << std::endl; - + // ================================================================== + // GPU: Build block map + VSA LUT in one kernel (always run) + // ================================================================== + std::cout << "Building Sparge block map + VSA LUT (GPU)..." 
<< std::endl; + ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); + auto vsa_lut_gpu = sparge_blockmap_gpu(q_host, + k_host, + block_map_gpu, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + i_perm, + simthreshd1, + cdfthreshd, + topk, + static_cast(BLKQ), + static_cast(BLKK), + 0); + + // ================================================================== + // VSA sparse attention kernel (always run) + // ================================================================== std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; try @@ -254,8 +239,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) vsa_sparge_attention(q_host, k_host, v_host, - vsa_lut.lut, - vsa_lut.valid_block_num, + vsa_lut_gpu.lut, + vsa_lut_gpu.valid_block_num, output_host, batch, nhead, @@ -276,8 +261,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) vsa_sparge_attention(q_host, k_host, v_host, - vsa_lut.lut, - vsa_lut.valid_block_num, + vsa_lut_gpu.lut, + vsa_lut_gpu.valid_block_num, output_host, batch, nhead, @@ -301,8 +286,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) vsa_sparge_attention(q_host, k_host, v_host, - vsa_lut.lut, - vsa_lut.valid_block_num, + vsa_lut_gpu.lut, + vsa_lut_gpu.valid_block_num, output_host, batch, nhead, @@ -332,17 +317,168 @@ bool run_test(const ck_tile::ArgParser& arg_parser) return false; } + // ================================================================== + // Sparsity statistics (always run, pure CPU read of HostTensor) + // ================================================================== + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_map_gpu(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float 
actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + // ================================================================== + // Validation (only when -v=1) + // ================================================================== bool pass = true; if(do_validation) { std::cout << "\n--- Performing CPU validation ---" << std::endl; + + // CPU golden: block map + VSA LUT + std::cout << "Building Sparge block map (CPU golden)..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl; + auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + + // Validate block map + std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl; + { + std::size_t bmap_mismatches = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(block_map_gpu(b, h, qb, kb) != + block_relation_onehot(b, h, qb, kb)) + { + bmap_mismatches++; + if(bmap_mismatches <= 10) + { + std::cout + << " block_map mismatch at [" << b << "," << h << "," + << qb << "," << kb + << "]: GPU=" + << static_cast(block_map_gpu(b, h, qb, kb)) + << " CPU=" + << static_cast( + block_relation_onehot(b, h, qb, kb)) + << std::endl; + } + } + } + } + } + } + std::cout << " Block map mismatches: " << bmap_mismatches << " / " + << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl; + if(bmap_mismatches > 0) + { + std::cout 
<< ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl; + pass = false; + } + else + { + std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl; + } + } + + // Validate VSA LUT + std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl; + { + std::size_t lut_mismatches = 0; + std::size_t valid_mismatches = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + if(vsa_lut_gpu.valid_block_num(b, h, qb) != + vsa_lut_cpu.valid_block_num(b, h, qb)) + { + valid_mismatches++; + if(valid_mismatches <= 5) + { + std::cout + << " valid_block_num mismatch at [" << b << "," << h + << "," << qb + << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) + << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) + << std::endl; + } + } + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(vsa_lut_gpu.lut(b, h, qb, kb) != + vsa_lut_cpu.lut(b, h, qb, kb)) + { + lut_mismatches++; + if(lut_mismatches <= 10) + { + std::cout + << " LUT mismatch at [" << b << "," << h << "," << qb + << "," << kb + << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) + << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) + << std::endl; + } + } + } + } + } + } + std::cout << " LUT mismatches: " << lut_mismatches << std::endl; + std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl; + if(lut_mismatches == 0 && valid_mismatches == 0) + { + std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + // Validate attention output float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - std::cout << "Computing reference output..." << std::endl; + std::cout << "\nComputing reference attention output..." 
<< std::endl; auto q_ref = to_bhsd(q_host, i_perm); auto k_ref = to_bhsd(k_host, i_perm); auto v_ref = to_bhsd(v_host, i_perm); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); ck_tile::reference_blocked_attention( q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); @@ -374,7 +510,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) } } - std::cout << "\nValidation results:" << std::endl; + std::cout << "\nAttention validation results:" << std::endl; std::cout << " Max absolute difference: " << max_diff << std::endl; std::cout << " Max relative difference: " << max_rel_diff << std::endl; std::cout << " Number of mismatches: " << num_errors << " / " diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp new file mode 100644 index 00000000000..ca177abf23a --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp @@ -0,0 +1,195 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { + +template +struct SpargeBlockMapKernel +{ + using Pipeline = remove_cvref_t; + + static constexpr index_t kBlockSize = Pipeline::kBlockSize; + static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + + using QDataType = typename Pipeline::QDataType; + using KDataType = typename Pipeline::KDataType; + + static constexpr index_t kM0 = Pipeline::kM0; + static constexpr index_t kN0 = Pipeline::kN0; + static constexpr index_t D = Pipeline::D; + + static constexpr index_t kAlignment = 16 / sizeof(QDataType); + + struct Kargs + { + const void* q_ptr; + const void* k_ptr; + + index_t seqlen_q; + index_t seqlen_k; + index_t hdim_q; + + index_t nhead_q; + index_t nhead_ratio_qk; + + index_t stride_q; + index_t stride_k; + index_t nhead_stride_q; + index_t nhead_stride_k; + index_t batch_stride_q; + index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; + + index_t N_k; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr, + const void* k_ptr, + index_t seqlen_q, + index_t seqlen_k, + index_t hdim_q, + index_t nhead_q, + index_t nhead_ratio_qk, + index_t stride_q, + index_t stride_k, + index_t nhead_stride_q, + index_t nhead_stride_k, + index_t batch_stride_q, + index_t batch_stride_k, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + void* block_map_ptr, + void* lut_ptr, + void* valid_block_num_ptr) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return Kargs{q_ptr, + k_ptr, + seqlen_q, + seqlen_k, + hdim_q, + nhead_q, + nhead_ratio_qk, + stride_q, + stride_k, + nhead_stride_q, + nhead_stride_k, + batch_stride_q, + batch_stride_k, + simthreshd1, + cdfthreshd, + topk, + scale, + block_map_ptr, + lut_ptr, + valid_block_num_ptr, + N_k}; + } + + CK_TILE_HOST static constexpr auto 
GridSize(index_t batch, index_t nhead_q, index_t seqlen_q) + { + const index_t Q_blk = integer_divide_ceil(seqlen_q, kM0); + return dim3(Q_blk, nhead_q, batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const index_t qb = static_cast(blockIdx.x); + const index_t hq = static_cast(blockIdx.y); + const index_t b = static_cast(blockIdx.z); + + const index_t hk = hq / kargs.nhead_ratio_qk; + + // Q pointer for this (batch, head, q_block) + const auto* q_base = reinterpret_cast(kargs.q_ptr) + + b * kargs.batch_stride_q + hq * kargs.nhead_stride_q + + qb * kM0 * kargs.stride_q; + + // K pointer for this (batch, head_k) + const auto* k_base = reinterpret_cast(kargs.k_ptr) + + b * kargs.batch_stride_k + hk * kargs.nhead_stride_k; + + // Q DRAM view with OOB padding + const auto q_dram_naive = make_naive_tensor_view( + q_base, + make_tuple(kargs.seqlen_q - qb * kM0, D), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = pad_tensor_view( + q_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto q_window = make_tile_window(q_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeQBlockDistribution()); + + // K DRAM view with OOB padding + const auto k_dram_naive = + make_naive_tensor_view(k_base, + make_tuple(kargs.seqlen_k, D), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = pad_tensor_view( + k_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto k_window = make_tile_window(k_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeKBlockDistribution()); + + // Output pointers for this (batch, head, q_block) + const index_t N_k = kargs.N_k; + const index_t bmap_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) * N_k + qb * N_k; + auto* bmap_ptr = reinterpret_cast(kargs.block_map_ptr) + bmap_offset; + + int32_t* lut_out = nullptr; + 
int32_t* valid_out = nullptr; + if(kargs.lut_ptr != nullptr) + { + lut_out = reinterpret_cast(kargs.lut_ptr) + bmap_offset; + const index_t valid_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) + qb; + valid_out = reinterpret_cast(kargs.valid_block_num_ptr) + valid_offset; + } + + // Shared memory + __shared__ char smem[Pipeline::GetSmemSize()]; + + Pipeline{}(q_window, + k_window, + kargs.seqlen_q, + kargs.seqlen_k, + qb, + N_k, + kargs.nhead_ratio_qk, + kargs.simthreshd1, + kargs.cdfthreshd, + kargs.topk, + kargs.scale, + bmap_ptr, + lut_out, + valid_out, + static_cast(smem)); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp new file mode 100644 index 00000000000..222e73c60e2 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -0,0 +1,521 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" + +namespace ck_tile { + +template +struct SpargeBlockMapPipeline +{ + using Problem = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t D = BlockFmhaShape::kQKHeaddim; + static constexpr index_t NumWarps = BlockFmhaShape::NumWarps; + static constexpr index_t WarpSize = get_warp_size(); + + static constexpr index_t KPerThread = 16 / sizeof(QDataType); + static constexpr index_t KThreads = D / KPerThread; + static constexpr index_t SeqThreadPerWarp = WarpSize / KThreads; + static constexpr index_t MPerThread = kM0 / (SeqThreadPerWarp * NumWarps); + static constexpr index_t NPerThread = kN0 / (SeqThreadPerWarp * NumWarps); + + static constexpr index_t kBlockPerCu = 1; + static constexpr index_t kMaxKBlocks = 1024; + + // LDS layout (non-overlapping, all used simultaneously in Phase 2): + // [0 .. kReduceBytes) cross-warp reduction scratch + // [kScoreOffset ..) scores[N_k] + // [kBmapOffset ..) block_map[N_k] + // [kSmallOffset ..) 
Phase 3 argmax scratch (2*NumWarps floats) + static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float); + static constexpr index_t kScoreOffset = kReduceBytes; + static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float); + static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return kSmallOffset + 2 * NumWarps * sizeof(float); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeQBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // Extract tile data into a local float array via static_for (compile-time indices). + template + CK_TILE_DEVICE static void tile_to_float(const Tile& tile, float (&out)[BufSize]) + { + static_assert(Tile::get_thread_buffer_size() == BufSize); + const auto& buf = tile.get_thread_buffer(); + static_for<0, BufSize, 1>{}([&](auto i) { out[i.value] = type_convert(buf[i]); }); + } + + // Column-wise (dim=0) sum: accumulate SeqPerThread rows into KPerThread partial sums, + // then xor-shuffle across m_idx within warp. 
+ template + CK_TILE_DEVICE static void column_reduce_thread_and_warp(const float* __restrict__ data, + float (&col_acc)[KPerThread]) + { + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k]; + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Cross-warp LDS reduction for column sums. + CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread], + float* __restrict__ smem_reduce) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + const index_t k_idx = lane_id % KThreads; + const index_t m_idx = lane_id / KThreads; + + if(m_idx == 0) + for(index_t k = 0; k < KPerThread; ++k) + smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k]; + __syncthreads(); + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k]; + __syncthreads(); + } + + // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx. 
+ template + CK_TILE_DEVICE static void row_reduce_sq_norm(const float* __restrict__ data, + float (&row_norms)[SeqPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + { + float v = data[m * KPerThread + k]; + sq += v * v; + } + for(index_t stride = 1; stride < KThreads; stride *= 2) + sq += warp_shuffle(sq, __lane_id() ^ stride); + + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + row_norms[m] = (gsq < actual_seq) ? sq : 0.f; + } + } + + // Column reduce of normalised rows: sum_hat[d] = sum_i data[i,d] / ||data[i,:]||. + template + CK_TILE_DEVICE static void column_reduce_normalised(const float* __restrict__ data, + const float* __restrict__ row_norms, + float (&col_acc)[KPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f; + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + if(gsq < actual_seq) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k] * inv_norm; + } + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Scalar reduce across k_idx lanes (within warp). 
+ CK_TILE_DEVICE static float reduce_across_k(float v) + { + for(index_t stride = 1; stride < KThreads; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + return v; + } + + // Full-block scalar reduce (warp xor + cross-warp LDS). + CK_TILE_DEVICE static float block_reduce_sum(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + s += smem_small[w]; + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + CK_TILE_DEVICE static float block_reduce_max(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v = max(v, warp_shuffle(v, __lane_id() ^ stride)); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = smem_small[0]; + for(index_t w = 1; w < NumWarps; ++w) + s = max(s, smem_small[w]); + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + // ====================================================================== + template + CK_TILE_DEVICE void operator()(const QWindowType& q_window_in, + const KWindowType& k_window_in, + index_t seqlen_q, + index_t seqlen_k, + index_t qb, + index_t N_k, + index_t /*nhead_ratio_qk*/, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + uint8_t* block_map_ptr, + int32_t* lut_ptr, + int32_t* valid_block_num_ptr, + void* smem_ptr) const + { + const index_t tid = static_cast(threadIdx.x); + + auto* smem_float = reinterpret_cast(smem_ptr); + auto* smem_scores = + reinterpret_cast(reinterpret_cast(smem_ptr) 
+ kScoreOffset); + auto* smem_bmap = + reinterpret_cast(reinterpret_cast(smem_ptr) + kBmapOffset); + auto* smem_small = + reinterpret_cast(reinterpret_cast(smem_ptr) + kSmallOffset); + + const index_t bs_q = min(static_cast(kM0), seqlen_q - qb * kM0); + const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast(bs_q)) : 0.f; + + // ================================================================== + // Phase 1: Q Block Statistics + // ================================================================== + auto q_tile = load_tile(q_window_in); + + float q_data[MPerThread * KPerThread]; + tile_to_float(q_tile, q_data); + + // 1a. L2 norm per token + float psq[MPerThread]; + row_reduce_sq_norm(q_data, psq, bs_q); + + // 1b. Column sum -> mean + float pooled_q_mean[KPerThread]; + column_reduce_thread_and_warp(q_data, pooled_q_mean); + column_reduce_cross_warp(pooled_q_mean, smem_float); + for(index_t k = 0; k < KPerThread; ++k) + pooled_q_mean[k] *= inv_bs_q; + + // 1c. Normalised sum_hat + float sum_hat[KPerThread]; + column_reduce_normalised(q_data, psq, sum_hat, bs_q); + column_reduce_cross_warp(sum_hat, smem_float); + + // 1d. 
sim_q = ||sum_hat||^2 / bs_q^2 + float sh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + sh_sq += sum_hat[k] * sum_hat[k]; + sh_sq = reduce_across_k(sh_sq); + const float denom_q = static_cast(bs_q) * static_cast(bs_q); + const bool sim_q = (denom_q > 0.f) && ((sh_sq / denom_q) > simthreshd1); + + // Not similar → force all K blocks ON, early exit + if(!sim_q) + { + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = 1; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + return; + } + + // ================================================================== + // Phase 2: K Block Loop + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_bmap[i] = 0; + __syncthreads(); + + auto k_window = k_window_in; + + for(index_t kb = 0; kb < N_k; ++kb) + { + const index_t bs_k = min(static_cast(kN0), seqlen_k - kb * kN0); + const float inv_bs_k = (bs_k > 0) ? 
(1.0f / static_cast(bs_k)) : 0.f; + + auto k_tile = load_tile(k_window); + + float k_data[NPerThread * KPerThread]; + tile_to_float(k_tile, k_data); + + // K mean + float pooled_k_mean[KPerThread]; + column_reduce_thread_and_warp(k_data, pooled_k_mean); + column_reduce_cross_warp(pooled_k_mean, smem_float); + for(index_t k = 0; k < KPerThread; ++k) + pooled_k_mean[k] *= inv_bs_k; + + // dot(pooled_q_mean, pooled_k_mean) + float dot = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + dot += pooled_q_mean[k] * pooled_k_mean[k]; + dot = reduce_across_k(dot); + + // K L2 norms + normalised sum_hat + float k_psq[NPerThread]; + row_reduce_sq_norm(k_data, k_psq, bs_k); + + float k_sum_hat[KPerThread]; + column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k); + column_reduce_cross_warp(k_sum_hat, smem_float); + + // sim_k + float ksh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + ksh_sq += k_sum_hat[k] * k_sum_hat[k]; + ksh_sq = reduce_across_k(ksh_sq); + const float denom_k = static_cast(bs_k) * static_cast(bs_k); + const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1); + + if(tid == 0) + { + if(!sim_k) + { + smem_bmap[kb] = 1; + smem_scores[kb] = -numeric::infinity(); + } + else + { + smem_scores[kb] = dot * scale; + } + } + __syncthreads(); + + move_tile_window(k_window, {kN0, 0}); + } + + // ================================================================== + // Phase 3: Softmax + Selection + // ================================================================== + + // max + float lmax = -numeric::infinity(); + for(index_t i = tid; i < N_k; i += kBlockSize) + lmax = max(lmax, smem_scores[i]); + const float max_score = block_reduce_max(lmax, smem_small); + + // exp + sum + float lsum = 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + float e = (smem_scores[i] > -numeric::infinity()) + ? 
__builtin_expf(smem_scores[i] - max_score) + : 0.f; + smem_scores[i] = e; + lsum += e; + } + const float sum_exp = block_reduce_sum(lsum, smem_small); + + // normalise + const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_scores[i] *= inv_sum; + __syncthreads(); + + // Selection: iterative argmax + index_t num_to_select = + (topk > 0.f) + ? max(static_cast(1), static_cast(topk * static_cast(N_k))) + : N_k; + + float cumulative_prob = 0.f; + for(index_t round = 0; round < num_to_select; ++round) + { + // thread-local argmax + float best_val = -1.f; + index_t best_idx = 0; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + if(smem_scores[i] > best_val || (smem_scores[i] == best_val && i < best_idx)) + { + best_val = smem_scores[i]; + best_idx = i; + } + } + + // warp argmax + for(index_t stride = 1; stride < WarpSize; stride *= 2) + { + float rv = warp_shuffle(best_val, __lane_id() ^ stride); + index_t ri = warp_shuffle(best_idx, __lane_id() ^ stride); + if(rv > best_val || (rv == best_val && ri < best_idx)) + { + best_val = rv; + best_idx = ri; + } + } + + // cross-warp argmax via LDS + const index_t lane_id = tid % WarpSize; + const index_t warp_id = tid / WarpSize; + if(lane_id == 0) + { + smem_small[warp_id] = best_val; + smem_small[NumWarps + warp_id] = bit_cast(static_cast(best_idx)); + } + __syncthreads(); + + if(tid == 0) + { + float bv = smem_small[0]; + index_t bi = bit_cast(smem_small[NumWarps]); + for(index_t w = 1; w < NumWarps; ++w) + { + float wv = smem_small[w]; + index_t wi = bit_cast(smem_small[NumWarps + w]); + if(wv > bv || (wv == bv && wi < bi)) + { + bv = wv; + bi = wi; + } + } + smem_small[0] = bv; + smem_small[1] = bit_cast(static_cast(bi)); + } + __syncthreads(); + + float g_val = smem_small[0]; + index_t g_idx = bit_cast(smem_small[1]); + + if(g_val <= 0.f) + break; + + if(tid == 0) + { + smem_bmap[g_idx] = 1; + smem_scores[g_idx] = -1.f; + } + __syncthreads(); + + 
if(topk > 0.f) + { + if(round + 1 >= num_to_select) + break; + } + else + { + cumulative_prob += g_val; + if(cumulative_prob >= cdfthreshd) + break; + } + } + + // ================================================================== + // Write outputs to global memory + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = smem_bmap[i]; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + if(smem_bmap[kb] != 0) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + } +}; + +} // namespace ck_tile From c7e6e4f616b483f2a9aafd3e8d00238f02de77e5 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Tue, 14 Apr 2026 10:11:00 -0400 Subject: [PATCH 4/6] fix extra host side operations. --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 4 - .../50_sparse_attn/sparge_blockmap.cpp | 156 --------- .../ck_tile/50_sparse_attn/sparge_blockmap.h | 26 -- .../test_sparge_vsa_sparse_attn.cpp | 296 ++++++++++-------- .../50_sparse_attn/vsa_sparge_attention.cpp | 195 ------------ .../50_sparse_attn/vsa_sparge_attention.h | 28 -- 6 files changed, 164 insertions(+), 541 deletions(-) delete mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.cpp delete mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.h delete mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp delete mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.h diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 169ed87ac3b..f234f631b6b 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -249,14 +249,12 @@ set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") add_library(${SPARGE_VSA_INSTANCES} 
OBJECT EXCLUDE_FROM_ALL ${SPARGE_VSA_GEN_BLOBS} - ${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp ) target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn ) set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP) set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE @@ -273,7 +271,6 @@ set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp - ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp ) target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} @@ -281,7 +278,6 @@ target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE ) set_source_files_properties( ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp - ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp PROPERTIES LANGUAGE HIP ) set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp deleted file mode 100644 index b9ac56c533c..00000000000 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT -#include "sparge_blockmap.h" -#include "sparge_blockmap_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include -#include - -template -sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - ck_tile::HostTensor& block_map_out, - int batch, - int nhead_q, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - bool i_perm, - float simthreshd1, - float cdfthreshd, - float topk, - int blkq, - int blkk, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "sparge_blockmap_gpu supports fp16/bf16 only."); - - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq); - const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk); - - const float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - - // Allocate device memory - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - - const std::size_t bmap_bytes = - static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t); - const std::size_t lut_bytes = - static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t); - const std::size_t valid_bytes = - static_cast(batch) * nhead_q * num_q_blocks * sizeof(int32_t); - - ck_tile::DeviceMem bmap_buf(bmap_bytes); - ck_tile::DeviceMem lut_buf(lut_bytes); - ck_tile::DeviceMem valid_buf(valid_bytes); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - bmap_buf.SetZero(); - lut_buf.SetZero(); - valid_buf.SetZero(); - - // Compute strides (assumes BHSD if i_perm, BSHD otherwise) - const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q; - const ck_tile::index_t stride_k = i_perm ? 
hdim_q : nhead_k * hdim_q; - const ck_tile::index_t nhead_stride_q = - i_perm ? static_cast(seqlen_q) * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_k = - i_perm ? static_cast(seqlen_k) * hdim_q : hdim_q; - const ck_tile::index_t batch_stride_q = - static_cast(nhead_q) * seqlen_q * hdim_q; - const ck_tile::index_t batch_stride_k = - static_cast(nhead_k) * seqlen_k * hdim_q; - - ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false}; - - sparge_blockmap_args args; - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.batch = batch; - args.seqlen_q = seqlen_q; - args.seqlen_k = seqlen_k; - args.hdim_q = hdim_q; - args.nhead_q = nhead_q; - args.nhead_k = nhead_k; - args.stride_q = stride_q; - args.stride_k = stride_k; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.simthreshd1 = simthreshd1; - args.cdfthreshd = cdfthreshd; - args.topk = topk; - args.scale = scale; - args.block_map_ptr = bmap_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); - - sparge_blockmap_traits traits; - traits.data_type = data_type; - traits.hdim_q = hdim_q; - - sparge_blockmap_fwd(traits, args, stream_config); - - // Copy results back to host - bmap_buf.FromDevice(block_map_out.data(), bmap_bytes); - - sparge::VSALut vsa_lut{ - ck_tile::HostTensor({batch, nhead_q, num_q_blocks, num_k_blocks}), - ck_tile::HostTensor({batch, nhead_q, num_q_blocks}), - }; - lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes); - valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes); - - return vsa_lut; -} - -// Explicit template instantiations -template sparge::VSALut -sparge_blockmap_gpu(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - bool, - float, - float, - float, - 
int, - int, - int); - -template sparge::VSALut -sparge_blockmap_gpu(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - bool, - float, - float, - float, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.h b/example/ck_tile/50_sparse_attn/sparge_blockmap.h deleted file mode 100644 index 3057257ca14..00000000000 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "sparge_tool.hpp" - -template -sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - ck_tile::HostTensor& block_map_out, - int batch, - int nhead_q, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - bool i_perm, - float simthreshd1, - float cdfthreshd, - float topk, - int blkq, - int blkk, - int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index 638a867b0f3..572b708f9ef 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -1,23 +1,17 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention +// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device) #include -#include #include -#include #include -#include -#include -#include - #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "vsa_sparge_attention.h" -#include "sparge_blockmap.h" +#include "sparge_blockmap_trek.hpp" +#include "fmha_fwd_trek.hpp" #include "sparge_tool.hpp" // ============================================================================ @@ -192,7 +186,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) << ", topk=" << topk << ")" << std::endl; std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; - // Create host tensors + // Create host tensors and fill with random data ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); @@ -206,119 +200,157 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); // ================================================================== - // GPU: Build block map + VSA LUT in one kernel (always run) + // Allocate device memory once, HtoD once // ================================================================== - std::cout << "Building Sparge block map + VSA LUT (GPU)..." 
<< std::endl; - ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); - auto vsa_lut_gpu = sparge_blockmap_gpu(q_host, - k_host, - block_map_gpu, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - i_perm, - simthreshd1, - cdfthreshd, - topk, - static_cast(BLKQ), - static_cast(BLKK), - 0); + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + + const std::size_t bmap_bytes = + static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t); + const std::size_t lut_bytes = + static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t); + const std::size_t valid_bytes = + static_cast(batch) * nhead * num_q_blocks * sizeof(int32_t); + + ck_tile::DeviceMem bmap_buf(bmap_bytes); + ck_tile::DeviceMem lut_buf(lut_bytes); + ck_tile::DeviceMem valid_buf(valid_bytes); + bmap_buf.SetZero(); + lut_buf.SetZero(); + valid_buf.SetZero(); // ================================================================== - // VSA sparse attention kernel (always run) + // Common stride calculations // ================================================================== - std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; + assert(nhead % nhead_k == 0); + const float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q; + const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q; + const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v; + const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v; + const ck_tile::index_t nhead_stride_q = i_perm ? 
seqlen_q * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v; + const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v; + const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q; + const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k; + const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v; + + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + data_type = "bf16"; - try - { - if(kname) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 1); - } + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - for(int i = 0; i < warmup; ++i) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } + // ================================================================== + // GPU: Build block map + VSA LUT (always run, device-only) + // ================================================================== + std::cout << "Building Sparge block map + VSA LUT (GPU)..." 
<< std::endl; + { + sparge_blockmap_args args; + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.batch = batch; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.hdim_q = hdim_q; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.simthreshd1 = simthreshd1; + args.cdfthreshd = cdfthreshd; + args.topk = topk; + args.scale = scale_s; + args.block_map_ptr = bmap_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + + sparge_blockmap_traits traits; + traits.data_type = data_type; + traits.hdim_q = hdim_q; + + sparge_blockmap_fwd(traits, args, ck_tile::stream_config{}); + } - [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); - auto start = std::chrono::high_resolution_clock::now(); + // ================================================================== + // VSA sparse attention kernel (always run, LUT stays on device) + // ================================================================== + std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; - for(int i = 0; i < repeat; ++i) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } + fmha_vsa_fwd_args fmha_args; + fmha_args.q_ptr = q_buf.GetDeviceBuffer(); + fmha_args.k_ptr = k_buf.GetDeviceBuffer(); + fmha_args.v_ptr = v_buf.GetDeviceBuffer(); + fmha_args.lut_ptr = lut_buf.GetDeviceBuffer(); + fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + fmha_args.o_ptr = o_buf.GetDeviceBuffer(); + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen_q; + 
fmha_args.seqlen_k = seqlen_k; + fmha_args.max_seqlen_q = seqlen_q; + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + fmha_args.scale_s = scale_s; + fmha_args.stride_q = stride_q; + fmha_args.stride_k = stride_k; + fmha_args.stride_v = stride_v; + fmha_args.stride_o = stride_o; + fmha_args.nhead_stride_q = nhead_stride_q; + fmha_args.nhead_stride_k = nhead_stride_k; + fmha_args.nhead_stride_v = nhead_stride_v; + fmha_args.nhead_stride_o = nhead_stride_o; + fmha_args.batch_stride_q = batch_stride_q; + fmha_args.batch_stride_k = batch_stride_k; + fmha_args.batch_stride_v = batch_stride_v; + fmha_args.batch_stride_o = batch_stride_o; + fmha_args.window_size_left = mask.left; + fmha_args.window_size_right = mask.right; + fmha_args.mask_type = static_cast(mask.type); + + fmha_vsa_fwd_traits fmha_traits; + fmha_traits.hdim_q = hdim_q; + fmha_traits.hdim_v = hdim_v; + fmha_traits.data_type = data_type; + fmha_traits.is_v_rowmajor = true; + fmha_traits.mask_type = mask.type; + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ kname ? 
1 : 0, + warmup, + repeat, + false}; + + float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config); + + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; - [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); - auto end = std::chrono::high_resolution_clock::now(); - double avg_time_ms = - std::chrono::duration(end - start).count() / repeat; + // DtoH: attention output (always needed) + o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes()); - std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" - << std::endl; - } - catch(const std::exception& e) - { - std::cerr << "Error during kernel execution: " << e.what() << std::endl; - return false; - } + // DtoH: block_map (needed for sparsity stats and validation) + ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); + bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes); // ================================================================== - // Sparsity statistics (always run, pure CPU read of HostTensor) + // Sparsity statistics (pure CPU, reads block_map HostTensor) // ================================================================== std::size_t total_blocks = 0; std::size_t active_blocks = 0; @@ -366,6 +398,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << "Converting block map to VSA LUT (delta, CPU)..." 
<< std::endl; auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + // DtoH: LUT + valid_block_num (only for validation) + sparge::VSALut vsa_lut_gpu{ + ck_tile::HostTensor({batch, nhead, num_q_blocks, num_k_blocks}), + ck_tile::HostTensor({batch, nhead, num_q_blocks}), + }; + lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes); + valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes); + // Validate block map std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl; { @@ -378,20 +418,16 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if(block_map_gpu(b, h, qb, kb) != - block_relation_onehot(b, h, qb, kb)) + if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb)) { bmap_mismatches++; if(bmap_mismatches <= 10) { std::cout - << " block_map mismatch at [" << b << "," << h << "," - << qb << "," << kb - << "]: GPU=" - << static_cast(block_map_gpu(b, h, qb, kb)) - << " CPU=" - << static_cast( - block_relation_onehot(b, h, qb, kb)) + << " block_map mismatch at [" << b << "," << h << "," << qb + << "," << kb << "]: GPU=" + << static_cast(block_map_gpu(b, h, qb, kb)) << " CPU=" + << static_cast(block_relation_onehot(b, h, qb, kb)) << std::endl; } } @@ -429,28 +465,24 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_mismatches++; if(valid_mismatches <= 5) { - std::cout - << " valid_block_num mismatch at [" << b << "," << h - << "," << qb - << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) - << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) - << std::endl; + std::cout << " valid_block_num mismatch at [" << b << "," << h + << "," << qb + << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) + << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) + << std::endl; } } for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if(vsa_lut_gpu.lut(b, h, qb, kb) != - vsa_lut_cpu.lut(b, h, qb, kb)) + if(vsa_lut_gpu.lut(b, h, qb, 
kb) != vsa_lut_cpu.lut(b, h, qb, kb)) { lut_mismatches++; if(lut_mismatches <= 10) { std::cout << " LUT mismatch at [" << b << "," << h << "," << qb - << "," << kb - << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) - << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) - << std::endl; + << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) + << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl; } } } diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp deleted file mode 100644 index 5f9c2676ddb..00000000000 --- a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#include "vsa_sparge_attention.h" -#include "fmha_fwd_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include - -template -ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& TKV_block_idx, - const ck_tile::HostTensor& TKV_blocks, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "VSA sparse attention supports fp16/bf16 only."); - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - if(max_seqlen_q == 0) - max_seqlen_q = seqlen_q; - if(max_seqlen_k == 0) - max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - - const ck_tile::index_t shape_seqlen_q = seqlen_q; - const ck_tile::index_t 
shape_seqlen_k = seqlen_k; - - ck_tile::stream_config stream_config{nullptr, - false, // time_kernel - log_level, - 0, - 1, - false}; - - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes()); - ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - v_buf.ToDevice(TV.data()); - lut_buf.ToDevice(TKV_block_idx.data()); - valid_block_num_buf.ToDevice(TKV_blocks.data()); - - const auto init_args = [&](auto& args) { - assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; - }(); - const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); - - args.batch = batch; - args.seqlen_q = shape_seqlen_q; - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - - args.o_ptr = o_buf.GetDeviceBuffer(); - - args.seqlen_k = shape_seqlen_k; - args.max_seqlen_q = max_seqlen_q; - - args.scale_s = scale_s; - - args.stride_o = stride_o; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - }; - - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_v_rowmajor = is_v_rowmajor; - traits.mask_type = mask.type; - }; - - fmha_vsa_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_vsa_fwd_args args; - init_args(args); - - sparge_vsa_fwd(fmha_traits, args, stream_config); - - o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); - - return Y; -} - -template ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor&, - const 
ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); - -template ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h deleted file mode 100644 index d51a7e8c00b..00000000000 --- a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" - -template -ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& TKV_block_idx, - const ck_tile::HostTensor& TKV_blocks, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level = 0); From ab44b835667e29cf5ba844d2a7bdd52fc9f4cc17 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Wed, 22 Apr 2026 13:13:37 -0400 Subject: [PATCH 5/6] refactor to combine two kernel --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 131 +-- .../codegen/ops/fmha_fwd_jenga.py | 141 +++- .../codegen/ops/fmha_fwd_vsa.py | 141 +++- .../codegen/ops/sparge_fwd_jenga.py | 799 ------------------ .../codegen/ops/sparge_fwd_vsa.py | 799 ------------------ .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 16 +- 
.../50_sparse_attn/jenga_sparge_attention.cpp | 189 ----- .../50_sparse_attn/jenga_sparge_attention.h | 27 - .../50_sparse_attn/sparge_blockmap_inst.cpp | 139 +++ .../50_sparse_attn/sparge_blockmap_trek.hpp | 13 + .../ck_tile/50_sparse_attn/test_sparge.cpp | 432 ++++++++++ .../test_sparge_jenga_sparse_attn.cpp | 422 --------- .../test_sparge_vsa_sparse_attn.cpp | 597 ------------- ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 40 +- 14 files changed, 896 insertions(+), 2990 deletions(-) delete mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py delete mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py delete mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp delete mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.h create mode 100644 example/ck_tile/50_sparse_attn/test_sparge.cpp delete mode 100644 example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp delete mode 100644 example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index f234f631b6b..b20a661805f 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -88,68 +88,6 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) -# ============================================================================ -# Sparge Jenga (64x128 tile) -# ============================================================================ -set(SPARGE_JENGA_CODE_GEN_ARGS - ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api sparge_fwd_jenga - --receipt 600 -) - -execute_process( - COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS} - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt - RESULT_VARIABLE ret -) -if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list") -endif() - -file(STRINGS 
${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS) - -add_custom_command( - OUTPUT ${SPARGE_JENGA_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS} - --output_dir ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${CODE_GEN_SCRIPTS} - COMMENT "Generate CK Tile Sparge Jenga kernels" -) - -message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}") - -set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances") - -add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL - ${SPARGE_JENGA_GEN_BLOBS} - ${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp -) -target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn -) -set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP) -set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) - -target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE - -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN - -DCK_TILE_FMHA_FWD_FAST_EXP2 - -Wno-undefined-func-template - -Wno-float-equal -) - -# Sparge + Jenga Example executable -set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn") -message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}") -add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES}) -target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE - -Wno-undefined-func-template - -Wno-float-equal -) - # ============================================================================ # VSA Sparse Attention # 
============================================================================ @@ -215,55 +153,6 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) -# ============================================================================ -# Sparge VSA (64x128 tile) -# ============================================================================ -set(SPARGE_VSA_CODE_GEN_ARGS - ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api sparge_fwd_vsa - --receipt 600 -) - -execute_process( - COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt - RESULT_VARIABLE ret -) -if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to generate Sparge VSA kernel list") -endif() - -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS) - -add_custom_command( - OUTPUT ${SPARGE_VSA_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} - --output_dir ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${CODE_GEN_SCRIPTS} - COMMENT "Generate CK Tile Sparge VSA kernels" -) - -message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}") - -set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") - -add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL - ${SPARGE_VSA_GEN_BLOBS} -) -target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn -) -set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) - -target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE - -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN - -DCK_TILE_FMHA_FWD_FAST_EXP2 - -Wno-undefined-func-template - -Wno-float-equal -) - # ============================================================================ # Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen) # 
============================================================================ @@ -289,16 +178,20 @@ target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE -Wno-float-equal ) -# Sparge + VSA Example executable (now links blockmap kernel too) -set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") -message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") -add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} - ${SPARGE_VSA_INSTANCES} +# ---------------------------------------------------------------------------- +# Build unified Sparge test: combines blockmap, Jenga, and VSA attention +# for end-to-end evaluation and timing in a single executable. +# ---------------------------------------------------------------------------- +set(EXAMPLE_SPARGE "tile_example_sparge") +message(DEBUG "adding example ${EXAMPLE_SPARGE}") +add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp) +target_link_libraries(${EXAMPLE_SPARGE} + ${SPARSE_ATTN_JENGA_INSTANCES} + ${SPARSE_ATTN_VSA_INSTANCES} ${SPARGE_BLOCKMAP_INSTANCES} ) -target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE +target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_SPARGE} PRIVATE -Wno-undefined-func-template -Wno-float-equal ) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index a3d32652a98..1f0a78048d9 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -141,6 +141,17 @@ def update_file(file_path, content): constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, 
grids, blocks, 0, kargs)); }} + +template<> +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp" @@ -219,6 +230,45 @@ def update_file(file_path, content): }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_jenga_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass 
class CppConstraint: @@ -274,10 +324,7 @@ def scheck(self) -> str: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ def api(self) -> str: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = 
per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,6 +690,27 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), FmhaFwdTileSize( # fmt: skip 16, 32, @@ -780,7 +909,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 != 64 or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async": continue @@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +995,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py 
b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 038738de246..217cfcfe2a4 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -141,6 +141,17 @@ def update_file(file_path, content): constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp" @@ -219,6 +230,45 @@ def update_file(file_path, content): }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_vsa_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_vsa_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + 
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_vsa_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ def scheck(self) -> str: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ def api(self) -> str: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + 
F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,6 +690,27 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), FmhaFwdTileSize( # fmt: skip 16, 32, @@ -780,7 +909,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 != 64 or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async_vsa": continue @@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def 
write_blobs( @@ -865,3 +995,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py deleted file mode 100644 index 872da2326ea..00000000000 --- a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py +++ /dev/null @@ -1,799 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass, field -import fnmatch -import itertools -import os -import os.path as path -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cpp_symbol_map import ( - BOOL_MAP, - FWD_DTYPE_MAP, - LAYOUT_MAP, - MODE_MAP, - PIPELINE_ENUM_MAP, - PIPELINE_MAP, - get_mask_check_map, - get_mask_map, -) - -GEN_DIR = "" - - -def update_file(file_path, content): - """Update the file at file_path with the given content if it differs from the existing content. 
- - It avoids unnecessary touching of the file which triggers rebuilds - """ - - existing_content = "" - if path.exists(file_path): - with open(file_path, "r") as file: - existing_content = file.read() - if existing_content == content: - return - with open(file_path, "w") as file: - file.write(content) - - -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} - -K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd_trek.hpp" -#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" -#include "kernel/fmha_fwd_jenga_kernel.hpp" - -""" - -# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert: -# - Group mode: NOT supported (batch mode only) -# - Bias: NOT supported (NO_BIAS only) -# - LSE output: NOT supported (false only) -# - Dropout: NOT supported (false only) -# - Logits soft-cap: NOT supported (false only) -# - FP8 static quantization: NOT supported (NO_SCALE only) -# The template below hardcodes these unsupported features accordingly. 
- -FMHA_FWD_KERNEL_BODY = """ -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, -// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - false, // has_logits_soft_cap - NOT supported - ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported - false, // store_lse - NOT supported - false, // has_dropout - NOT supported - false, // has_randval - NOT supported - ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported - {F_occupancy}, - false>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaSparseFwdTypeConfig::QDataType, - typename FmhaSparseFwdTypeConfig::KDataType, - typename FmhaSparseFwdTypeConfig::VDataType, - typename FmhaSparseFwdTypeConfig::SaccDataType, - typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, - typename FmhaSparseFwdTypeConfig::BiasDataType, - typename FmhaSparseFwdTypeConfig::RandValOutputDataType, - typename FmhaSparseFwdTypeConfig::LSEDataType, - typename FmhaSparseFwdTypeConfig::PDataType, - typename FmhaSparseFwdTypeConfig::OaccDataType, - typename FmhaSparseFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - {F_trload}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - 
-using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdJengaKernel; - -using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - -#include - -template<> -float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << "{F_kernel_name}" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp" -FMHA_FWD_API = """ -#include - -#include - -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device"); - return false; - }} - - hipDeviceProp_t props{{}}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device properties"); - return false; - }} - - num_cus = props.multiProcessorCount; - return true; -}} - -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ - const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 - - return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; - - [[maybe_unused]] const float min_cu_util_rate 
= 0.8; // minimum CU utilization rate - - unsigned num_cus; - if (!get_num_cus(num_cus)) {{ - return r; - }} - - [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ - return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); - }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); - -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - -FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - return fmha_jenga_fwd_(s, a); - }} -""" - - -@dataclass -class CppConstraint: - bool_expr: str = None - - def __str__(self): - if self.bool_expr is None: - return "true" - else: - return f"{self.bool_expr}" - - def __and__(self, other): - return CppConstraint(f"({str(self)}) && ({str(other)})") - - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag: str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim: str - dtype: str # data type - mode: str # value from MODE_MAP - bm0: int # tile size along q seqlen (block size) - bn0: int # tile size along qk seqlen - bk0: int # tile size along qk gemm unroll - bn1: int # tile size along v head_dim - bk1: int # tile size along kv gemm unroll - bk0max: int - vlayout: str - logits: str - mask: str - spad: str - skpad: str - dpad: str - dvpad: str - tr_load: str - constraint: CppConstraint - - @property - def name(self) -> str: - 
return ( - f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" - ) - - @property - def scheck(self) -> str: - if self.mode == "group": - return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.spad == "t": - return "true" # always support - return "true" - - @property - def seqtune(self) -> str: - return "true" - - @property - def skcheck(self) -> str: - if self.mode == "group": - return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true - if self.skpad == "t": - return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" - return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" - - @property - def dcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == "t": - return f"a.hdim_q % {vec} == 0" - assert False - - @property - def dvcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == "t": - return f"a.hdim_v % {vec} == 0" - assert False - - -@dataclass -class FmhaFwdPipeline: - tag: str - - F_vlayout: str # row/col - F_spad: str # true/false - F_skpad: str # - F_dpad: str # - F_dvpad: str # - F_logits: str # t/f - F_mask: str # value from MASK_MAP - F_trload: str # true/false - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - def pad_name() -> str: - n = "" - if self.F_spad == "t": - n += "s" - if self.F_skpad == "t": - n += "sk" - if self.F_dpad == "t": - n += "d" - if self.F_dvpad == "t": - n += "dv" - if n != "": - n = "p" + n - return n - - pn = pad_name() - n = f"{self.tag}_v{self.F_vlayout[0]}" - if pn != "": - n += f"_{pn}" - else: - n += "_npad" - - if self.F_logits == "t": - n += "_logits" - else: - n += "_nlogits" - - n += "_nbias" - - if self.F_mask[0:2] == "s_": - if self.F_mask == "s_mask": - n += "_mask" 
- else: - n += "_nmask" - else: - if self.F_mask != "no": - n += f"_m{self.F_mask[0]}" - else: - n += "_nmask" - - n += "_nskip" - - n += "_nsquant" - - if self.F_trload == "t": - n += "_trload" - else: - n += "_ntrload" - - return n - - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - hdim = trait.hdim, trait.bn1 - if hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][hdim] = list() - - self.pool[trait.dtype][hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - tr_load_cond_map = {"t": "has_load_tr", "f": "true"} - - per_tr_load = str() - for tr_load in ["t", "f"]: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits = [ - t - for t in self.pool[dtype][(hdim, hdim_v)] - if tr_load == t.tr_load - ] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( - F_if=if_k, - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - # F_logits removed - hardcoded to false (NOT supported) - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_trload=BOOL_MAP[trait.tr_load], - F_scheck=trait.scheck, - F_seqtune=trait.seqtune, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, - F_bn0=trait.bn0, - F_bk0=trait.bk0, - F_bn1=trait.bn1, - F_bk1=trait.bk1, - F_bk0max=trait.bk0max, - F_hdim=hdim, - 
F_dtype=FWD_DTYPE_MAP[dtype], - ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners - ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case - ) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( - F_if="if", - F_trload_cond=tr_load_cond_map[tr_load], - F_dtype_case=per_dtypes, - ) - if not per_tr_load: - # empty string we add some ignore to suppress warning in api - per_tr_load += " (void)t ; (void)s ; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) - - -@dataclass -class FmhaFwdTileSize: - F_bm0: int # tile size along q seqlen (block size) - F_bn0: int # tile size along k seqlen - F_bk0: int # tile size along qk gemm unroll - F_bn1: int # tile size along v head_dim - F_bk1: int # tile size along kv gemm unroll - F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0: int # number of warps for gemm0 along q seqlen - F_rn0: int # number of warps for gemm0 along k seqlen - F_rk0: int # number of warps for gemm0 along head dim q (not used) - F_rm1: int # number of warps for gemm1 along q seqlen - F_rn1: int # number of warps for gemm1 along head dim v - F_rk1: int # number of warps for gemm1 along k seqlen (not used) - F_wm0: int # gemm0 warp size along m - F_wn0: int # gemm0 warp size along n - F_wk0: int # gemm0 warp size along k - F_wm1: int # gemm1 warp size along m - F_wn1: int # gemm1 warp size along n - F_wk1: int # gemm1 warp size along k - F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - return ( - 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" - + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" - + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" - + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - ) - - -@dataclass -class FmhaFwdKernel: - F_idx: int # this is not a tunable, but a counter to differentiate symbol - F_hdim: int # hdim - F_dtype: str # data type - F_mode: str # value from MODE_MAP - F_tile: FmhaFwdTileSize - F_pipeline: FmhaFwdPipeline - mask_impl: str - - @property - def template(self) -> str: - # kernel_body removed - unused - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, - F_hdim=self.F_hdim, - F_dtype=FWD_DTYPE_MAP[self.F_dtype], - F_bm0=self.F_tile.F_bm0, - F_bn0=self.F_tile.F_bn0, - F_bk0=self.F_tile.F_bk0, - F_bn1=self.F_tile.F_bn1, - F_bk1=self.F_tile.F_bk1, - F_bk0max=self.F_tile.F_bk0max, - F_rm0=self.F_tile.F_rm0, - F_rn0=self.F_tile.F_rn0, - F_rk0=self.F_tile.F_rk0, - F_rm1=self.F_tile.F_rm1, - F_rn1=self.F_tile.F_rn1, - F_rk1=self.F_tile.F_rk1, - F_wm0=self.F_tile.F_wm0, - F_wn0=self.F_tile.F_wn0, - F_wk0=self.F_tile.F_wk0, - F_wm1=self.F_tile.F_wm1, - F_wn1=self.F_tile.F_wn1, - F_wk1=self.F_tile.F_wk1, - F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad=BOOL_MAP[self.F_pipeline.F_spad], - F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], - # F_logits removed - hardcoded to false in template (NOT supported) - F_occupancy=self.F_tile.F_occupancy, - F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], - F_trload=BOOL_MAP[self.F_pipeline.F_trload], - F_kernel_name=self.name, - ) - - @property - def name(self) -> str: - # TODO: we don't encode idx 
here - return ( - f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" - + self.F_tile.name - + "_" - + self.F_pipeline.name - ) - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, - ) - - -class KernelComponentFactory: - # TODO: design a more practical way to do it - # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128, 128): [ - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - ], - # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 
32, 16, -1)], - # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - else: - return None - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert) - pipelines = [] - if dtype in ["fp16", "bf16"]: - for logits, mask in itertools.product( - ["f"], # logits soft-cap NOT supported, always false - get_mask_map(mask_impl).keys(), - ): - if hdim == 256 and hdim_v == 256: - # jenga fmha only supports dim <= 192 for now. - continue - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr_async", - "row", - "t", - "f", - "t", - "t", - logits, - mask, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr_async", - "row", - "t", - "t", - "t", - "t", - logits, - mask, - "f", - ) - ) - else: - assert False - return pipelines - - -class CustomFactory(KernelComponentFactory): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": - if (128, 128) in result.keys(): - result[(128, 128)].insert( - 0, - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - CppConstraint( - "get_num_blocks(128) < num_cus * min_cu_util_rate" - ), - ), - ) - return result - - -def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl -) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - gen = list() - 
api_pool = FmhaFwdApiPool(mask_impl) - - factory = ( - CustomFactory - if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" - else KernelComponentFactory - ) - - # Only generate fp16/bf16 kernels for now. - # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) - for dtype in ["fp16", "bf16"]: - d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): - for tile, pipeline in itertools.product( - tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) - ): - if pipeline.tag != "qr_async": - continue - k = FmhaFwdKernel( - F_idx=2, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= mode == "batch" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - - -def write_single_fwd_kernel(kernel: 
FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) - - -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) - - -def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - - -def list_blobs( - file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py deleted file mode 100644 index c9a389df3fa..00000000000 --- a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py +++ /dev/null @@ -1,799 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass, field -import fnmatch -import itertools -import os -import os.path as path -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cpp_symbol_map import ( - BOOL_MAP, - FWD_DTYPE_MAP, - LAYOUT_MAP, - MODE_MAP, - PIPELINE_ENUM_MAP, - PIPELINE_MAP, - get_mask_check_map, - get_mask_map, -) - -GEN_DIR = "" - - -def update_file(file_path, content): - """Update the file at file_path with the given content if it differs from the existing content. 
- - It avoids unnecessary touching of the file which triggers rebuilds - """ - - existing_content = "" - if path.exists(file_path): - with open(file_path, "r") as file: - existing_content = file.read() - if existing_content == content: - return - with open(file_path, "w") as file: - file.write(content) - - -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} - -K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd_trek.hpp" -#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" -#include "kernel/fmha_fwd_vsa_kernel.hpp" - -""" - -# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert: -# - Group mode: NOT supported (batch mode only) -# - Bias: NOT supported (NO_BIAS only) -# - LSE output: NOT supported (false only) -# - Dropout: NOT supported (false only) -# - Logits soft-cap: NOT supported (false only) -# - FP8 static quantization: NOT supported (NO_SCALE only) -# The template below hardcodes these unsupported features accordingly. 
- -FMHA_FWD_KERNEL_BODY = """ -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, -// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - false, // has_logits_soft_cap - NOT supported - ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported - false, // store_lse - NOT supported - false, // has_dropout - NOT supported - false, // has_randval - NOT supported - ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported - {F_occupancy}, - false>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaSparseFwdTypeConfig::QDataType, - typename FmhaSparseFwdTypeConfig::KDataType, - typename FmhaSparseFwdTypeConfig::VDataType, - typename FmhaSparseFwdTypeConfig::SaccDataType, - typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, - typename FmhaSparseFwdTypeConfig::BiasDataType, - typename FmhaSparseFwdTypeConfig::RandValOutputDataType, - typename FmhaSparseFwdTypeConfig::LSEDataType, - typename FmhaSparseFwdTypeConfig::PDataType, - typename FmhaSparseFwdTypeConfig::OaccDataType, - typename FmhaSparseFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - {F_trload}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA< - 
fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdVSAKernel; - -using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - -#include - -template<> -float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << "{F_kernel_name}" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp" -FMHA_FWD_API = """ -#include - -#include - -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device"); - return false; - }} - - hipDeviceProp_t props{{}}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device properties"); - return false; - }} - - num_cus = props.multiProcessorCount; - return true; -}} - -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ - const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 - - return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; - - [[maybe_unused]] const 
float min_cu_util_rate = 0.8; // minimum CU utilization rate - - unsigned num_cus; - if (!get_num_cus(num_cus)) {{ - return r; - }} - - [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ - return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); - }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); - -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - -FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - return fmha_vsa_fwd_(s, a); - }} -""" - - -@dataclass -class CppConstraint: - bool_expr: str = None - - def __str__(self): - if self.bool_expr is None: - return "true" - else: - return f"{self.bool_expr}" - - def __and__(self, other): - return CppConstraint(f"({str(self)}) && ({str(other)})") - - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag: str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim: str - dtype: str # data type - mode: str # value from MODE_MAP - bm0: int # tile size along q seqlen (block size) - bn0: int # tile size along qk seqlen - bk0: int # tile size along qk gemm unroll - bn1: int # tile size along v head_dim - bk1: int # tile size along kv gemm unroll - bk0max: int - vlayout: str - logits: str - mask: str - spad: str - skpad: str - dpad: str - dvpad: str - tr_load: str - constraint: CppConstraint - - @property - def 
name(self) -> str: - return ( - f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" - ) - - @property - def scheck(self) -> str: - if self.mode == "group": - return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.spad == "t": - return "true" # always support - return "true" - - @property - def seqtune(self) -> str: - return "true" - - @property - def skcheck(self) -> str: - if self.mode == "group": - return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true - if self.skpad == "t": - return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" - return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" - - @property - def dcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == "t": - return f"a.hdim_q % {vec} == 0" - assert False - - @property - def dvcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == "t": - return f"a.hdim_v % {vec} == 0" - assert False - - -@dataclass -class FmhaFwdPipeline: - tag: str - - F_vlayout: str # row/col - F_spad: str # true/false - F_skpad: str # - F_dpad: str # - F_dvpad: str # - F_logits: str # t/f - F_mask: str # value from MASK_MAP - F_trload: str # true/false - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - def pad_name() -> str: - n = "" - if self.F_spad == "t": - n += "s" - if self.F_skpad == "t": - n += "sk" - if self.F_dpad == "t": - n += "d" - if self.F_dvpad == "t": - n += "dv" - if n != "": - n = "p" + n - return n - - pn = pad_name() - n = f"{self.tag}_v{self.F_vlayout[0]}" - if pn != "": - n += f"_{pn}" - else: - n += "_npad" - - if self.F_logits == "t": - n += "_logits" - else: - n += "_nlogits" - - n += "_nbias" - - if self.F_mask[0:2] == "s_": - if self.F_mask == 
"s_mask": - n += "_mask" - else: - n += "_nmask" - else: - if self.F_mask != "no": - n += f"_m{self.F_mask[0]}" - else: - n += "_nmask" - - n += "_nskip" - - n += "_nsquant" - - if self.F_trload == "t": - n += "_trload" - else: - n += "_ntrload" - - return n - - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - hdim = trait.hdim, trait.bn1 - if hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][hdim] = list() - - self.pool[trait.dtype][hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - tr_load_cond_map = {"t": "has_load_tr", "f": "true"} - - per_tr_load = str() - for tr_load in ["t", "f"]: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits = [ - t - for t in self.pool[dtype][(hdim, hdim_v)] - if tr_load == t.tr_load - ] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( - F_if=if_k, - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - # F_logits removed - hardcoded to false (NOT supported) - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_trload=BOOL_MAP[trait.tr_load], - F_scheck=trait.scheck, - F_seqtune=trait.seqtune, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, - F_bn0=trait.bn0, - F_bk0=trait.bk0, - F_bn1=trait.bn1, - F_bk1=trait.bk1, - F_bk0max=trait.bk0max, - 
F_hdim=hdim, - F_dtype=FWD_DTYPE_MAP[dtype], - ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners - ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case - ) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( - F_if="if", - F_trload_cond=tr_load_cond_map[tr_load], - F_dtype_case=per_dtypes, - ) - if not per_tr_load: - # empty string we add some ignore to suppress warning in api - per_tr_load += " (void)t ; (void)s ; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) - - -@dataclass -class FmhaFwdTileSize: - F_bm0: int # tile size along q seqlen (block size) - F_bn0: int # tile size along k seqlen - F_bk0: int # tile size along qk gemm unroll - F_bn1: int # tile size along v head_dim - F_bk1: int # tile size along kv gemm unroll - F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0: int # number of warps for gemm0 along q seqlen - F_rn0: int # number of warps for gemm0 along k seqlen - F_rk0: int # number of warps for gemm0 along head dim q (not used) - F_rm1: int # number of warps for gemm1 along q seqlen - F_rn1: int # number of warps for gemm1 along head dim v - F_rk1: int # number of warps for gemm1 along k seqlen (not used) - F_wm0: int # gemm0 warp size along m - F_wn0: int # gemm0 warp size along n - F_wk0: int # gemm0 warp size along k - F_wm1: int # gemm1 warp size along m - F_wn1: int # gemm1 warp size along n - F_wk1: int # gemm1 warp size along k - F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - return ( - 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" - + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" - + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" - + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - ) - - -@dataclass -class FmhaFwdKernel: - F_idx: int # this is not a tunable, but a counter to differentiate symbol - F_hdim: int # hdim - F_dtype: str # data type - F_mode: str # value from MODE_MAP - F_tile: FmhaFwdTileSize - F_pipeline: FmhaFwdPipeline - mask_impl: str - - @property - def template(self) -> str: - # kernel_body removed - unused - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, - F_hdim=self.F_hdim, - F_dtype=FWD_DTYPE_MAP[self.F_dtype], - F_bm0=self.F_tile.F_bm0, - F_bn0=self.F_tile.F_bn0, - F_bk0=self.F_tile.F_bk0, - F_bn1=self.F_tile.F_bn1, - F_bk1=self.F_tile.F_bk1, - F_bk0max=self.F_tile.F_bk0max, - F_rm0=self.F_tile.F_rm0, - F_rn0=self.F_tile.F_rn0, - F_rk0=self.F_tile.F_rk0, - F_rm1=self.F_tile.F_rm1, - F_rn1=self.F_tile.F_rn1, - F_rk1=self.F_tile.F_rk1, - F_wm0=self.F_tile.F_wm0, - F_wn0=self.F_tile.F_wn0, - F_wk0=self.F_tile.F_wk0, - F_wm1=self.F_tile.F_wm1, - F_wn1=self.F_tile.F_wn1, - F_wk1=self.F_tile.F_wk1, - F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad=BOOL_MAP[self.F_pipeline.F_spad], - F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], - # F_logits removed - hardcoded to false in template (NOT supported) - F_occupancy=self.F_tile.F_occupancy, - F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], - F_trload=BOOL_MAP[self.F_pipeline.F_trload], - F_kernel_name=self.name, - ) - - @property - def name(self) -> str: - # TODO: we don't encode idx 
here - return ( - f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" - + self.F_tile.name - + "_" - + self.F_pipeline.name - ) - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, - ) - - -class KernelComponentFactory: - # TODO: design a more practical way to do it - # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128, 128): [ - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - ], - # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 
32, 16, -1)], - # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - else: - return None - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert) - pipelines = [] - if dtype in ["fp16", "bf16"]: - for logits, mask in itertools.product( - ["f"], # logits soft-cap NOT supported, always false - get_mask_map(mask_impl).keys(), - ): - if hdim == 256 and hdim_v == 256: - # vsa fmha only supports dim <= 192 for now. - continue - pipelines.append( - FmhaFwdPipeline( - "qr_async_vsa", - "row", - "t", - "f", - "t", - "t", - logits, - mask, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_async_vsa", - "row", - "t", - "t", - "t", - "t", - logits, - mask, - "f", - ) - ) - else: - assert False - return pipelines - - -class CustomFactory(KernelComponentFactory): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": - if (128, 128) in result.keys(): - result[(128, 128)].insert( - 0, - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - CppConstraint( - "get_num_blocks(128) < num_cus * min_cu_util_rate" - ), - ), - ) - return result - - -def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl -) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - gen = list() - api_pool = 
FmhaFwdApiPool(mask_impl) - - factory = ( - CustomFactory - if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" - else KernelComponentFactory - ) - - # Only generate fp16/bf16 kernels for now. - # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) - for dtype in ["fp16", "bf16"]: - d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): - for tile, pipeline in itertools.product( - tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) - ): - if pipeline.tag != "qr_async_vsa": - continue - k = FmhaFwdKernel( - F_idx=1, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= mode == "batch" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, 
autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) - - -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) - - -def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - - -def list_blobs( - file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 25e3513d2fa..350d1803f66 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -277,13 +277,13 @@ struct fmha_jenga_fwd_traits float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); -// sparge jenga -float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); - template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); -float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); +template +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args); + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); // VSA uses the same traits structure as Jenga; aliases for clarity template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); -float 
fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&); +template +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args); + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp deleted file mode 100644 index 88f3e08204e..00000000000 --- a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#include "jenga_sparge_attention.h" -#include "fmha_fwd_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include - -template -ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "Jenga sparse attention supports fp16/bf16 only."); - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - if(max_seqlen_q == 0) - max_seqlen_q = seqlen_q; - if(max_seqlen_k == 0) - max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - - const ck_tile::index_t shape_seqlen_q = seqlen_q; - const ck_tile::index_t shape_seqlen_k = seqlen_k; - - ck_tile::stream_config stream_config{nullptr, - false, // time_kernel - log_level, - 0, - 1, - false}; - - ck_tile::DeviceMem 
q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); - ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - v_buf.ToDevice(TV.data()); - block_relation_buf.ToDevice(Tblock_relation_onehot.data()); - - const auto init_args = [&](auto& args) { - assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; - }(); - const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); - - args.batch = batch; - args.seqlen_q = shape_seqlen_q; - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - - args.o_ptr = o_buf.GetDeviceBuffer(); - - args.seqlen_k = shape_seqlen_k; - args.max_seqlen_q = max_seqlen_q; - - args.scale_s = scale_s; - - args.stride_o = stride_o; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - }; - - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_v_rowmajor = is_v_rowmajor; - traits.mask_type = mask.type; - }; - - fmha_jenga_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_jenga_fwd_args args; - init_args(args); - - sparge_jenga_fwd(fmha_traits, args, stream_config); - - o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); - - return Y; -} - -template ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const 
ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); - -template ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h deleted file mode 100644 index 6259fcc73cf..00000000000 --- a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" - -template -ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp index fbd18b9ff24..a2df5bac569 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -61,6 +61,57 @@ using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem; using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel; +// ============================================================================ +// bf16: D=128, kM0=64, kN0=128 +// ============================================================================ + +using 
bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; + +using bmap_bf16_shape = + ck_tile::TileFmhaShape, + ck_tile::sequence<16, 16, 16>, + ck_tile::sequence<4, 1, 1>, + ck_tile::sequence<16, 16, 16>, + true>; + +using bmap_bf16_trait = ck_tile::TileFmhaTraits; + +using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_bf16_mask = ck_tile::GenericAttentionMask; + +using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel; + // ============================================================================ // Dispatch // ============================================================================ @@ -81,8 +132,96 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits, s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); } + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + using k_ = bmap_bf16_kernel; + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_bf16_d128" << std::flush; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); + } + if(s.log_level_ > 0) std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl; return -1.f; } + +// ============================================================================ +// Oneshot version: launches kernel without timing wrapper +// ============================================================================ + +void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + using k_ = bmap_fp16_kernel; + auto 
[kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); + return; + } + + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + using k_ = bmap_bf16_kernel; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); + return; + } + + std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; +} + +// ============================================================================ +// Combined functions: blockmap + attention timed together via launch_kernel +// ============================================================================ + +float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, + fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { + fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); + }); +} + +float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, + fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << 
attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { + fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); + }); +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp index 1e7e33248a2..6eaeb9ea77b 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -91,3 +91,16 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args) float sparge_blockmap_fwd(sparge_blockmap_traits traits, sparge_blockmap_args args, const ck_tile::stream_config& stream_config); + +void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Combined functions: blockmap + attention with unified timing +float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args, + fmha_jenga_fwd_traits, fmha_jenga_fwd_args, + const ck_tile::stream_config&); + +float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args, + fmha_vsa_fwd_traits, fmha_vsa_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp new file mode 100644 index 00000000000..7c30a10b062 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -0,0 +1,432 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA). 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "fmha_fwd_trek.hpp" +#include "sparge_blockmap_trek.hpp" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helpers +// ============================================================================ + +template +ck_tile::HostTensor +make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm) +{ + if(i_perm) + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + for(ck_tile::index_t h = 0; h < nhead; ++h) + for(ck_tile::index_t s = 0; s < seqlen; ++s) + for(ck_tile::index_t d = 0; d < hdim; ++d) + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Arg parser +// ============================================================================ +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser + .insert("v", "1", "0:no validation, 1:cpu validation") + .insert("pipeline", "jenga", "attention pipeline: jenga / vsa") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)") + .insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)") + .insert("simthreshd1", "0.6", "similarity threshold for blockmap") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// 
============================================================================ +// Main test +// ============================================================================ +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + std::string pipeline = arg_parser.get_str("pipeline"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + float topk = arg_parser.get_float("topk"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + if(nhead_k < 0) nhead_k = nhead; + if(seqlen_k < 0) seqlen_k = seqlen_q; + if(hdim_v < 0) hdim_v = hdim_q; + + // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode + if(cdfthreshd >= 0.0f) + topk = -1.0f; + + constexpr ck_tile::index_t BLKQ = 64; + constexpr ck_tile::index_t BLKK = 128; + + if(hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<\n" + << "Kernel instances are generated for hdim=128 only.\n"; + return true; + } + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::string prec_str = std::is_same_v ? 
"fp16" : "bf16"; + std::cout << "[" << pipeline << "|" << prec_str + << "] b=" << batch << " h=" << nhead << " s=" << seqlen_q + << " d=" << hdim_q << " topk=" << topk + << " sim1=" << simthreshd1 << std::flush; + + // ---- allocate host tensors ---- + auto q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + auto k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + auto v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + auto output_host = o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + + ck_tile::HostTensor block_map_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor lut_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor valid_block_num_host({batch, nhead, num_q_blocks}); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // ---- device tensors ---- + ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes()); + + q_dev.ToDevice(q_host.data()); + k_dev.ToDevice(k_host.data()); + v_dev.ToDevice(v_host.data()); + o_dev.SetZero(); + block_map_dev.SetZero(); + lut_dev.SetZero(); + valid_bn_dev.SetZero(); + + // ---- strides (BHSD when i_perm=true) ---- + auto q_strides = q_host.get_strides(); + auto k_strides = k_host.get_strides(); + auto v_strides = v_host.get_strides(); + auto 
o_strides = output_host.get_strides(); + + float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + // ---- build blockmap args ---- + sparge_blockmap_traits bmap_traits; + bmap_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + bmap_traits.hdim_q = hdim_q; + + sparge_blockmap_args bmap_args; + bmap_args.q_ptr = q_dev.GetDeviceBuffer(); + bmap_args.k_ptr = k_dev.GetDeviceBuffer(); + bmap_args.batch = batch; + bmap_args.seqlen_q = seqlen_q; + bmap_args.seqlen_k = seqlen_k; + bmap_args.hdim_q = hdim_q; + bmap_args.nhead_q = nhead; + bmap_args.nhead_k = nhead_k; + bmap_args.stride_q = q_strides[i_perm ? 2 : 1]; + bmap_args.stride_k = k_strides[i_perm ? 2 : 1]; + bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + bmap_args.batch_stride_q = q_strides[0]; + bmap_args.batch_stride_k = k_strides[0]; + bmap_args.simthreshd1 = simthreshd1; + bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f; + bmap_args.topk = topk; + bmap_args.scale = scale_s; + bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); + bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; + bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr; + + // ---- build attention args ---- + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = nullptr; + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = kname; + stream_cfg.cold_niters_ = warmup; + stream_cfg.nrepeat_ = repeat; + + float avg_ms = -1.0f; + + if(pipeline == "jenga") + { + fmha_jenga_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? 
"fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask_enum::no_mask; + + fmha_jenga_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; + + avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else if(pipeline == "vsa") + { + fmha_vsa_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? 
"fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask_enum::no_mask; + + fmha_vsa_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.lut_ptr = lut_dev.GetDeviceBuffer(); + attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; + + avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else + { + std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n"; + return false; + } + + // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ---- + std::size_t flop = static_cast(batch) * nhead * + (static_cast(2) * seqlen_q * seqlen_k * hdim_q + + static_cast(2) * seqlen_q * seqlen_k * hdim_v); + float tflops = (avg_ms > 0.f) ? 
static_cast(flop) / 1.E9f / avg_ms : 0.f; + + if(avg_ms > 0.f) + { + std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, " + << std::setprecision(2) << tflops << " TFlops" << std::flush; + } + + // ---- copy results back ---- + o_dev.FromDevice(output_host.data()); + block_map_dev.FromDevice(block_map_host.data()); + + // ---- count active blocks ---- + ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; + ck_tile::index_t active_blocks = 0; + for(size_t i = 0; i < block_map_host.mData.size(); ++i) + if(block_map_host.mData[i]) + active_blocks++; + float actual_sparsity = 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity + << "(" << active_blocks << "/" << total_blocks << ")" << std::flush; + + // ---- validation ---- + bool pass = true; + if(do_validation) + { + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + + if(diff > atol && rel_diff > rtol) + num_errors++; + } + + pass = (num_errors == 0); + std::cout << ", " << (pass ? 
"PASS" : "FAIL") + << "(err=" << num_errors << "/" << output_host_bhsd.mData.size() + << " maxdiff=" << max_diff << ")"; + } + + std::cout << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments\n"; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << "\n"; + return -1; + } + + return test_result ? 0 : -1; +} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp deleted file mode 100644 index 590e51db144..00000000000 --- a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> Jenga sparse attention - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ck_tile/host.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/reference/reference_blocked_attention.hpp" -#include "ck_tile/core/utility/bit_cast.hpp" - -#include "jenga_sparge_attention.h" -#include "sparge_tool.hpp" - -// ============================================================================ -// Helper Functions -// ============================================================================ - -template -ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t seqlen, - ck_tile::index_t hdim, - bool i_perm) -{ - if(i_perm) - { - return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); - } - return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); -} - -template -ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) -{ - auto lens = tensor.get_lengths(); - ck_tile::index_t batch = lens[0]; - ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; - ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; - ck_tile::index_t hdim = lens[3]; - - ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t s = 0; s < seqlen; ++s) - { - for(ck_tile::index_t d = 0; d < hdim; ++d) - { - out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); - } - } - } - } - return out; -} - -template -auto get_error_tolerance() -{ - double rtol = 1e-2; - double atol = 4e-2; - if constexpr(std::is_same_v) - { - atol = 2e-1; - rtol = 2e-1; - } - return ck_tile::make_tuple(rtol, atol); -} - -template -float to_float_for_compare(T value) -{ - return static_cast(value); -} - -template <> -float to_float_for_compare(ck_tile::bf16_t value) -{ -#if CK_TILE_USE_CUSTOM_DATA_TYPE - return static_cast(value); -#else - return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); -#endif -} - -// ============================================================================ -// Command line argument parser -// ============================================================================ - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") - .insert("b", "1", "batch size") - .insert("h", "4", "num of head for q") - .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") - .insert("s", "4096", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("prec", "fp16", "data type: fp16/bf16") - .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") - .insert("operm", "1", "permute output") - .insert("seed", "42", "random seed") - .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations") - .insert("kname", "0", "print kernel name") - // Sparge-specific - .insert("blkq", "64", "Sparge BLKQ") - .insert("blkk", "128", "Sparge BLKK") - .insert("simthreshd1", "0.6", "Sparge sim threshold") - .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") - .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, 
arg_parser); -} - -// ============================================================================ -// Main Test Function -// ============================================================================ - -template -bool run_test(const ck_tile::ArgParser& arg_parser) -{ - int do_validation = arg_parser.get_int("v"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int kname = arg_parser.get_int("kname"); - - // Sparge params - ck_tile::index_t blkq = arg_parser.get_int("blkq"); - ck_tile::index_t blkk = arg_parser.get_int("blkk"); - float simthreshd1 = arg_parser.get_float("simthreshd1"); - float cdfthreshd = arg_parser.get_float("cdfthreshd"); - float topk = arg_parser.get_float("topk"); - - if(nhead_k < 0) - nhead_k = nhead; - if(seqlen_k < 0) - seqlen_k = seqlen_q; - if(hdim_v < 0) - hdim_v = hdim_q; - - if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) - { - std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, " - "hdim_q=128, hdim_v=128 only." 
- << std::endl; - std::cout << "TEST SKIPPED" << std::endl; - return true; - } - - ck_tile::index_t BLKQ = blkq; - ck_tile::index_t BLKK = blkk; - - ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; - ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; - - std::cout << "============================================================" << std::endl; - std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl; - std::cout << "============================================================" << std::endl; - std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k - << std::endl; - std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; - std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; - std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; - std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks - << std::endl; - std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd - << ", topk=" << topk << ")" << std::endl; - std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; - - // Create host tensors - ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); - ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); - ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); - ck_tile::HostTensor output_host = - o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) - : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); - ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); - - std::cout << "\nInitializing tensors..." 
<< std::endl; - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - - // Build block map using Sparge tool - std::cout << "Building Sparge block map..." << std::endl; - sparge::SpargeParams p; - p.BLKQ = static_cast(BLKQ); - p.BLKK = static_cast(BLKK); - p.simthreshd1 = simthreshd1; - p.cdfthreshd = cdfthreshd; - p.topk = topk; - p.i_perm = i_perm; - - ck_tile::HostTensor block_relation_onehot = - sparge::build_block_map_meansim(q_host, k_host, p); - - // Print actual sparsity - std::size_t total_blocks = 0; - std::size_t active_blocks = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - total_blocks++; - if(block_relation_onehot(b, h, qb, kb) != 0) - active_blocks++; - } - } - } - } - float actual_sparsity = - 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" - << total_blocks << " blocks active)" << std::endl; - - std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; - - try - { - if(kname) - { - jenga_sparge_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 1); - } - - for(int i = 0; i < warmup; ++i) - { - jenga_sparge_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } - - [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); - auto start = std::chrono::high_resolution_clock::now(); - - for(int i = 0; i < repeat; ++i) - { - 
jenga_sparge_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } - - [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); - auto end = std::chrono::high_resolution_clock::now(); - double avg_time_ms = - std::chrono::duration(end - start).count() / repeat; - - std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<" - << std::endl; - } - catch(const std::exception& e) - { - std::cerr << "Error during kernel execution: " << e.what() << std::endl; - return false; - } - - bool pass = true; - if(do_validation) - { - std::cout << "\n--- Performing CPU validation ---" << std::endl; - float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - - std::cout << "Computing reference output..." << std::endl; - auto q_ref = to_bhsd(q_host, i_perm); - auto k_ref = to_bhsd(k_host, i_perm); - auto v_ref = to_bhsd(v_host, i_perm); - - ck_tile::reference_blocked_attention( - q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); - - auto [rtol, atol] = get_error_tolerance(); - - float max_diff = 0.0f; - float max_rel_diff = 0.0f; - std::size_t num_errors = 0; - - auto output_host_bhsd = to_bhsd(output_host, o_perm); - for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) - { - float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); - float ref_val = to_float_for_compare(output_ref.mData[i]); - float diff = std::abs(gpu_val - ref_val); - float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; - - max_diff = std::max(max_diff, diff); - max_rel_diff = std::max(max_rel_diff, rel_diff); - - if(diff > atol && rel_diff > rtol) - { - num_errors++; - if(num_errors <= 5) - { - std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val - << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; - } - } - } - - std::cout << "\nValidation results:" << std::endl; - std::cout << " Max absolute difference: " << max_diff << std::endl; - std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Number of mismatches: " << num_errors << " / " - << output_host_bhsd.mData.size() << std::endl; - - if(num_errors == 0) - { - std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; - } - else - { - std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; - pass = false; - } - } - - std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; - return pass; -} - -// ============================================================================ -// Main -// ============================================================================ - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - { - std::cerr << "Failed to parse arguments" << std::endl; - return -1; - } - - std::string prec = arg_parser.get_str("prec"); - - bool test_result = false; - if(prec == "fp16") - { - test_result = run_test(arg_parser); - } - else if(prec == "bf16") - { - test_result = run_test(arg_parser); - } - else - { - std::cerr << "Unsupported precision: " << prec << std::endl; - return -1; - } - - return test_result ? 
0 : -1; -} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp deleted file mode 100644 index 572b708f9ef..00000000000 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ /dev/null @@ -1,597 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device) - -#include -#include -#include -#include "ck_tile/host.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/reference/reference_blocked_attention.hpp" -#include "ck_tile/core/utility/bit_cast.hpp" - -#include "sparge_blockmap_trek.hpp" -#include "fmha_fwd_trek.hpp" -#include "sparge_tool.hpp" - -// ============================================================================ -// Helper Functions -// ============================================================================ - -template -ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t seqlen, - ck_tile::index_t hdim, - bool i_perm) -{ - if(i_perm) - { - return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); - } - return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); -} - -template -ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) -{ - auto lens = tensor.get_lengths(); - ck_tile::index_t batch = lens[0]; - ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; - ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; - ck_tile::index_t hdim = lens[3]; - - ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t s = 0; s < seqlen; ++s) - { - for(ck_tile::index_t d = 0; d < hdim; ++d) - { - out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); - } - } - } - } - return out; -} - -template -auto get_error_tolerance() -{ - double rtol = 1e-2; - double atol = 4e-2; - if constexpr(std::is_same_v) - { - atol = 2e-1; - rtol = 2e-1; - } - return ck_tile::make_tuple(rtol, atol); -} - -template -float to_float_for_compare(T value) -{ - return static_cast(value); -} - -template <> -float to_float_for_compare(ck_tile::bf16_t value) -{ -#if CK_TILE_USE_CUSTOM_DATA_TYPE - return static_cast(value); -#else - return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); -#endif -} - -// ============================================================================ -// Command line argument parser -// ============================================================================ - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") - .insert("b", "1", "batch size") - .insert("h", "4", "num of head for q") - .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") - .insert("s", "4096", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("prec", "fp16", "data type: fp16/bf16") - .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") - .insert("operm", "1", "permute output") - .insert("seed", "42", "random seed") - .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations") - .insert("kname", "0", "print kernel name") - // Sparge-specific - .insert("blkq", "64", "Sparge BLKQ") - .insert("blkk", "128", "Sparge BLKK") - .insert("simthreshd1", "0.6", "Sparge sim threshold") - .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") - .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, 
arg_parser); -} - -// ============================================================================ -// Main Test Function -// ============================================================================ - -template -bool run_test(const ck_tile::ArgParser& arg_parser) -{ - int do_validation = arg_parser.get_int("v"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int kname = arg_parser.get_int("kname"); - - // Sparge params - ck_tile::index_t blkq = arg_parser.get_int("blkq"); - ck_tile::index_t blkk = arg_parser.get_int("blkk"); - float simthreshd1 = arg_parser.get_float("simthreshd1"); - float cdfthreshd = arg_parser.get_float("cdfthreshd"); - float topk = arg_parser.get_float("topk"); - - if(nhead_k < 0) - nhead_k = nhead; - if(seqlen_k < 0) - seqlen_k = seqlen_q; - if(hdim_v < 0) - hdim_v = hdim_q; - - if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) - { - std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, " - "hdim_q=128, hdim_v=128 only." 
- << std::endl; - std::cout << "TEST SKIPPED" << std::endl; - return true; - } - - ck_tile::index_t BLKQ = blkq; - ck_tile::index_t BLKK = blkk; - - ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; - ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; - - std::cout << "============================================================" << std::endl; - std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl; - std::cout << "============================================================" << std::endl; - std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k - << std::endl; - std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; - std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; - std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; - std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks - << std::endl; - std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd - << ", topk=" << topk << ")" << std::endl; - std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; - - // Create host tensors and fill with random data - ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); - ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); - ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); - ck_tile::HostTensor output_host = - o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) - : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); - - std::cout << "\nInitializing tensors..." 
<< std::endl; - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - - // ================================================================== - // Allocate device memory once, HtoD once - // ================================================================== - ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes()); - - q_buf.ToDevice(q_host.data()); - k_buf.ToDevice(k_host.data()); - v_buf.ToDevice(v_host.data()); - - const std::size_t bmap_bytes = - static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t); - const std::size_t lut_bytes = - static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t); - const std::size_t valid_bytes = - static_cast(batch) * nhead * num_q_blocks * sizeof(int32_t); - - ck_tile::DeviceMem bmap_buf(bmap_bytes); - ck_tile::DeviceMem lut_buf(lut_bytes); - ck_tile::DeviceMem valid_buf(valid_bytes); - bmap_buf.SetZero(); - lut_buf.SetZero(); - valid_buf.SetZero(); - - // ================================================================== - // Common stride calculations - // ================================================================== - assert(nhead % nhead_k == 0); - const float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); - - const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q; - const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q; - const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v; - const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v; - const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_k = i_perm ? 
seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v; - const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v; - const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q; - const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k; - const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v; - - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - data_type = "bf16"; - - std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - - // ================================================================== - // GPU: Build block map + VSA LUT (always run, device-only) - // ================================================================== - std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl; - { - sparge_blockmap_args args; - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.batch = batch; - args.seqlen_q = seqlen_q; - args.seqlen_k = seqlen_k; - args.hdim_q = hdim_q; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - args.stride_q = stride_q; - args.stride_k = stride_k; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.simthreshd1 = simthreshd1; - args.cdfthreshd = cdfthreshd; - args.topk = topk; - args.scale = scale_s; - args.block_map_ptr = bmap_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); - - sparge_blockmap_traits traits; - traits.data_type = data_type; - traits.hdim_q = hdim_q; - - sparge_blockmap_fwd(traits, args, ck_tile::stream_config{}); - } - - // ================================================================== - // VSA sparse attention kernel (always run, LUT stays on device) - // 
================================================================== - std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; - - fmha_vsa_fwd_args fmha_args; - fmha_args.q_ptr = q_buf.GetDeviceBuffer(); - fmha_args.k_ptr = k_buf.GetDeviceBuffer(); - fmha_args.v_ptr = v_buf.GetDeviceBuffer(); - fmha_args.lut_ptr = lut_buf.GetDeviceBuffer(); - fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); - fmha_args.o_ptr = o_buf.GetDeviceBuffer(); - fmha_args.batch = batch; - fmha_args.seqlen_q = seqlen_q; - fmha_args.seqlen_k = seqlen_k; - fmha_args.max_seqlen_q = seqlen_q; - fmha_args.hdim_q = hdim_q; - fmha_args.hdim_v = hdim_v; - fmha_args.nhead_q = nhead; - fmha_args.nhead_k = nhead_k; - fmha_args.scale_s = scale_s; - fmha_args.stride_q = stride_q; - fmha_args.stride_k = stride_k; - fmha_args.stride_v = stride_v; - fmha_args.stride_o = stride_o; - fmha_args.nhead_stride_q = nhead_stride_q; - fmha_args.nhead_stride_k = nhead_stride_k; - fmha_args.nhead_stride_v = nhead_stride_v; - fmha_args.nhead_stride_o = nhead_stride_o; - fmha_args.batch_stride_q = batch_stride_q; - fmha_args.batch_stride_k = batch_stride_k; - fmha_args.batch_stride_v = batch_stride_v; - fmha_args.batch_stride_o = batch_stride_o; - fmha_args.window_size_left = mask.left; - fmha_args.window_size_right = mask.right; - fmha_args.mask_type = static_cast(mask.type); - - fmha_vsa_fwd_traits fmha_traits; - fmha_traits.hdim_q = hdim_q; - fmha_traits.hdim_v = hdim_v; - fmha_traits.data_type = data_type; - fmha_traits.is_v_rowmajor = true; - fmha_traits.mask_type = mask.type; - - ck_tile::stream_config stream_config{nullptr, - true, - /* log_level = */ kname ? 
1 : 0, - warmup, - repeat, - false}; - - float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config); - - std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" - << std::endl; - - // DtoH: attention output (always needed) - o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes()); - - // DtoH: block_map (needed for sparsity stats and validation) - ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); - bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes); - - // ================================================================== - // Sparsity statistics (pure CPU, reads block_map HostTensor) - // ================================================================== - std::size_t total_blocks = 0; - std::size_t active_blocks = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - total_blocks++; - if(block_map_gpu(b, h, qb, kb) != 0) - active_blocks++; - } - } - } - } - float actual_sparsity = - 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" - << total_blocks << " blocks active)" << std::endl; - - // ================================================================== - // Validation (only when -v=1) - // ================================================================== - bool pass = true; - if(do_validation) - { - std::cout << "\n--- Performing CPU validation ---" << std::endl; - - // CPU golden: block map + VSA LUT - std::cout << "Building Sparge block map (CPU golden)..." 
<< std::endl; - sparge::SpargeParams p; - p.BLKQ = static_cast(BLKQ); - p.BLKK = static_cast(BLKK); - p.simthreshd1 = simthreshd1; - p.cdfthreshd = cdfthreshd; - p.topk = topk; - p.i_perm = i_perm; - - ck_tile::HostTensor block_relation_onehot = - sparge::build_block_map_meansim(q_host, k_host, p); - - std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl; - auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); - - // DtoH: LUT + valid_block_num (only for validation) - sparge::VSALut vsa_lut_gpu{ - ck_tile::HostTensor({batch, nhead, num_q_blocks, num_k_blocks}), - ck_tile::HostTensor({batch, nhead, num_q_blocks}), - }; - lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes); - valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes); - - // Validate block map - std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl; - { - std::size_t bmap_mismatches = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb)) - { - bmap_mismatches++; - if(bmap_mismatches <= 10) - { - std::cout - << " block_map mismatch at [" << b << "," << h << "," << qb - << "," << kb << "]: GPU=" - << static_cast(block_map_gpu(b, h, qb, kb)) << " CPU=" - << static_cast(block_relation_onehot(b, h, qb, kb)) - << std::endl; - } - } - } - } - } - } - std::cout << " Block map mismatches: " << bmap_mismatches << " / " - << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl; - if(bmap_mismatches > 0) - { - std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl; - pass = false; - } - else - { - std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl; - } - } - - // Validate VSA LUT - std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl; - 
{ - std::size_t lut_mismatches = 0; - std::size_t valid_mismatches = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - if(vsa_lut_gpu.valid_block_num(b, h, qb) != - vsa_lut_cpu.valid_block_num(b, h, qb)) - { - valid_mismatches++; - if(valid_mismatches <= 5) - { - std::cout << " valid_block_num mismatch at [" << b << "," << h - << "," << qb - << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) - << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) - << std::endl; - } - } - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb)) - { - lut_mismatches++; - if(lut_mismatches <= 10) - { - std::cout - << " LUT mismatch at [" << b << "," << h << "," << qb - << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) - << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl; - } - } - } - } - } - } - std::cout << " LUT mismatches: " << lut_mismatches << std::endl; - std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl; - if(lut_mismatches == 0 && valid_mismatches == 0) - { - std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl; - } - else - { - std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl; - pass = false; - } - } - - // Validate attention output - float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - - std::cout << "\nComputing reference attention output..." 
<< std::endl; - auto q_ref = to_bhsd(q_host, i_perm); - auto k_ref = to_bhsd(k_host, i_perm); - auto v_ref = to_bhsd(v_host, i_perm); - - ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); - ck_tile::reference_blocked_attention( - q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); - - auto [rtol, atol] = get_error_tolerance(); - - float max_diff = 0.0f; - float max_rel_diff = 0.0f; - std::size_t num_errors = 0; - - auto output_host_bhsd = to_bhsd(output_host, o_perm); - for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) - { - float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); - float ref_val = to_float_for_compare(output_ref.mData[i]); - float diff = std::abs(gpu_val - ref_val); - float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; - - max_diff = std::max(max_diff, diff); - max_rel_diff = std::max(max_rel_diff, rel_diff); - - if(diff > atol && rel_diff > rtol) - { - num_errors++; - if(num_errors <= 5) - { - std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val - << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; - } - } - } - - std::cout << "\nAttention validation results:" << std::endl; - std::cout << " Max absolute difference: " << max_diff << std::endl; - std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Number of mismatches: " << num_errors << " / " - << output_host_bhsd.mData.size() << std::endl; - - if(num_errors == 0) - { - std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; - } - else - { - std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; - pass = false; - } - } - - std::cout << "\n" << (pass ? 
"TEST PASSED" : "TEST FAILED") << std::endl; - return pass; -} - -// ============================================================================ -// Main -// ============================================================================ - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - { - std::cerr << "Failed to parse arguments" << std::endl; - return -1; - } - - std::string prec = arg_parser.get_str("prec"); - - bool test_result = false; - if(prec == "fp16") - { - test_result = run_test(arg_parser); - } - else if(prec == "bf16") - { - test_result = run_test(arg_parser); - } - else - { - std::cerr << "Unsupported precision: " << prec << std::endl; - return -1; - } - - return test_result ? 0 : -1; -} diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 67936c4353f..9fe8b365b00 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -318,26 +318,26 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga { if(!block_relation_onehot[i_total_loops]) { - i_total_loops++; - if(i_total_loops < num_total_loop) - { - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - - if(block_relation_onehot[i_total_loops]) - { - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); - } - move_tile_window(k_dram_window, {0, kK0}); - move_tile_window(v_dram_window, {0, kN0}); - continue; - } - break; + // scan-ahead: find the next active block in one shot + index_t next = i_total_loops + 1; + while(next < num_total_loop && !block_relation_onehot[next]) + next++; + if(next >= num_total_loop) + break; + const 
index_t delta = next - i_total_loops; + i_total_loops = next; + // jump K/V windows to the next active block + move_tile_window(k_dram_block_window, {kN0 * delta, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + move_tile_window(v_dram_window, {0, kN0 * delta}); + // immediately prefetch the active K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + continue; } // STAGE 1, QK gemm From eca3cb3e0abdcb927a02f9b7b9ed786b9a9cdda2 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Fri, 24 Apr 2026 05:13:51 -0400 Subject: [PATCH 6/6] sparse_attn: add bm0 dispatch for sparge blockmap compatibility Add bm0 field to fmha_jenga_fwd_traits so callers can specify the preferred Q-tile size. Codegen now emits separate tile configs for bm0=64 (sparge blockmap) and bm0=128 (original), with CppConstraint guards to select the right kernel at runtime. End-to-end test passes for both jenga and vsa paths. Performance is known to be suboptimal at this stage; tile sizes and warp counts for the bm0=64 path have not been tuned. 
Co-Authored-By: Claude Opus 4.7 --- .../codegen/ops/fmha_fwd_jenga.py | 58 ++++++++++--------- .../codegen/ops/fmha_fwd_vsa.py | 58 ++++++++++--------- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 2 +- .../ck_tile/50_sparse_attn/test_sparge.cpp | 2 + 4 files changed, 63 insertions(+), 57 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 1f0a78048d9..fc4b8642ddd 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -690,12 +690,12 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 - 64, + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) 128, - 64, 128, - 64, + 32, + 128, + 32, 128, 4, 1, @@ -703,13 +703,36 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 4, 1, 1, + 32, + 32, 16, + 32, + 32, 16, + -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), + ), + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, + 128, + 32, + 128, + 32, + 128, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, 16, - 16, - 16, + 32, + 32, 16, -1, + CppConstraint("t.bm0 == 64"), ), FmhaFwdTileSize( # fmt: skip 16, @@ -774,27 +797,6 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 16, -1, ), - FmhaFwdTileSize( # fmt: skip - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), ], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 
16, 32, 32, 16, -1)], @@ -909,7 +911,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 64 or tile.F_bn0 != 128: + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async": continue diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 217cfcfe2a4..208877037f1 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -690,12 +690,12 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 - 64, + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) 128, - 64, 128, - 64, + 32, + 128, + 32, 128, 4, 1, @@ -703,13 +703,36 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 4, 1, 1, + 32, + 32, 16, + 32, + 32, 16, + -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), + ), + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, + 128, + 32, + 128, + 32, + 128, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, 16, - 16, - 16, + 32, + 32, 16, -1, + CppConstraint("t.bm0 == 64"), ), FmhaFwdTileSize( # fmt: skip 16, @@ -774,27 +797,6 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 16, -1, ), - FmhaFwdTileSize( # fmt: skip - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), ], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 
1, 32, 32, 16, 32, 32, 16, -1)], @@ -909,7 +911,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 64 or tile.F_bn0 != 128: + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async_vsa": continue diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 350d1803f66..62d40ffbe02 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -272,7 +272,7 @@ struct fmha_jenga_fwd_traits std::string data_type; bool is_v_rowmajor; mask_enum mask_type; - // TODO: padding check is inside this api + int bm0 = 0; // preferred Q-tile size; 0 = don't care (dispatch picks largest) }; float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index 7c30a10b062..81a49ca006b 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -249,6 +249,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; attn_traits.is_v_rowmajor = true; attn_traits.mask_type = mask_enum::no_mask; + attn_traits.bm0 = BLKQ; fmha_jenga_fwd_args attn_args; attn_args.q_ptr = q_dev.GetDeviceBuffer(); @@ -291,6 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; attn_traits.is_v_rowmajor = true; attn_traits.mask_type = mask_enum::no_mask; + attn_traits.bm0 = BLKQ; fmha_vsa_fwd_args attn_args; attn_args.q_ptr = q_dev.GetDeviceBuffer();