diff --git a/attentionbench/problems.py b/attentionbench/problems.py
index f88dc9f..faca8e4 100644
--- a/attentionbench/problems.py
+++ b/attentionbench/problems.py
@@ -46,6 +46,13 @@ def bert_attn_sweep(dtype: str) -> list[AttentionConfig]:
         configs.append(AttentionConfig(B, M, N, K1, K2, dtype))
     return configs
 
+def llama3_405b_attn_sweep(dtype: str) -> list[AttentionConfig]:
+    configs = []
+    # Sweep the query length M with K2 = M; B = 512 and N = K1 = 128 stay fixed.
+    for M in [1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192]:
+        K2 = M
+        configs.append(AttentionConfig(512, M, 128, 128, K2, dtype))
+    return configs
 
 def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
     configs: list[tuple[str, AttentionConfig]] = []
@@ -55,9 +62,12 @@ def get_attention_configs() -> list[tuple[str, AttentionConfig]]:
     sdxl_configs += sdxl_unet_sweep("f8E4M3FNUZ")
     bert_configs = bert_attn_sweep("f16")
     bert_configs += bert_attn_sweep("f8E4M3FNUZ")
+    llama3_configs = llama3_405b_attn_sweep("f16")
+    llama3_configs += llama3_405b_attn_sweep("f8E4M3FNUZ")
 
-    configs += [("llm_sweep", x) for x in llm_configs]
-    configs += [("sdxl_unet_sweep", x) for x in sdxl_configs]
-    configs += [("bert_attn_sweep", x) for x in bert_configs]
+    configs += [("llm", x) for x in llm_configs]
+    configs += [("sdxl_unet", x) for x in sdxl_configs]
+    configs += [("bert", x) for x in bert_configs]
+    configs += [("llama3_405b", x) for x in llama3_configs]
 
     return configs
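For context, here is a runnable sketch of what the new sweep produces. The `AttentionConfig` definition below is a hypothetical stand-in (the real class is defined elsewhere in `attentionbench/problems.py`); only the positional field order, inferred from the call sites in the diff above, is assumed, and the per-field comments are illustrative interpretations.

```python
from dataclasses import dataclass

@dataclass
class AttentionConfig:
    # Hypothetical stand-in for the real class in attentionbench/problems.py;
    # the field order matches the positional calls in the diff above.
    B: int      # batch dimension (interpretation; name is illustrative)
    M: int      # query sequence length
    N: int      # output head dimension
    K1: int     # query/key head dimension
    K2: int     # key/value sequence length
    dtype: str

def llama3_405b_attn_sweep(dtype: str) -> list[AttentionConfig]:
    configs = []
    for M in [1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192]:
        # K2 tracks M, so the KV sequence grows with the query sequence.
        configs.append(AttentionConfig(512, M, 128, 128, M, dtype))
    return configs

if __name__ == "__main__":
    # Prints eight configs per dtype; M and K2 step together from 1024 to 8192.
    for config in llama3_405b_attn_sweep("f16"):
        print(config)
```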