From 78134ef0e8a00706753ea0030b232042a755b8b3 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Mon, 11 Dec 2023 08:36:29 +0000 Subject: [PATCH] set smem size for repetition penalty kernel --- src/turbomind/kernels/sampling_penalty_kernels.cu | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index 4877bdb1a0..f7ebfeff03 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T* logits, dim3 grid(local_batch_size); size_t smem_size = step * (sizeof(float) + sizeof(int)); if (penalty_type == RepetitionPenaltyType::Additive) { + check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); batchApplyRepetitionPenalty<<>>( logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); } else if (penalty_type == RepetitionPenaltyType::Multiplicative) { + check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); batchApplyRepetitionPenalty<<>>( logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); }