From 5e2ffc6e261acbee87d65abff8fe94c928c8cfc2 Mon Sep 17 00:00:00 2001
From: Vinayak Gokhale
Date: Mon, 8 Jul 2024 22:25:33 +0000
Subject: [PATCH] Reduce some autotune configs and keys to reduce runtime

---
 python/perf-kernels/flash-attention.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index e5e71e2889f9..6184d6eab60c 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -313,19 +313,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                       num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
-        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=8),
-        # TODO: This config fails with head_size not pow2 with data mismatches. Check why.
-        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
+        # Fall-back config.
         triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                       num_warps=4),
     ],
     key=[
-        'IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN',
+        'IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN',
         'HQ', 'HK'
     ],
     use_cuda_graph=True,
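
For context on why trimming `configs` and `key` shortens runtime: `triton.autotune` compiles and benchmarks every entry in `configs` the first time the kernel is launched with a new combination of the `key` arguments, then caches the winner, so each removed config and each removed key field means fewer benchmark runs and fewer re-tunes. Below is a minimal sketch of that mechanism on an illustrative vector-add kernel; the kernel, its arguments, and the chosen configs are assumptions for the example, not the flash-attention kernel touched by this patch.

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # Two candidate configs; both are benchmarked on the first launch per key value.
        triton.Config({'BLOCK_SIZE': 256}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_stages=1, num_warps=8),
    ],
    key=['n_elements'],  # re-tune only when n_elements changes
)
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    # BLOCK_SIZE is supplied by whichever config the autotuner selects.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _add_kernel[grid](x, y, out, n_elements)
    return out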