From 4dd086f3ee1df95a4a6fba309dffaf94fc33d019 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Tue, 14 May 2024 20:03:48 +0000 Subject: [PATCH] Formatting --- python/perf-kernels/flash-attention.py | 232 +++++++++++++------------ 1 file changed, 122 insertions(+), 110 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 70197689e44a..aef1306ad1c2 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -26,6 +26,7 @@ import triton import triton.language as tl + class MetaData(): cu_seqlens_q = None cu_seqlens_k = None @@ -106,32 +107,38 @@ def check_args(self, q, k, v, o): assert self.layout is not None assert self.layout == 'thd' or not self.varlen + @triton.jit -def cdiv_fn(x,y): +def cdiv_fn(x, y): return (x + y - 1) // y + @triton.jit def max_fn(x, y): return tl.math.max(x, y) + @triton.jit def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): ms = tl.arange(0, m) ns = tl.arange(0, n) return philox_offset + ms[:, None] * stride + ns[None, :] + @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) + @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep + @triton.jit def load_fn(block_ptr, first, second, pad): if first and second: @@ -144,6 +151,7 @@ def load_fn(block_ptr, first, second, pad): tensor = tl.load(block_ptr) return tensor + @triton.jit def print_gpu(prefix, val=None): if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): @@ -152,6 +160,7 @@ def print_gpu(prefix, val=None): else: tl.device_print(prefix) + @triton.jit def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix @@ -184,6 +193,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo else: return alibi_block + def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) @@ -276,47 +286,42 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i + @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 
'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), - # TODO: This config fails with head_size not pow2 with data mismatches. Check why. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], - use_cuda_graph=True, + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + use_cuda_graph=True, ) @triton.jit -def attn_fwd( - Q, K, V, bias, sm_scale, L, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - stride_bz, stride_bh, stride_bm, stride_bn, - stride_az, stride_ah, - cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, - alibi_slopes, - HQ: tl.constexpr, HK:tl.constexpr, - ACTUAL_BLOCK_DMODEL:tl.constexpr, - MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_ALIBI: tl.constexpr -): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: 
tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -591,6 +596,7 @@ def _attn_bwd_preprocess( else: tl.store(delta_ptrs, delta) + @triton.jit def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, # shared by Q/K/V/DO. @@ -642,6 +648,7 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) return dk, dv + @triton.jit def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, # shared by Q/K/V/DO. @@ -689,6 +696,7 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq + @triton.jit def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # shared by Q/K/V/DO. @@ -818,6 +826,7 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, empty = torch.empty(128, device="cuda") + # TODO: This can probably optimized to have fewer lines of code. def get_strides_from_layout(metadata, q, k, v, o): if metadata.layout == 'thd': @@ -845,7 +854,9 @@ def get_strides_from_layout(metadata, q, k, v, o): assert False, 'Got unsupported layout.' return batch, nheads_q, nheads_k, q_strides, k_strides, v_strides, o_strides + class _attention(torch.autograd.Function): + @staticmethod def forward(ctx, q, k, v, o, metadata): # NOTE: a large bias tensor leads to overflow during pointer arithmetic @@ -902,9 +913,8 @@ def forward(ctx, q, k, v, o, metadata): ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, - USE_ALIBI=False if metadata.alibi_slopes is None else True, - ENABLE_DROPOUT=metadata.dropout_p > 0.0, - RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) + USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid @@ -992,8 +1002,10 @@ def backward(ctx, do, _): return dq, dk, dv, None, None + attention = _attention.apply + def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): torch.manual_seed(20) @@ -1016,6 +1028,7 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): input_metadata.layout = layout return q, k, v, input_metadata + def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): torch.manual_seed(20) @@ -1023,11 +1036,11 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen if not equal_seqlens: max_seqlens_q = N_CTX_Q // Z max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) else: - seqlens_q = torch.full((Z,), N_CTX_Q // Z) - seqlens_k = torch.full((Z,), N_CTX_K // Z) + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) # Calculate cumulative sequence lengths cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) @@ -1046,6 +1059,7 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen input_metadata.set_varlen_params(cu_seqlens_q, 
cu_seqlens_k) return q, k, v, input_metadata + @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ (4, 48, 24, 1024, 1024, 64), (1, 24, 6, 8192, 8192, 64), @@ -1112,6 +1126,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ (4, 48, 1024, 1024, 64), (4, 24, 8192, 8192, 64), @@ -1193,6 +1208,7 @@ def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): attention(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + @pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), @@ -1220,6 +1236,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 attention(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ (4, 48, 1024, 64), (4, 48, 2048, 64), @@ -1320,47 +1337,52 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + def nonvarlen_benchmark_configs(): - configs=[(16, 16, 16, 1024, 1024), - (8, 16, 16, 2048, 2048), - (4, 16, 16, 4096, 4096), - (2, 16, 16, 8192, 8192), - (1, 16, 16, 16384, 16384), - (2, 48, 48, 1024, 1024), - (2, 48, 48, 2048, 1024), - (2, 48, 48, 4096, 8192), - (2, 48, 48, 8192, 4096), - (2, 48, 48, 16384, 8192), - (8, 16, 16, 1989, 15344), - (4, 16, 16, 4097, 163), - (2, 16, 16, 8122, 2159), - (1, 16, 16, 16281, 7), - (2, 48, 48, 1021, 1020), - (2, 48, 48, 2001, 2048), - (2, 48, 48, 3996, 9639), - (2, 48, 48, 8181, 1021), - ] + configs = [ + (16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (2, 48, 48, 4096, 8192), + (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + (8, 16, 16, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 48, 48, 3996, 9639), + (2, 48, 48, 8181, 1021), + ] return configs + def varlen_benchmark_configs(): - configs=[(2, 16, 4, 1024, 1024), - (8, 16, 2, 2048, 2048), - (4, 16, 8, 4096, 4096), - (2, 16, 4, 8192, 8192), - (2, 16, 8, 16384, 16384), - (2, 48, 12, 1024, 1024), - (2, 48, 24, 2048, 2048), - (2, 48, 8, 4096, 4096), - (2, 48, 4, 8192, 8192), - (2, 48, 2, 16384, 16384), - (2, 64, 32, 1024, 1024), - (4, 64, 16, 2048, 2048), - (4, 64, 8, 4096, 4096), - (4, 64, 32, 8192, 8192), - (4, 128, 16, 16384, 16384), - ] + configs = [ + (2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] return configs + def run_benchmark(custom, args): dtype = arg_to_torch_dtype[args.dtype] @@ -1368,12 +1390,12 @@ def run_benchmark(custom, args): sk = args.sq if not args.sk else 
args.sk head_size = 128 if not args.d else args.d mode = 'fwd' - x_names=['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal varlen = args.layout == 'thd' configs = [] if custom: - x_vals_list=[(args.b, args.hq, hk, args.sq, sk)] + x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] else: if varlen: x_vals_list = varlen_benchmark_configs() @@ -1381,26 +1403,14 @@ def run_benchmark(custom, args): x_vals_list = nonvarlen_benchmark_configs() print_time = args.return_time line_names = 'Time (ms)' if print_time else 'TFLOPS' - configs.append(triton.testing.Benchmark( - x_names=x_names, - x_vals=x_vals_list, - line_arg='provider', - line_vals=['triton'], - line_names=[line_names], - styles=[('red', '-')], - ylabel='ms', - plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}', - args={ - 'D_HEAD': head_size, - 'dtype': dtype, - 'causal': causal, - 'mode': mode}) - ) + configs.append( + triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], + line_names=[line_names], styles=[('red', '-')], ylabel='ms', + plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}', + args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) @triton.testing.perf_report(configs) - def bench_flash_attention( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda" - ): + def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"): assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 @@ -1418,10 +1428,11 @@ def bench_flash_attention( flops_per_matmul = 0 if varlen: - q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.equal_seqlens) - for i in range (0, input_metadata.num_contexts): - seqlen_q = input_metadata.cu_seqlens_q[i+1] - input_metadata.cu_seqlens_q[i] - seqlen_k = input_metadata.cu_seqlens_k[i+1] - input_metadata.cu_seqlens_k[i] + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, + args.equal_seqlens) + for i in range(0, input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] # x2 for 2 GEMMs flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 else: @@ -1449,6 +1460,7 @@ def bench_flash_attention( bench_flash_attention.run(save_path=".", print_data=True) + def supported_layouts(): layouts = \ 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ @@ -1457,6 +1469,7 @@ def supported_layouts(): 'This layout is sometimes called "varlen" or "grouped" layout.' return layouts + def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", @@ -1477,11 +1490,9 @@ def parse_args(): parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) return parser.parse_args() -arg_to_torch_dtype = { - 'fp16': torch.float16, - 'bf16': torch.bfloat16, - 'fp32': torch.float32 -} + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + def main(): args = parse_args() @@ -1500,5 +1511,6 @@ def main(): run_benchmark(custom_config, args) + if __name__ == '__main__': sys.exit(main())
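The change above is whitespace-only; the new layout (continuation arguments aligned under the opening parenthesis, single-element tuples written as (Z, )) matches what yapf emits. Assuming yapf is the formatter this tree standardizes on, which the patch itself does not state, the same result should be reproducible mechanically rather than by hand, along the lines of:

# Hedged sketch, not part of the patch: re-run the assumed formatter over the kernel file.
# If the repository pins a different tool or style file, defer to that configuration instead.
import subprocess

subprocess.run(["yapf", "--in-place", "python/perf-kernels/flash-attention.py"], check=True)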
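For orientation, the entry points touched here keep their calling convention: build q/k/v plus a MetaData descriptor (the file's own input_helper does both) and call attention with a preallocated output tensor, exactly as the tests in this patch do. A minimal sketch, assuming the file is loaded by path (its name contains a hyphen, so a bare import will not work) and that a ROCm/CUDA device is present, since the module allocates a CUDA tensor at import time:

# Minimal usage sketch built only from helpers defined in this file; the loading-by-path
# approach and the working directory are assumptions, not part of the patch.
import importlib.util

import torch

spec = importlib.util.spec_from_file_location("flash_attention", "python/perf-kernels/flash-attention.py")
fa = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fa)

# batch=2, 16 query heads, 16 K/V heads, 1024 x 1024 context, head_size=64, 'bhsd' layout.
q, k, v, input_metadata = fa.input_helper(2, 16, 16, 1024, 1024, 64, torch.float16, layout='bhsd')
o = torch.empty_like(q)

# Same convention as the tests in the patch: the kernel writes its result into o.
fa.attention(q, k, v, o, input_metadata)
print(o.shape)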
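The 'thd' ("varlen") layout described by supported_layouts() packs every sequence in the batch into one tensor and marks the boundaries with cumulative sequence lengths; varlen_input_helper in this file constructs exactly that, so the packed path can be exercised the same way. Continuing the sketch above (same fa module and GPU assumptions):

# Hedged continuation of the sketch above: drive the packed, variable-length path.
q, k, v, varlen_metadata = fa.varlen_input_helper(2, 16, 4, 2048, 2048, 64, torch.float16)
o = torch.empty_like(q)
fa.attention(q, k, v, o, varlen_metadata)

# Per-sequence lengths fall out of the cu_seqlens arrays, the same way the benchmark
# loop in this patch derives its FLOP counts.
for i in range(varlen_metadata.num_contexts):
    seqlen_q = varlen_metadata.cu_seqlens_q[i + 1] - varlen_metadata.cu_seqlens_q[i]
    seqlen_k = varlen_metadata.cu_seqlens_k[i + 1] - varlen_metadata.cu_seqlens_k[i]
    print(int(seqlen_q), int(seqlen_k))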