From d0eb0177f6421a4e78c5eba21df63805dbfe3620 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 10 May 2024 21:24:12 +0000 Subject: [PATCH] Change all block pointers to regular tensor pointers --- python/perf-kernels/flash-attention.py | 1100 +++++++++++++----------- 1 file changed, 596 insertions(+), 504 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 6fc861b281fa..6cd29493e0ba 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -2,36 +2,34 @@ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf -Features supported: +Credits: +AMD Triton kernels team +OpenAI kernel team -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: +Currently only the forward kernel is supported, and contains these features: -1) Non power of two head dims +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias """ import argparse import pytest +import random import sys import torch import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 - -TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') -if TORCH_HAS_FP8E5: - torch_dtype: tl.constexpr = torch.float8_e5m2fnuz - +torch_dtype:tl.constexpr = torch.float16 class MetaData(): cu_seqlens_q = None @@ -56,9 +54,9 @@ def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): assert len(cu_seqlens_q) >= 2 assert len(cu_seqlens_q) == len(cu_seqlens_k) self.num_contexts = len(cu_seqlens_q) - 1 - for i in range(0, self.num_contexts): - self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) - self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + for i in range (0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i+1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): assert bias.is_cuda @@ -77,7 +75,7 @@ def need_alibi(self, alibi_slopes, batch, nheads): def need_causal(self): self.causal = True - def need_dropout(self, dropout_p, return_encoded_softmax): + def need_dropout(dropout_p, return_encoded_softmax): self.dropout_p = dropout_p self.return_encoded_softmax = return_encoded_softmax @@ -91,7 +89,7 @@ def check_args(self, q, k, v, o): assert self.cu_seqlens_k is not None assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) # TODO: Remove once bias is supported with varlen - assert self.bias is None + assert self.bias == None # TODO:Remove once dropout is supported with varlen assert self.dropout_p == 0.0 assert not self.return_encoded_softmax @@ -109,51 +107,50 @@ def check_args(self, q, k, v, o): assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 - @triton.jit -def cdiv_fn(x, y): +def cdiv_fn(x,y): return (x + y - 1) // y - @triton.jit def max_fn(x, y): return tl.math.max(x, y) - @triton.jit def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): ms = tl.arange(0, m) ns = tl.arange(0, n) return philox_offset + ms[:, None] * stride + ns[None, :] - @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) - @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep - +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor - @triton.jit def print_gpu(prefix, val=None): if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): @@ -162,9 +159,8 @@ def print_gpu(prefix, val=None): else: tl.device_print(prefix) - @triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose = False): # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix # for casual mask we want something like this where (1 is kept and 0 is masked) # seqlen_q = 2 and seqlen_k = 5 @@ -188,35 +184,62 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo # [4], [ 4, 3, 2, 1, 0]] # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + relative_pos_block = offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) if transpose: return alibi_block.T else: return alibi_block - def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) - + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, - philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, - masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr): +def _attn_fwd_inner( + acc, l_i, m_i, q, + k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, + start_m, + actual_seqlen_k, + actual_seqlen_q, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + block_min, block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr +): # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): + for start_n in range (block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n @@ -229,8 +252,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ # that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] + size_n = start_n + OFFS_N[None,:] + mask = size_n < boundary_m[:,None] qk = tl.where(mask, qk, float("-inf")) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -238,8 +261,9 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) # While bias is added after multiplying qk with sm_scale, # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. @@ -247,16 +271,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ if alibi_slope is not None: # Compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_m_positions = start_m*BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) - - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, - global_n_positions) - - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) # softmax - m_ij = tl.maximum(m_i, tl.max(qk, 1)) + m_ij = tl.maximum(m_i, tl.max(qk,1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) @@ -266,105 +287,65 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i - @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - # TODO: This config fails with head_size not pow2 with data mismatches. Check why. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], - use_cuda_graph=True, + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + use_cuda_graph=True, ) @triton.jit def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - stride_az, - stride_ah, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, + Q, K, V, bias, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + stride_bz, stride_bh, stride_bm, stride_bn, + stride_az, stride_ah, + cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, + HQ: tl.constexpr, HK:tl.constexpr, + ACTUAL_BLOCK_DMODEL:tl.constexpr, + MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, + USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -375,6 +356,7 @@ def attn_fwd( off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -405,25 +387,30 @@ def attn_fwd( # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, + BLOCK_N + ) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) # If we have no blocks after adjusting for seqlen deltas, this WG is part of # the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1)) + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # We store inf to LSE, not -inf because in the bwd pass, we subtract this # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - tl.store(l_ptrs, l) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) # TODO: Should dropout and return encoded softmax be handled here too? return @@ -434,63 +421,47 @@ def attn_fwd( else: off_h_k = off_h_q - # need_padding = False + need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: - # need_padding = True + need_padding = True n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: - # need_padding = True + need_padding = True n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr(base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1)) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - if BIAS_TYPE != 0: - b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs - bias_ptr = tl.make_block_ptr( - base=bias + b_offset, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn else: - bias_ptr = None + bias_ptrs = None if USE_ALIBI: - a_offset = off_z * stride_az + off_h_q * stride_ah + a_offset = off_z * stride_az + off_h_q * stride_ah alibi_slope = tl.load(alibi_slopes + a_offset) else: alibi_slope = None if ENABLE_DROPOUT: - off_hz = off_z * HQ + off_h_q batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. In # this case, we return an invalid pointer so indicate the mask is not valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr(base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0)) + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] else: - encoded_softmax_block_ptr = 0 + encoded_sm_ptrs = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -499,8 +470,11 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -522,14 +496,17 @@ def attn_fwd( # value because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, - # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL + ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -540,18 +517,21 @@ def attn_fwd( offs_n_causal = offs_n + (seqlen_q - seqlen_k) else: offs_n_causal = 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, - alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL + ) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -566,7 +546,7 @@ def attn_fwd( acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 @@ -577,37 +557,28 @@ def attn_fwd( # This is only true for the last M block. For others, overflow_size will be -ve overflow_size = end_m_idx - seqlen_q if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) - # This is a > check because mask being 0 blocks the store. - l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) else: tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) @triton.jit def _attn_bwd_preprocess( - Out, - DO, + Out, DO, Delta, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_doz, - stride_doh, - stride_dom, - stride_don, + stride_oz, stride_oh, stride_om, stride_on, + stride_doz, stride_doh, stride_dom, stride_don, seqlen_q, head_dim, BLOCK_M: tl.constexpr, @@ -616,20 +587,32 @@ def _attn_bwd_preprocess( # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) # off_n = tl.arange(0, D_HEAD) off_m = tl.program_id(0) * BLOCK_M - off_h = tl.program_id(1) # head index - off_z = tl.program_id(2) # batch index + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index num_h = tl.num_programs(1) o_offset = off_h * stride_oh + off_z * stride_oz - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), - offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, head_dim), + strides=(stride_om, stride_on), + offsets=(off_m, 0), + block_shape=(BLOCK_M, D_HEAD), + order=(1, 0) + ) do_offset = off_h * stride_doh + off_z * stride_doz - DO_block_ptr = tl.make_block_ptr(base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), - offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + DO_block_ptr = tl.make_block_ptr( + base=DO + do_offset, + shape=(seqlen_q, head_dim), + strides=(stride_dom, stride_don), + offsets=(off_m, 0), + block_shape=(BLOCK_M, D_HEAD), + order=(1, 0) + ) # load # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) - do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32) # compute delta = tl.sum(o * do, axis=1) # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) @@ -644,21 +627,39 @@ def _attn_bwd_preprocess( else: tl.store(delta_ptrs, delta) - @triton.jit -def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, - # shared by Q/K/V/DO. - stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - # Filled in by the wrapper. - start_n, start_m, num_steps, MASK: tl.constexpr): +def _bwd_kernel_dk_dv( + dk, dv, + Q, k, v, sm_scale, alibi_slope, + DO, + M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, + MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) - # offs_k = tl.arange(0, BLOCK_DMODEL) - QT_block_ptr = tl.make_block_ptr(base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), - offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1)) - DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), - offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0)) + offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr( + base=Q, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_m), + block_shape=(BLOCK_DMODEL, BLOCK_M1), + order=(0,1) + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M1, BLOCK_DMODEL), + order=(1,0) + ) # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) curr_m = start_m @@ -696,21 +697,37 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) return dk, dv - @triton.jit -def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, - # shared by Q/K/V/DO. - stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - # Filled in by the wrapper. - start_m, start_n, num_steps, MASK: tl.constexpr): +def _bwd_kernel_dq(dq, q, K, V, + do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, + MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - # offs_k = tl.arange(0, BLOCK_DMODEL) - KT_block_ptr = tl.make_block_ptr(base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), - offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) - VT_block_ptr = tl.make_block_ptr(base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), - offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) + VT_block_ptr = tl.make_block_ptr( + base=V, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -744,14 +761,22 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq - @triton.jit -def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, + DO, + DQ, DK, DV, + M, D, # shared by Q/K/V/DO. stride_z, stride_h, stride_tok, stride_d, # H = 16, N_CTX = 1024 - H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): + H, N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + USE_ALIBI: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) @@ -770,7 +795,7 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, M += off_chz D += off_chz - # offs_k = tl.arange(0, BLOCK_DMODEL) + offs_k = tl.arange(0, BLOCK_DMODEL) start_n = pid * BLOCK_N1 # This assignment is important. It is what allows us to pick the diagonal @@ -779,7 +804,7 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, start_m = start_n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - # offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_n = start_n + tl.arange(0, BLOCK_N1) dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) @@ -813,24 +838,54 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # compute dK and dV for blocks close to the diagonal that need to be masked num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + dk, dv = _bwd_kernel_dk_dv( + dk, dv, + Q, k, v, sm_scale, alibi_slope, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True + ) # compute dK and dV for blocks that don't need masking further from the diagonal start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 - dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + dk, dv = _bwd_kernel_dk_dv( + dk, dv, + Q, k, v, sm_scale, alibi_slope, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False + ) - DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), - offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + DV_block_ptrs = tl.make_block_ptr( + base=DV, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) tl.store(DV_block_ptrs, dv.to(v.dtype)) # Write back dK. dk *= sm_scale - DK_block_ptrs = tl.make_block_ptr(base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), - offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + DK_block_ptrs = tl.make_block_ptr( + base=DK, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) tl.store(DK_block_ptrs, dk.to(k.dtype)) # THIS BLOCK DOES DQ: @@ -840,11 +895,23 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - Q_block_ptr = tl.make_block_ptr(base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), - offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) - DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), - offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) q = tl.load(Q_block_ptr) do = tl.load(DO_block_ptr) dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) @@ -858,30 +925,46 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, - BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True) + dq = _bwd_kernel_dq(dq, q, K, V, + do, m, D, alibi_slope, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True + ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 - dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, - BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False) + dq = _bwd_kernel_dq(dq, q, K, V, + do, m, D, alibi_slope, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False + ) # Write back dQ. - DQ_block_ptr = tl.make_block_ptr(base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), - offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) dq *= LN2 tl.store(DQ_block_ptr, dq.to(q.dtype)) - empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, o, metadata): # NOTE: a large bias tensor leads to overflow during pointer arithmetic if (metadata.bias is not None): - assert (metadata.bias.numel() < 2**31) + assert(metadata.bias.numel() < 2 ** 31) if o is None: o = torch.empty_like(q, dtype=v.dtype) @@ -906,15 +989,18 @@ def forward(ctx, q, k, v, o, metadata): padded_d_model = 1 << (head_size - 1).bit_length() padded_d_model = max(padded_d_model, 16) - grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + grid = lambda META: ( + triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), + nheads_q, + batch + ) # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing # only. This return holds no useful output aside from debugging. if metadata.return_encoded_softmax: - encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, - dtype=torch.float32) + encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, dtype=torch.float32) else: encoded_softmax = None @@ -925,25 +1011,38 @@ def forward(ctx, q, k, v, o, metadata): philox_offset = 0x1D4B42 if metadata.bias is not None: - bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2), - metadata.bias.stride(3)) + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), + metadata.bias.stride(2), metadata.bias.stride(3)) else: - bias_strides = (0, 0, 0, 0) - + bias_strides = (0,0,0,0) + if metadata.alibi_slopes is not None: alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) else: alibi_strides = (0, 0) - attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, - *bias_strides, *alibi_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k, - dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, - encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, - MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, - BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, - USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p - > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) + attn_fwd[grid]( + q, k, v, metadata.bias, metadata.sm_scale, M, o, + *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, + dropout_p=metadata.dropout_p, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + alibi_slopes = metadata.alibi_slopes, + HQ=nheads_q, HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, + VARLEN=metadata.varlen, + BLOCK_DMODEL=padded_d_model, + USE_BIAS=False if metadata.bias is None else True, + USE_ALIBI=False if metadata.alibi_slopes is None else True, + ENABLE_DROPOUT=metadata.dropout_p > 0.0, + RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, + BATCH_SIZE= q.shape[0] + ) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid @@ -973,7 +1072,7 @@ def backward(ctx, do, _): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - # NUM_WARPS, NUM_STAGES = 4, 1 + NUM_WARPS, NUM_STAGES = 4, 1 BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) @@ -981,60 +1080,37 @@ def backward(ctx, do, _): arg_k = arg_k * (ctx.sm_scale * RCP_LN2) assert N_CTX % PRE_BLOCK == 0 delta = torch.empty_like(M) - _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] - # padded_head = (Lk != ctx.BLOCK_DMODEL) + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + padded_head = (Lk != ctx.BLOCK_DMODEL) grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) _attn_bwd_preprocess[grid_preprocess]( - o, - do, - delta, - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - do.stride(0), - do.stride(1), - do.stride(2), - do.stride(3), + o, do, delta, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + do.stride(0), do.stride(1), do.stride(2), do.stride(3), seqlen_q, head_dim=Lk, - BLOCK_M=BLOCK, - D_HEAD=ctx.BLOCK_DMODEL, + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: ( + triton.cdiv(N_CTX, META['BLOCK_N1']), + 1, + BATCH * N_HEAD ) - grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) _attn_bwd[grid]( - q, - arg_k, - v, - ctx.sm_scale, - ctx.alibi_slopes, - do, - dq, - dk, - dv, - M, - delta, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - N_HEAD, - N_CTX, + q, arg_k, v, ctx.sm_scale, ctx.alibi_slopes, do, dq, dk, dv, + M, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + N_HEAD, N_CTX, BLOCK_DMODEL=ctx.BLOCK_DMODEL, - BLOCK_M1=BLOCK_M1, - BLOCK_N1=BLOCK_N1, - BLOCK_M2=BLOCK_M2, - BLOCK_N2=BLOCK_N2, + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - USE_ALIBI=False if ctx.alibi_slopes is None else True, + USE_ALIBI= False if ctx.alibi_slopes is None else True, ) return dq, dk, dv, None, None - attention = _attention.apply - def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): torch.manual_seed(20) @@ -1048,15 +1124,14 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): input_metadata.max_seqlens_k = N_CTX_K return q, k, v, input_metadata - def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): torch.manual_seed(20) # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs max_seqlens_q = N_CTX_Q // Z max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) max_seqlens_q = torch.max(seqlens_q).item() max_seqlens_k = torch.max(seqlens_k).item() @@ -1066,7 +1141,7 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") # -1 because the last entry of cu_seqlens_q specifies the end of the last seq - # num_ctxs = len(cu_seqlens_q) - 1 + num_ctxs = len(cu_seqlens_q) - 1 # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() @@ -1079,25 +1154,24 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) return q, k, v, input_metadata - -@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 24, 1024, 1024, 64), - (1, 24, 6, 8192, 8192, 64), - (1, 4, 2, 16384, 16384, 128), - (2, 16, 4, 1020, 987, 128), - (2, 16, 4, 15498, 2, 128), - (2, 16, 2, 7, 16219, 64), - (4, 48, 12, 1, 1, 64), - (4, 48, 48, 1, 1, 128), - (4, 48, 24, 3, 3, 128), - (4, 48, 48, 1001, 990, 64), - (1, 8, 8, 8081, 7099, 64), - (1, 4, 4, 16330, 15989, 128), - (4, 4, 1, 1024, 1024, 33), - (4, 4, 2, 65, 1018, 65), - (4, 4, 4, 128, 128, 65), - (4, 4, 4, 113, 123, 1), -]) +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', + [(4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), + ]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_alibi', [True, False]) def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16): @@ -1108,15 +1182,11 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to if use_alibi: # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) + alibi_slopes = torch.tensor([2**(-8/HQ*i) for i in range(1, HQ+1)], dtype=torch.float32, device="cuda").repeat(Z, 1) input_metadata.need_alibi(alibi_slopes, Z, HQ) else: - alibi_slopes = None + alibi = None - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1124,17 +1194,20 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to # Replicate K and V if using MQA/GQA if HQ != HK: - k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], - k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) - v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], - v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + k = k.view( + k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3]).expand( + -1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view( + v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3]).expand( + -1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale if causal: - mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) - scores[:, :, mask == 0] = float("-inf") + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), + diagonal=N_CTX_K-N_CTX_Q) + scores[:, :, mask==0] = float("-inf") if use_alibi: - scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + scores+= compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) p = torch.softmax(scores, dim=-1) if causal: @@ -1142,36 +1215,35 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix # this by converting the NaNs to 0s, which is what they should be out of the softmax. nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 + p[nan_mask==1] = 0 ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1024, 1024, 64), - (4, 24, 8192, 8192, 64), - (2, 4, 16384, 16384, 128), - (2, 16, 1020, 987, 128), - (2, 16, 15498, 2, 128), - (2, 16, 7, 16219, 64), - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 48, 1001, 990, 64), - (1, 8, 8081, 7099, 64), - (1, 8, 16330, 15989, 128), - (4, 4, 1024, 1024, 33), - (4, 4, 65, 1019, 65), - (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', + [(4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 16, 15498, 2, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + # TODO: This config fails. Disabled until triaged and fixed. + # (4, 4, 113, 123, 1), + ]) +@pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_bias', [True]) def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): - pytest.skip() torch.manual_seed(20) - sm_scale = D_HEAD**-0.5 + sm_scale = D_HEAD ** -0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.max_seqlens_q = N_CTX_Q input_metadata.max_seqlens_k = N_CTX_K @@ -1185,9 +1257,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1196,8 +1265,9 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale if causal: - mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) - scores[:, :, mask == 0] = float("-inf") + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), + diagonal=N_CTX_K-N_CTX_Q) + scores[:, :, mask==0] = float("-inf") if use_bias: scores += input_metadata.bias p = torch.softmax(scores, dim=-1) @@ -1206,39 +1276,55 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix # this by converting the NaNs to 0s, which is what they should be out of the softmax. nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 + p[nan_mask==1] = 0 ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64), - (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64), - (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128), - (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', + [(4, 48, 8192, 64), + (4, 48, 256, 64), + (4, 48, 512, 64), + (4, 48, 1024, 64), + (8, 48, 4096, 64), + (4, 48, 8192, 64), + (4, 48, 128, 128), + (4, 48, 4096, 128), + (4, 48, 16384, 128), + (4, 16, 1024, 128), + (4, 16, 8192, 128), + (32, 48, 8192, 128) + ]) @pytest.mark.parametrize('causal', [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - pytest.skip() + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) ref_out = torch.empty_like(q) for i in range(0, input_metadata.num_contexts): start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] - end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + end_q, end_k = input_metadata.cu_seqlens_q[i+1], input_metadata.cu_seqlens_k[i+1] scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) attention(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) - -@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), - (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), - (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), - (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), - (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', + [(2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), + (8, 48, 6, 4096, 64), + (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), + (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), + (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), + (32, 128, 32, 8192, 128) + ]) @pytest.mark.parametrize('causal', [False]) def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) @@ -1250,7 +1336,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) for i in range(0, input_metadata.num_contexts): start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] - end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + end_q, end_k = input_metadata.cu_seqlens_q[i+1], input_metadata.cu_seqlens_k[i+1] k_curr = k_ref[start_k:end_k] k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) v_curr = v_ref[start_k:end_k] @@ -1261,22 +1347,20 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 attention(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ - (4, 48, 1024, 64), - (4, 48, 2048, 64), - (2, 48, 4096, 64), - (1, 16, 1024, 64), - (1, 16, 1024, 128), - #(1, 16, 8192, 63), - #(1, 16, 1022, 64), -]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', + [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (2, 48, 4096, 64), + (1, 16, 1024, 64), + (1, 16, 1024, 128), + #(1, 16, 8192, 63), + #(1, 16, 1022, 64), + ]) @pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) @pytest.mark.parametrize('torch_sdpa_test', [False, True]) @pytest.mark.parametrize('causal', [True]) @pytest.mark.parametrize('use_alibi', [False, True]) -def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, - dtype=torch.float16): +def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): torch.manual_seed(20) if qseqlen_not_equal_kseqlen is not None: seqlen_q = qseqlen_not_equal_kseqlen @@ -1289,7 +1373,7 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd if causal and seqlen_q != seqlen_k: pytest.skip() - sm_scale = D_HEAD**-0.5 + sm_scale = D_HEAD ** -0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.max_seqlens_q = seqlen_q input_metadata.max_seqlens_k = seqlen_k @@ -1305,14 +1389,15 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd if use_alibi and not torch_sdpa_test: # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) + alibi_slopes = torch.tensor([2**(-8/H*i) for i in range(1, H+1)], dtype=torch.float32, device="cuda").repeat(Z, 1) input_metadata.need_alibi(alibi_slopes, Z, H) dout = torch.randn_like(q) # reference implementation if torch_sdpa_test: - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, - is_causal=causal, scale=sm_scale, + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, + dropout_p=dropout_p, + is_causal=causal, + scale=sm_scale, dropout_mask=None) ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) ref_dv, v.grad = v.grad.clone(), None @@ -1322,10 +1407,10 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if use_alibi: - p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) + p+= compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) if causal: p[:, :, M == 0] = float("-inf") - + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) ref_out = torch.matmul(p, v) ref_out.backward(dout) @@ -1362,66 +1447,61 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) - def nonvarlen_benchmark_configs(): - configs = [ - (16, 16, 16, 1024, 1024), - (8, 16, 16, 2048, 2048), - (4, 16, 16, 4096, 4096), - (2, 16, 16, 8192, 8192), - (1, 16, 16, 16384, 16384), - (2, 48, 48, 1024, 1024), - (2, 48, 48, 2048, 1024), - (2, 48, 48, 4096, 8192), - (2, 48, 48, 8192, 4096), - (2, 48, 48, 16384, 8192), - (8, 16, 16, 1989, 15344), - (4, 16, 16, 4097, 163), - (2, 16, 16, 8122, 2159), - (1, 16, 16, 16281, 7), - (2, 48, 48, 1021, 1020), - (2, 48, 48, 2001, 2048), - (2, 48, 48, 3996, 9639), - (2, 48, 48, 8181, 1021), - ] + configs=[(16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (2, 48, 48, 4096, 8192), + (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + (8, 16, 16, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 48, 48, 3996, 9639), + (2, 48, 48, 8181, 1021), + ] return configs - def varlen_benchmark_configs(): - configs = [ - (2, 16, 4, 1024, 1024), - (8, 16, 2, 2048, 2048), - (4, 16, 8, 4096, 4096), - (2, 16, 4, 8192, 8192), - (2, 16, 8, 16384, 16384), - (2, 48, 12, 1024, 1024), - (2, 48, 24, 2048, 2048), - (2, 48, 8, 4096, 4096), - (2, 48, 4, 8192, 8192), - (2, 48, 2, 16384, 16384), - (2, 64, 32, 1024, 1024), - (4, 64, 16, 2048, 2048), - (4, 64, 8, 4096, 4096), - (4, 64, 32, 8192, 8192), - (4, 128, 16, 16384, 16384), - ] + configs=[(2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] return configs - def run_benchmark(custom): args = parse_args() dtype = arg_to_torch_dtype[args.dtype] - # hk = args.hq if not args.hk else args.hk - # sk = args.sq if not args.sk else args.sk + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d mode = 'fwd' - x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + x_names=['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal varlen = args.varlen configs = [] if custom: - x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)] + x_vals_list=[(args.b, args.hq, args.hk, args.sq, args.sk)] else: if varlen: x_vals_list = varlen_benchmark_configs() @@ -1429,14 +1509,26 @@ def run_benchmark(custom): x_vals_list = nonvarlen_benchmark_configs() print_time = args.return_time line_names = 'Time (ms)' if print_time else 'TFLOPS' - configs.append( - triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], - line_names=[line_names], styles=[('red', '-')], ylabel='ms', - plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', - args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) + configs.append(triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg='provider', + line_vals=['triton'], + line_names=[line_names], + styles=[('red', '-')], + ylabel='ms', + plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', + args={ + 'D_HEAD': head_size, + 'dtype': dtype, + 'causal': causal, + 'mode': mode}) + ) @triton.testing.perf_report(configs) - def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"): + def bench_flash_attention( + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda" + ): assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 @@ -1446,7 +1538,7 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX) # else: # bias = None - # bias = None + bias = None # Bwd pass only supports causal=True right now if mode == 'bwd': @@ -1455,9 +1547,9 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal flops_per_matmul = 0 if varlen: q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) - for i in range(0, input_metadata.num_contexts): - seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] - seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] + for i in range (0, input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i+1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i+1] - input_metadata.cu_seqlens_k[i] # x2 for 2 GEMMs flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 else: @@ -1485,7 +1577,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal bench_flash_attention.run(save_path=".", print_data=True) - def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", @@ -1503,9 +1594,11 @@ def parse_args(): parser.add_argument("-return_time", action='store_true', default=False) return parser.parse_args() - -arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} - +arg_to_torch_dtype = { + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'fp32': torch.float32 +} def main(): args = parse_args() @@ -1522,6 +1615,5 @@ def main(): run_benchmark(custom_config) - if __name__ == '__main__': sys.exit(main())