diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index ed4403a17..efbf3f31e 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -82,14 +82,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
             k_offs_n = None
         k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
         k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
-        if IS_FP8:
-            k = (k.to(tl.float16) / k_scale.to(tl.float16)).to(k.type.element_ty)
         if PRE_LOAD_V:
             # We can use the same offsets as k, just with dims transposed.
             v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
-            if IS_FP8:
-                v = (v.to(tl.float16) / v_scale.to(tl.float16)).to(v.type.element_ty)
+
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         # We start from end of seqlen_k so only the first iteration would need
         # to be checked for padding if it is not a multiple of block_n
@@ -107,7 +104,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                 qk = tl.where(mask, qk, float("-inf"))
 
         # -- compute qk ----
-        qk += tl.dot(q, k)
+        qk += tl.dot(q.to(tl.float16), k.to(tl.float16))
         qk_scaled = qk * SM_SCALE
         if IS_FP8:
             qk_scaled = qk_scaled * q_scale * k_scale # descale qk after matmul if quantized
@@ -173,17 +170,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
         acc = acc * alpha[:, None]
         if not PRE_LOAD_V:
             v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
-            if IS_FP8:
-                v = (v.to(tl.float16) / v_scale.to(tl.float16)).to(v.type.element_ty)
         # -- update m_i and l_i
         l_i = l_i * alpha + l_ij
         # update m_i and l_i
         m_i = m_ij
 
         if IS_FP8:
-            p_scale = 1 # NOTE: for proper scaling set this = tl.max(p) (increases error)
-            p_scaled = (p / p_scale)
-            acc += tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale * p_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
+            acc += tl.dot(p.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
         else:
             # NOTE: if you make the below operation tl.float16 + set FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE=1. It passes. --> acc += tl.dot(p.to(tl.float16), v.to(tl.float16)) PASSES
             acc += tl.dot(p.to(v.type.element_ty), v).to(tl.float32)
@@ -416,8 +409,6 @@ def attn_fwd(Q, K, V, bias, Q_SCALE, K_SCALE, V_SCALE, stride_qscale_z, stride_k
     # if IS FP8 get q_scale and quantize
     if IS_FP8:
         q_scale = tl.load(Q_SCALE + off_z*stride_qscale_z + off_h_q)
-        q = (q.to(tl.float16) / q_scale.to(tl.float16)).to(q.type.element_ty) # scale q by q_scale
-
         k_scale = tl.load(K_SCALE + off_z*stride_kvscale_z + off_h_k)
         v_scale = tl.load(V_SCALE + off_z*stride_kvscale_z + off_h_k)
     else:
@@ -570,12 +561,16 @@ def attention_prefill_forward_triton_impl(
 
     is_fp8 = check_is_fp8(q)
 
-    # if qkv are fp8, then find scaling factor for quantization
-    q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
-    q_scale_stride_z = q_scale.stride(0)
-    kv_scale_stride_z = k_scale.stride(0)
-
-    # import pdb; pdb.set_trace()
+    if is_fp8:
+        # if qkv are fp8, then find scaling factor for quantization
+        q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
+        q_scale_stride_z = q_scale.stride(0)
+        kv_scale_stride_z = k_scale.stride(0)
+        q = (q.to(torch.float32) / q_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, q.shape[-2], q.shape[-1])).to(q.dtype)
+        k = (k.to(torch.float32) / k_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, k.shape[-2], k.shape[-1])).to(k.dtype)
+        v = (v.to(torch.float32) / v_scale.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, v.shape[-2], v.shape[-1])).to(v.dtype)
+    else:
+        q_scale = k_scale = v_scale = 1
 
     if DEBUG:
         print()
@@ -661,8 +656,6 @@ def attention_prefill_forward_triton_impl(
     else:
         alibi_strides = (0, 0)
 
-    # import pdb; pdb.set_trace()
-
     attn_fwd[grid](q, k, v, bias, q_scale, k_scale, v_scale, q_scale_stride_z, kv_scale_stride_z, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides,
                    *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
                    dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py
index 771b3551a..13465347a 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_ref.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -4,20 +4,8 @@
 
 DEBUG_CORE = False
 
-def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
-    is_fp8 = check_is_fp8(q)
-    if is_fp8:
-        # if qkv are fp8, then find scaling factor for quantization
-        q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
-        q_scale_stride_z = q_scale.stride(0)
-        kv_scale_stride_z = k_scale.stride(0)
-
-        # scale qkv tensors if FP8
-        q = q / q_scale
-        k = k / k_scale
-        v = v / v_scale
-    else:
-        q_scale = k_scale = v_scale = 1
+def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8):
+
     if DEBUG_CORE:
         print()
         print("attention_forward_core_ref_impl")
@@ -32,9 +20,6 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p
         print("use_exp2:", use_exp2)
         print('layout:', layout)
         print('is_fp8:', is_fp8)
-        print('q_scale:', q_scale)
-        print('k_scale:', k_scale)
-        print('v_scale:', v_scale)
 
     # cast to float32
     q = q.to(torch.float32)
@@ -42,10 +27,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p
     v = v.to(torch.float32)
 
     # Compute attention scores
-    if is_fp8:
-        attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) * q_scale * v_scale
-    else:
-        attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
+    attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
 
     if DEBUG_CORE:
         print("attention_scores:", attention_scores, attention_scores.shape)
@@ -150,10 +132,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p
         print("softmax_lse:", softmax_lse, softmax_lse.shape)
 
     # Compute output
-    if is_fp8:
-        o = torch.matmul(p, v.to(torch.float32)) * v_scale
-    else:
-        o = torch.matmul(p, v)
+    o = torch.matmul(p, v)
 
     if DEBUG_CORE:
         print("o:", o, o.shape)
@@ -164,7 +143,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p
 
     return o, softmax_lse, sd_mask
 
-def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
+def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8):
     """Compute reference output and softmax_lse using PyTorch's built-in function"""
 
     # Ensure the layout is 'bhsd'
@@ -200,7 +179,7 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout
 
     # Call the core attention function
     o, softmax_lse, sd_mask = attention_forward_core_ref_impl(
-        q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2
+        q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8
     )
 
     if group_size != 1:
@@ -238,7 +217,8 @@ def attention_varlen_forward_pytorch_ref_impl(
     dropout_p,
     philox_seed,
     philox_offset,
-    use_exp2
+    use_exp2,
+    is_fp8
 ):
     # Ensure the layout is 'thd'
     if layout != 'thd':
@@ -302,7 +282,7 @@ def attention_varlen_forward_pytorch_ref_impl(
             v_i = v_i.reshape(nheads_k, seqlen_k, head_dim)
 
         # Call the core attention function for this sequence
-        o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2)
+        o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2, is_fp8)
 
         # Reshape outputs back to original dimensions
         if group_size != 1:
@@ -365,6 +345,12 @@ def attention_forward_pytorch_ref_impl(
         print("philox_offset:", philox_offset)
         print("use_exp2:", use_exp2)
 
+    is_fp8 = check_is_fp8(q)
+
+    # if is fp8 upcast to fp32 for torch ops to be supported
+    if is_fp8:
+        q, k, v = q.to(torch.float32), k.to(torch.float32), v.to(torch.float32)
+
     # compute reference
     if layout == "thd":
         o_ref, softmax_lse_ref, sd_mask_ref = attention_varlen_forward_pytorch_ref_impl(
@@ -382,6 +368,7 @@ def attention_forward_pytorch_ref_impl(
             philox_seed,
             philox_offset,
             use_exp2,
+            is_fp8
         )
     else:
         o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl(q.clone(),
@@ -393,7 +380,8 @@ def attention_forward_pytorch_ref_impl(
                                                                                          dropout_p,
                                                                                          philox_seed,
                                                                                          philox_offset,
-                                                                                         use_exp2)
+                                                                                         use_exp2,
+                                                                                         is_fp8)
 
     if DEBUG:
         print()
diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py
index ecc9a458d..2e9920d5c 100644
--- a/flash_attn/flash_attn_triton_amd/test.py
+++ b/flash_attn/flash_attn_triton_amd/test.py
@@ -390,7 +390,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
     if layout == "thd":
         q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
     else:
-        q, k, v, q_fp32, k_fp32, v_fp32, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
+        q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
     if DEBUG_INPUT:
         output_triton = torch.zeros_like(q).contiguous()
     else:
@@ -436,7 +436,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
                                                       metadata.use_exp2)
 
     output_ref, softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl(
-        q_fp32, k_fp32, v_fp32,
+        q, k, v,
         metadata.sm_scale,
         causal,
         layout,
diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py
index 6e50a634f..a148192bc 100644
--- a/flash_attn/flash_attn_triton_amd/utils.py
+++ b/flash_attn/flash_attn_triton_amd/utils.py
@@ -167,7 +167,6 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cud
         v = torch.randn(k_tensor_shape, dtype=torch.float32, device=device, requires_grad=True)
 
     q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
-    q_fp32, k_fp32, v_fp32 = q.to(torch.float32), k.to(torch.float32), v.to(torch.float32)
 
     if DEBUG_INPUT:
         sm_scale = 1
@@ -177,7 +176,7 @@ def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cud
     input_metadata.max_seqlens_q = N_CTX_Q
     input_metadata.max_seqlens_k = N_CTX_K
     input_metadata.layout = layout
-    return q, k, v, q_fp32, k_fp32, v_fp32, input_metadata
+    return q, k, v, input_metadata
 
 
 def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
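Note on the FP8 path above: the host side (attention_prefill_forward_triton_impl) now divides q/k/v by per-(batch, head) scales before launching the kernel, and the kernel multiplies the scales back in after each matmul (q_scale * k_scale after q.k, v_scale after p.v). Below is a minimal standalone sketch of that round trip. It assumes a bhsd layout and max-magnitude scales; per_head_scale and normalize are illustrative helpers, not part of this patch, and create_scale_tensors may compute its scales differently.

import torch

def per_head_scale(x: torch.Tensor) -> torch.Tensor:
    # One scale per (batch, head): max |x| over the seqlen and head_dim axes (assumption).
    return x.to(torch.float32).abs().amax(dim=(-2, -1)).clamp(min=1e-12)

def normalize(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Mirrors the q/k/v scaling in attention_prefill_forward_triton_impl; plain broadcasting
    # makes the explicit .expand(...) used there unnecessary.
    return (x.to(torch.float32) / scale[..., None, None]).to(x.dtype)

# Round-trip check in fp32 (an fp8 run would add quantization error on top).
torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))
sm_scale = 64 ** -0.5
q_scale, k_scale, v_scale = per_head_scale(q), per_head_scale(k), per_head_scale(v)
qn, kn, vn = normalize(q, q_scale), normalize(k, k_scale), normalize(v, v_scale)

# Kernel-side descale: qk * q_scale * k_scale and (p @ v) * v_scale recover the unscaled math.
qk = torch.matmul(qn, kn.transpose(-2, -1)) * sm_scale * q_scale[..., None, None] * k_scale[..., None, None]
o = torch.matmul(torch.softmax(qk, dim=-1), vn) * v_scale[..., None, None]

qk_ref = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
o_ref = torch.matmul(torch.softmax(qk_ref, dim=-1), v)
assert torch.allclose(o, o_ref, atol=1e-5)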