diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index cf904605132..ca0f960f46d 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -237,8 +237,8 @@ static __global__ void topk_f32( // Per-thread top-K heap (min-heap: smallest score at index 0) - // Max K=32, stored in registers - float heap_val[32]; - int32_t heap_idx[32]; + // Max K=64, stored in registers + float heap_val[64]; + int32_t heap_idx[64]; for (int i = 0; i < K; i++) { heap_val[i] = -FLT_MAX; heap_idx[i] = -1; @@ -527,7 +527,7 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int cub_mode = ggml_cuda_dflash_cub_topk_mode(); const bool use_cub_topk = K > 1 && !output_logprob && seed == 0 && - (K > 32 || cub_mode == 1 || (cub_mode < 0 && nrows > 32)); + (K > 64 || cub_mode == 1 || (cub_mode < 0 && nrows > 32)); if (ggml_cuda_dflash_argmax_profile_enabled() && K > 1) { GGML_LOG_INFO("%s: dflash argmax profile path=%s K=%d nrows=%" PRId64 " vocab=%" PRId64 " temp=%.3f seed=%" PRIu64 " output_logprob=%d cub_mode=%d\n", @@ -554,7 +554,7 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (K == 1) { argmax_f32<<>>(src0_d, dst_d, ne00, nrows, inv_temp, seed, output_logprob); } else { - GGML_ASSERT(K <= 32); + GGML_ASSERT(K <= 64); // Shared memory: K * n_warps floats + K * n_warps ints + 2 * n_warps floats (softmax) const int n_warps = (int)(num_threads / WARP_SIZE); const size_t smem_size = K * n_warps * (sizeof(float) + sizeof(int32_t)) + 2 * n_warps * sizeof(float);