From aa3aaf6d2ed0359ceb2f4428afe38705b69f7400 Mon Sep 17 00:00:00 2001 From: byshiue Date: Mon, 17 Oct 2022 14:04:45 +0800 Subject: [PATCH] Update stop_criteria_kernels.cu fix: fix bug of length_criterion --- .../kernels/stop_criteria_kernels.cu | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/fastertransformer/kernels/stop_criteria_kernels.cu b/src/fastertransformer/kernels/stop_criteria_kernels.cu index 5420e90e3..2c038c0f0 100644 --- a/src/fastertransformer/kernels/stop_criteria_kernels.cu +++ b/src/fastertransformer/kernels/stop_criteria_kernels.cu @@ -14,22 +14,13 @@ * limitations under the License. */ -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#elif (CUDART_VERSION >= 11050) -#include <cub/cub.cuh> -#else -#include "3rdparty/cub/cub.cuh" -#endif - #include "src/fastertransformer/kernels/stop_criteria_kernels.h" #include "src/fastertransformer/utils/cuda_utils.h" #include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" namespace fastertransformer { -constexpr int LENGTH_CRITERION_BLOCKSIZE = 256; - __global__ void stop_words_criterion(const int* output_ids, const int* parent_ids, const int* stop_words, @@ -100,6 +91,7 @@ void invokeStopWordsCriterion(const int* output_ids, int step, cudaStream_t stream) { + FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); // Check if we have sampled a word from the stop_words list. If so, stop the sequence.
dim3 block, grid; block.x = min(((stop_words_len + 32 - 1) / 32) * 32, 256UL); @@ -119,22 +111,25 @@ __global__ void length_criterion(bool* finished, int beam_width, int step) { - typedef cub::BlockReduce<int, LENGTH_CRITERION_BLOCKSIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - const int index = blockIdx.x * blockDim.x + threadIdx.x; + int thread_finished_count = 0; + for (int index = threadIdx.x; index < batch_size * beam_width; index += blockDim.x) { + const int batch_idx = index / beam_width; - // const int beam_idx = index % beam_width; - const int batch_idx = index / beam_width; - - if (index >= batch_size * beam_width) { - return; + finished[index] |= step >= sequence_limit_length[batch_idx]; + thread_finished_count += finished[index] ? 1 : 0; } + int block_finished_count = 0; + if (blockDim.x <= 32) { + block_finished_count = warpReduceSum(thread_finished_count); + } + else { + block_finished_count = blockReduceSum(thread_finished_count); + } + __syncthreads(); - finished[index] |= step >= sequence_limit_length[batch_idx]; - - int agg = BlockReduce(temp_storage).Sum((int)finished[index]); - atomicAdd(finished_sum, agg); + if (threadIdx.x == 0) { + finished_sum[0] = block_finished_count; + } } void invokeLengthCriterion(bool* finished, @@ -148,15 +143,16 @@ void invokeLengthCriterion(bool* finished, { // Check if we have attained the sequence length limit. If so, stop the sequence.
// In addition, check if all sequences are stopped and return the result in should_stop - dim3 block{LENGTH_CRITERION_BLOCKSIZE}; - dim3 grid{(batch_size * beam_width + block.x - 1) / block.x}; + FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + dim3 block{min(512, uint32_t(batch_size * beam_width))}; + dim3 grid{1}; length_criterion<<<grid, block, 0, stream>>>( finished, should_stop, finished_sum, sequence_limit_length, batch_size, beam_width, step); + sync_check_cuda_error(); int h_finished_sum = 0; cudaD2Hcpy(&h_finished_sum, finished_sum, 1); - *should_stop = h_finished_sum == batch_size * beam_width; }