diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index 6184d6eab60c..3a8bfea059f2 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -319,10 +319,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
         triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
     ],
-    key=[
-        'IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN',
-        'HQ', 'HK'
-    ],
+    key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
     use_cuda_graph=True,
 )
 @triton.jit
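For context, a minimal sketch of how the `key` argument that this hunk reflows behaves in `triton.autotune`: the tuner benchmarks every entry in `configs` and caches the winner per unique combination of the arguments named in `key`, so collapsing the list onto one line is purely cosmetic and does not change which arguments trigger re-tuning. The kernel below is a hypothetical `_double_kernel`, not the repo's `_attn_fwd`, and is only meant to illustrate the decorator pattern.

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
    ],
    key=['n_elements'],  # a new value of n_elements triggers a fresh tuning pass
)
@triton.jit
def _double_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    # BLOCK_SIZE is chosen by the autotuner, so it is not passed at launch.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _double_kernel[grid](x, out, n_elements)
    return out
```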