From dc1271ad750a1291dd4e14f3ed05304598f4778c Mon Sep 17 00:00:00 2001
From: Alex Kranias
Date: Wed, 13 Nov 2024 15:51:40 -0600
Subject: [PATCH 1/2] feat: added rotary support in kvcache

---
 .../flash_attn_triton_amd/interface_fa.py | 39 +++++++++++++++++++
 flash_attn/flash_attn_triton_amd/utils.py | 10 +++++
 tests/test_flash_attn_triton_amd.py       |  8 ++--
 3 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py
index 59a306d5d..f93f1ee69 100644
--- a/flash_attn/flash_attn_triton_amd/interface_fa.py
+++ b/flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -6,6 +6,8 @@
 from .fwd_ref import attention_forward_pytorch_ref_impl
 from .bwd_ref import attention_backward_pytorch_ref_impl
 from .utils import MetaData, get_shape_from_layout, DEBUG
+from einops import rearrange, repeat
+from flash_attn.layers.rotary import apply_rotary_emb
 
 USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
 
@@ -516,6 +518,43 @@ def fwd_kvcache(
     batch, _ , nheads_q, _= q.shape
     metadata.need_alibi(alibi_slopes, batch, nheads_q)
 
+    # rotary boolean
+    apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin)
+    if apply_rotary:
+        metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved)
+
+    # Rotary Embedding Implementation
+    if apply_rotary:
+        if metadata.causal: # NOTE: when support is addede. Add `or metadata.local`
+            q_ro = apply_rotary_emb(
+                q,
+                metadata.rotary_cos,
+                metadata.rotary_sin,
+                seqlen_offsets=metadata.cache_seqlens,
+                interleaved=metadata.rotary_interleaved,
+            )
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    metadata.rotary_cos,
+                    metadata.rotary_sin,
+                    seqlen_offsets=metadata.cache_seqlens,
+                    interleaved=metadata.rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=metadata.max_seqlens_q,
+            )
+        k_ro = apply_rotary_emb(
+            metadata.k_new,
+            metadata.rotary_cos,
+            metadata.rotary_sin,
+            seqlen_offsets=metadata.cache_seqlens,
+            interleaved=metadata.rotary_interleaved,
+        )
+
+        q, metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype)
+
     # launch kernel
     # TODO: pass output as an arg. Maybe we are copying output which is causing slow down
     output, softmax_lse = attention_decode_forward_triton_impl(
diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py
index 530455063..7d4321818 100644
--- a/flash_attn/flash_attn_triton_amd/utils.py
+++ b/flash_attn/flash_attn_triton_amd/utils.py
@@ -27,6 +27,10 @@ class MetaData():
     dropout_p, return_scores= 0.0, False
     # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
     use_exp2 = False
+    rotary_sin = None
+    rotary_cos = None
+    rotary_interleaved = False
+    rotary_conjunction = False
 
 
     def __repr__(self) -> str:
@@ -85,6 +89,12 @@ def need_alibi(self, alibi_slopes, batch, nheads):
     def need_causal(self):
         self.causal = True
 
+    def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False):
+        self.rotary_sin = sin
+        self.rotary_cos = cos
+        self.rotary_interleaved = rotary_interleaved
+        self.rotary_conjunction = rotary_conjunction
+
     def need_dropout(self, dropout_p, return_scores):
         self.dropout_p = dropout_p
         self.return_scores = return_scores
diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py
index d64246f95..fc29533e2 100644
--- a/tests/test_flash_attn_triton_amd.py
+++ b/tests/test_flash_attn_triton_amd.py
@@ -1851,9 +1851,10 @@ def test_flash_attn_varlen_causal(
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
 # @pytest.mark.parametrize("rotary_interleaved", [False, True])
-@pytest.mark.parametrize("rotary_interleaved", [False])
+@pytest.mark.parametrize("rotary_interleaved", [True])
 # @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-@pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("rotary_fraction", [0.5, 1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
 # @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
 @pytest.mark.parametrize("paged_kv_block_size", [None])
@@ -1907,9 +1908,6 @@ def test_flash_attn_kvcache(
 
     if local == True:
         pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
-
-    if rotary_interleaved == True or rotary_fraction > 0.0:
-        pytest.skip("rotary embedding not supported on AMD's Triton Backend yet")
 
     if has_leftpad == True:
         pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet")

From e02ceeeeeb4027cfbfda5108a2b2b3063d812fcd Mon Sep 17 00:00:00 2001
From: Alex Kranias
Date: Wed, 13 Nov 2024 16:11:01 -0600
Subject: [PATCH 2/2] confirmed non-fused rotary passes all tests

---
 flash_attn/flash_attn_triton_amd/interface_fa.py | 2 +-
 tests/test_flash_attn_triton_amd.py              | 7 +++----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py
index f93f1ee69..f2aacc963 100644
--- a/flash_attn/flash_attn_triton_amd/interface_fa.py
+++ b/flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -525,7 +525,7 @@ def fwd_kvcache(
     # Rotary Embedding Implementation
     if apply_rotary:
-        if metadata.causal: # NOTE: when support is addede. Add `or metadata.local`
+        if metadata.causal: # NOTE: when support is added. Add `or metadata.local`
             q_ro = apply_rotary_emb(
                 q,
                 metadata.rotary_cos,
                 metadata.rotary_sin,
                 seqlen_offsets=metadata.cache_seqlens,
                 interleaved=metadata.rotary_interleaved,
             )
diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py
index fc29533e2..f7d0f1728 100644
--- a/tests/test_flash_attn_triton_amd.py
+++ b/tests/test_flash_attn_triton_amd.py
@@ -1850,10 +1850,9 @@ def test_flash_attn_varlen_causal(
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
-# @pytest.mark.parametrize("rotary_interleaved", [False, True])
-@pytest.mark.parametrize("rotary_interleaved", [True])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_interleaved", [False])
-# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-@pytest.mark.parametrize("rotary_fraction", [0.5, 1.0])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 # @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
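
Usage sketch (illustrative, not part of the patches above): with this non-fused path, fwd_kvcache pre-rotates q and metadata.k_new on the host via apply_rotary_emb before launching the Triton decode kernel, so callers pass rotary_cos/rotary_sin through flash_attn_with_kvcache exactly as on the CUDA backend. The snippet below is a minimal sketch assuming a ROCm/CUDA device, fp16 tensors, and rotary applied to the full head dimension; all sizes are invented, and the random rotary tables follow the construction used in test_flash_attn_kvcache.

    import math
    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, headdim = 2, 8, 64
    seqlen_cache, seqlen_new = 128, 1
    device, dtype = "cuda", torch.float16

    q = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)
    k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
    v_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
    k_new = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)
    v_new = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)
    # Number of valid tokens already in the cache, per batch element.
    cache_seqlens = torch.full((batch,), 64, dtype=torch.int32, device=device)

    # Rotary tables covering the whole head dim, shape (seqlen, headdim // 2).
    angle = torch.rand(seqlen_cache + seqlen_new, headdim // 2, device=device) * 2 * math.pi
    rotary_cos, rotary_sin = torch.cos(angle).to(dtype), torch.sin(angle).to(dtype)

    # q is rotated at offset cache_seqlens and k_new is rotated before being written
    # into the cache; the attention itself still runs in the Triton decode kernel.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k_new, v=v_new,
        rotary_cos=rotary_cos, rotary_sin=rotary_sin,
        cache_seqlens=cache_seqlens, causal=True, rotary_interleaved=True,
    )

The width of the tables is what fixes the rotated slice (rotary_dim = 2 * rotary_cos.shape[-1]), which is how the rotary_fraction values 0.5 and 1.0 re-enabled in the tests come into play; rotary_fraction = 0.0 passes rotary_cos = rotary_sin = None and skips the pre-rotation entirely.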