diff --git a/python/perf-kernels/streamk/tune_streamk.py b/python/perf-kernels/streamk/tune_streamk.py
index 1215d963463d..1c6cbabd6921 100755
--- a/python/perf-kernels/streamk/tune_streamk.py
+++ b/python/perf-kernels/streamk/tune_streamk.py
@@ -15,6 +15,7 @@
 from datetime import datetime
 import multiprocessing
 import pandas as pd
+import itertools
 
 from utils.file_generator import (
     gen_configStr,
@@ -63,22 +64,17 @@ def get_full_tuning_space():
     kpack_range = [1, 2]
     num_sms_range = [304]
 
-    for block_m in block_mn_range:
-        for block_n in block_mn_range:
-            for block_k in block_k_range:
-                for num_warps in num_warps_range:
-                    for group_m in group_m_range:
-                        for num_sms in num_sms_range:
-                            for num_stages in num_stage_range:
-                                for waves_per_eu in waves_per_eu_range:
-                                    for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
-                                        for kpack in kpack_range:
-                                            configs.append({
-                                                'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
-                                                block_k, 'GROUP_SIZE_M': group_m, 'NUM_SMS': num_sms, 'num_warps':
-                                                num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
-                                                'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
-                                            })
+    space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
+                              num_sms_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
+                              kpack_range)
+
+    for instance in space:
+        block_m, block_n, block_k, num_warps, group_m, num_sms, num_stages, waves_per_eu, matrix_instr_nonkdim, kpack = instance
+        configs.append({
+            'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
+            'NUM_SMS': num_sms, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
+            'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
+        })
 
     return configs
 
@@ -139,8 +135,14 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
             continue
         # out of shared memory resource
         # TODO (zhanglx): This does not consider the LDS usage in the epilogue
-        LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
-        LDS = LDS if not num_stages else LDS * (num_stages - 1)
+        LDSA = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+        LDSB = BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+        if num_stages <= 1:
+            # No pipeline, buffer A and buffer B can re-use each other
+            LDS = max(LDSA, LDSB)
+        else:
+            # Pipeline, we need (num_stages - 1) buffers for both A and B at the same time
+            LDS = (LDSA + LDSB) * (num_stages - 1)
         if LDS > 65536:
             continue
         # Skip small block sizes and num_warps for large gemm
diff --git a/python/perf-kernels/tools/tune_gemm/README.md b/python/perf-kernels/tools/tune_gemm/README.md
index c22382143544..c187942ed894 100644
--- a/python/perf-kernels/tools/tune_gemm/README.md
+++ b/python/perf-kernels/tools/tune_gemm/README.md
@@ -323,6 +323,7 @@ will be added later.
 - Switched back to rocprofv1. Check [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details.
 - Improved the post-procesing logic to filter out the "spikes" in the profiling results.
 - Reduced the number of iterations in both tuning and benchmark mode (120 and 200).
+- Extended the tuning parameter space with instruction scheduling variants for the main gemm loop (k-loop).
 
 # One config running script
 
diff --git a/python/perf-kernels/tools/tune_gemm/matmul_kernel.py b/python/perf-kernels/tools/tune_gemm/matmul_kernel.py
index 1d9902bc2de6..18cd8edc13c2 100644
--- a/python/perf-kernels/tools/tune_gemm/matmul_kernel.py
+++ b/python/perf-kernels/tools/tune_gemm/matmul_kernel.py
@@ -7,6 +7,15 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
                   stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                   BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
                   EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):
+
+    tl.assume(stride_am > 0)
+    tl.assume(stride_ak > 0)
+    tl.assume(stride_bk > 0)
+    tl.assume(stride_bn > 0)
+    tl.assume(stride_cm > 0)
+    tl.assume(stride_cn > 0)
+    tl.assume(stride_bias > 0)
+
     pid = tl.program_id(axis=0)
     pid_z = tl.program_id(1)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -33,6 +42,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
 
+    tl.assume(pid_m >= 0)
+    tl.assume(pid_n >= 0)
+
     if SPLIT_K == 1:
         offs_k = tl.arange(0, BLOCK_SIZE_K)
     else:
diff --git a/python/perf-kernels/tools/tune_gemm/test_regression.py b/python/perf-kernels/tools/tune_gemm/test_regression.py
index 2ca69e15b5b0..ce69e4c1be06 100644
--- a/python/perf-kernels/tools/tune_gemm/test_regression.py
+++ b/python/perf-kernels/tools/tune_gemm/test_regression.py
@@ -101,6 +101,7 @@ def teardown_class(self):
         },
     ], ids=lambda val: f"Config: {val}")
     def test_matmul_performance_regression(self, config, record_property):
+        config.setdefault('instruction_sched_variant', 'none')
 
         M, N, K, col_a, col_b, runConfig = tune_gemm.process_item(deepcopy(config))
 
diff --git a/python/perf-kernels/tools/tune_gemm/tune_gemm.py b/python/perf-kernels/tools/tune_gemm/tune_gemm.py
index 740ce88d8946..6f26acb3b6fc 100755
--- a/python/perf-kernels/tools/tune_gemm/tune_gemm.py
+++ b/python/perf-kernels/tools/tune_gemm/tune_gemm.py
@@ -16,6 +16,7 @@
 from datetime import datetime
 import multiprocessing
 import pandas as pd
+import itertools
 
 from utils.file_generator import (
     gen_configStr,
@@ -64,23 +65,19 @@ def get_full_tuning_space():
     waves_per_eu_range = [0]
     matrix_instr_nonkdim_range = [16, 32]
     kpack_range = [1, 2]
+    sched_variants = ["none"]
 
-    for block_m in block_mn_range:
-        for block_n in block_mn_range:
-            for block_k in block_k_range:
-                for num_warps in num_warps_range:
-                    for group_m in group_m_range:
-                        for split_k in split_k_range:
-                            for num_stages in num_stage_range:
-                                for waves_per_eu in waves_per_eu_range:
-                                    for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
-                                        for kpack in kpack_range:
-                                            configs.append({
-                                                'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
-                                                block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
-                                                num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
-                                                'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
-                                            })
+    space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
+                              split_k_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
+                              sched_variants, kpack_range)
+
+    for instance in space:
+        block_m, block_n, block_k, num_warps, group_m, split_k, num_stages, waves_per_eu, matrix_instr_nonkdim, sched_variant, kpack = instance
+        configs.append({
+            'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
+            'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
+            'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack, 'instruction_sched_variant': sched_variant
+        })
 
     return configs
 
@@ -355,7 +352,7 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,
 
 
 def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu,
-           mfmaInstrSize, kpack, use_bias):
+           mfmaInstrSize, kpack, use_bias, sched_variant):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     #assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -372,12 +369,13 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
                         c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
                         BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
                         num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
-                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
+                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds,
+                        instruction_sched_variant=sched_variant)
     return c
 
 
 def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose):
-    block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
+    block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
        config)
     use_bias = bias_vector
     torch.manual_seed(0)
@@ -393,7 +391,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     # Allocates output.
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages,
-                           waves_per_eu, mfmaInstrSize, kpack, use_bias)
+                           waves_per_eu, mfmaInstrSize, kpack, use_bias, sched_variant)
     torch_output = torch.matmul(a_fp16, b_fp16)
     if use_bias:
         torch_output += bias_fp16[:, None]
@@ -658,11 +656,11 @@ def main():
         formatted_tflops = format_output(tri_tflops)
         minTime = format_output(minTime)
         if not run_bench:
-            print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
+            print(f'\nTFLOPS: {formatted_tflops}; time(us): {minTime}', end=" ", flush=True)
 
         bestConfig_compact_str = gen_configStr(bestConfig)
         if not run_bench:
-            print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
+            print(f'\nbest_config: {bestConfig_compact_str}', end=" ", flush=True)
 
         # write best config to tuning_results.yaml
         if run_bench:
diff --git a/python/perf-kernels/tools/tune_gemm/utils/file_generator.py b/python/perf-kernels/tools/tune_gemm/utils/file_generator.py
index d92079dab9a0..6bbc51c387d9 100644
--- a/python/perf-kernels/tools/tune_gemm/utils/file_generator.py
+++ b/python/perf-kernels/tools/tune_gemm/utils/file_generator.py
@@ -21,15 +21,17 @@ def read_config(config):
     waves_per_eu = config.get('waves_per_eu')
     mfma_instr_size = config.get('matrix_instr_nonkdim')
     kpack = config.get('kpack')
-    return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack
+    sched_variant = config.get('instruction_sched_variant')
+    return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack, sched_variant
 
 
 def gen_configStr(config):
-    block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
+    block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
        config)
 
     ## {M}_{N}_{K} is removed since the same kernel can be used for differen gemm sizes
-    configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}"
+    sched_variant = sched_variant.upper().replace('-', '_')
+    configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_variant}"
 
     return configStr
 
@@ -69,7 +71,7 @@ def generate_matmul_kernels(configs):
 ## construct the configStr and generate the wrapper function matmul_{configStr}()
 ## If `warmup` is set, the generated kernel will be **compiled**
 def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, warmup):
-    block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
+    block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
        config)
 
     configStr = gen_configStr(config)
@@ -112,6 +114,7 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
         EVEN_K = {EVEN_K},
         GRID_MN = grid_mn,
         NUM_XCDS = {num_xcds},
+        instruction_sched_variant = \"{sched_variant}\",
         grid=(1,),
     )
     return None
@@ -145,7 +148,8 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
         BIAS = {use_bias},
         EVEN_K = {EVEN_K},
         GRID_MN = grid[0],
-        NUM_XCDS = {num_xcds}
+        NUM_XCDS = {num_xcds},
+        instruction_sched_variant = \"{sched_variant}\",
     )
     return c
"""
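
The sketch below is an illustrative, standalone summary of the two ideas the tuning-script changes rely on: enumerating the config space with itertools.product instead of nested for-loops, and pruning configs by the same LDS estimate as prune_configs (a single shared buffer when there is no pipelining, (num_stages - 1) buffers for A and B otherwise). It is not part of the patch; the helper names (estimate_lds, build_space) and the reduced value ranges are assumptions made only for this example.

# Standalone sketch (not part of the patch) of the tuning-space enumeration and LDS pruning.
import itertools

LDS_LIMIT_BYTES = 65536  # 64 KB of shared memory (LDS) per workgroup, as used by prune_configs


def estimate_lds(block_m, block_n, block_k, num_stages, elem_bytes_a, elem_bytes_b):
    """Estimate LDS usage of the main gemm loop, mirroring the patched prune_configs logic."""
    lds_a = block_k * block_m * elem_bytes_a
    lds_b = block_k * block_n * elem_bytes_b
    if num_stages <= 1:
        # No software pipelining: the A and B tiles can share one buffer.
        return max(lds_a, lds_b)
    # Pipelined: (num_stages - 1) in-flight buffers are needed for A and B at the same time.
    return (lds_a + lds_b) * (num_stages - 1)


def build_space():
    """Enumerate a (deliberately small) tuning space the same way the patch does."""
    block_mn_range = [64, 128]
    block_k_range = [32, 64]
    num_stage_range = [2]
    sched_variants = ["none"]
    space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_stage_range, sched_variants)
    return [{
        'BLOCK_SIZE_M': m, 'BLOCK_SIZE_N': n, 'BLOCK_SIZE_K': k, 'num_stages': s,
        'instruction_sched_variant': v
    } for m, n, k, s, v in space]


if __name__ == '__main__':
    # Keep only configs whose estimated LDS footprint fits, e.g. for fp16 A and B (2 bytes each).
    configs = [
        c for c in build_space()
        if estimate_lds(c['BLOCK_SIZE_M'], c['BLOCK_SIZE_N'], c['BLOCK_SIZE_K'], c['num_stages'], 2, 2) <= LDS_LIMIT_BYTES
    ]
    print(f"{len(configs)} configs survive LDS pruning")

For fp16 inputs (2 bytes per element), BLOCK_SIZE_M = BLOCK_SIZE_N = 128 with BLOCK_SIZE_K = 64 and num_stages = 2 gives (128*64*2 + 128*64*2) * 1 = 32768 bytes, which fits in the 64 KB LDS budget, so such a config would not be pruned.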