perf: improved block tiling config for rotary
alexkranias-amd committed Nov 22, 2024
1 parent e02ceee commit 5ac2b83
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flash_attn/ops/triton/rotary.py
@@ -194,7 +194,7 @@ def apply_rotary(
         else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
     )
     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
-    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)
+    BLOCK_M = 2 if interleaved else (4 if rotary_dim <= 128 else 2)

     # Need this, otherwise Triton tries to launch from cuda:0 and we get
     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
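
The change halves BLOCK_M in every branch, and since the first grid dimension is triton.cdiv(seqlen, BLOCK_M), the number of programs launched along the sequence axis roughly doubles. A minimal sketch of that arithmetic in plain Python (no Triton required) follows. The helper names block_m_before, block_m_after, and launch_grid, and the example shape, are assumptions made for illustration and are not part of the commit.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b


def block_m_before(rotary_dim: int, interleaved: bool) -> int:
    # BLOCK_M expression on the removed ("-") line of the diff.
    return 4 if interleaved else (8 if rotary_dim <= 128 else 4)


def block_m_after(rotary_dim: int, interleaved: bool) -> int:
    # BLOCK_M expression on the added ("+") line of the diff.
    return 2 if interleaved else (4 if rotary_dim <= 128 else 2)


def launch_grid(seqlen: int, batch: int, nheads: int, block_m: int) -> tuple:
    # Mirrors the grid lambda: (cdiv(seqlen, BLOCK_M), batch, nheads).
    return (cdiv(seqlen, block_m), batch, nheads)


if __name__ == "__main__":
    # Hypothetical shape, chosen only to make the numbers concrete.
    seqlen, batch, nheads, rotary_dim = 2048, 4, 32, 64
    for interleaved in (False, True):
        old = launch_grid(seqlen, batch, nheads, block_m_before(rotary_dim, interleaved))
        new = launch_grid(seqlen, batch, nheads, block_m_after(rotary_dim, interleaved))
        print(f"interleaved={interleaved}: grid {old} -> {new}")
```

Whether the smaller tile is faster depends on the hardware and the workload; the sketch only shows how the launch geometry changes, not the resulting speedup.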