
Added Support for Rotary Positional Embeddings #99

Merged: 2 commits merged into main_perf from alexkranias/rotary_embedding_main_perf on Nov 20, 2024

Conversation


@alexkranias-amd commented on Nov 13, 2024

Motivation

Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding

Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.

RoPEs work by applying a unique rotation transformation to the vectors that represent each token within our q and k tensors, based on each token's position $$m$$ in the sequence.

To compute attention, we must first compute $$\text{matmul}(Q,\ K^T)$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K^T$$. Given two tokens at positions $$i$$ and $$j$$: the closer $$i$$ and $$j$$ are to each other, the more similarly their vector embeddings are rotated, and the dot product between the two embedding vectors is largely unchanged. The further apart the two tokens are, the more the transformations applied to their embeddings diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, so the model effectively learns that, for a given token, nearby tokens should receive more attention than tokens much further away.
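To make the relative-position behavior concrete, here is a minimal NumPy sketch (independent of this PR's Triton kernels; the helper name `rope` and the frequency schedule are illustrative assumptions following the RoFormer paper). Rotating q at position $$i$$ and k at position $$j$$ and then taking their dot product yields a score that depends only on the offset $$i - j$$, which is the quantity the decay curve below is plotted against.

```python
import numpy as np

def rope(x, m, thetas):
    """Apply RoPE to one head_dim-sized vector x at sequence position m."""
    out = x.copy()
    for d, theta in enumerate(thetas):            # one angle per 2-D chunk
        angle = m * theta
        c, s = np.cos(angle), np.sin(angle)
        x0, x1 = x[2 * d], x[2 * d + 1]
        out[2 * d]     = c * x0 - s * x1          # rotate the (x0, x1) pair
        out[2 * d + 1] = s * x0 + c * x1
    return out

rng = np.random.default_rng(0)
head_dim = 8
thetas = 10000.0 ** (-np.arange(head_dim // 2) * 2 / head_dim)
q, k = rng.standard_normal(head_dim), rng.standard_normal(head_dim)

# The attention score depends only on the relative offset i - j:
print(np.dot(rope(q, 5, thetas), rope(k, 3, thetas)))    # i - j = 2
print(np.dot(rope(q, 12, thetas), rope(k, 10, thetas)))  # i - j = 2, same value
```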

Figure: Dot Product Decay (the dot product between rotated token embeddings falls off as the relative distance between the tokens grows)

A more detailed explanation

Fundamentally, RoPEs work by dividing the embedding space of our q and k vectors (the $$\text{head\_dim}$$) into many chunks of two dimensions. Each 2-dimensional chunk can be thought of as a vector subcomponent of q and k projected onto a 2-dimensional plane within the higher-dimensional embedding space of q and k. RoPE "rotates" these planar chunks of our q and k vectors based on the position of the token in the sequence: each chunk is rotated by a unique angle $$\theta_{m, d/2}$$ determined by the token's position $$m$$ in the sequence and the dimension index $$d$$ of the subcomponents of q and k being rotated.
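As a rough illustration of this mechanism, here is a minimal NumPy sketch that applies the rotation to a whole (seqlen, head_dim) tensor at once. It is not this PR's fused Triton kernel or the flash_attn API; the function name `apply_rope`, the interleaved even/odd pair layout, and the base of 10000 are assumptions taken from the RoFormer paper.

```python
import numpy as np

def apply_rope(x, base=10000.0):
    """Rotate every 2-D chunk of x, where x has shape (seqlen, head_dim).

    Chunk d is assigned the frequency theta_d = base**(-2d / head_dim), and
    the token at position m is rotated by the angle m * theta_d in that chunk.
    """
    seqlen, head_dim = x.shape
    thetas = base ** (-np.arange(0, head_dim, 2) / head_dim)   # (head_dim/2,)
    angles = np.arange(seqlen)[:, None] * thetas[None, :]      # (seqlen, head_dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]                     # the 2-D chunks
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out

q = np.random.randn(16, 64)   # (seqlen, head_dim)
q_rot = apply_rope(q)         # same shape; each token rotated according to its position
```

The fused path this PR adds presumably applies the equivalent rotation inside the attention kernels rather than as a separate pass over q and k.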

RoPE Implementation Details

@alexkranias-amd self-assigned this on Nov 13, 2024
@alexkranias-amd changed the title from "Added Rotary Embedding (non-fused kernel)" to "Added Rotary Embedding (both non-fused and fused kernel)" on Nov 14, 2024
@alexkranias-amd changed the title from "Added Rotary Embedding (both non-fused and fused kernel)" to "Added Support for Rotary Positional Embeddings (both non-fused and fused kernel)" on Nov 14, 2024
Review comments (now outdated and resolved) on tests/test_flash_attn_triton_amd.py (2) and flash_attn/flash_attn_triton_amd/fwd_decode.py (4)
@alexkranias-amd force-pushed the alexkranias/rotary_embedding_main_perf branch from 2a14c01 to e02ceee on November 19, 2024 at 22:56
@alexkranias-amd changed the title from "Added Support for Rotary Positional Embeddings (both non-fused and fused kernel)" to "Added Support for Rotary Positional Embeddings" on Nov 19, 2024
@micmelesse merged commit 1fcc51b into main_perf on Nov 20, 2024
4 checks passed