diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py
index 8778c7c..6ce3066 100644
--- a/gemmbench/gemm_bench.py
+++ b/gemmbench/gemm_bench.py
@@ -38,7 +38,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
         help="Set the logging level",
     )
-    parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942")
+    parser.add_argument("--target", help="The IREE hip target to compile for. The special value host_cpu results in an llvm-cpu benchmark instead of HIP, compiled for the host CPU.", type=str, default="gfx942")
     parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
     parser.add_argument(
         "--Xiree_compile",
@@ -76,10 +76,15 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
         default=None,
         help="Directory to which executable files will be dumped."
     )
+    parser.add_argument(
+        "--raw_accumulators",
+        action='store_true',
+        help="If true, benchmark matmuls returning the raw accumulator type with no truncation. If false (default), the results are truncated to the input element type."
+    )
     args = parser.parse_args()
 
     # Handle default values here, since list args are not compatible with defaulted lists.
-    requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
+    requested_dtypes = ["f16", "bf16", "i8"] if not args.dtypes else list(args.dtypes)
     requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)
 
     logging.basicConfig(level=args.log_level)
@@ -91,7 +96,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
     tk = args.tk
     configs = get_tk_gemm_configs() if tk else get_gemm_configs()
-    configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex)
+    configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex, args.raw_accumulators)
     print(f"Generated {len(configs)} gemm configs.")
 
     num_cpus = max(1, max(cpu_count() // 2, 1))
@@ -108,7 +113,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
     target = args.target
     extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
     dump_dir = args.dump_dir
-    device = args.device
+    device = "local-task" if args.target == "host_cpu" else args.device
     compile_args = itertools.starmap(
         lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs
@@ -130,9 +135,12 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
     results = []
     index = 0
-    output_csv = "results/iree_gemm.csv"
+    output_csv_base = "iree_gemm"
+    if args.raw_accumulators:
+        output_csv_base += "_raw_accumulators"
     if tk:
-        output_csv = "results/iree_gemm_tk.csv"
+        output_csv_base += "_tk"
+    output_csv = f"results/{output_csv_base}.csv"
     csv_dir = os.path.dirname(output_csv)
     if not os.path.exists(csv_dir):
         os.makedirs(csv_dir)
diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py
index 86d8aec..a196574 100644
--- a/gemmbench/gemm_utils.py
+++ b/gemmbench/gemm_utils.py
@@ -8,6 +8,7 @@ from iree.turbine.kernel.lang.global_symbols import *
 import torch
 
+
 @dataclass
 class GemmConfig:
     M: int
@@ -48,13 +49,15 @@ def get_byte_count(self) -> int:
         }
         operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
         result_bytes_per_element = dtype_to_bytes[self.result_element_type]
-        byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element
-        return byte_count
+        byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element
+        byte_count_output = (self.M * self.N) * result_bytes_per_element
+        return byte_count_input + byte_count_output
 
     def get_flops(self) -> int:
         flops = 2 * self.M * self.N * self.K
         return flops
 
+
 def generate_mlir(config: GemmConfig):
     K = config.K
     M = config.M
@@ -62,59 +65,63 @@ def generate_mlir(config: GemmConfig):
     operand_element_type = config.operand_element_type
     acc_element_type = config.accumulator_element_type
     result_element_type = config.result_element_type
-    assert not operand_element_type.startswith('i'), "Integer types not supported yet"
+    is_integer = operand_element_type.startswith('i')
+    literal_zero = "0" if is_integer else "0.0"
+    trunc_op = "arith.trunci" if is_integer else "arith.truncf"
     tA = config.tA
     tB = config.tB
-    mlir_template_A = f"""
+    mlir_template_matmul_transpose_a = f"""
 module {{
   func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
-    %cst = arith.constant 0.000000e+00 : {acc_element_type}
+    %cst = arith.constant {literal_zero} : {acc_element_type}
     %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
     %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
     %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
-    %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
-    return %3 : tensor<{M}x{N}x{result_element_type}>
-  }}
-}}
 """
 
-    mlir_template_B = f"""
+    mlir_template_matmul_transpose_b = f"""
 module {{
   func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
-    %cst = arith.constant 0.000000e+00 : {acc_element_type}
+    %cst = arith.constant {literal_zero} : {acc_element_type}
     %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
     %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
     %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
-    %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
-    return %3 : tensor<{M}x{N}x{result_element_type}>
-  }}
-}}
 """
 
-    mlir_template = f"""
+    mlir_template_matmul_normal = f"""
 module {{
   func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
-    %cst = arith.constant 0.000000e+00 : {acc_element_type}
+    %cst = arith.constant {literal_zero} : {acc_element_type}
     %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
     %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
     %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
-    %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+"""
+    mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal
+
+    mlir_template_return_truncated = f"""
+    %3 = {trunc_op} %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
     return %3 : tensor<{M}x{N}x{result_element_type}>
   }}
 }}
 """
-    if tA == "T":
-        return mlir_template_A
-    if tB == "T":
-        return mlir_template_B
-    return mlir_template
+
+    mlir_template_return_untruncated = f"""
+    return %2 : tensor<{M}x{N}x{result_element_type}>
+  }}
+}}
+"""
+
+    mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated
+
+    return mlir_template_matmul + mlir_template_return
+
 
 @dataclass
 class TkTunedConfig:
@@ -131,6 +138,7 @@ class TkTunedConfig:
     DELAY_SHARED: int
     DELAY_GLOBAL: int
 
+
 def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
     if config.M == 2048 and config.N == 10240 and config.K == 1280:
         return TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)
@@ -145,6 +153,7 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
     # Default config
     return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2)
 
+
 def generate_tk_mlir(config: GemmConfig):
     # TODO: Enable waves_per_eu
     # TODO: Use scheduling barriers with LLVM patch
@@ -166,14 +175,16 @@ def generate_tk_mlir(config: GemmConfig):
     STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
 
     # Expose user-constraints
-    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints: list[tkw.Constraint] = [
+        tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
     constraints += [tkw.TilingConstraint(K, BLOCK_K)]
     constraints += [tkw.WaveConstraint(M, BLOCK_M / tc.RATIO_M)]
     constraints += [tkw.WaveConstraint(N, BLOCK_N / tc.RATIO_N)]
     constraints += [
-        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
+        tkw.HardwareConstraint(threads_per_wave=64,
+                               waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
     ]
 
     # Wave-level micro-kernel.
@@ -266,13 +277,22 @@ def compile_gemm_config(
     exec_args = [
         "iree-compile",
         f"{mlir_file}",
-        "--iree-hal-target-backends=rocm",
-        f"--iree-hip-target={target}",
-        "--iree-llvmgpu-enable-prefetch=true",
         "-o",
         f"{vmfb_file}",
     ] + extra_compiler_args
 
+    if target == "host_cpu":
+        exec_args += [
+            "--iree-hal-target-backends=llvm-cpu",
+            "--iree-llvmcpu-target-cpu=host"
+        ]
+    else:
+        exec_args += [
+            "--iree-hal-target-backends=rocm",
+            f"--iree-hip-target={target}",
+            "--iree-llvmgpu-enable-prefetch=true",
+        ]
+
     print(" ".join(exec_args))
 
     ret_value, stdout, stderr = run_iree_command(exec_args)
diff --git a/gemmbench/problems.py b/gemmbench/problems.py
index ab2dc7d..6b23c6c 100644
--- a/gemmbench/problems.py
+++ b/gemmbench/problems.py
@@ -19,16 +19,17 @@ def get_default_accumulator_element_type(operand_element_type: str) -> str:
     ]
 
 
-def get_default_result_element_type(operand_element_type: str) -> str:
-    return operand_element_type
+def get_default_result_element_type(operand_element_type: str, raw_accumulators: bool) -> str:
+    return get_default_accumulator_element_type(operand_element_type) if raw_accumulators else operand_element_type
 
 
-def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool:
+def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool) -> bool:
     """Is this GEMM compute (or memory) bound?"""
     magic_ratio = 64
     flops = 2 * M * N * K
     elem_type_bytes = num_bytes(dtype)
-    result_bytes = num_bytes(get_default_result_element_type(dtype))
+    result_bytes = num_bytes(
+        get_default_result_element_type(dtype, raw_accumulators))
     bytes = elem_type_bytes * (M * K + K * N) + result_bytes * (M * N)
     return flops > magic_ratio * bytes
@@ -680,23 +681,26 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool:
     (8192, 8192, 8192),
 ]
 
+
 def llama13bmatvec(dtype: str) -> list[GemmConfig]:
     configs = []
     """LLAMA 13b, single batch, FP16."""
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "13b":
-            configs.append(
-                GemmConfig(
-                    m,
-                    n,
-                    k,
-                    "T",
-                    "N",
-                    dtype,
-                    get_default_accumulator_element_type(dtype),
-                    get_default_result_element_type(dtype),
+            for raw_accumulators in [False, True]:
+                configs.append(
+                    GemmConfig(
+                        m,
+                        n,
+                        k,
+                        "T",
+                        "N",
+                        dtype,
+                        get_default_accumulator_element_type(dtype),
+                        get_default_result_element_type(
+                            dtype, raw_accumulators),
+                    )
                 )
-            )
     return configs
@@ -705,16 +709,17 @@ def llama13bmatvecbf16(dtype: str) -> list[GemmConfig]:
     """LLAMA 13b, single batch, BF16."""
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "13b":
-            configs.append(GemmConfig(
-                m,
-                n,
-                k,
-                "T",
-                "N",
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
-            ))
+            for raw_accumulators in [False, True]:
+                configs.append(GemmConfig(
+                    m,
+                    n,
+                    k,
+                    "T",
+                    "N",
+                    dtype,
+                    get_default_accumulator_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
+                ))
     return configs
@@ -723,16 +728,17 @@ def llama70bmatvec(dtype: str) -> list[GemmConfig]:
     configs = []
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "70b":
-            configs.append(GemmConfig(
-                m,
-                n,
-                k,
-                "T",
-                "N",
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
-            ))
+            for raw_accumulators in [False, True]:
+                configs.append(GemmConfig(
+                    m,
+                    n,
+                    k,
+                    "T",
+                    "N",
+                    dtype,
+                    get_default_accumulator_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
+                ))
     return configs
@@ -741,16 +747,17 @@ def llama70bmatvecbf16(dtype: str) -> list[GemmConfig]:
     configs = []
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "70b":
-            configs.append(GemmConfig(
-                m,
-                n,
-                k,
-                "T",
-                "N",
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
-            ))
+            for raw_accumulators in [False, True]:
+                configs.append(GemmConfig(
+                    m,
+                    n,
+                    k,
+                    "T",
+                    "N",
+                    dtype,
+                    get_default_accumulator_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
+                ))
     return configs
@@ -760,16 +767,18 @@ def llama13bskinny(dtype: str) -> list[GemmConfig]:
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "13b":
             for batch in [2, 4, 8, 16, 32]:
-                configs.append(GemmConfig(
-                    m,
-                    batch,
-                    k,
-                    "T",
-                    "N",
-                    dtype,
-                    get_default_accumulator_element_type(dtype),
-                    get_default_result_element_type(dtype),
-                ))
+                for raw_accumulators in [False, True]:
+                    configs.append(GemmConfig(
+                        m,
+                        batch,
+                        k,
+                        "T",
+                        "N",
+                        dtype,
+                        get_default_accumulator_element_type(dtype),
+                        get_default_result_element_type(
+                            dtype, raw_accumulators),
+                    ))
     return configs
@@ -779,16 +788,18 @@ def llama13bskinnybf16(dtype: str) -> list[GemmConfig]:
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "13b":
             for batch in [2, 4, 8, 16, 32]:
-                configs.append(GemmConfig(
-                    m,
-                    batch,
-                    k,
-                    "T",
-                    "N",
-                    dtype,
-                    get_default_accumulator_element_type(dtype),
-                    get_default_result_element_type(dtype),
-                ))
+                for raw_accumulators in [False, True]:
+                    configs.append(GemmConfig(
+                        m,
+                        batch,
+                        k,
+                        "T",
+                        "N",
+                        dtype,
+                        get_default_accumulator_element_type(dtype),
+                        get_default_result_element_type(
+                            dtype, raw_accumulators),
+                    ))
     return configs
@@ -798,16 +809,18 @@ def llama70bskinny(dtype: str) -> list[GemmConfig]:
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "70b":
             for batch in [2, 4, 8, 16, 32]:
-                configs.append(GemmConfig(
-                    m,
-                    batch,
-                    k,
-                    "T",
-                    "N",
-                    dtype,
-                    get_default_accumulator_element_type(dtype),
-                    get_default_result_element_type(dtype),
-                ))
+                for raw_accumulators in [False, True]:
+                    configs.append(GemmConfig(
+                        m,
+                        batch,
+                        k,
+                        "T",
+                        "N",
+                        dtype,
+                        get_default_accumulator_element_type(dtype),
+                        get_default_result_element_type(
+                            dtype, raw_accumulators),
+                    ))
     return configs
@@ -817,16 +830,18 @@ def llama70bskinnybf16(dtype: str) -> list[GemmConfig]:
     for m, n, k, model, gcount in LLAMA:
         if n == 1 and model == "70b":
             for batch in [2, 4, 8, 16, 32]:
-                configs.append(GemmConfig(
-                    m,
-                    batch,
-                    k,
-                    "T",
-                    "N",
-                    dtype,
-                    get_default_accumulator_element_type(dtype),
-                    get_default_result_element_type(dtype),
-                ))
+                for raw_accumulators in [False, True]:
+                    configs.append(GemmConfig(
+                        m,
+                        batch,
+                        k,
+                        "T",
+                        "N",
+                        dtype,
+                        get_default_accumulator_element_type(dtype),
+                        get_default_result_element_type(
+                            dtype, raw_accumulators),
+                    ))
     return configs
@@ -834,18 +849,19 @@ def gpt4memory(dtype: str) -> list[GemmConfig]:
     """GPT4 memory bound GEMMs; FP16."""
     configs = []
     for m, n, k in GPT4:
-        hgemm = GemmConfig(
-            m,
-            n,
-            k,
-            "N",
-            "N",
-            dtype,
-            get_default_accumulator_element_type(dtype),
-            get_default_result_element_type(dtype),
-        )
-        if not is_compute_bound(m, n, k, dtype):
-            configs.append(hgemm)
+        for raw_accumulators in [False, True]:
+            hgemm = GemmConfig(
+                m,
+                n,
+                k,
+                "N",
+                "N",
+                dtype,
+                get_default_accumulator_element_type(dtype),
+                get_default_result_element_type(dtype, raw_accumulators),
+            )
+            if not is_compute_bound(m, n, k, dtype, raw_accumulators):
+                configs.append(hgemm)
     return configs
@@ -853,36 +869,43 @@ def gpt4compute(dtype: str) -> list[GemmConfig]:
     """GPT4 compute bound GEMMs; FP16."""
     configs = []
     for m, n, k in GPT4:
-        hgemm = GemmConfig(
-            m,
-            n,
-            k,
-            "N",
-            "N",
-            dtype,
-            get_default_accumulator_element_type(dtype),
-            get_default_result_element_type(dtype),
-        )
-        if is_compute_bound(m, n, k, dtype):
-            configs.append(hgemm)
+        for raw_accumulators in [False, True]:
+            hgemm = GemmConfig(
+                m,
+                n,
+                k,
+                "N",
+                "N",
+                dtype,
+                get_default_accumulator_element_type(dtype),
+                get_default_result_element_type(dtype, raw_accumulators),
+            )
+            if is_compute_bound(m, n, k, dtype, raw_accumulators):
+                configs.append(hgemm)
     return configs
 
 
 def tk_default(dtype: str) -> list[GemmConfig]:
     """TK Shapes."""
-    acc_type = get_default_accumulator_element_type(dtype)
-    res_type = get_default_result_element_type(dtype)
-    configs = []
-    M, N, K = 2048, 10240, 1280
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
-    M, N, K = 2048, 1280, 1280
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
-    M, N, K = 2048, 1280, 5120
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
-    M, N, K = 128, 1280, 2048
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
-    M, N, K = 8192, 5120, 640
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    configs = []
+    for raw_accumulators in [False, True]:
+        acc_type = get_default_accumulator_element_type(dtype)
+        res_type = get_default_result_element_type(dtype, raw_accumulators)
+        M, N, K = 2048, 10240, 1280
+        configs.append(GemmConfig(M, N, K, "N", "T",
+                                  dtype, acc_type, res_type))
+        M, N, K = 2048, 1280, 1280
+        configs.append(GemmConfig(M, N, K, "N", "T",
+                                  dtype, acc_type, res_type))
+        M, N, K = 2048, 1280, 5120
+        configs.append(GemmConfig(M, N, K, "N", "T",
+                                  dtype, acc_type, res_type))
+        M, N, K = 128, 1280, 2048
+        configs.append(GemmConfig(M, N, K, "N", "T",
+                                  dtype, acc_type, res_type))
+        M, N, K = 8192, 5120, 640
+        configs.append(GemmConfig(M, N, K, "N", "T",
+                                  dtype, acc_type, res_type))
     return configs
@@ -890,18 +913,19 @@ def tk_unet(dtype: str) -> list[GemmConfig]:
     """UNET Shapes for TK."""
     configs = []
     for m, n, k in UNET:
-        configs.append(
-            GemmConfig(
-                m,
-                n,
-                k,
-                "N",
-                "T",
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
+        for raw_accumulators in [False, True]:
+            configs.append(
+                GemmConfig(
+                    m,
+                    n,
+                    k,
+                    "N",
+                    "T",
+                    dtype,
+                    get_default_accumulator_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
+                )
             )
-        )
     return configs
@@ -909,18 +933,19 @@ def llama70bmemory(dtype: str) -> list[GemmConfig]:
     """LLAMA 70b memory bound GEMMs; NT; BF16."""
     configs = []
     for n in [1280, 3584, 7168]:
-        configs.append(
-            GemmConfig(
-                2,
-                n,
-                8192,
-                "N",
-                "T",
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
+        for raw_accumulators in [False, True]:
+            configs.append(
+                GemmConfig(
+                    2,
+                    n,
+                    8192,
+                    "N",
+                    "T",
+                    dtype,
+                    get_default_accumulator_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
+                )
             )
-        )
     return configs
@@ -928,18 +953,19 @@ def compute(dtype: str) -> list[GemmConfig]:
     """Compute bound GEMMs."""
     configs = []
     for tA, tB in [("N", "N"), ("N", "T"), ("T", "N")]:
-        configs.append(
-            GemmConfig(
-                4096,
-                4096,
-                8192,
-                tA,
-                tB,
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
+        for raw_accumulators in [False, True]:
+            configs.append(
+                GemmConfig(
+                    4096,
+                    4096,
+                    8192,
+                    tA,
+                    tB,
+                    dtype,
+                    get_default_accumulator_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
+                )
             )
-        )
     return configs
@@ -947,39 +973,42 @@ def unet(dtype: str) -> list[GemmConfig]:
     configs = []
     for tA, tB in [("N", "N"), ("N", "T")]:
         for m, n, k in UNET:
+            for raw_accumulators in [False, True]:
+                configs.append(
+                    GemmConfig(
+                        m,
+                        n,
+                        k,
+                        tA,
+                        tB,
+                        dtype,
+                        get_default_accumulator_element_type(dtype),
+                        get_default_result_element_type(
+                            dtype, raw_accumulators),
+                    )
+                )
+    return configs
+
+
+def square(dtype: str) -> list[GemmConfig]:
+    configs = []
+    for m, n, k in SQUARE:
+        for raw_accumulators in [False, True]:
             configs.append(
                 GemmConfig(
                     m,
                     n,
                     k,
-                    tA,
-                    tB,
+                    "N",
+                    "T",
                     dtype,
                     get_default_accumulator_element_type(dtype),
-                    get_default_result_element_type(dtype),
+                    get_default_result_element_type(dtype, raw_accumulators),
                 )
             )
     return configs
 
 
-def square(dtype: str) -> list[GemmConfig]:
-    configs = []
-    for m, n, k in SQUARE:
-        configs.append(
-            GemmConfig(
-                m,
-                n,
-                k,
-                "N",
-                "T",
-                dtype,
-                get_default_accumulator_element_type(dtype),
-                get_default_result_element_type(dtype),
-            )
-        )
-    return configs
-
-
 def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
     llama13bmatvec_configs: list[GemmConfig] = []
     llama13bmatvec_configs += llama13bmatvec("f16")
@@ -1009,7 +1038,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
     unet_configs += unet("f16")
     unet_configs += unet("bf16")
 
-    square_configs: list[GemmConfig] = square("f16")
+    square_configs: list[GemmConfig] = square("f16") + square("bf16") + square("i8")
 
     all_configs: list[tuple[str, GemmConfig]] = []
     all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
@@ -1041,6 +1070,7 @@ def get_matching_configs(
     dtypes: list[str],
     variants: list[str],
     tag_regex: str,
+    raw_accumulators: bool
 ) -> list[tuple[str, GemmConfig]]:
     tag_re = re.compile(tag_regex)
     matching_configs: list[tuple[str, GemmConfig]] = []
@@ -1051,6 +1081,14 @@ def get_matching_configs(
             continue
         if not tag_re.match(tag):
             continue
+        # The raw_accumulators arg means "test configs where the result element
+        # type is different from what it would be in the default mode".
+        # We can't just test for (result_element_type == accumulator_element_type),
+        # as that would cause e.g. f32 matmuls to be omitted in the default mode.
+        default_result_element_type = get_default_result_element_type(config.operand_element_type, False)
+        is_raw_accumulators_config = (config.result_element_type != default_result_element_type)
+        if raw_accumulators != is_raw_accumulators_config:
+            continue
         matching_configs.append((tag, config))
 
     return matching_configs
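
A few illustrative notes on the changes above (not part of the diff). First, how the new --raw_accumulators flag and --target host_cpu interact with output naming and target selection; the helper below is a hypothetical sketch that mirrors the logic in gemm_bench.py rather than code from the patch:

# Sketch of how the new flags combine (mirrors gemm_bench.py; illustrative only).
def output_csv(tk: bool, raw_accumulators: bool) -> str:
    base = "iree_gemm"
    if raw_accumulators:
        base += "_raw_accumulators"
    if tk:
        base += "_tk"
    return f"results/{base}.csv"

assert output_csv(tk=False, raw_accumulators=False) == "results/iree_gemm.csv"
assert output_csv(tk=True, raw_accumulators=True) == "results/iree_gemm_raw_accumulators_tk.csv"

# --target host_cpu switches both the compile backend and the runtime device:
#   compile: --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host
#   run:     device = "local-task"  (otherwise args.device, default "hip")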
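
Second, a worked example of the refactored byte/FLOP accounting in get_byte_count and is_compute_bound, assuming the usual 2 bytes per f16 element and 4 bytes per f32 element (matching the dtype_to_bytes/num_bytes tables):

# Standalone arithmetic sketch; f16 operands with a raw f32 accumulator result.
M = N = K = 4096
operand_bytes, result_bytes = 2, 4

byte_count_input = (M + N) * K * operand_bytes     # 67_108_864
byte_count_output = M * N * result_bytes           # 67_108_864
byte_count = byte_count_input + byte_count_output  # 134_217_728

flops = 2 * M * N * K                              # 137_438_953_472

magic_ratio = 64
print(flops > magic_ratio * byte_count)            # True -> treated as compute bound
# With the default truncated f16 result, byte_count drops to 100_663_296,
# which is still compute bound for this shape.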
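
Third, a hypothetical usage sketch of the reworked generate_mlir for an i8 config benchmarked in raw-accumulator mode. It assumes gemm_utils is importable from the benchmark directory and that i8 accumulates in i32 (the accumulator mapping itself is outside the hunks shown):

# Illustrative only: exercises the new template assembly for integer GEMMs.
# GemmConfig field order: M, N, K, tA, tB, operand/accumulator/result element types.
from gemm_utils import GemmConfig, generate_mlir

cfg = GemmConfig(1024, 1024, 1024, "N", "T", "i8", "i32", "i32")
mlir = generate_mlir(cfg)

assert "arith.constant 0 :" in mlir  # integer literal_zero, not 0.0
assert "arith.trunc" not in mlir     # result type == accumulator type, so no truncation
assert "return %2" in mlir           # untruncated return path
print(mlir)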
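
Finally, the selection rule added to get_matching_configs can be read as: a config belongs to the raw-accumulators set exactly when its result type differs from the default result type for its operands. A minimal sketch of that predicate (names as in problems.py):

# Sketch of the predicate used to partition configs (illustrative only).
from problems import get_default_result_element_type

def is_raw_accumulators_config(config) -> bool:
    default_result = get_default_result_element_type(config.operand_element_type, False)
    return config.result_element_type != default_result

# An f32 x f32 -> f32 matmul is *not* a raw-accumulators config (its default
# result type is already f32), which is why the check is not simply
# result_element_type == accumulator_element_type.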