diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 6fc861b281fa..fe94d287efef 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -1081,7 +1081,7 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 24, 1024, 1024, 64), + (3, 48, 24, 1024, 1024, 64), (1, 24, 6, 8192, 8192, 64), (1, 4, 2, 16384, 16384, 128), (2, 16, 4, 1020, 987, 128),