Commit

Merge main_perf
vgokhale committed May 14, 2024
2 parents 221ed7c + d2eeac6 commit 6754168
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/perf-kernels/flash-attention.py
@@ -1015,7 +1015,7 @@ def forward(ctx, q, k, v, o, metadata):
                        metadata.bias.stride(2), metadata.bias.stride(3))
    else:
        bias_strides = (0,0,0,0)

    if metadata.alibi_slopes is not None:
        alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1))
    else:
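Note: this hunk passes zero strides whenever the optional bias or ALiBi-slopes tensor is absent, so the kernel launch keeps a uniform argument list either way. A minimal sketch of that pattern in plain PyTorch (the optional_strides helper and the shapes below are illustrative, not part of the repository):

import torch

def optional_strides(t, ndim):
    # Hypothetical helper mirroring the pattern above: return the tensor's
    # strides, or all-zero strides when the tensor is not provided.
    if t is not None:
        return tuple(t.stride(i) for i in range(ndim))
    return (0,) * ndim

bias = None                                  # stands in for metadata.bias
bias_strides = optional_strides(bias, 4)     # -> (0, 0, 0, 0)
alibi_slopes = torch.rand(2, 8)              # illustrative (batch, heads) slopes
alibi_strides = optional_strides(alibi_slopes, 2)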
@@ -1203,7 +1203,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to

    scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale
    if causal:
-        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
+        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"),
                           diagonal=N_CTX_K-N_CTX_Q)
        scores[:, :, mask==0] = float("-inf")
    if use_alibi:
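For context, the diagonal=N_CTX_K-N_CTX_Q offset in the changed lines anchors the causal triangle at the bottom-right corner when query and key lengths differ, so the last query row attends to every key. A small CPU-only sketch with illustrative sizes:

import torch

N_CTX_Q, N_CTX_K = 3, 5                      # 3 queries attending over 5 keys
scores = torch.zeros(1, 1, N_CTX_Q, N_CTX_K)

# diagonal = N_CTX_K - N_CTX_Q keeps entries with key index <= query index + offset.
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K), diagonal=N_CTX_K - N_CTX_Q)
scores[:, :, mask == 0] = float("-inf")

print(mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])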
@@ -1360,7 +1360,9 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
@pytest.mark.parametrize('torch_sdpa_test', [False, True])
@pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('use_alibi', [False, True])
-def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, dtype=torch.float16):
+def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi,
+                dtype=torch.float16):
+    pytest.skip()
    torch.manual_seed(20)
    if qseqlen_not_equal_kseqlen is not None:
        seqlen_q = qseqlen_not_equal_kseqlen
@@ -1407,10 +1409,10 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd
    M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if use_alibi:
-        p+= compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX)
+        p+= compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX)
    if causal:
        p[:, :, M == 0] = float("-inf")

    p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
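The reference path in the last hunk builds plain-PyTorch attention (optional ALiBi bias, causal mask, softmax, weighted sum over values) and backpropagates through it to obtain gradients for comparison with the Triton kernel. A self-contained sketch of that shape; reference_attention and the sizes below are illustrative, not the test's exact setup:

import torch

def reference_attention(q, k, v, sm_scale, causal=True, alibi=None):
    # Plain-PyTorch ground truth: scaled scores, optional ALiBi bias,
    # causal mask, softmax, then the weighted sum over values.
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if alibi is not None:
        p += alibi                           # bias broadcast over the batch dim
    if causal:
        seqlen_q, seqlen_k = q.shape[-2], k.shape[-2]
        M = torch.tril(torch.ones(seqlen_q, seqlen_k, device=q.device))
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).type(q.dtype)
    return torch.matmul(p, v)

# Illustrative sizes: batch=2, heads=4, seqlen=8, head_dim=16.
q, k, v = (torch.randn(2, 4, 8, 16, requires_grad=True) for _ in range(3))
out = reference_attention(q, k, v, sm_scale=16 ** -0.5)
out.backward(torch.randn_like(out))          # gradients land in q.grad, k.grad, v.grad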
