diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 90e6021dcb..e1b0a3b50d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1450,6 +1450,29 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf_fp32(cidx) = gate * up; + } } else { @@ -1511,6 +1534,21 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf_fp32(cidx) = gate * up; + } } else { @@ -1923,6 +1961,29 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf_fp32(cidx) = gate * up; + } } else { @@ -1984,6 +2045,21 @@ struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; } + else if constexpr(ActivationOperation == + Activation::swiglustep_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + gate = gate < 7.0f ? gate : 7.0f; + up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f; + c_thread_buf_fp32(cidx) = gate * up; + } } else {