From f24a8eacd8454c031c3f6400b355938e2ead13b3 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:12:20 +0000 Subject: [PATCH] fix --- python/perf-kernels/flash-attention.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index f013dc54ab8b..7de377560521 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -63,7 +63,7 @@ def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): def set_persistent(self, persistent): self.persistent = persistent - + def set_int8_params(self, q_descale, k_descale, v_descale, p_scale, p_descale): self.int8 = True self.q_descale = q_descale @@ -441,12 +441,14 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, INT8: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr): + start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) + if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -2220,19 +2222,14 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] # x2 for 2 GEMMs flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 - else: q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD - if causal: input_metadata.need_causal() -<<<<<<< HEAD input_metadata.set_persistent(args.persistent) -======= if int8: q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv) ->>>>>>> main_perf o = torch.empty_like(q) fn = lambda: attention(q, k, v, o, input_metadata) if mode == 'bwd': @@ -2243,6 +2240,8 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal total_flops = 2 * flops_per_matmul if causal: # total_flops *= 0.5 # normally, but we have to take into account the unequal seqlen_q/k + seqlen_q= N_CTX_Q + seqlen_k = N_CTX_K if seqlen_q > seqlen_k: total_flops *= seqlen_k / (2 * seqlen_q) else: