Skip to content

Commit

Permalink
Add square shapes to gemm problems (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar authored Oct 10, 2024
1 parent 1aa0004 commit c5ad991
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions gemmbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,15 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool:
(4096, 5120, 640),
]

SQUARE = [
(128, 128, 128),
(256, 256, 256),
(512, 512, 512),
(1024, 1024, 1024),
(2048, 2048, 2048),
(4096, 4096, 4096),
(8192, 8192, 8192),
]

def llama13bmatvec(dtype: str) -> list[GemmConfig]:
configs = []
Expand Down Expand Up @@ -953,6 +962,24 @@ def unet(dtype: str) -> list[GemmConfig]:
return configs


def square(dtype: str) -> list[GemmConfig]:
configs = []
for m, n, k in SQUARE:
configs.append(
GemmConfig(
m,
n,
k,
"N",
"T",
dtype,
get_default_accumulator_element_type(dtype),
get_default_result_element_type(dtype),
)
)
return configs


def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
llama13bmatvec_configs: list[GemmConfig] = []
llama13bmatvec_configs += llama13bmatvec("f16")
Expand Down Expand Up @@ -982,6 +1009,8 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
unet_configs += unet("f16")
unet_configs += unet("bf16")

square_configs: list[GemmConfig] = square("f16")

all_configs: list[tuple[str, GemmConfig]] = []
all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
Expand All @@ -991,6 +1020,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
all_configs += [("llama70bmemory", x) for x in llama70bmemory_configs]
all_configs += [("compute", x) for x in compute_configs]
all_configs += [("unet", x) for x in unet_configs]
all_configs += [("square", x) for x in square_configs]
all_configs += [("tk", x) for x in tk_default_configs]

return all_configs
Expand Down

0 comments on commit c5ad991

Please sign in to comment.