diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu index 4877bdb1a0..f7ebfeff03 100644 --- a/src/turbomind/kernels/sampling_penalty_kernels.cu +++ b/src/turbomind/kernels/sampling_penalty_kernels.cu @@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T* logits, dim3 grid(local_batch_size); size_t smem_size = step * (sizeof(float) + sizeof(int)); if (penalty_type == RepetitionPenaltyType::Additive) { + check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); batchApplyRepetitionPenalty<<>>( logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); } else if (penalty_type == RepetitionPenaltyType::Multiplicative) { + check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); batchApplyRepetitionPenalty<<>>( logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); }