Add support for i8 dtype, add --raw_accumulators flag, add --target=host_cpu for easy local testing. #22

Merged 4 commits on Oct 10, 2024
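For quick local testing with the new options, an invocation along the following lines should work; the exact spelling of the --dtypes argument is an assumption based on this diff rather than something stated in the PR:

python gemmbench/gemm_bench.py --target=host_cpu --raw_accumulators --dtypes i8

With --target=host_cpu the kernels are compiled for the llvm-cpu backend and executed on the local-task device instead of HIP; with --raw_accumulators the results are written to results/iree_gemm_raw_accumulators.csv.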
20 changes: 14 additions & 6 deletions gemmbench/gemm_bench.py
@@ -38,7 +38,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
help="Set the logging level",
)

parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942")
parser.add_argument("--target", help="The IREE hip target to compile for. The special value host_cpu results in a llvm-cpu benchmark instead of HIP, compiled for the host CPU.", type=str, default="gfx942")
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
parser.add_argument(
"--Xiree_compile",
@@ -76,10 +76,15 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
default=None,
help="Directory to which executable files will be dumped."
)
parser.add_argument(
"--raw_accumulators",
action='store_true',
help="If true, benchmark matmuls returning the raw accumulator type with no truncation. If false (default), the results are truncated and cast to the input element type."
)

args = parser.parse_args()
# Handle default values here, since list args are not compatible with defaulted lists.
requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
requested_dtypes = ["f16", "bf16", "i8"] if not args.dtypes else list(args.dtypes)
requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)

logging.basicConfig(level=args.log_level)
@@ -91,7 +96,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,

tk = args.tk
configs = get_tk_gemm_configs() if tk else get_gemm_configs()
configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex)
configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex, args.raw_accumulators)
print(f"Generated {len(configs)} gemm configs.")

num_cpus = max(1, max(cpu_count() // 2, 1))
@@ -108,7 +113,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
target = args.target
extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
dump_dir = args.dump_dir
device = args.device
device = "local-task" if args.target == "host_cpu" else args.device

compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs
@@ -130,9 +135,12 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,

results = []
index = 0
output_csv = "results/iree_gemm.csv"
output_csv_base = "iree_gemm"
if args.raw_accumulators:
output_csv_base += "_raw_accumulators"
if tk:
output_csv = "results/iree_gemm_tk.csv"
output_csv_base += "_tk"
output_csv = f"results/{output_csv_base}.csv"
csv_dir = os.path.dirname(output_csv)
if not os.path.exists(csv_dir):
os.makedirs(csv_dir)
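For reference, here is a minimal standalone sketch of the output-path logic introduced above; the helper name output_csv_path is made up for illustration, since the real script computes the path inline:

def output_csv_path(raw_accumulators: bool, tk: bool) -> str:
    # Mirrors the naming scheme above: mode suffixes are appended to a common base.
    base = "iree_gemm"
    if raw_accumulators:
        base += "_raw_accumulators"
    if tk:
        base += "_tk"
    return f"results/{base}.csv"

# output_csv_path(False, False) -> "results/iree_gemm.csv"
# output_csv_path(True, False)  -> "results/iree_gemm_raw_accumulators.csv"
# output_csv_path(False, True)  -> "results/iree_gemm_tk.csv"
# output_csv_path(True, True)   -> "results/iree_gemm_raw_accumulators_tk.csv"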
76 changes: 48 additions & 28 deletions gemmbench/gemm_utils.py
@@ -8,6 +8,7 @@
from iree.turbine.kernel.lang.global_symbols import *
import torch


@dataclass
class GemmConfig:
M: int
@@ -48,73 +49,79 @@ def get_byte_count(self) -> int:
}
operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
result_bytes_per_element = dtype_to_bytes[self.result_element_type]
byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element
return byte_count
byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element
byte_count_output = (self.M * self.N) * result_bytes_per_element
return byte_count_input + byte_count_output

def get_flops(self) -> int:
flops = 2 * self.M * self.N * self.K
return flops
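# Illustrative example, not part of this change: for a hypothetical 4096x4096x4096
# GEMM with i8 operands (1 byte each) and an i32 result kept untruncated
# (--raw_accumulators), the two methods above give
#   byte_count = (4096 + 4096) * 4096 * 1 + 4096 * 4096 * 4 = 100_663_296
#   flops      = 2 * 4096 * 4096 * 4096 = 137_438_953_472
# i.e. an arithmetic intensity of roughly 1365 flops per byte.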


def generate_mlir(config: GemmConfig):
K = config.K
M = config.M
N = config.N
operand_element_type = config.operand_element_type
acc_element_type = config.accumulator_element_type
result_element_type = config.result_element_type
assert not operand_element_type.startswith('i'), "Integer types not supported yet"
is_integer = operand_element_type.startswith('i')
literal_zero = "0" if is_integer else "0.0"
trunc_op = "arith.trunci" if is_integer else "arith.truncf"

tA = config.tA
tB = config.tB
mlir_template_A = f"""
mlir_template_matmul_transpose_a = f"""
module {{
func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
%cst = arith.constant 0.000000e+00 : {acc_element_type}
%cst = arith.constant {literal_zero} : {acc_element_type}
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template_B = f"""
mlir_template_matmul_transpose_b = f"""
module {{
func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
%cst = arith.constant 0.000000e+00 : {acc_element_type}
%cst = arith.constant {literal_zero} : {acc_element_type}
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template = f"""
mlir_template_matmul_normal = f"""
module {{
func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
%cst = arith.constant 0.000000e+00 : {acc_element_type}
%cst = arith.constant {literal_zero} : {acc_element_type}
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
"""
mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal

mlir_template_return_truncated = f"""
%3 = {trunc_op} %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""
if tA == "T":
return mlir_template_A
if tB == "T":
return mlir_template_B
return mlir_template

mlir_template_return_untruncated = f"""
return %2 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated

return mlir_template_matmul + mlir_template_return


@dataclass
class TkTunedConfig:
@@ -131,6 +138,7 @@ class TkTunedConfig:
DELAY_SHARED: int
DELAY_GLOBAL: int


def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
if config.M == 2048 and config.N == 10240 and config.K == 1280:
return TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)
@@ -145,6 +153,7 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
# Default config
return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2)


def generate_tk_mlir(config: GemmConfig):
# TODO: Enable waves_per_eu
# TODO: Use scheduling barriers with LLVM patch
@@ -166,14 +175,16 @@ def generate_tk_mlir(config: GemmConfig):
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Expose user-constraints
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints: list[tkw.Constraint] = [
tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / tc.RATIO_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / tc.RATIO_N)]

constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
tkw.HardwareConstraint(threads_per_wave=64,
waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
]

# Wave-level micro-kernel.
Expand Down Expand Up @@ -266,13 +277,22 @@ def compile_gemm_config(
exec_args = [
"iree-compile",
f"{mlir_file}",
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={target}",
"--iree-llvmgpu-enable-prefetch=true",
"-o",
f"{vmfb_file}",
] + extra_compiler_args

if target == "host_cpu":
exec_args += [
"--iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu=host"
]
else:
exec_args += [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={target}",
"--iree-llvmgpu-enable-prefetch=true",
]

print(" ".join(exec_args))

ret_value, stdout, stderr = run_iree_command(exec_args)
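As a rough sanity check of the new branch above, the compile command printed by compile_gemm_config should look approximately like the following; the file names are placeholders, and any extra --Xiree_compile flags would also appear. With --target=host_cpu:

iree-compile gemm.mlir -o gemm.vmfb --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host

and with the default gfx942 target:

iree-compile gemm.mlir -o gemm.vmfb --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-llvmgpu-enable-prefetch=true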