Skip to content

Commit

Permalink
Update stop_criteria_kernels.cu
Browse files Browse the repository at this point in the history
fix: fix bug of length_criterion
  • Loading branch information
byshiue authored Oct 17, 2022
1 parent bc077a9 commit aa3aaf6
Showing 1 changed file with 22 additions and 26 deletions.
48 changes: 22 additions & 26 deletions src/fastertransformer/kernels/stop_criteria_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,13 @@
* limitations under the License.
*/

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "src/fastertransformer/kernels/stop_criteria_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/memory_utils.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"

namespace fastertransformer {

constexpr int LENGTH_CRITERION_BLOCKSIZE = 256;

__global__ void stop_words_criterion(const int* output_ids,
const int* parent_ids,
const int* stop_words,
Expand Down Expand Up @@ -100,6 +91,7 @@ void invokeStopWordsCriterion(const int* output_ids,
int step,
cudaStream_t stream)
{
FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
// Check if we have sampled a word from the stop_words list. If so, stop the sequence.
dim3 block, grid;
block.x = min(((stop_words_len + 32 - 1) / 32) * 32, 256UL);
Expand All @@ -119,22 +111,25 @@ __global__ void length_criterion(bool* finished,
int beam_width,
int step)
{
typedef cub::BlockReduce<int, LENGTH_CRITERION_BLOCKSIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const int index = blockIdx.x * blockDim.x + threadIdx.x;
int thread_finished_count = 0;
for (int index = threadIdx.x; index < batch_size * beam_width; index += blockDim.x) {
const int batch_idx = index / beam_width;

// const int beam_idx = index % beam_width;
const int batch_idx = index / beam_width;

if (index >= batch_size * beam_width) {
return;
finished[index] |= step >= sequence_limit_length[batch_idx];
thread_finished_count += finished[index] ? 1 : 0;
}
int block_finished_count = 0;
if (blockDim.x <= 32) {
block_finished_count = warpReduceSum(thread_finished_count);
}
else {
block_finished_count = blockReduceSum(thread_finished_count);
}
__syncthreads();

finished[index] |= step >= sequence_limit_length[batch_idx];

int agg = BlockReduce(temp_storage).Sum((int)finished[index]);
atomicAdd(finished_sum, agg);
if (threadIdx.x == 0) {
finished_sum[0] = block_finished_count;
}
}

void invokeLengthCriterion(bool* finished,
Expand All @@ -148,15 +143,16 @@ void invokeLengthCriterion(bool* finished,
{
// Check if we have attained the sequence length limit. If so, stop the sequence.
// In addition, check if all sequences are stopped and return the result in should_stop
dim3 block{LENGTH_CRITERION_BLOCKSIZE};
dim3 grid{(batch_size * beam_width + block.x - 1) / block.x};
FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
dim3 block{min(512, uint32_t(batch_size * beam_width))};
dim3 grid{1};

length_criterion<<<grid, block, 0, stream>>>(
finished, should_stop, finished_sum, sequence_limit_length, batch_size, beam_width, step);
sync_check_cuda_error();

int h_finished_sum = 0;
cudaD2Hcpy(&h_finished_sum, finished_sum, 1);

*should_stop = h_finished_sum == batch_size * beam_width;
}

Expand Down

0 comments on commit aa3aaf6

Please sign in to comment.