Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ static __global__ void topk_f32(

// Per-thread top-K heap (min-heap: smallest score at index 0)
// Max K=32, stored in registers
float heap_val[32];
int32_t heap_idx[32];
float heap_val[64];
int32_t heap_idx[64];
for (int i = 0; i < K; i++) {
heap_val[i] = -FLT_MAX;
heap_idx[i] = -1;
Expand Down Expand Up @@ -527,7 +527,7 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cub_mode = ggml_cuda_dflash_cub_topk_mode();
const bool use_cub_topk =
K > 1 && !output_logprob && seed == 0 &&
(K > 32 || cub_mode == 1 || (cub_mode < 0 && nrows > 32));
(K > 64 || cub_mode == 1 || (cub_mode < 0 && nrows > 32));
if (ggml_cuda_dflash_argmax_profile_enabled() && K > 1) {
GGML_LOG_INFO("%s: dflash argmax profile path=%s K=%d nrows=%" PRId64
" vocab=%" PRId64 " temp=%.3f seed=%" PRIu64 " output_logprob=%d cub_mode=%d\n",
Expand All @@ -554,7 +554,7 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (K == 1) {
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows, inv_temp, seed, output_logprob);
} else {
GGML_ASSERT(K <= 32);
GGML_ASSERT(K <= 64);
// Shared memory: K * n_warps floats + K * n_warps ints + 2 * n_warps floats (softmax)
const int n_warps = (int)(num_threads / WARP_SIZE);
const size_t smem_size = K * n_warps * (sizeof(float) + sizeof(int32_t)) + 2 * n_warps * sizeof(float);
Expand Down
Loading