
Commit

fix
juuso-oskari committed Dec 18, 2024
1 parent 6fc40bb commit f24a8ea
Showing 1 changed file with 5 additions and 6 deletions.
python/perf-kernels/flash-attention.py: 11 changes (5 additions & 6 deletions)
@@ -63,7 +63,7 @@ def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):

    def set_persistent(self, persistent):
        self.persistent = persistent

    def set_int8_params(self, q_descale, k_descale, v_descale, p_scale, p_descale):
        self.int8 = True
        self.q_descale = q_descale
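
The set_int8_params hook above records per-tensor descale factors for Q, K, V and the softmax probabilities. As an illustration only (the helper name, the amax-based scaling, and the p_scale/p_descale placeholders are assumptions, not the repository's quantize_input, which appears further down in this diff), a minimal per-tensor int8 quantizer compatible with that interface could look like:

import torch

def quantize_per_tensor_int8(x: torch.Tensor):
    # map the largest magnitude to the int8 range; multiplying back by scale recovers fp values
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    x_int8 = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return x_int8, scale

# q_int8, q_descale = quantize_per_tensor_int8(q)
# k_int8, k_descale = quantize_per_tensor_int8(k)
# v_int8, v_descale = quantize_per_tensor_int8(v)
# input_metadata.set_int8_params(q_descale, k_descale, v_descale, p_scale=1.0, p_descale=1.0)
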
@@ -441,12 +441,14 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
             BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
             RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, INT8: tl.constexpr,
             USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr):

    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
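
The three tl.program_id reads above imply a three-axis launch grid: axis 0 tiles the query sequence into BLOCK_M-row blocks, axis 1 walks the query heads, and axis 2 walks the batch (for varlen, the packed sequences). A minimal sketch of such a grid, with assumed argument names (the actual launch code is not shown in this diff):

import triton

def attn_fwd_grid(max_seqlen_q, nheads_q, batch, BLOCK_M):
    # one program instance per (query block, query head, batch / sequence index)
    return (triton.cdiv(max_seqlen_q, BLOCK_M), nheads_q, batch)
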
@@ -2220,19 +2222,14 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
            seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]
            # x2 for 2 GEMMs
            flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2
    else:
        q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout)
        flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
    if causal:
        input_metadata.need_causal()
<<<<<<< HEAD

    input_metadata.set_persistent(args.persistent)

=======
    if int8:
        q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv)
>>>>>>> main_perf
    o = torch.empty_like(q)
    fn = lambda: attention(q, k, v, o, input_metadata)
    if mode == 'bwd':
@@ -2243,6 +2240,8 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
    total_flops = 2 * flops_per_matmul
    if causal:
        # total_flops *= 0.5 # normally, but we have to take into account the unequal seqlen_q/k
        seqlen_q = N_CTX_Q
        seqlen_k = N_CTX_K
        if seqlen_q > seqlen_k:
            total_flops *= seqlen_k / (2 * seqlen_q)
        else:
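
The causal FLOP adjustment in the last hunk can be read as a small standalone estimate. The sketch below reuses the benchmark's names and only reproduces the branch visible above; the seqlen_q <= seqlen_k branch is cut off by the diff, so it is left out here.

def estimate_attention_flops(BATCH, HQ, N_CTX_Q, N_CTX_K, D_HEAD, causal=False):
    # 2*M*N*K FLOPs for one GEMM of the score matrix
    flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
    # two GEMMs per forward pass: QK^T and P @ V
    total_flops = 2 * flops_per_matmul
    if causal:
        # the causal mask skips part of the score matrix, so the usual 0.5 factor
        # only holds when seqlen_q == seqlen_k
        seqlen_q, seqlen_k = N_CTX_Q, N_CTX_K
        if seqlen_q > seqlen_k:
            total_flops *= seqlen_k / (2 * seqlen_q)
    return total_flops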
