diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py index 4c6bbda135b9..127dce72c68b 100644 --- a/python/perf-kernels/rmsnorm.py +++ b/python/perf-kernels/rmsnorm.py @@ -37,28 +37,87 @@ def get_autotune_config(): return get_hip_autotune_config() +# accumulate sum of squares for a row in a blocked manner +@triton.jit +def accumulate_sum_squares(input_ptr, input_row_stride, n_cols, BLOCK_SIZE, row_idx): + col_offsets = tl.arange(0, BLOCK_SIZE) + sum_squares = tl.zeros([1], dtype=tl.float32) + row_input_ptr = input_ptr + row_idx * input_row_stride + + n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 + for start in range(0, n_cols_blks * BLOCK_SIZE, BLOCK_SIZE): + cols = start + col_offsets + input_ptrs = row_input_ptr + cols + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + x = tl.load(input_ptrs) + sum_squares += tl.sum(x * x, axis=0) + + # loop peeling for mask + cols = n_cols_blks * BLOCK_SIZE + col_offsets + mask = cols < n_cols + input_ptrs = row_input_ptr + cols + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") + sum_squares += tl.sum(x * x, axis=0) + + return sum_squares + + +# apply normalization to each block of the row +@triton.jit +def apply_normalization(input_ptr, output_ptr, g_ptr, input_row_stride, output_row_stride, n_cols, norm_factor, + BLOCK_SIZE, row_idx): + col_offsets = tl.arange(0, BLOCK_SIZE) + row_input_ptr = input_ptr + row_idx * input_row_stride + row_output_ptr = output_ptr + row_idx * output_row_stride + + for start in range(0, n_cols, BLOCK_SIZE): + cols = start + col_offsets + mask = cols < n_cols + input_ptrs = row_input_ptr + cols + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + g_ptrs = g_ptr + cols + output_ptrs = row_output_ptr + cols + x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") + g = tl.load(g_ptrs, mask=mask, other=0.0) + rms_norm = x * norm_factor * g + tl.store(output_ptrs, rms_norm, mask=mask) + + +# Main kernel with both blocked and non-blocked versions based on BLOCK_SIZE @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) @triton.jit def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, - BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr): - row_start = tl.program_id(0) + BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr): + row_idx = tl.program_id(0) # Each program instance handles one row col_offsets = tl.arange(0, BLOCK_SIZE) - mask = col_offsets < n_cols - tl.assume(input_row_stride >= 0) - tl.assume(output_row_stride >= 0) - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS): + + if USE_BLOCKED: + # Blocked Approach: Accumulate sum of squares and normalize in chunks + sum_squares = accumulate_sum_squares(input_ptr, input_row_stride, n_cols, BLOCK_SIZE, row_idx) + mean_square = sum_squares / n_cols + norm_factor = tl.rsqrt(mean_square + epsilon) + + # Apply normalization + apply_normalization(input_ptr, output_ptr, g_ptr, input_row_stride, output_row_stride, n_cols, norm_factor, + BLOCK_SIZE, row_idx) + + else: + mask = col_offsets < n_cols + tl.assume(input_row_stride >= 0) + tl.assume(output_row_stride >= 0) row_start_ptr = input_ptr + row_idx * input_row_stride input_ptrs = row_start_ptr + col_offsets input_ptrs = tl.multiple_of(input_ptrs, (16, )) row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0) - row_norm = row * row #square each value - row_norm = tl.sum(row_norm, axis=-1) #sum across columns(axis=-1) - row_norm = row_norm / n_cols #divide by n_cols - row_norm = row_norm + epsilon #add epsilon - row_norm = tl.rsqrt(row_norm) #take rsqrt, this is normalization value - rms_norm = row * row_norm #multiply each x by normalization value - rms_norm = rms_norm * g #element wise multiplication with g + row_norm = row * row + row_norm = tl.sum(row_norm, axis=-1) + row_norm = row_norm / n_cols + row_norm = row_norm + epsilon + row_norm = tl.rsqrt(row_norm) + rms_norm = row * row_norm + rms_norm = rms_norm * g output_row_start_ptr = output_ptr + row_idx * output_row_stride output_ptrs = output_row_start_ptr + col_offsets @@ -68,10 +127,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6): BLOCK_SIZE = blk_size + # Use blocked approach if BLOCK_SIZE larger than 65536 // x.element_size() + USE_BLOCKED = n_cols > BLOCK_SIZE NUM_PRGMS = n_rows grid = lambda meta: (NUM_PRGMS, ) - rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS) + rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, USE_BLOCKED, NUM_PRGMS) return y @@ -93,6 +154,8 @@ def torch_rmsnorm(x, g): (8192, 4096), (4096, 8192), (1, 8192), + (1, 31744), + (3, 65536), (873, 1245), ]) def test_rmsnorm(M, N): @@ -100,7 +163,8 @@ def test_rmsnorm(M, N): x = torch.randn(M, N, device='cuda') y = torch.zeros_like(x, device='cuda') n_rows, n_cols = x.shape - blk_size = triton.next_power_of_2(n_cols) + MAX_FUSED_SIZE = 65536 // x.element_size() + blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols)) g = torch.ones((1, N), device='cuda') y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) @@ -153,7 +217,8 @@ def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=dtype) y = torch.zeros_like(x, device='cuda') n_rows, n_cols = x.shape - blk_size = triton.next_power_of_2(n_cols) + MAX_FUSED_SIZE = 65536 // x.element_size() + blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols)) stream = torch.cuda.Stream() torch.cuda.set_stream(stream) g = torch.ones((1, N), device='cuda') @@ -199,7 +264,8 @@ def main(): x = torch.randn(args.M_start, args.N_start, device='cuda') y = torch.zeros_like(x, device='cuda') n_rows, n_cols = x.shape - blk_size = triton.next_power_of_2(n_cols) + MAX_FUSED_SIZE = 65536 // x.element_size() + blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols)) g = torch.ones((1, args.N_start), device='cuda') triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) else: