Skip to content

Commit

Permalink
add blocked version to address performance issue when N is large (#672)
Browse files Browse the repository at this point in the history

* add blocked version to address performance issue when N is large

* use_blocked should be based on size of N

* block size should be size of total elements

* correct comments

* remove fake persistent loop
  • Loading branch information
xiaohuguo2023 authored Dec 6, 2024
1 parent 27a1b5b commit 736071f
Showing 1 changed file with 83 additions and 17 deletions.
100 changes: 83 additions & 17 deletions python/perf-kernels/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,87 @@ def get_autotune_config():
return get_hip_autotune_config()


# accumulate sum of squares for a row in a blocked manner
@triton.jit
def accumulate_sum_squares(input_ptr, input_row_stride, n_cols, BLOCK_SIZE, row_idx):
    # Computes sum(x * x) over row `row_idx` of a 2-D tensor, walking the row
    # in BLOCK_SIZE-wide chunks so rows wider than one block can be reduced.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    sum_squares = tl.zeros([1], dtype=tl.float32)  # fp32 accumulator
    row_input_ptr = input_ptr + row_idx * input_row_stride

    # All blocks except the last are guaranteed full, so they are loaded
    # without a bounds mask (the cheaper fast path).
    n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
    for start in range(0, n_cols_blks * BLOCK_SIZE, BLOCK_SIZE):
        cols = start + col_offsets
        input_ptrs = row_input_ptr + cols
        input_ptrs = tl.multiple_of(input_ptrs, (16, ))  # alignment hint for vectorized loads
        x = tl.load(input_ptrs)
        sum_squares += tl.sum(x * x, axis=0)

    # loop peeling for mask: the final (possibly partial) block needs a mask
    # so out-of-row lanes load 0.0 and contribute nothing to the sum.
    cols = n_cols_blks * BLOCK_SIZE + col_offsets
    mask = cols < n_cols
    input_ptrs = row_input_ptr + cols
    input_ptrs = tl.multiple_of(input_ptrs, (16, ))
    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
    sum_squares += tl.sum(x * x, axis=0)

    return sum_squares


# apply normalization to each block of the row
@triton.jit
def apply_normalization(input_ptr, output_ptr, g_ptr, input_row_stride, output_row_stride, n_cols, norm_factor,
                        BLOCK_SIZE, row_idx):
    # Writes output[row_idx] = input[row_idx] * norm_factor * g, streaming the
    # row in BLOCK_SIZE-wide chunks. `norm_factor` is the row's precomputed
    # rsqrt(mean_square + epsilon); `g` is the per-column gain vector.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_input_ptr = input_ptr + row_idx * input_row_stride
    row_output_ptr = output_ptr + row_idx * output_row_stride

    # Every chunk is masked here (no unmasked fast path): the store must not
    # touch columns at or beyond n_cols.
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + col_offsets
        mask = cols < n_cols
        input_ptrs = row_input_ptr + cols
        input_ptrs = tl.multiple_of(input_ptrs, (16, ))  # alignment hint for vectorized loads
        g_ptrs = g_ptr + cols
        output_ptrs = row_output_ptr + cols
        x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
        g = tl.load(g_ptrs, mask=mask, other=0.0)
        rms_norm = x * norm_factor * g
        tl.store(output_ptrs, rms_norm, mask=mask)


# Main kernel with both blocked and non-blocked versions based on BLOCK_SIZE
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr):
row_start = tl.program_id(0)
BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr):
row_idx = tl.program_id(0) # Each program instance handles one row
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
tl.assume(input_row_stride >= 0)
tl.assume(output_row_stride >= 0)
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):

if USE_BLOCKED:
# Blocked Approach: Accumulate sum of squares and normalize in chunks
sum_squares = accumulate_sum_squares(input_ptr, input_row_stride, n_cols, BLOCK_SIZE, row_idx)
mean_square = sum_squares / n_cols
norm_factor = tl.rsqrt(mean_square + epsilon)

# Apply normalization
apply_normalization(input_ptr, output_ptr, g_ptr, input_row_stride, output_row_stride, n_cols, norm_factor,
BLOCK_SIZE, row_idx)

else:
mask = col_offsets < n_cols
tl.assume(input_row_stride >= 0)
tl.assume(output_row_stride >= 0)
row_start_ptr = input_ptr + row_idx * input_row_stride
input_ptrs = row_start_ptr + col_offsets
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
row_norm = row * row #square each value
row_norm = tl.sum(row_norm, axis=-1) #sum across columns(axis=-1)
row_norm = row_norm / n_cols #divide by n_cols
row_norm = row_norm + epsilon #add epsilon
row_norm = tl.rsqrt(row_norm) #take rsqrt, this is normalization value
rms_norm = row * row_norm #multiply each x by normalization value
rms_norm = rms_norm * g #element wise multiplication with g
row_norm = row * row
row_norm = tl.sum(row_norm, axis=-1)
row_norm = row_norm / n_cols
row_norm = row_norm + epsilon
row_norm = tl.rsqrt(row_norm)
rms_norm = row * row_norm
rms_norm = rms_norm * g

output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
Expand All @@ -68,10 +127,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride

def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
    """Launch the Triton RMSNorm kernel: y = x * rsqrt(mean(x^2) + epsilon) * g.

    Args:
        x: (n_rows, n_cols) input tensor on the GPU.
        y: preallocated output tensor of the same shape as x (written in place).
        g: per-column gain (weight) vector.
        n_rows, n_cols: dimensions of x.
        blk_size: Triton BLOCK_SIZE for the kernel (a power of two; the caller
            caps it so one block's bytes stay within the fused-size limit).
        epsilon: small constant added to the mean square for numerical stability.

    Returns:
        y, after the kernel has filled it.
    """
    BLOCK_SIZE = blk_size
    # Use the blocked (multi-chunk) code path when a single block cannot cover
    # a whole row; otherwise each row is processed in one tile.
    USE_BLOCKED = n_cols > BLOCK_SIZE

    # One program instance per row.
    NUM_PRGMS = n_rows
    grid = lambda meta: (NUM_PRGMS, )
    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, USE_BLOCKED, NUM_PRGMS)

    return y

Expand All @@ -93,14 +154,17 @@ def torch_rmsnorm(x, g):
(8192, 4096),
(4096, 8192),
(1, 8192),
(1, 31744),
(3, 65536),
(873, 1245),
])
def test_rmsnorm(M, N):
torch.manual_seed(0)
x = torch.randn(M, N, device='cuda')
y = torch.zeros_like(x, device='cuda')
n_rows, n_cols = x.shape
blk_size = triton.next_power_of_2(n_cols)
MAX_FUSED_SIZE = 65536 // x.element_size()
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
g = torch.ones((1, N), device='cuda')
y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)

Expand Down Expand Up @@ -153,7 +217,8 @@ def benchmark(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=dtype)
y = torch.zeros_like(x, device='cuda')
n_rows, n_cols = x.shape
blk_size = triton.next_power_of_2(n_cols)
MAX_FUSED_SIZE = 65536 // x.element_size()
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
g = torch.ones((1, N), device='cuda')
Expand Down Expand Up @@ -199,7 +264,8 @@ def main():
x = torch.randn(args.M_start, args.N_start, device='cuda')
y = torch.zeros_like(x, device='cuda')
n_rows, n_cols = x.shape
blk_size = triton.next_power_of_2(n_cols)
MAX_FUSED_SIZE = 65536 // x.element_size()
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
g = torch.ones((1, args.N_start), device='cuda')
triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
else:
Expand Down

0 comments on commit 736071f

Please sign in to comment.