Add llama3 405b attention shapes. (#29)
This PR adds the llama3 405b attention shapes that we see in the sharktank export
(https://gist.github.com/KyleHerndon/a9c60ce93264d6ba7ec9e878c879f218).
We make sure the dynamic sequence length is always a multiple of 16; every swept M here is a multiple of 128, which satisfies that.
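
For reference, a minimal sketch of what each swept entry describes, assuming the convention that AttentionConfig(B, M, N, K1, K2, dtype) maps to Q: BxMxK1, K: BxK2xK1, V: BxK2xN. The dataclass below is a hypothetical stand-in for the real class in attentionbench/problems.py, not a copy of it.

from dataclasses import dataclass

# Hypothetical stand-in for the AttentionConfig used in attentionbench/problems.py;
# field names mirror the call sites in the diff below.
@dataclass
class AttentionConfig:
    B: int     # batch dimension (512 in the llama3 405b shapes)
    M: int     # query sequence length
    N: int     # V / output head dimension
    K1: int    # Q / K head dimension
    K2: int    # key-value sequence length (tied to M in this sweep)
    dtype: str

def operand_shapes(cfg: AttentionConfig) -> dict[str, tuple[int, int, int]]:
    # Assumed shape mapping for O = softmax(Q @ K^T) @ V.
    return {
        "Q": (cfg.B, cfg.M, cfg.K1),
        "K": (cfg.B, cfg.K2, cfg.K1),
        "V": (cfg.B, cfg.K2, cfg.N),
        "O": (cfg.B, cfg.M, cfg.N),
    }

# The smallest llama3 405b entry added by this PR:
print(operand_shapes(AttentionConfig(512, 1024, 128, 128, 1024, "f16")))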
saienduri authored Oct 28, 2024
1 parent 3f3c514 commit 982eb72
Showing 1 changed file with 12 additions and 3 deletions.
attentionbench/problems.py: 12 additions & 3 deletions
@@ -46,6 +46,12 @@ def bert_attn_sweep(dtype: str) -> list[AttentionConfig]:
         configs.append(AttentionConfig(B, M, N, K1, K2, dtype))
     return configs

+def llama3_405b_attn_sweep(dtype: str) -> list[AttentionConfig]:
+    configs = []
+    for M in [1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192]:
+        K2 = M
+        configs.append(AttentionConfig(512, M, 128, 128, K2, dtype))
+    return configs

 def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
     configs: list[tuple[str, AttentionConfig]] = []
@@ -55,9 +61,12 @@ def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
     sdxl_configs += sdxl_unet_sweep("f8E4M3FNUZ")
     bert_configs = bert_attn_sweep("f16")
     bert_configs += bert_attn_sweep("f8E4M3FNUZ")
+    llama3_configs = llama3_405b_attn_sweep("f16")
+    llama3_configs += llama3_405b_attn_sweep("f8E4M3FNUZ")

configs += [("llm_sweep", x) for x in llm_configs]
configs += [("sdxl_unet_sweep", x) for x in sdxl_configs]
configs += [("bert_attn_sweep", x) for x in bert_configs]
configs += [("llm", x) for x in llm_configs]
configs += [("sdxl_unet", x) for x in sdxl_configs]
configs += [("bert", x) for x in bert_configs]
configs += [("llama3_405b", x) for x in llama3_configs]

     return configs
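
To consume the new sweep, a hedged usage sketch (assuming the repo root is on PYTHONPATH so attentionbench.problems imports as shown; the "llama3_405b" tag matches the one registered in the diff above):

# Usage sketch: pull out only the llama3 405b entries registered above.
from attentionbench.problems import get_attention_configs

llama3 = [cfg for tag, cfg in get_attention_configs() if tag == "llama3_405b"]
print(len(llama3))  # expect 16: 8 sequence lengths x 2 dtypes
for cfg in llama3:
    print(cfg)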
