feat: added benchmark rotary script
alexkranias-amd committed Nov 22, 2024
1 parent 5ac2b83 commit 76698a3
Showing 2 changed files with 321 additions and 1 deletion.
292 changes: 292 additions & 0 deletions benchmarks/benchmark_rotary.py
@@ -0,0 +1,292 @@
import argparse
import math
import torch
import triton
from flash_attn.flash_attn_triton_amd.utils import (
MetaData,
input_helper,
varlen_input_helper,
)
from flash_attn.flash_attn_triton_amd.interface_torch import attention_decode

ARGS_TO_TORCH_DTYPE = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}

FUNCTIONS = {
"decode": attention_decode
}

def get_benchmark_configs(args, varlen=False):
"""
Returns benchmark configurations based on whether variable-length sequences are used.
"""
if args.custom_config:
hk = args.hq if not args.hk else args.hk
sk = args.sq if not args.sk else args.sk
return [(args.b, args.hq, hk, args.sq, sk)]
elif varlen:
return [
(2, 16, 4, 1024, 1024),
(8, 16, 2, 2048, 2048),
(4, 16, 8, 4096, 4096),
(2, 16, 4, 8192, 8192),
(2, 16, 8, 16384, 16384),
(2, 48, 12, 1024, 1024),
(2, 48, 24, 2048, 2048),
(2, 48, 8, 4096, 4096),
(2, 48, 4, 8192, 8192),
(2, 48, 2, 16384, 16384),
(2, 64, 32, 1024, 1024),
(4, 64, 16, 2048, 2048),
(4, 64, 8, 4096, 4096),
(4, 64, 32, 8192, 8192),
(4, 128, 16, 16384, 16384),
]
else:
return [
(16, 16, 16, 1024, 1024),
(8, 16, 16, 2048, 2048),
(4, 16, 16, 4096, 4096),
(1, 8, 8, 8192, 8192),
(1, 2, 2, 16384, 16384),
(2, 48, 48, 1024, 1024),
(2, 48, 48, 2048, 1024),
(1, 8, 8, 4096, 8192),
(1, 8, 8, 8192, 4096),
(2, 4, 4, 16384, 8192),
(2, 8, 8, 1989, 15344),
(4, 16, 16, 4097, 163),
(2, 16, 16, 8122, 2159),
(1, 16, 16, 16281, 7),
(2, 48, 48, 1021, 1020),
(2, 48, 48, 2001, 2048),
(2, 8, 8, 3996, 9639),
(2, 8, 8, 8181, 1021),
]

def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal, rotary_fraction=0.0, rotary_interleaved=False):
flops_per_matmul = 0

q = torch.randn(
[BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
)
k = torch.randn(
[BATCH, N_CTX_K, HK, 1, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
).expand(-1, -1, -1, HQ // HK, -1)
v = torch.randn(
[BATCH, N_CTX_K, HK, 1, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
).expand(-1, -1, -1, HQ // HK, -1)
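# Illustrative shapes for the bsghd layout (assumed example values): BATCH=2, HQ=16, HK=4,
# N_CTX_Q=1, N_CTX_K=8192, D_HEAD=128 gives q of shape [2, 1, 4, 4, 128], while k and v
# are expanded views of [2, 8192, 4, 1, 128] with shape [2, 8192, 4, 4, 128].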
input_metadata = MetaData(sm_scale=1.3)
input_metadata.layout = "bsghd"

rotary_dim = math.floor(int(rotary_fraction * D_HEAD) / 16) * 16  # round the rotary dim down to a multiple of 16; 0 disables rotary
if rotary_dim > 0:
angle = (
torch.rand(
max(N_CTX_K, N_CTX_Q), # NOTE: must be the max of the two, otherwise it segfaults when the longer sequence indexes past the shorter table
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)

# add rotary
input_metadata.need_rotary(sin, cos, rotary_interleaved=rotary_interleaved)

# FLOPs of a single attention matmul (QK^T or PV): 2 * B * HQ * Sq * Sk * D
flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD

input_data = (q, k, v, input_metadata)

return input_data, flops_per_matmul

def run_benchmark(args, fn_name, fn, mode):
"""
Runs the benchmark for the provided function based on the provided arguments.
"""
print(f"Benchmarking {fn_name} in {mode} mode...")

dtype = ARGS_TO_TORCH_DTYPE[args.dtype]
head_size = args.d if args.d else 128
causal = args.causal
rotary_fraction = args.rotary_fraction
rotary_interleaved = args.rotary_interleaved
varlen = args.layout == "thd"
return_tflops = args.return_tflops
line_names = "TFLOPS" if return_tflops else "Time (ms)"

# Determine configurations
x_vals_list = get_benchmark_configs(args, varlen=varlen)

# Setup benchmark configurations
configs = [
triton.testing.Benchmark(
x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"],
x_vals=x_vals_list,
line_arg="provider",
line_vals=["triton"],
line_names=[line_names],
styles=[("red", "-")],
ylabel=line_names,
plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}",
args={
"D_HEAD": head_size,
"dtype": dtype,
"causal": causal,
"rotary_fraction": rotary_fraction,
"rotary_interleaved": rotary_interleaved,
"mode": mode,
},
)
]

@triton.testing.perf_report(configs)
def bench_function(
BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, rotary_fraction, rotary_interleaved, mode, provider, device="cuda"
):
warmup = 25
rep = 100
flops_per_matmul = 0

# generate function inputs
fn_inputs, flops_per_matmul = gen_fn_inputs(
fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal, rotary_fraction, rotary_interleaved
)

# define the function to benchmark
if mode == "fwd":
benchmark_fn = lambda: fn(*fn_inputs)
total_flops = 2 * flops_per_matmul
elif mode == "bwd":
outputs = fn(*fn_inputs)
output = outputs[0]
grad_output = torch.randn_like(output)
benchmark_fn = lambda: output.backward(grad_output, retain_graph=True)
total_flops = 2 * flops_per_matmul * 2.5
else:
raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.")

if causal:
total_flops *= 0.5

# Run the benchmark
ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep, return_mode="median")

if return_tflops:
return total_flops / ms * 1e-9
else:
return ms

bench_function.run(save_path=".", print_data=True)

def supported_layouts():
"""
Returns a string describing the supported layouts.
"""
return (
"bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n"
"bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n"
"thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n"
'This layout is sometimes called "varlen" or "grouped" layout.'
)

def parse_args():
"""
Parses command-line arguments.
"""
parser = argparse.ArgumentParser(
prog="Benchmark FlashAttention",
allow_abbrev=False,
)
parser.add_argument("-b", type=int, default=0)
parser.add_argument("-hq", type=int, default=0)
parser.add_argument("-hk", type=int, default=0)
parser.add_argument("-sq", type=int, default=0)
parser.add_argument("-sk", type=int, default=0)
parser.add_argument(
"-equal_seqlens",
action="store_true",
default=False,
help="If specified, each context within the thd layout has same seqlen as sq and sk",
)
parser.add_argument("-d", type=int, default=0)
parser.add_argument("-causal", action="store_true", default=False)
parser.add_argument("-rotary_fraction", type=float, default=0.0)
parser.add_argument("-rotary_interleaved", action="store_true", default=False)
parser.add_argument("-dtype", default="fp16")
parser.add_argument("-return_tflops", action="store_true", default=False)
parser.add_argument(
"-layout",
type=str,
default="bhsd",
help=supported_layouts(),
)
parser.add_argument(
"-benchmark_fn",
type=str,
nargs="*",
choices=FUNCTIONS.keys(),
help="Function(s) to benchmark: prefill, decode, or both",
)
parser.add_argument(
"-mode",
type=str,
nargs='*',
default=["fwd", "bwd"],
choices=["fwd", "bwd"],
help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass",
)
return parser.parse_args()

def main():
"""
Main function to run benchmarks.
"""
args = parse_args()

# Validate arguments
assert (
args.layout == "thd" or not args.equal_seqlens
), "Equal sequence lengths arg must be used with the thd layout."
args.custom_config = False
if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
args.custom_config = True
assert args.b and args.hq and args.sq and args.d, (
"If custom config is specified, please provide all of batch, "
"number of Q heads, Q sequence length, and head size."
)
assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported."

# determine the functions to benchmark
if args.benchmark_fn is None or len(args.benchmark_fn) == 0:
bench_fn_list = FUNCTIONS.keys()
else:
bench_fn_list = args.benchmark_fn

# benchmark functions
for fn_name in bench_fn_list:
if fn_name not in FUNCTIONS:
raise ValueError(f"Invalid benchmark function specified: {fn_name}")
for mode in args.mode:
if fn_name == "decode" and mode == "bwd":
print(f"Decode kernel doesnot have a backward pass")
continue
run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)

if __name__ == "__main__":
main()
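
# Example usage (illustrative invocations; flags correspond to parse_args above):
#   python benchmarks/benchmark_rotary.py -benchmark_fn decode -rotary_fraction 0.5 -return_tflops
#   python benchmarks/benchmark_rotary.py -b 2 -hq 16 -hk 4 -sq 1 -sk 8192 -d 128 \
#       -rotary_fraction 0.25 -rotary_interleaved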
30 changes: 29 additions & 1 deletion flash_attn/flash_attn_triton_amd/interface_torch.py
@@ -2,7 +2,8 @@
from .fwd_prefill import attention_prefill_forward_triton_impl
from .bwd_prefill import attention_prefill_backward_triton_impl
from .fwd_decode import attention_decode_forward_triton_impl

from einops import rearrange, repeat, parse_shape
from flash_attn.layers.rotary import apply_rotary_emb

class _attention_prefill(torch.autograd.Function):
@staticmethod
@@ -78,6 +79,33 @@ def backward(ctx, do, *args):
class _attention_decode(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, metadata):
if metadata.rotary_cos is not None:
q_original_shape = parse_shape(q, 'b s g h d')
if metadata.causal: # NOTE: when local attention support is added, change this to `metadata.causal or metadata.local`
q_ro = apply_rotary_emb(
q,
metadata.rotary_cos,
metadata.rotary_sin,
seqlen_offsets=metadata.cache_seqlens if metadata.cache_seqlens is not None else 0,
interleaved=metadata.rotary_interleaved,
)
else:
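# With no causal mask every query position is rotated by the same offset (the cached
# sequence length), so fold (s, g, h) into a single axis, apply rotary once per batch
# entry, and restore the original bsghd shape afterwards.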
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s g h d -> b 1 (s g h) d"),
metadata.rotary_cos,
metadata.rotary_sin,
seqlen_offsets=metadata.cache_seqlens if metadata.cache_seqlens is not None else 0,
interleaved=metadata.rotary_interleaved,
),
"b 1 (s g h) d -> b s g h d",
s=q_original_shape['s'],
g=q_original_shape['g'],
h=q_original_shape['h']
)

q, metadata.k_new = q_ro.to(q.dtype), None

output, softmax_lse = attention_decode_forward_triton_impl(
q,
k,
