Enable MQA/GQA in backward #100

Merged · 21 commits · Nov 15, 2024
108 changes: 55 additions & 53 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
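This PR teaches the Triton backward prefill kernel to handle MQA/GQA, where HQ query heads share HK key/value heads and HQ is a multiple of HK. A minimal sketch of the shape contract the kernel now assumes (tensor names and the head-major ordering here are illustrative, not taken from the diff):

```python
import torch

# Hypothetical GQA setup: 8 query heads share 2 K/V heads, so GROUP_SIZE = 4.
batch, seqlen, head_dim = 2, 128, 64
HQ, HK = 8, 2

q = torch.randn(batch, HQ, seqlen, head_dim, dtype=torch.float16)
k = torch.randn(batch, HK, seqlen, head_dim, dtype=torch.float16)
v = torch.randn(batch, HK, seqlen, head_dim, dtype=torch.float16)

# dq matches q, but dk/dv match k/v: the gradients produced by the 4 query heads
# in each group all land in the same K/V head slice and must be accumulated.
```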
@@ -112,13 +112,8 @@ def _bwd_kernel_one_col_block(
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
@@ -129,6 +124,7 @@ def _bwd_kernel_one_col_block(
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
USE_EXP2: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
@@ -153,8 +149,8 @@
# load k and v once per column block
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
k = tl.load(k_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
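Casting k and v (and q/do below) to tl.float32 at load time keeps the recomputed softmax and the dk/dv dot products in full precision for fp16/bf16 inputs; results are cast back to the storage dtype only at write-back. A rough PyTorch analogue of the same pattern (illustrative only, not part of the PR):

```python
import torch

a = torch.randn(256, 256, dtype=torch.float16)
b = torch.randn(256, 256, dtype=torch.float16)

# Upcast before the matmul so rounding error does not accumulate in half
# precision, then store the result back in the original dtype.
c = (a.float() @ b.float()).to(torch.float16)
```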

# loop over rows
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
@@ -168,8 +164,8 @@
q_mask = mask_m[:, None] & mask_d[None, :]

# load q, k, v, do on-chip
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
do = tl.load(do_ptrs, mask=q_mask, other=0.0)
q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float32)
do = tl.load(do_ptrs, mask=q_mask, other=0.0).to(tl.float32)

# recompute p = softmax(qk, dim=-1).T
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -196,9 +192,10 @@
# mask block in the cases where the data is smaller the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)
p = p.to(tl.float32)

# compute dv
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
dv += tl.dot(tl.trans(p), do)

# compute dp
dp = tl.dot(do, tl.trans(v))
@@ -207,7 +204,7 @@
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
ds = (p * (dp - Di[:, None])) * sm_scale
ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)
ds = tl.where(p_mask, ds, 0.0)

# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds), q)
@@ -225,8 +222,13 @@
dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk

# write-back
tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
if GROUP_SIZE != 1:
# use atomic_add to properly accumulate gradients from multiple query heads
tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
else:
tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
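With GQA, GROUP_SIZE programs (one per query head in the group) write dk/dv for the same K/V head, so tl.atomic_add is used to make those writes accumulate rather than race. A self-contained way to see the accumulation rule the kernel has to reproduce (toy check, not part of the PR):

```python
import torch

torch.manual_seed(0)
B, HQ, HK, N, D = 1, 4, 2, 16, 8
group = HQ // HK

q = torch.randn(B, HQ, N, D, dtype=torch.float64)
k = torch.randn(B, HK, N, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(B, HK, N, D, dtype=torch.float64, requires_grad=True)

# Expand the shared K/V heads across their query-head group; autograd's backward
# through repeat_interleave sums the per-query-head gradients back onto the
# shared head, which is exactly what the atomic adds accomplish in the kernel.
k_e = k.repeat_interleave(group, dim=1)
v_e = v.repeat_interleave(group, dim=1)
p = torch.softmax(q @ k_e.transpose(-1, -2) / D**0.5, dim=-1)
(p @ v_e).sum().backward()
print(k.grad.shape, v.grad.shape)  # both [B, HK, N, D], accumulated over the group
```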

@triton.jit
def _bwd_kernel(
@@ -258,7 +260,8 @@ def _bwd_kernel(
stride_deltah,
stride_deltam,
Z,
H,
HQ,
HK,
num_block_m,
num_block_n,
cu_seqlens_q,
@@ -275,11 +278,17 @@
IS_VARLEN: tl.constexpr,
):
# program ids
off_hz = tl.program_id(0)
off_zh = tl.program_id(0)
if SEQUENCE_PARALLEL:
start_n = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
off_z = off_zh // HQ
off_hq = off_zh % HQ

GROUP_SIZE = HQ // HK
if GROUP_SIZE != 1:
off_hk = off_hq // GROUP_SIZE
else:
off_hk = off_hq
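The launch grid is unchanged (batch * HQ programs along axis 0); each program decodes its batch and query-head index from program_id(0) and then folds the query head onto its shared K/V head. A standalone sketch of that index math (plain Python, names mirror the kernel but this is only an illustration):

```python
def decompose_program_id(off_zh: int, HQ: int, HK: int):
    """Map a flat (batch * query-head) program id to the indices the kernel uses."""
    off_z = off_zh // HQ           # batch index
    off_hq = off_zh % HQ           # query head handled by this program
    group_size = HQ // HK
    off_hk = off_hq // group_size  # K/V head shared by this query head's group
    return off_z, off_hq, off_hk

# Example: 8 query heads over 2 K/V heads -> query heads 0-3 map to K/V head 0.
print([decompose_program_id(i, HQ=8, HK=2) for i in range(8)])
```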

if IS_VARLEN:
# Compute sequence lengths for the current batch
@@ -299,20 +308,20 @@


# input tensor offsets
q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm
k_offset = K + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn
do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm
l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam
d_offset = D + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam

# output tensor offsets
dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn
dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn
if SEQUENCE_PARALLEL:
dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm
else:
dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm

# inner loop
if SEQUENCE_PARALLEL:
@@ -353,13 +362,8 @@ def _bwd_kernel(
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
@@ -370,6 +374,7 @@
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE
)
else:
for start_n in range(0, num_block_n):
@@ -410,13 +415,8 @@ def _bwd_kernel(
stride_deltaz,
stride_deltah,
stride_deltam,
Z,
H,
N_CTX_Q,
N_CTX_K,
off_h,
off_z,
off_hz,
start_n,
num_block_m,
num_block_n,
@@ -427,6 +427,7 @@
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE
)


@@ -454,7 +455,7 @@ def attention_prefill_backward_triton_impl(
):
if DEBUG:
print()
print("attention_prefill_backward_triton_new_impl")
print("attention_prefill_backward_triton_impl")
print("do:", do, do.shape)
print("q:", q, q.shape)
print("k:", k, k.shape)
@@ -488,7 +489,6 @@
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
batch_headsize = batch * nheads_q
is_varlen = layout == "thd"

# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
@@ -538,22 +538,30 @@

# deal with dk, dv
if (dk is None) or (dv is None):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
else:
# store og
dk_og = dk
dv_og = dv


if (not dk.is_contiguous()):
dk_og = dk
dk = dk.contiguous()
copy_back["dk"] = True

if (not dv.is_contiguous()):
dv_og = dv
dv = dv.contiguous()
copy_back["dv"] = True

if DEBUG:
print("copy_back:", copy_back)

# zero out
dq.zero_()
dk.zero_()
dv.zero_()
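Allocating dk/dv with torch.zeros_like (rather than empty_like) matters for the GQA path: tl.atomic_add accumulates into the output buffers, so they must start from zero or uninitialized memory would be folded into the gradients. A tiny illustration of the pitfall (shapes are illustrative):

```python
import torch

k = torch.randn(2, 2, 16, 8, dtype=torch.float16)
dk = torch.zeros_like(k)    # safe: atomic adds accumulate onto a clean zero buffer
# dk = torch.empty_like(k)  # unsafe: adds would accumulate onto whatever garbage
#                           # happened to be in the allocation
```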

# assert contigious
assert do.is_contiguous()
assert q.is_contiguous()
@@ -570,7 +578,7 @@
else:
stride_deltaz, stride_deltah, stride_deltam = delta.stride()

_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
_bwd_preprocess_use_o[(num_blocks_m, batch * nheads_q)](
o,
do,
delta,
@@ -622,7 +630,7 @@
print("num_blocks_m:", num_blocks_m)
print("num_blocks_n:", num_blocks_n)

_bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
_bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)](
q,
k,
v,
@@ -641,6 +649,7 @@
stride_deltaz, stride_deltah, stride_deltam,
batch,
nheads_q,
nheads_k,
num_blocks_m,
num_blocks_n,
cu_seqlens_q,
@@ -660,18 +669,11 @@
IS_VARLEN=is_varlen
)

if DEBUG:
print("_bwd_kernel outputs")
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)
print("delta:", delta, delta.shape)

if sequence_parallel:
dq = dq.sum(dim=0)
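In the sequence-parallel path each column block writes its partial dq into its own slice along a leading dimension (indexed by start_n through stride_dq_all), and the partials are reduced on the host afterwards, hence the dq.sum(dim=0). A toy version of that reduction (shapes are illustrative):

```python
import torch

num_blocks_n, batch, nheads_q, seqlen_q, head_dim = 4, 2, 8, 128, 64
# One partial dq per column block, as the sequence-parallel kernel produces.
dq_partials = torch.randn(num_blocks_n, batch, nheads_q, seqlen_q, head_dim)
dq = dq_partials.sum(dim=0)  # fold the per-block partials into the final dq
```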

if DEBUG:
print("attention_prefill_backward_triton_new_impl outputs")
print("attention_prefill_backward_triton_impl outputs")
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)