Added Benchmark for Rotary Decode Kernel + Performance Speed Up for Rotary Kernel #102

Open · wants to merge 4 commits into base: main_perf
292 changes: 292 additions & 0 deletions benchmarks/benchmark_rotary.py
@@ -0,0 +1,292 @@
import argparse
import math
import torch
import triton
from flash_attn.flash_attn_triton_amd.utils import (
MetaData,
input_helper,
varlen_input_helper,
)
from flash_attn.flash_attn_triton_amd.interface_torch import attention_decode

ARGS_TO_TORCH_DTYPE = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}

FUNCTIONS = {
"decode": attention_decode
}
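# NOTE: only the decode kernel is benchmarked here for now; add prefill to this map to benchmark it as well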

def get_benchmark_configs(args, varlen=False):
"""
    Returns benchmark configurations, as (BATCH, HQ, HK, N_CTX_Q, N_CTX_K) tuples, based on whether variable-length sequences are used.
"""
if args.custom_config:
hk = args.hq if not args.hk else args.hk
sk = args.sq if not args.sk else args.sk
return [(args.b, args.hq, hk, args.sq, sk)]
elif varlen:
return [
(2, 16, 4, 1024, 1024),
(8, 16, 2, 2048, 2048),
(4, 16, 8, 4096, 4096),
(2, 16, 4, 8192, 8192),
(2, 16, 8, 16384, 16384),
(2, 48, 12, 1024, 1024),
(2, 48, 24, 2048, 2048),
(2, 48, 8, 4096, 4096),
(2, 48, 4, 8192, 8192),
(2, 48, 2, 16384, 16384),
(2, 64, 32, 1024, 1024),
(4, 64, 16, 2048, 2048),
(4, 64, 8, 4096, 4096),
(4, 64, 32, 8192, 8192),
(4, 128, 16, 16384, 16384),
]
else:
return [
(16, 16, 16, 1024, 1024),
(8, 16, 16, 2048, 2048),
(4, 16, 16, 4096, 4096),
(1, 8, 8, 8192, 8192),
(1, 2, 2, 16384, 16384),
(2, 48, 48, 1024, 1024),
(2, 48, 48, 2048, 1024),
(1, 8, 8, 4096, 8192),
(1, 8, 8, 8192, 4096),
(2, 4, 4, 16384, 8192),
(2, 8, 8, 1989, 15344),
(4, 16, 16, 4097, 163),
(2, 16, 16, 8122, 2159),
(1, 16, 16, 16281, 7),
(2, 48, 48, 1021, 1020),
(2, 48, 48, 2001, 2048),
(2, 8, 8, 3996, 9639),
(2, 8, 8, 8181, 1021),
]

def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal, rotary_fraction=0.0, rotary_interleaved=False):
flops_per_matmul = 0

q = torch.randn(
[BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
)
k = torch.randn(
[BATCH, N_CTX_K, HK, 1, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
).expand(-1, -1, -1, HQ // HK, -1)
v = torch.randn(
[BATCH, N_CTX_K, HK, 1, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
).expand(-1, -1, -1, HQ // HK, -1)
input_metadata = MetaData(sm_scale=1.3)
input_metadata.layout = "bsghd"

rotary_dim = math.floor(int(rotary_fraction * D_HEAD) / 16) * 16
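    # rotary_dim is rounded down to a multiple of 16 (0 disables rotary), mirroring the upstream flash-attn tests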
if rotary_dim > 0:
angle = (
torch.rand(
                max(N_CTX_K, N_CTX_Q),  # NOTE: must cover the longer of the two sequences, otherwise indexing past the table segfaults
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)

# add rotary
input_metadata.need_rotary(sin, cos, rotary_interleaved=rotary_interleaved)

    # flops for a single attention matmul (QK^T or PV) at this problem size
flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD

input_data = (q, k, v, input_metadata)

return input_data, flops_per_matmul

def run_benchmark(args, fn_name, fn, mode):
"""
Runs the benchmark for the provided function based on the provided arguments.
"""
print(f"Benchmarking {fn_name} in {mode} mode...")

dtype = ARGS_TO_TORCH_DTYPE[args.dtype]
head_size = args.d if args.d else 128
causal = args.causal
rotary_fraction = args.rotary_fraction
rotary_interleaved = args.rotary_interleaved
varlen = args.layout == "thd"
return_tflops = args.return_tflops
line_names = "TFLOPS" if return_tflops else "Time (ms)"

# Determine configurations
x_vals_list = get_benchmark_configs(args, varlen=varlen)

# Setup benchmark configurations
configs = [
triton.testing.Benchmark(
x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"],
x_vals=x_vals_list,
line_arg="provider",
line_vals=["triton"],
line_names=[line_names],
styles=[("red", "-")],
ylabel="ms",
plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}",
args={
"D_HEAD": head_size,
"dtype": dtype,
"causal": causal,
"rotary_fraction": rotary_fraction,
"rotary_interleaved": rotary_interleaved,
"mode": mode,
},
)
]

@triton.testing.perf_report(configs)
def bench_function(
BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, rotary_fraction, rotary_interleaved, mode, provider, device="cuda"
):
warmup = 25
rep = 100
flops_per_matmul = 0

# generate function inputs
fn_inputs, flops_per_matmul = gen_fn_inputs(
fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal, rotary_fraction, rotary_interleaved
)

# define the function to benchmark
if mode == "fwd":
benchmark_fn = lambda: fn(*fn_inputs)
total_flops = 2 * flops_per_matmul
elif mode == "bwd":
outputs = fn(*fn_inputs)
output = outputs[0]
grad_output = torch.randn_like(output)
benchmark_fn = lambda: output.backward(grad_output, retain_graph=True)
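            # backward does roughly 2.5x the matmul work of the forward pass (~5 matmuls vs 2)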
total_flops = 2 * flops_per_matmul * 2.5
else:
raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.")

if causal:
total_flops *= 0.5

# Run the benchmark
ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep, return_mode="median")

if return_tflops:
return total_flops / ms * 1e-9
else:
return ms

bench_function.run(save_path=".", print_data=True)

def supported_layouts():
"""
Returns a string describing the supported layouts.
"""
return (
"bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n"
"bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n"
"thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n"
'This layout is sometimes called "varlen" or "grouped" layout.'
)

def parse_args():
"""
Parses command-line arguments.
"""
parser = argparse.ArgumentParser(
prog="Benchmark FlashAttention",
allow_abbrev=False,
)
parser.add_argument("-b", type=int, default=0)
parser.add_argument("-hq", type=int, default=0)
parser.add_argument("-hk", type=int, default=0)
parser.add_argument("-sq", type=int, default=0)
parser.add_argument("-sk", type=int, default=0)
parser.add_argument(
"-equal_seqlens",
action="store_true",
default=False,
        help="If specified, each context within the thd layout has the same seqlen as sq and sk",
)
parser.add_argument("-d", type=int, default=0)
parser.add_argument("-causal", action="store_true", default=False)
parser.add_argument("-rotary_fraction", type=float, default=0.0)
parser.add_argument("-rotary_interleaved", action="store_true", default=False)
parser.add_argument("-dtype", default="fp16")
parser.add_argument("-return_tflops", action="store_true", default=False)
parser.add_argument(
"-layout",
type=str,
default="bhsd",
help=supported_layouts(),
)
parser.add_argument(
"-benchmark_fn",
type=str,
nargs="*",
choices=FUNCTIONS.keys(),
        help="Function(s) to benchmark (currently only 'decode' is available)",
)
parser.add_argument(
"-mode",
type=str,
nargs='*',
default=["fwd", "bwd"],
choices=["fwd", "bwd"],
help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass",
)
return parser.parse_args()

def main():
"""
Main function to run benchmarks.
"""
args = parse_args()

# Validate arguments
assert (
args.layout == "thd" or not args.equal_seqlens
), "Equal sequence lengths arg must be used with the thd layout."
args.custom_config = False
if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
args.custom_config = True
assert args.b and args.hq and args.sq and args.d, (
"If custom config is specified, please provide all of batch, "
"number of Q heads, Q sequence length, and head size."
)
assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported."

# determine the functions to benchmark
if args.benchmark_fn is None or len(args.benchmark_fn) == 0:
bench_fn_list = FUNCTIONS.keys()
else:
bench_fn_list = args.benchmark_fn

# benchmark functions
for fn_name in bench_fn_list:
if fn_name not in FUNCTIONS:
raise ValueError(f"Invalid benchmark function specified: {fn_name}")
for mode in args.mode:
if fn_name == "decode" and mode == "bwd":
                print("Decode kernel does not have a backward pass")
continue
run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)

if __name__ == "__main__":
main()
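
As a quick sanity check outside of triton.testing, the decode path exercised here can also be driven directly. The sketch below is illustrative only: it reuses gen_fn_inputs and attention_decode from this file with one of the configs above, assumes a ROCm/CUDA device, and enables rotary over the full head dim (the CLI equivalent would be along the lines of python benchmarks/benchmark_rotary.py -benchmark_fn decode -mode fwd -rotary_fraction 1.0).

# minimal sketch (not part of the benchmark script): run the decode kernel once with rotary enabled
(q, k, v, metadata), _ = gen_fn_inputs(
    "decode", 2, 16, 4, 1024, 1024, 128,
    dtype=torch.float16, device="cuda", layout="bsghd", causal=False,
    rotary_fraction=1.0, rotary_interleaved=False,
)
out = attention_decode(q, k, v, metadata)
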
39 changes: 39 additions & 0 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -6,6 +6,8 @@
from .fwd_ref import attention_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import MetaData, get_shape_from_layout, DEBUG
from einops import rearrange, repeat
from flash_attn.layers.rotary import apply_rotary_emb

USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')

@@ -516,6 +518,43 @@ def fwd_kvcache(
    batch, _, nheads_q, _ = q.shape
metadata.need_alibi(alibi_slopes, batch, nheads_q)

    # determine whether rotary embeddings should be applied
apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin)
if apply_rotary:
metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved)

    # apply rotary embeddings to q and to the new k entries before launching the kernel
if apply_rotary:
        if metadata.causal:  # NOTE: add `or metadata.local` once local attention support is added
q_ro = apply_rotary_emb(
q,
metadata.rotary_cos,
metadata.rotary_sin,
seqlen_offsets=metadata.cache_seqlens,
interleaved=metadata.rotary_interleaved,
)
else:
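            # without causal masking, every query token receives the rotary embedding at offset
            # cache_seqlens, i.e. all query tokens are treated as sitting at the same position
            # (this matches the documented flash_attn_with_kvcache behaviour)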
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
metadata.rotary_cos,
metadata.rotary_sin,
seqlen_offsets=metadata.cache_seqlens,
interleaved=metadata.rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=metadata.max_seqlens_q,
)
k_ro = apply_rotary_emb(
metadata.k_new,
metadata.rotary_cos,
metadata.rotary_sin,
seqlen_offsets=metadata.cache_seqlens,
interleaved=metadata.rotary_interleaved,
)

q, metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype)

# launch kernel
# TODO: pass output as an arg. Maybe we are copying output which is causing slow down
output, softmax_lse = attention_decode_forward_triton_impl(
30 changes: 29 additions & 1 deletion flash_attn/flash_attn_triton_amd/interface_torch.py
@@ -2,7 +2,8 @@
from .fwd_prefill import attention_prefill_forward_triton_impl
from .bwd_prefill import attention_prefill_backward_triton_impl
from .fwd_decode import attention_decode_forward_triton_impl

from einops import rearrange, repeat, parse_shape
from flash_attn.layers.rotary import apply_rotary_emb

class _attention_prefill(torch.autograd.Function):
@staticmethod
@@ -78,6 +79,33 @@ def backward(ctx, do, *args):
class _attention_decode(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, metadata):
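        # apply rotary embeddings to q before launching the decode kernel; only q is rotated here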
if metadata.rotary_cos is not None:
q_original_shape = parse_shape(q, 'b s g h d')
            if metadata.causal:  # NOTE: add `or metadata.local` once local attention support is added
q_ro = apply_rotary_emb(
q,
metadata.rotary_cos,
metadata.rotary_sin,
                    seqlen_offsets=metadata.cache_seqlens if metadata.cache_seqlens is not None else 0,
interleaved=metadata.rotary_interleaved,
)
else:
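                # non-causal: fold all query tokens into a single position so they share the
                # rotary offset cache_seqlens (same convention as fwd_kvcache)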
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s g h d -> b 1 (s g h) d"),
metadata.rotary_cos,
metadata.rotary_sin,
                        seqlen_offsets=metadata.cache_seqlens if metadata.cache_seqlens is not None else 0,
interleaved=metadata.rotary_interleaved,
),
"b 1 (s g h) d -> b s g h d",
s=q_original_shape['s'],
g=q_original_shape['g'],
h=q_original_shape['h']
)

q, metadata.k_new = q_ro.to(q.dtype), None

output, softmax_lse = attention_decode_forward_triton_impl(
q,
k,
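
For context on what apply_rotary_emb computes in the hunks above, the following is a minimal pure-PyTorch sketch of the assumed rotary semantics (zero position offset; the non-interleaved form pairs the two halves of the rotary dims, the interleaved form pairs even/odd dims). It is a reviewer-side reference sketch, not code from this PR.

import torch

def rotary_ref(x, cos, sin, interleaved=False):
    # x: [batch, seqlen, nheads, headdim]; cos/sin: [seqlen, rotary_dim // 2]
    rotary_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    cos = cos[: x.shape[1]].unsqueeze(1)  # [seqlen, 1, rotary_dim // 2]; broadcasts over batch and heads
    sin = sin[: x.shape[1]].unsqueeze(1)
    if interleaved:
        x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]    # pair even/odd dims
    else:
        x1, x2 = x_rot.chunk(2, dim=-1)                # pair first half with second half
    o1, o2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos  # 2D rotation of each pair
    if interleaved:
        rotated = torch.stack((o1, o2), dim=-1).flatten(-2)
    else:
        rotated = torch.cat((o1, o2), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)        # dims beyond rotary_dim pass through unchanged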