Merge pull request #649 from ROCm/ravil/main_perf
Added instr.sched options to `tune_gemm.py`
ravil-mobile authored Nov 26, 2024
2 parents 94961d9 + c0ff468 commit 6e7ad94
Showing 6 changed files with 63 additions and 45 deletions.
38 changes: 20 additions & 18 deletions python/perf-kernels/streamk/tune_streamk.py
@@ -15,6 +15,7 @@
from datetime import datetime
import multiprocessing
import pandas as pd
import itertools

from utils.file_generator import (
gen_configStr,
@@ -63,22 +64,17 @@ def get_full_tuning_space():
kpack_range = [1, 2]
num_sms_range = [304]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for num_sms in num_sms_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'NUM_SMS': num_sms, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})
space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
num_sms_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
kpack_range)

for instance in space:
block_m, block_n, block_k, num_warps, group_m, num_sms, num_stages, waves_per_eu, matrix_instr_nonkdim, kpack = instance
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
'NUM_SMS': num_sms, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})

return configs
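The hunk above replaces ten nested `for` loops with a single `itertools.product` over the same parameter ranges, so the enumerated configurations are unchanged. A minimal sketch of the idea, using shortened illustrative ranges rather than the script's real ones:

```python
import itertools

# Editor's sketch, not repo code: itertools.product walks the same Cartesian
# product that the removed nested for-loops enumerated, one tuple per config.
block_mn_range = [16, 32]      # illustrative values; the real ranges are longer
num_warps_range = [4, 8]
kpack_range = [1, 2]

configs = []
for block_m, block_n, num_warps, kpack in itertools.product(block_mn_range, block_mn_range,
                                                             num_warps_range, kpack_range):
    configs.append({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n,
                    'num_warps': num_warps, 'kpack': kpack})

assert len(configs) == 2 * 2 * 2 * 2   # the space size is the product of the range lengths
```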

@@ -139,8 +135,14 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
LDS = LDS if not num_stages else LDS * (num_stages - 1)
LDSA = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
LDSB = BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
if num_stages <= 1:
# No pipeline, buffer A and buffer B can re-use each other
LDS = max(LDSA, LDSB)
else:
# Pipeline, we need (num_stages - 1) buffers for both A and B at the same time
LDS = (LDSA + LDSB) * (num_stages - 1)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
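The revised pruning estimate above separates the unpipelined case, where the A and B tiles can occupy the same LDS buffer, from the pipelined case, which needs (num_stages - 1) live buffers for each operand. A worked example under assumptions chosen by the editor (fp16 operands, hence 2 bytes per element), not taken from the commit:

```python
# Editor's illustration of the new LDS check for a hypothetical config.
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 256, 64
elemBytes_a = elemBytes_b = 2                      # fp16 operands (assumption)

LDSA = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a   # 32768 bytes
LDSB = BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b   # 32768 bytes

for num_stages in (1, 2, 3):
    if num_stages <= 1:
        LDS = max(LDSA, LDSB)                      # no pipeline: A and B can share one buffer
    else:
        LDS = (LDSA + LDSB) * (num_stages - 1)     # pipeline: (num_stages - 1) buffers per operand
    print(num_stages, LDS, 'pruned' if LDS > 65536 else 'kept')
# num_stages=1 -> 32768 kept; 2 -> 65536 kept; 3 -> 131072 pruned (exceeds 64 KiB of LDS)
```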
1 change: 1 addition & 0 deletions python/perf-kernels/tools/tune_gemm/README.md
@@ -323,6 +323,7 @@ will be added later.
- Switched back to rocprofv1. Check [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details.
- Improved the post-processing logic to filter out the "spikes" in the profiling results.
- Reduced the number of iterations in both tuning and benchmark mode (120 and 200).
- Extended the parameter tuning space with instruction scheduling variants for the main gemm-loop (k-loop).


# One config running script
12 changes: 12 additions & 0 deletions python/perf-kernels/tools/tune_gemm/matmul_kernel.py
@@ -7,6 +7,15 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):

tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
tl.assume(stride_bias > 0)

pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -33,6 +42,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

tl.assume(pid_m > 0)
tl.assume(pid_n > 0)

if SPLIT_K == 1:
offs_k = tl.arange(0, BLOCK_SIZE_K)
else:
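The `tl.assume()` calls added above are hints to the compiler: they assert properties of the runtime arguments (positive strides and positive tile indices) that cannot be inferred from the kernel signature, which can let the backend simplify address arithmetic. A minimal sketch of the pattern, assuming a Triton build that provides `tl.assume` and a GPU device; the kernel below is the editor's illustration, not code from this repository:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows(src_ptr, dst_ptr, n_cols, stride_src, stride_dst, BLOCK: tl.constexpr):
    # The caller promises positive row strides; tl.assume() forwards that promise
    # to the compiler so offset computations can be treated as non-negative.
    tl.assume(stride_src > 0)
    tl.assume(stride_dst > 0)
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    vals = tl.load(src_ptr + row * stride_src + offs, mask=mask)
    tl.store(dst_ptr + row * stride_dst + offs, vals, mask=mask)


x = torch.randn(64, 128, device='cuda')
y = torch.empty_like(x)
copy_rows[(x.shape[0], )](x, y, x.shape[1], x.stride(0), y.stride(0), BLOCK=128)
```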
1 change: 1 addition & 0 deletions python/perf-kernels/tools/tune_gemm/test_regression.py
@@ -101,6 +101,7 @@ def teardown_class(self):
},
], ids=lambda val: f"Config: {val}")
def test_matmul_performance_regression(self, config, record_property):
config.setdefault('instruction_sched_variant', 'none')

M, N, K, col_a, col_b, runConfig = tune_gemm.process_item(deepcopy(config))

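The `setdefault` call above backfills the new key so that regression configs written before this change still run, while configs that already name a scheduling variant keep their value. A small sketch (the non-default variant name below is hypothetical):

```python
# Editor's sketch of the backward-compatibility behaviour of dict.setdefault().
old_config = {'BLOCK_SIZE_M': 128, 'num_warps': 8}                         # pre-change config, key missing
new_config = {'BLOCK_SIZE_M': 128, 'instruction_sched_variant': 'iglp-0'}  # hypothetical variant name

for cfg in (old_config, new_config):
    cfg.setdefault('instruction_sched_variant', 'none')

assert old_config['instruction_sched_variant'] == 'none'     # filled in with the default
assert new_config['instruction_sched_variant'] == 'iglp-0'   # existing value left untouched
```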
42 changes: 20 additions & 22 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
Expand Up @@ -16,6 +16,7 @@
from datetime import datetime
import multiprocessing
import pandas as pd
import itertools

from utils.file_generator import (
gen_configStr,
@@ -64,23 +65,19 @@ def get_full_tuning_space():
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]
sched_variants = ["none"]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})
space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
split_k_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
sched_variants, kpack_range)

for instance in space:
block_m, block_n, block_k, num_warps, group_m, split_k, num_stages, waves_per_eu, matrix_instr_nonkdim, sched_variant, kpack = instance
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack, 'instruction_sched_variant': sched_variant
})

return configs
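With `sched_variants = ["none"]` the generated space has the same size as before; each config simply gains an `instruction_sched_variant` entry, and listing further variants would multiply the number of configs accordingly. An illustrative entry as produced after this change (all values other than the new key are the editor's, not defaults of the script):

```python
# Editor's illustration of one entry emitted by get_full_tuning_space() after this change.
example_config = {
    'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4,
    'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 2, 'waves_per_eu': 0,
    'matrix_instr_nonkdim': 16, 'kpack': 2,
    'instruction_sched_variant': 'none',   # new key; "none" is the only variant enumerated so far
}
```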

@@ -355,7 +352,7 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,


def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu,
mfmaInstrSize, kpack, use_bias):
mfmaInstrSize, kpack, use_bias, sched_variant):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
#assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -372,12 +369,13 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds,
instruction_sched_variant=sched_variant)
return c


def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)
use_bias = bias_vector
torch.manual_seed(0)
@@ -393,7 +391,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
# Allocates output.
c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages,
waves_per_eu, mfmaInstrSize, kpack, use_bias)
waves_per_eu, mfmaInstrSize, kpack, use_bias, sched_variant)
torch_output = torch.matmul(a_fp16, b_fp16)
if use_bias:
torch_output += bias_fp16[:, None]
@@ -658,11 +656,11 @@ def main():
formatted_tflops = format_output(tri_tflops)
minTime = format_output(minTime)
if not run_bench:
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
print(f'\nTFLOPS: {formatted_tflops}; time(us): {minTime}', end=" ", flush=True)

bestConfig_compact_str = gen_configStr(bestConfig)
if not run_bench:
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
print(f'\nbest_config: {bestConfig_compact_str}', end=" ", flush=True)

# write best config to tuning_results.yaml
if run_bench:
14 changes: 9 additions & 5 deletions python/perf-kernels/tools/tune_gemm/utils/file_generator.py
@@ -21,15 +21,17 @@ def read_config(config):
waves_per_eu = config.get('waves_per_eu')
mfma_instr_size = config.get('matrix_instr_nonkdim')
kpack = config.get('kpack')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack
sched_variant = config.get('instruction_sched_variant')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack, sched_variant


def gen_configStr(config):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)

## {M}_{N}_{K} is removed since the same kernel can be used for different gemm sizes
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}"
sched_variant = sched_variant.upper().replace('-', '_')
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_variant}"

return configStr

@@ -69,7 +71,7 @@ def generate_matmul_kernels(configs):
## construct the configStr and generate the wrapper function matmul_{configStr}()
## If `warmup` is set, the generated kernel will be **compiled**
def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, warmup):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)

configStr = gen_configStr(config)
@@ -112,6 +114,7 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
EVEN_K = {EVEN_K},
GRID_MN = grid_mn,
NUM_XCDS = {num_xcds},
instruction_sched_variant = \"{sched_variant}\",
grid=(1,),
)
return None
@@ -145,7 +148,8 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
BIAS = {use_bias},
EVEN_K = {EVEN_K},
GRID_MN = grid[0],
NUM_XCDS = {num_xcds}
NUM_XCDS = {num_xcds},
instruction_sched_variant = \"{sched_variant}\",
)
return c
"""
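`gen_configStr()` now encodes the scheduling variant in the generated kernel name, upper-cased and with '-' mapped to '_'. A quick sketch of the resulting suffix (the surrounding parameter values are illustrative):

```python
# Editor's sketch mirroring the suffix construction in the diff above.
sched_variant = 'none'
suffix = '_sched' + sched_variant.upper().replace('-', '_')
configStr = f"BM128_BN128_BK64_GM4_SK1_nW8_nS2_EU0_kP2_mfma16{suffix}"
print(configStr)   # BM128_BN128_BK64_GM4_SK1_nW8_nS2_EU0_kP2_mfma16_schedNONE
```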
