
Commit

Revert "Cleanup to pass CI"
This reverts commit 713846d.
vgokhale committed May 13, 2024
1 parent 713846d commit 221ed7c
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions python/perf-kernels/flash-attention.py
@@ -75,7 +75,7 @@ def need_alibi(self, alibi_slopes, batch, nheads):
def need_causal(self):
self.causal = True

- def need_dropout(self, dropout_p, return_encoded_softmax):
+ def need_dropout(dropout_p, return_encoded_softmax):
self.dropout_p = dropout_p
self.return_encoded_softmax = return_encoded_softmax

@@ -89,7 +89,7 @@ def check_args(self, q, k, v, o):
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
- assert self.bias is None
+ assert self.bias == None
# TODO:Remove once dropout is supported with varlen
assert self.dropout_p == 0.0
assert not self.return_encoded_softmax
@@ -421,10 +421,13 @@ def attn_fwd(
else:
off_h_k = off_h_q

+ need_padding = False
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
+ need_padding = True
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
+ need_padding = True
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)

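Aside on the need_padding / n_extra_tokens bookkeeping restored in this hunk: n_extra_tokens records how the key sequence length lines up with the BLOCK_N tile width. A standalone Python sketch of the same branch logic, with an assumed BLOCK_N of 64:

# Sketch of the padding bookkeeping above; BLOCK_N=64 is an assumed example value.
def padding_info(seqlen_k, BLOCK_N=64):
    need_padding = False
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        need_padding = True
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        need_padding = True
        n_extra_tokens = seqlen_k % BLOCK_N
    return need_padding, n_extra_tokens

print(padding_info(40))    # (True, 24)  - shorter than one block
print(padding_info(130))   # (True, 2)   - last block only partially filled
print(padding_info(128))   # (False, 0)  - exact multiple of BLOCK_N
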
@@ -449,7 +452,7 @@
alibi_slope = None

if ENABLE_DROPOUT:
- batch_philox_offset = philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
+ batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout. In
@@ -640,6 +643,7 @@ def _bwd_kernel_dk_dv(
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
+ offs_k = tl.arange(0, BLOCK_DMODEL)
QT_block_ptr = tl.make_block_ptr(
base=Q,
shape=(BLOCK_DMODEL, N_CTX),
@@ -707,6 +711,7 @@ def _bwd_kernel_dq(dq, q, K, V,
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
+ offs_k = tl.arange(0, BLOCK_DMODEL)
KT_block_ptr = tl.make_block_ptr(
base=K,
shape=(BLOCK_DMODEL, N_CTX),
@@ -790,13 +795,16 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes,
M += off_chz
D += off_chz

+ offs_k = tl.arange(0, BLOCK_DMODEL)
+
start_n = pid * BLOCK_N1
# This assignment is important. It is what allows us to pick the diagonal
# blocks. Later, when we want to do the lower triangular, we update start_m
# after the first dkdv call.
start_m = start_n

MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+ offs_n = start_n + tl.arange(0, BLOCK_N1)

dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
@@ -1033,8 +1041,7 @@ def forward(ctx, q, k, v, o, metadata):
USE_ALIBI=False if metadata.alibi_slopes is None else True,
ENABLE_DROPOUT=metadata.dropout_p > 0.0,
RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
- BATCH_SIZE= q.shape[0],
- allow_flush_denorm=True
+ BATCH_SIZE= q.shape[0]
)

ctx.save_for_backward(q, k, v, o, M)
@@ -1065,14 +1072,16 @@ def backward(ctx, do, _):
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
+ NUM_WARPS, NUM_STAGES = 4, 1
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
assert N_CTX % PRE_BLOCK == 0
delta = torch.empty_like(M)
- Lk = q.shape[-1], k.shape[-1], v.shape[-1]
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ padded_head = (Lk != ctx.BLOCK_DMODEL)
grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])
_attn_bwd_preprocess[grid_preprocess](
o, do, delta,
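
Aside on RCP_LN2 above: multiplying k by ctx.sm_scale * RCP_LN2 folds the softmax scale together with a natural-log-to-base-2 conversion, which is useful if the kernel evaluates the softmax with base-2 exponentials (an assumption here, not shown in this excerpt), since exp(x) equals 2**(x * RCP_LN2). A quick standalone check:

import math

RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2), as in the hunk above
x = 0.37                      # arbitrary value
# e**x and 2**(x * RCP_LN2) agree, so pre-scaling scores by RCP_LN2 lets a
# kernel use exp2 while matching an exp-based softmax.
assert math.isclose(math.exp(x), 2.0 ** (x * RCP_LN2), rel_tol=1e-12)
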
@@ -1131,6 +1140,8 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype):
cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)])
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
+ # -1 because the last entry of cu_seqlens_q specifies the end of the last seq
+ num_ctxs = len(cu_seqlens_q) - 1

# Initialize q, k, v with variable lengths
total_q = cu_seqlens_q[-1].item()
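
Aside on num_ctxs above: cu_seqlens_q holds cumulative sequence boundaries, so N sequences yield N + 1 entries, the last of which is the total token count; hence the -1. A small illustration with made-up lengths:

import torch

seqlens_q = torch.tensor([3, 5, 2], dtype=torch.int32)   # made-up per-sequence lengths
cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32),
                          seqlens_q.cumsum(dim=0, dtype=torch.int32)])
print(cu_seqlens_q)                 # tensor([ 0,  3,  8, 10], dtype=torch.int32)
num_ctxs = len(cu_seqlens_q) - 1    # 3 sequences
total_q = cu_seqlens_q[-1].item()   # 10 total query tokens
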
@@ -1173,6 +1184,8 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to
# for n heads the set of slopes is the geometric sequence that starts 2^(-8/n)
alibi_slopes = torch.tensor([2**(-8/HQ*i) for i in range(1, HQ+1)], dtype=torch.float32, device="cuda").repeat(Z, 1)
input_metadata.need_alibi(alibi_slopes, Z, HQ)
+ else:
+ alibi = None

o = torch.empty_like(q)

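Aside on the alibi_slopes expression above: for HQ heads the slopes are the geometric sequence 2**(-8/HQ), 2**(-16/HQ), ..., 2**(-8), i.e. slope_i = 2**(-8*i/HQ). A small sketch with an assumed HQ of 4 (the device argument is omitted so it runs anywhere):

import torch

Z, HQ = 2, 4   # assumed batch size and query-head count
alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)],
                            dtype=torch.float32).repeat(Z, 1)
print(alibi_slopes.shape)   # torch.Size([2, 4]) - one row of slopes per batch entry
print(alibi_slopes[0])      # tensor([0.2500, 0.0625, 0.0156, 0.0039])
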
@@ -1488,7 +1501,7 @@ def run_benchmark(custom):
varlen = args.varlen
configs = []
if custom:
- x_vals_list=[(args.b, args.hq, hk, args.sq, sk)]
+ x_vals_list=[(args.b, args.hq, args.hk, args.sq, args.sk)]
else:
if varlen:
x_vals_list = varlen_benchmark_configs()
@@ -1519,6 +1532,13 @@ def bench_flash_attention(
assert mode in ["fwd", "bwd"]
warmup = 25
rep = 100
+ # TODO: Enable bias after testing.
+ # if use_bias:
+ # bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda")
+ # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX)
+ # else:
+ # bias = None
+ bias = None

# Bwd pass only supports causal=True right now
if mode == 'bwd':
