From e23a3c2cba15df591f3c2341f1648dcc6079ce6a Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Thu, 19 Dec 2024 17:10:17 -0600 Subject: [PATCH 01/11] Move preamble code into tikzplot.tex --- .../tools/plot-layout/plot_layout.py | 65 ------------------- .../tools/plot-layout/tikzplot.tex | 60 +++++++++++++++++ 2 files changed, 60 insertions(+), 65 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 599f92c790e4..a394fe87e099 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -4,68 +4,6 @@ import subprocess -def draw_preamble_cmd(): - return '''\\documentclass[tikz, border=1mm, dvipsnames]{standalone} -\\usepackage{ifthen} -\\usepackage{tikz} -\\usetikzlibrary{arrows.meta,arrows} -\\usetikzlibrary{intersections} -\\usetikzlibrary{calc, quotes} -\\usetikzlibrary{patterns} -\\usepackage{xparse} - -\\ExplSyntaxOn -\\NewExpandableDocumentCommand{\\bitwiseXor}{mm} - { - \\recuenco_bitwise_xor:nn { #1 } { #2 } - } - -\\cs_new:Nn \\recuenco_bitwise_xor:nn - { - \\int_from_bin:e - { - \\__recuenco_bitwise_xor:ee { \\int_to_bin:n { #1 } } { \\int_to_bin:n { #2 } } - } - } -\\cs_generate_variant:Nn \\int_from_bin:n { e } - -\\cs_new:Nn \\__recuenco_bitwise_xor:nn - { - \\__recuenco_bitwise_xor_binary:ee - { - \\prg_replicate:nn - { - \\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #1 } - } - { 0 } - #1 - } - { - \\prg_replicate:nn - { - \\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #2 } - } - { 0 } - #2 - } - } -\\cs_generate_variant:Nn \\__recuenco_bitwise_xor:nn { ee } - -\\cs_new:Nn \\__recuenco_bitwise_xor_binary:nn - { - \\__recuenco_bitwise_xor_binary:w #1;#2; - } -\\cs_generate_variant:Nn \\__recuenco_bitwise_xor_binary:nn { ee } - -\\cs_new:Npn \\__recuenco_bitwise_xor_binary:w #1#2;#3#4; - { - \\int_abs:n { #1-#3 } - \\tl_if_empty:nF { #2 } { 
\\__recuenco_bitwise_xor_binary:w #2;#4; } - } - -\\ExplSyntaxOff''' - - def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): return f'''\\begin{{document}} \\begin{{tikzpicture}} @@ -254,8 +192,6 @@ def main(): with open("tikzplot.tex") as file: tikz_code = file.read() - preamble_str = draw_preamble_cmd() - draw_blockedLayout_str = draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order) draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack) @@ -264,7 +200,6 @@ def main(): draw_wmma_str = draw_wmma_instr_cmd(waveSize) - f_plot.write(preamble_str + "\n") f_plot.write(tikz_code) if plot_mode == 'blocked': f_plot.write(draw_blockedLayout_str) diff --git a/python/perf-kernels/tools/plot-layout/tikzplot.tex b/python/perf-kernels/tools/plot-layout/tikzplot.tex index d8441b042f02..a9d08b16e5db 100644 --- a/python/perf-kernels/tools/plot-layout/tikzplot.tex +++ b/python/perf-kernels/tools/plot-layout/tikzplot.tex @@ -1,3 +1,63 @@ +\documentclass[tikz, border=1mm, dvipsnames]{standalone} +\usepackage{ifthen} +\usepackage{tikz} +\usetikzlibrary{arrows.meta,arrows} +\usetikzlibrary{intersections} +\usetikzlibrary{calc, quotes} +\usetikzlibrary{patterns} +\usepackage{xparse} + +\ExplSyntaxOn +\NewExpandableDocumentCommand{\bitwiseXor}{mm} + { + \recuenco_bitwise_xor:nn { #1 } { #2 } + } + +\cs_new:Nn \recuenco_bitwise_xor:nn + { + \int_from_bin:e + { + \__recuenco_bitwise_xor:ee { \int_to_bin:n { #1 } } { \int_to_bin:n { #2 } } + } + } +\cs_generate_variant:Nn \int_from_bin:n { e } + +\cs_new:Nn \__recuenco_bitwise_xor:nn + { + \__recuenco_bitwise_xor_binary:ee + { + \prg_replicate:nn + { + \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #1 } + } + { 0 } + #1 + } + { + \prg_replicate:nn + { + \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #2 } + } + { 0 } + #2 + } + } +\cs_generate_variant:Nn \__recuenco_bitwise_xor:nn { ee } + +\cs_new:Nn 
\__recuenco_bitwise_xor_binary:nn + { + \__recuenco_bitwise_xor_binary:w #1;#2; + } +\cs_generate_variant:Nn \__recuenco_bitwise_xor_binary:nn { ee } + +\cs_new:Npn \__recuenco_bitwise_xor_binary:w #1#2;#3#4; + { + \int_abs:n { #1-#3 } + \tl_if_empty:nF { #2 } { \__recuenco_bitwise_xor_binary:w #2;#4; } + } + +\ExplSyntaxOff + \newcommand{\drawBlockedWave}[5]{ %% %% Draw a wave coverage with blocked layout From a1b286a22b7cb9e9ca3d05dcb02dea65220e82b2 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Thu, 19 Dec 2024 17:12:28 -0600 Subject: [PATCH 02/11] Rename kpack to kWidth and allow kWidth = 32 --- .../tools/plot-layout/plot_layout.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index a394fe87e099..648f8f86ff17 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -4,7 +4,7 @@ import subprocess -def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): +def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} @@ -14,7 +14,7 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): \\def\\opColorAR{{cyan}} \\def\\opColorBL{{Maroon}} \\def\\opColorBR{{BlueGreen}} - \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}} + \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kWidth}}} \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); \\def\\mfmaTrans{{{trans}}} @@ -24,8 +24,8 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} - \\coordinate (C TL) at 
($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kpack}*\\elem, 0)$); - \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}} + \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*\\elem, 0)$); + \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{\\mfmaTrans}} \\end{{tikzpicture}} \\end{{document}}''' @@ -42,7 +42,7 @@ def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, or \\end{{document}}''' -def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp): +def draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp): if ldsLayout == 'swizzle': hasSwizzle = 1 elif ldsLayout == 'padding': @@ -62,7 +62,7 @@ def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, thread \\def\\scale{{1}} \\def\\M{{{M}}} \\def\\K{{{K}}} - \\def\\vec{{{kpack}}} + \\def\\vec{{{kWidth}}} \\def\\hasSwizzle{{{hasSwizzle}}} \\def\\accessMode{{{accessMode}}} @@ -112,7 +112,7 @@ def parse_args(): parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4)) parser.add_argument("-order", type=int, nargs=2, default=(1, 0)) ## LDS access parameters - parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16], help='number of elements per thread') + parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16, 32], help='number of contiguous elements per thread') parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') parser.add_argument("-lds_access", type=str, default="none", choices=['read', 'write', 'none'], @@ -138,7 +138,7 @@ def main(): K = shape[2] plot_mode = args.plot mfmaNonKDim = args.nonKDim - kpack = args.kWidth + kWidth = args.kWidth trans = 1 if args.mfmaTrans else 0 ofilename = args.o keepSrc = args.keep @@ -167,7 +167,7 @@ def main(): mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" mfma_trans_str = ".trans" if trans else 
"" print(f"Plotting dot operation with shapes M={M},N={N},K={K}") - print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kpack}", end=" ") + print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kWidth}", end=" ") print(f"warpsPerCTA={warpsPerCTA}", end=" ") CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) @@ -181,10 +181,10 @@ def main(): if plot_mode == 'dot': assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" - assert K != 0 and K % (2 * kpack) == 0, "bad tensor dimension K" + assert K != 0 and K % (2 * kWidth) == 0, "bad tensor dimension K" if plot_mode == 'lds': - print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}") + print(f"Plotting LDS access for tensor M={M},K={K} with vec={kWidth}") if ldsAccess == 'write': print(f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}") @@ -194,9 +194,9 @@ def main(): draw_blockedLayout_str = draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order) - draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack) + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth) - draw_lds_str = draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) + draw_lds_str = draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) draw_wmma_str = draw_wmma_instr_cmd(waveSize) From ef61330dfd5c3f6ee7f77ce71daaa2ee59253fcc Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Sun, 22 Dec 2024 10:26:29 -0600 Subject: [PATCH 03/11] [API change] Take user input to set dim names API change: - For blocked layout, use -tensorShape, which only takes two dims as dim0,dim1 - For dot layout, use -dotShape, which takes three dims as M,N,K --- .../perf-kernels/tools/plot-layout/README.md | 26 +++++----- .../tools/plot-layout/plot_layout.py | 52 ++++++++++++------- 
.../tools/plot-layout/tikzplot.tex | 6 ++- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index 40de35bdb3aa..563e42c697d9 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -41,36 +41,34 @@ sudo apt install texlive-full Examples: ```bash -python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1 -python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2 -python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1 +python3 plot_layout.py -plot blocked -tensorShape 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1 +python3 plot_layout.py -plot blocked -tensorShape 16 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2 +python3 plot_layout.py -plot blocked -tensorShape 32 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1 ``` Blocked layouts are used during global load. It is used to describe the layout of the tensor for pointers and results. -We can provide tensor shape (`-shape M N K`) and blocked layout parameters ( +We can provide tensor shape (`-tensorShape dim0 dim1`) and blocked layout parameters ( `-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`). We can also provide the order of the tensor as `-order x y` to control which dim is the fastest changing dimension. Notes -- All of the gemm dims (M, N, and K) are needed when providing the shape. But only - M and K will be used to plot the layout of the tensor. - The script does not support the case when threads are loading elements that are out of the boundary of the tensor dimensions. 
This means - - For M: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= M - - For K: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= K + - For dim0: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= dim0 + - For dim1: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= dim1 ## Draw mfma operand and result layouts (`-plot dot`) Examples: ```bash -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4 -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 +python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4 +python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 +python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans +python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 +python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 ``` This mode draws two graphs: @@ -86,7 +84,7 @@ Knobs Notes - The layout shows the mapping from the threads/wave to the elements in the - original tensor. It does not care if the elements are arranged in LDS, like + original tensor. It does not care if the elements are re-arranged in LDS, like swizzling to avoid bank conflicts. - The script does not allow settings for data type or k dim of the mfma instruction. This can be controled by the `-kWidth` flag. 
diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 648f8f86ff17..785f674bb386 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -31,13 +31,15 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): \\end{{document}}''' -def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order): +def draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threadsPerWarp, warpsPerCTA, order): return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} \\def\\elem{{0.06}} \\coordinate (TL) at (0,0); - \\drawBlockedTensor{{{M}}}{{{K}}}{{{sizePerThread[0]}}}{{{sizePerThread[1]}}}{{{threadsPerWarp[0]}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{order[0]}}} + \\def\\dimColName{{{dim0Name}}} + \\def\\dimRowName{{{dim1Name}}} + \\drawBlockedTensor{{{dim0}}}{{{dim1}}}{{{sizePerThread[0]}}}{{{sizePerThread[1]}}}{{{threadsPerWarp[0]}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{order[0]}}} \\end{{tikzpicture}} \\end{{document}}''' @@ -102,17 +104,22 @@ def parse_args(): allow_abbrev=False, ) ## tensor shapes - parser.add_argument("-shape", type=int, nargs=3, default=(32, 128, 64), help='Tensor shape in the form of M,N,K') + parser.add_argument("-tensorShape", type=int, nargs=2, default=(128, 64), + help='2D tensor shape in the form of dim0,dim1') + parser.add_argument("-dotShape", type=int, nargs=3, default=(32, 128, 64), help='Dot op shape in the form of M,N,K') parser.add_argument("-plot", type=str, default="blocked", choices=['blocked', 'dot', 'wmma', 'lds'], help='choose plot mode') parser.add_argument("-nonKDim", type=int, default=32, choices=[16, 32], help='mfma instruction dim') + parser.add_argument("-dim0", type=str, default="M", help='tensor dim0 name') + parser.add_argument("-dim1", type=str, default="K", help='tensor dim1 name') ## blocked layout parameters 
parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4)) parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4)) parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4)) parser.add_argument("-order", type=int, nargs=2, default=(1, 0)) ## LDS access parameters - parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16, 32], help='number of contiguous elements per thread') + parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16, 32], + help='number of contiguous elements per thread') parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') parser.add_argument("-lds_access", type=str, default="none", choices=['read', 'write', 'none'], @@ -132,10 +139,15 @@ def parse_args(): def main(): args = parse_args() - shape = args.shape - M = shape[0] - N = shape[1] - K = shape[2] + dotShape = args.dotShape + M = dotShape[0] + N = dotShape[1] + K = dotShape[2] + tShape = args.tensorShape + dim0 = tShape[0] + dim1 = tShape[1] + dim0Name = args.dim0 + dim1Name = args.dim1 plot_mode = args.plot mfmaNonKDim = args.nonKDim kWidth = args.kWidth @@ -155,31 +167,32 @@ def main(): CTAShape = [] if plot_mode == 'blocked': - print(f"Plotting tensor M={M},K={K} with blocked layout:") - print(f"sizePerThread={sizePerThread}", end=" ") - print(f"threadsPerWarp={threadsPerWarp}", end=" ") - print(f"warpsPerCTA={warpsPerCTA}", end=" ") - print(f"order={order}", end=" ") + print(f"Plotting tensor {dim0Name}={dim0},{dim1Name}={dim1} with blocked layout:") + print(f"{sizePerThread=}", end=" ") + print(f"{threadsPerWarp=}", end=" ") + print(f"{warpsPerCTA=}", end=" ") + print(f"{order=}", end=" ") CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0]) CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1]) if plot_mode == 'dot': mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" mfma_trans_str = 
".trans" if trans else "" - print(f"Plotting dot operation with shapes M={M},N={N},K={K}") - print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kWidth}", end=" ") - print(f"warpsPerCTA={warpsPerCTA}", end=" ") + print(f"Plotting dot operation with shapes {M=},{N=},{K=}") + print("MFMA: " + mfma_inst_str + mfma_trans_str + f" {kWidth=}", end=" ") + print(f"{warpsPerCTA=}", end=" ") CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) if plot_mode == 'blocked' or plot_mode == 'dot': print(f"CTAShape={CTAShape}") - assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" if plot_mode == 'blocked': - assert K != 0 and CTAShape[1] <= K and K % CTAShape[1] == 0, "bad tensor dimension K" + assert dim0 != 0 and CTAShape[0] <= dim0 and dim0 % CTAShape[0] == 0, "bad tensor dimension " + dim0Name + assert dim1 != 0 and CTAShape[1] <= dim1 and dim1 % CTAShape[1] == 0, "bad tensor dimension " + dim1Name if plot_mode == 'dot': + assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" assert K != 0 and K % (2 * kWidth) == 0, "bad tensor dimension K" @@ -192,7 +205,8 @@ def main(): with open("tikzplot.tex") as file: tikz_code = file.read() - draw_blockedLayout_str = draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order) + draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threadsPerWarp, + warpsPerCTA, order) draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth) diff --git a/python/perf-kernels/tools/plot-layout/tikzplot.tex b/python/perf-kernels/tools/plot-layout/tikzplot.tex index a9d08b16e5db..6af05c923597 100644 --- a/python/perf-kernels/tools/plot-layout/tikzplot.tex +++ b/python/perf-kernels/tools/plot-layout/tikzplot.tex @@ -159,6 +159,8 @@ %% %% TL: pre defined top-left 
coordinate of the tensor %% \elem: pre defined variable + %% \dimColName: dim0Name + %% \dimRowName: dim1Name %% %% #1: tensorShape[0] --> M %% #2: tensorShape[1] --> N @@ -193,8 +195,8 @@ \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order} } - \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {M=\M}; - \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {K=\N}; + \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {\dimColName=\M}; + \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {\dimRowName=\N}; \def\zoomR{1.5} \coordinate (zoomin BL) at ($(TL)+(0, .3)$); From 48ab75842808e496a27a1d9f1b1c9ae84cd68afc Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Sun, 22 Dec 2024 21:37:36 -0600 Subject: [PATCH 04/11] Re-structure files Separate each layout's code into their own files --- .../tools/plot-layout/blockedLayout.tex | 157 +++ .../tools/plot-layout/dotLayout.tex | 359 +++++++ .../tools/plot-layout/ldsLayout.tex | 274 +++++ .../tools/plot-layout/plot_layout.py | 14 +- .../tools/plot-layout/preamble.tex | 26 + .../tools/plot-layout/tikzplot.tex | 942 ------------------ .../tools/plot-layout/wmmaLayout.tex | 121 +++ 7 files changed, 944 insertions(+), 949 deletions(-) create mode 100644 python/perf-kernels/tools/plot-layout/blockedLayout.tex create mode 100644 python/perf-kernels/tools/plot-layout/dotLayout.tex create mode 100644 python/perf-kernels/tools/plot-layout/ldsLayout.tex create mode 100644 python/perf-kernels/tools/plot-layout/preamble.tex delete mode 100644 python/perf-kernels/tools/plot-layout/tikzplot.tex create mode 100644 python/perf-kernels/tools/plot-layout/wmmaLayout.tex diff --git a/python/perf-kernels/tools/plot-layout/blockedLayout.tex b/python/perf-kernels/tools/plot-layout/blockedLayout.tex new file mode 100644 index 000000000000..37aba60f5bf0 --- /dev/null +++ 
b/python/perf-kernels/tools/plot-layout/blockedLayout.tex @@ -0,0 +1,157 @@ +\newcommand{\drawBlockedWave}[5]{ + %% + %% Draw a wave coverage with blocked layout + %% + %% Wave TL: pre defined top-left coordinate of the wave + %% \elem: pre defined variable + %% + %% #1: sizePerThread[0] --> sizePerThreadM + %% #2: sizePerThread[1] --> sizePerThreadN + %% #3: threadsPerWarp[0] --> threadsPerWarpM + %% #4: threadsPerWarp[1] --> threadsPerWarpN + %% #5: fastest changing dim --> order + + \pgfmathsetmacro{\sizePerThreadM}{#1} + \pgfmathsetmacro{\sizePerThreadN}{#2} + \pgfmathsetmacro{\threadsPerWarpM}{#3} + \pgfmathsetmacro{\threadsPerWarpN}{#4} + \pgfmathsetmacro{\order}{#5} + + \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} + \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} + + \foreach \tid in {0,...,63}{ + \pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)} + \pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)} + \coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$); + \pgfmathsetmacro{\ratio}{\tidM*10} + + \ifthenelse{\tid = 0}{ + \draw [line width = 0.01mm, fill=red] (Thread TL) + rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); + }{ + \draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL) + rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); + } + } + \draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem); +} + +\newcommand{\drawBlockedCTA}[7]{ + %% + %% Draw a CTA coverage with blocked layout + %% + %% CTA TL: pre defined top-left coordinate of the CTA + %% \elem: pre defined variable + %% + %% #1: sizePerThread[0] --> sizePerThreadM + %% #2: sizePerThread[1] --> sizePerThreadN + %% #3: threadsPerWarp[0] --> threadsPerWarpM + %% #4: threadsPerWarp[1] --> threadsPerWarpN + %% #5: warpsPerCTA[0] --> warpsPerCTAM + %% #6: warpsPerCTA[1] --> warpsPerCTAN + %% #7: fastest changing dim --> order + + \pgfmathsetmacro{\sizePerThreadM}{#1} + 
\pgfmathsetmacro{\sizePerThreadN}{#2} + \pgfmathsetmacro{\threadsPerWarpM}{#3} + \pgfmathsetmacro{\threadsPerWarpN}{#4} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\order}{#7} + + \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} + \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} + \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} + \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} + + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1} + + \coordinate (Wave TL) at (CTA TL); + \drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order} + \foreach \waveId in {0,...,\maxWaveId}{ + \ifthenelse{\order=1} + { + \pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)} + \pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)} + \pgfmathsetmacro{\rot}{0} + }{ + \pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)} + \pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)} + \pgfmathsetmacro{\rot}{90} + } + + \coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$); + \draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem) + node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId}; + } + + \draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem); +} + +\newcommand{\drawBlockedTensor}[8]{ + %% + %% Draw a tensor with blocked layout of the following parameters + %% sizePerThread[2] + %% threadsPerWarp[2] + %% warpsPerCTA[2] + %% order[2] + %% + %% TL: pre defined top-left coordinate of the tensor + %% \elem: pre defined variable + %% \dimColName: dim0Name + %% \dimRowName: dim1Name + %% + %% #1: tensorShape[0] --> M + %% #2: tensorShape[1] --> N + %% #3: sizePerThread[0] --> sizePerThreadM + %% #4: sizePerThread[1] --> sizePerThreadN + %% #5: threadsPerWarp[0] --> threadsPerWarpM + %% Note that 
threadsPerWarp[1] is calculated by 64/threadsPerWarp[0] + %% #6: warpsPerCTA[0] --> warpsPerCTAM + %% #7: warpsPerCTA[1] --> warpsPerCTAN + %% #8: fastest changing dim --> order + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\sizePerThreadM}{#3} + \pgfmathsetmacro{\sizePerThreadN}{#4} + \pgfmathsetmacro{\threadsPerWarpM}{#5} + \pgfmathsetmacro{\warpsPerCTAM}{#6} + \pgfmathsetmacro{\warpsPerCTAN}{#7} + \pgfmathsetmacro{\order}{#8} + + \pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM} + \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} + \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} + \pgfmathsetmacro{\CTARepM}{\M/\CTASizeM} + \pgfmathsetmacro{\CTARepN}{\N/\CTASizeN} + \pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1} + + \foreach \ctaId in {0,...,\maxCTAId}{ + \pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)} + \pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)} + \coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$); + \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order} + } + + \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {\dimColName=\M}; + \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {\dimRowName=\N}; + + \def\zoomR{1.5} + \coordinate (zoomin BL) at ($(TL)+(0, .3)$); + + \foreach \hl in {0,...,\sizePerThreadM}{ + \draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0); + } + \foreach \vl in {0,...,\sizePerThreadN}{ + \draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR); + } + + \node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$}; + \node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN}; + + \draw [densely dotted] (TL) -- (zoomin BL); 
+ \draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$); + \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); +} diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex new file mode 100644 index 000000000000..f9ef6df60aff --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -0,0 +1,359 @@ +\newcommand{\drawBlockMFMALayoutLarge}[3]{ + %% + %% Draw a single block of MFMA_32x32x8xf16 or MFMA_16x16x16xf16 + %% + %% block TL: pre-defined top-left coordinate of the block + %% \elem: pre defined variable + %% + %% #1: 1 for mfma.trans, 0 for normal mfma + %% #2: mfmaNonKDim + %% #3: verbose. 1 means draw tid in each vec; 0 means draw nothing + + \pgfmathsetmacro{\trans}{#1} + \pgfmathsetmacro{\nonTrans}{1-#1} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\maxGID}{\groups-1} + \pgfmathsetmacro{\maxIVec}{\nonKDim*\nonKDim/256-1} + \pgfmathsetmacro{\verbose}{#3} + \foreach \iVec in {0,...,\maxIVec} { + \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); + \foreach \tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\colID}{\tg+4} + \pgfmathsetmacro{\col}{\Colors[\colID]} + \foreach \tid in {0,...,\maxTID} { + \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col!\ratio!white] + ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) + rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col!\ratio!white] + ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) + rectangle ++(\nonTrans*\elem+\trans*4*\elem, 
-\nonTrans*4*\elem-\trans*\elem) + node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid}; + } + } + } + } + \draw [thick] (block TL) rectangle ++(\nonKDim*\elem, -\nonKDim*\elem); +} + + +\newcommand{\drawTensorMFMALayout}[6]{ + %% + %% Draw a tensor with mfma layout. + %% + %% C TL: pre defined top-left coordinates of the tensor + %% + %% #1: M + %% #2: N + %% #3: MFMA nonKDim + %% #4: warpsPerCTA[0] + %% #5: warpsPerCTA[1] + %% #6: 1 for mfma.trans, 0 for normal mfma + + \pgfmathsetmacro{\tensorShapeH}{#1} + \pgfmathsetmacro{\tensorShapeW}{#2} + \pgfmathsetmacro{\mfmaNonKDim}{#3} + \pgfmathsetmacro{\warpsPerCTAH}{#4} + \pgfmathsetmacro{\warpsPerCTAW}{#5} + \pgfmathsetmacro{\mfmaTrans}{#6} + + \coordinate (old TL) at (TL); + \coordinate (TL) at (C TL); + + + \pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH} + \pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW} + \pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1} + \pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim} + \pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim} + + + \foreach \ctaId in {0,...,\maxCTAId}{ + \pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)} + \pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)} + \coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$); + %% Draw a detailed view of wave0 in each CTA + \coordinate (block TL) at (CTA TL); + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{0} + + \foreach \waveId in {0,...,\maxWaveId}{ + \pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)} + \pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)} + \coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$); + %% Inside the loop, only draw a rectangle + \draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem) + node [scale=.7*\mfmaNonKDim/32*\scale, pos=.5, fill=white, 
inner sep=0] {wave\waveId}; + } + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem); + } + + \coordinate (TL) at (old TL); +} + +\newcommand{\drawMFMAOperand}[4]{ + %% + %% Draw one mfma operand + %% + %% mfma op TL: pre defined coordinates of the top-left + %% \elem: pre defined variable + %% + %% #1: mfmNonKDim + %% #2: kpack + %% #3: 0 for opA and 1 for opB + %% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing + + \pgfmathsetmacro{\nonKDim}{#1} + \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\kpack}{#2} + \pgfmathsetmacro{\opIdxA}{#3} + \pgfmathsetmacro{\opIdxB}{1-\opIdxA} + \pgfmathsetmacro{\verbose}{#4} + + \foreach \col/\tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\col}{\Colors[\tg]} + \foreach \tid in {0,...,\maxTID} { + % \pgfmathsetmacro{\ratio}{\tid*2.5+15} + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col] + ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) + rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col] + ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) + rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA) + node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; + } + } + } +} + +\newcommand{\drawWaveOperand}[4]{ + %% + %% Draw the part of the tensor that is one operand of the wave + %% + %% Op TL: pre defined coordinates of the top-left of the operand + %% \elem: pre defined variable + %% + %% #1: K + %% #2: mfmNonKDim + %% #3: kpack + %% #4: 0 for opA and 1 for opB + + \pgfmathsetmacro{\K}{#1} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\kpack}{#3} + 
\pgfmathsetmacro{\opIdx}{#4} + \pgfmathsetmacro{\opIdxOther}{1-\opIdx} + + \coordinate (TL) at (Op TL); + + \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups} + \pgfmathsetmacro{\maxKRepId}{\numKRep-1} + + \foreach \repId in {0,...,\maxKRepId}{ + \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\opIdxOther, -\repId*\groups*\kpack*\elem*\opIdx)$); + \drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0} + \draw [thick] (mfma op TL) rectangle + ++(\groups*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\elem*\opIdx); + } +} + +\newcommand{\drawDotOperands}[7]{ + %% + %% Draw operand tensors of dot + %% + %% A TL and B TL: pre defined top-left coordinates of A and B tensor + %% \elem: pre defined variable + %% + %% #1: M + %% #2: N + %% #3: K + %% #4: MFMA nonKDim + %% #5: warpsPerCTA[0] + %% #6: warpsPerCTA[1] + %% #7: kpack + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\K}{#3} + \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\kpack}{#7} + + %% operand A + \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} + \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1} + \foreach \ctaId in {0,...,\maxCTAIdM}{ + \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$); + \foreach \waveId in {0,...,\maxWaveId}{ + \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$); + \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem); + } + %% Only draw the detailed view of the first wave in CTA + \coordinate (Op TL) at (CTA TL); + \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0} + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); + } + \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem); + + + %% operand B + 
\pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim} + \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1} + \foreach \ctaId in {0,...,\maxCTAIdN}{ + \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$); + \foreach \waveId in {0,...,\maxWaveId}{ + \coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$); + \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem); + } + %% Only draw the detailed view of the first wave in CTA + \coordinate (Op TL) at (CTA TL); + \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1} + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); + } + \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem); +} + + +\newcommand{\drawDot}[8]{ + %% + %% Draw C = dot A, B + %% + %% C TL: pre defined top-left coordinates of the result tensor + %% \elem: pre defined variable + %% + %% #1: M + %% #2: N + %% #3: K + %% #4: MFMA nonKDim + %% #5: warpsPerCTA[0] + %% #6: warpsPerCTA[1] + %% #7: 1 for mfma.trans, 0 for normal mfma + %% #8: kpack + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\K}{#3} + \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\mfmaTrans}{#7} + \pgfmathsetmacro{\kpack}{#8} + \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} + + \pgfmathsetmacro{\gap}{\elem*20} + \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); + \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); + + \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack} + + \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} + + %% Draw labels + \node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K}; + \node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M}; + + \node 
[scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K}; + \node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N}; + + \node [scale=\scale, above left] at (A TL) {A}; + \node [scale=\scale, above left] at (B TL) {B}; + \node [scale=\scale, above left] at (C TL) {C}; + + %% label nonKDim + \node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; + \node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; + %% label kpack + \node [scale=.8*\scale, above] at ($(A TL)+(0.5*\groups*\kpack*\elem, 0)$) {\kdim}; + \node [scale=.8*\scale, left] at ($(B TL)+(0, -0.5*\groups*\kpack*\elem)$) {\kdim}; +} + +\newcommand{\drawMFMAInstr}[3]{ + %% + %% Draw layout of mfma instructions with tid labeled + %% + %% C TL: pre defined top-left coordinates of the output matrix + %% \elem: pre defined variable + %% + %% #1: mfmaNonKDim + %% #2: kpack + %% #3: mfmaTrans + \pgfmathsetmacro{\mfmaNonKDim}{#1} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\kpack}{#2} + \pgfmathsetmacro{\mfmaTrans}{#3} + \pgfmathsetmacro{\nonTrans}{1-#3} + + \pgfmathsetmacro{\gap}{\elem*5} + \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elem, 0)$); + \coordinate (mfma op TL) at (mfma opA TL); + \drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} + \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elem)$); + \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} + + \coordinate (block TL) at (C TL); + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} + + %% Draw labels + \def\vecR{1.5} + \coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$); + \pgfmathsetmacro{\maxVec}{\kpack-1} + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$); + \draw [densely dotted] ($(mfma opA 
TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack}; + + \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$); + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$); + \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$); + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack}; + + \node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC}; + \ifthenelse{\mfmaTrans=0}{ + \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA}; + \node [scale=\scale, above] at (mfma op TL) {opB}; + \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$); + \foreach \vecId in {0,1,2,3}{ + \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem); + \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem); + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4}; + \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; + }{ + \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB}; + \node [scale=\scale, above] at (mfma op TL) {opA}; + \coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$); + \foreach \vecId in {0,1,2,3}{ + \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem); + \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 
3*\elem); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4}; + \node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True}; + } +} \ No newline at end of file diff --git a/python/perf-kernels/tools/plot-layout/ldsLayout.tex b/python/perf-kernels/tools/plot-layout/ldsLayout.tex new file mode 100644 index 000000000000..393c709a29f5 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/ldsLayout.tex @@ -0,0 +1,274 @@ +\ExplSyntaxOn +\NewExpandableDocumentCommand{\bitwiseXor}{mm} + { + \recuenco_bitwise_xor:nn { #1 } { #2 } + } + +\cs_new:Nn \recuenco_bitwise_xor:nn + { + \int_from_bin:e + { + \__recuenco_bitwise_xor:ee { \int_to_bin:n { #1 } } { \int_to_bin:n { #2 } } + } + } +\cs_generate_variant:Nn \int_from_bin:n { e } + +\cs_new:Nn \__recuenco_bitwise_xor:nn + { + \__recuenco_bitwise_xor_binary:ee + { + \prg_replicate:nn + { + \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #1 } + } + { 0 } + #1 + } + { + \prg_replicate:nn + { + \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #2 } + } + { 0 } + #2 + } + } +\cs_generate_variant:Nn \__recuenco_bitwise_xor:nn { ee } + +\cs_new:Nn \__recuenco_bitwise_xor_binary:nn + { + \__recuenco_bitwise_xor_binary:w #1;#2; + } +\cs_generate_variant:Nn \__recuenco_bitwise_xor_binary:nn { ee } + +\cs_new:Npn \__recuenco_bitwise_xor_binary:w #1#2;#3#4; + { + \int_abs:n { #1-#3 } + \tl_if_empty:nF { #2 } { \__recuenco_bitwise_xor_binary:w #2;#4; } + } + +\ExplSyntaxOff + +\newcommand{\drawTensorLayoutGlobalMem}{ + %% + %% Draw tensor layout in global memory without any swizzling + %% + %% TL: pre defined top-left coordinates of the tensor in global memory + %% \elem: per defined variable + %% \Colors: a pre defined array of 16 colors + %% + %% The following arguments are also expected to be pre defined + %% #1: M + %% #2: K + %% #3: vec: number of elements in a group + + \pgfmathsetmacro{\numVecK}{\K/\vec} + 
\pgfmathsetmacro{\maxVecId}{16*\numVecK-1} + \pgfmathsetmacro{\drawM}{20} + + %% Draw the tensor, but only draw 32 rows + \draw (TL) rectangle ++(\K*\elem, -\drawM*\elem); + %% Draw detailed vec view of the tensor + \foreach \vecId in {0,...,\maxVecId}{ + + \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} + \pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)} + \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); + + \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} + \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} + \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} + \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} + + \draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem) + node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + + } + %% M and K dim + \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M}; + \node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16}; + \node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K}; + %% label for vecSize + \def\vecR{1.5} + \coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$); + \pgfmathsetmacro{\maxVec}{\vec-1} + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$); + \draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec}; +} + + + +\newcommand{\drawLDSLayoutTritonSwizzling}[2]{ + %% + %% Draw tensor layout in LDS with swizzling + %% + %% TL: pre defined top-left coordinates of the tensor in global memory + %% \elem: per defined variable + %% \Colors: a pre defined array of 16 colors + %% + %% The following three arguments are expected to be pre defined + %% #1: M + %% #2: K + %% #3: vec: number of elements in a group + %% + %% #1: 
hasSwizzle, 0 means no swizzling and no padding, + %% 1 means optimal swizzling + %% 2 means padding + %% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write + %% For ds_write access, the following variables are assumed to be pre defined + %% \sizePerThreadK + %% \sizePerThreadM + %% \threadsPerWarpK + + \pgfmathsetmacro{\hasSwizzle}{#1} + \pgfmathsetmacro{\accessMode}{#2} + \pgfmathsetmacro{\numVecK}{\K/\vec} + + %% Assuming fp16 data type + \pgfmathsetmacro{\LDSK}{64} + \pgfmathsetmacro{\numLDSVec}{\LDSK/\vec} + \pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)} + \pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)} + + \ifthenelse{\accessMode = 2}{ + %% \accessMode == 2, draw 8 rows + \pgfmathsetmacro{\maxVecId}{8*\numVecK-1} + \pgfmathsetmacro{\drawM}{8*\K/\LDSK+4} + }{ + %% \accessMode == 0 or 1, draw 16 rows + \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} + \pgfmathsetmacro{\drawM}{16*\K/\LDSK+4} + } + + %% Parameters used for swizzling + \pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec} + %% perPhase = ceil(LDSK / K) + %% The number of the rows of the tensor that can share the same swizzling pattern + \pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)} + %% maxPhase: the total number of different swizzling patterns + \ifthenelse{\hasSwizzle=0}{ + %% When swizzling is disabled + \pgfmathsetmacro{\maxPhase}{1} + }{ + %% When vec is small enough, we want 16/perPhase different swizzling patterns + %% When vec is large, we can only have 64 / \vec different swizzling pattern at most + \pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)} + } + + %% Draw the LDS + \draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem); + + %% Draw detailed vec view of LDS + \foreach \vecId in {0,...,\maxVecId}{ + \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} + \pgfmathsetmacro{\vecCoordK}{int(mod(\vecId,\numVecK))} + \pgfmathsetmacro{\rawPhase}{floor(\vecId/\numVecSwizzleK)} + %% vec color + \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} + 
\pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} + \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} + \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} + + %% old vec coordinates + \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); + + %% new vec coordinates in LDS by swizzling + %% The following two conditions correspond to the relation between \LDSK and \K + \ifthenelse{\LDSK < \K}{ + \pgfmathsetmacro{\vecLDSM}{\vecCoordM*\K/\LDSK+floor(\vecCoordK*\vec/\LDSK)} + \pgfmathsetmacro{\vecLDSK}{int(mod(\vecCoordK, \LDSK/\vec))} + }{ + \pgfmathsetmacro{\vecLDSM}{floor(\vecCoordM/\perPhase)} + \pgfmathsetmacro{\vecLDSK}{int(\vecCoordK+mod(\vecCoordM,\perPhase)*\numVecK)} + } + %% + \pgfmathsetmacro{\phase}{int(mod(\rawPhase, \maxPhase))} + %% Compute the swizzled col id + \pgfmathsetmacro{\vecLDSKSwizzled}{\bitwiseXor{\vecLDSK}{\phase}} + + %% new vec coordinates in LDS by padding + \pgfmathsetmacro{\numPads}{floor(\vecId/\numLDSVec)} + \pgfmathsetmacro{\bankId}{\vec/2*\vecId+\numPads} + \pgfmathsetmacro{\vecPadM}{int(\bankId/32)} + \pgfmathsetmacro{\vecPadK}{int(mod(\bankId,32))} + + \ifthenelse{\hasSwizzle = 2}{ + %% vec coordinates by padding + \coordinate (new vec TL) at ($(TL)+(\vecPadK*2*\elem, -\vecPadM*\elem)$); + \pgfmathsetmacro{\tailBankId}{int(\vecPadK+\vec/2-1)} + }{ + %% vec coordinates by swizzling + \coordinate (new vec TL) at ($(TL)+(\vecLDSKSwizzled*\vec*\elem, -\vecLDSM*\elem)$); + \pgfmathsetmacro{\tailBankId}{0} + } + + \ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{ + \pgfmathsetmacro{\nextBanks}{\tailBankId-31} + \pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks} + \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem) + node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + \draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$) + rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + }{ + \draw 
[ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem) + node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + } + + %% ds_read + %% Highlight the elements the first 16 threads access in the first cycle + %% This is used to visualize bank conflicts + \ifthenelse{\accessMode = 1}{ + \ifthenelse{\vecCoordK = 0}{ + \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); + \draw (new vec TL) -- ++(\elem, -\elem); + \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); + }{} + }{} + + %% Draw ds_write pattern + \ifthenelse{\accessMode = 2}{ + %% First compute the coverage of the first 16 threads + \pgfmathsetmacro{\covK}{min(16, \threadsPerWarpK)*\sizePerThreadK/\vec} + \pgfmathsetmacro{\covM}{ceil(16/\threadsPerWarpK)*\sizePerThreadM} + %% Check conditions for the first 16 threads + \pgfmathsetmacro{\vecInThread}{int(mod(\vecCoordK, \sizePerThreadK/\vec))} + \ifthenelse{\vecInThread=0}{ + \ifthenelse{\vecCoordK<\covK \AND \vecCoordM<\covM}{ + \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); + \draw (new vec TL) -- ++(\elem, -\elem); + \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); + }{} + }{} + }{} + + %% Label the phase of each line if swizzling is used + \ifthenelse{\hasSwizzle = 2}{}{ + \pgfmathsetmacro{\lastVecId}{int(64/\vec)-1} + \ifthenelse{\vecLDSKSwizzled = \lastVecId}{ + \draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0) + node [scale=.6*\scale, right] {\phase}; + }{} + } + } + + %% Draw boundary of 32 banks + %% Assume fp16 data type + \foreach \bank in {0,...,31}{ + \draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem) + node [scale=.6*\scale, right, black] {\bank}; + } + \draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem); + \node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id}; + + \node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks}; + \node [scale=\scale, rotate=90, above] at ($(TL)+(0, 
-.5*\drawM*\elem)$) {LDSM=\LDSM}; + + %% label phase if swizzling is used + \ifthenelse{\hasSwizzle = 2}{}{ + \node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase}; + } +} diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 785f674bb386..298eab79c5a2 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -10,10 +10,6 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): \\def\\scale{{1}} \\def\\elem{{0.04}} \\coordinate (C TL) at (0,0); - \\def\\opColorAL{{magenta}} - \\def\\opColorAR{{cyan}} - \\def\\opColorBL{{Maroon}} - \\def\\opColorBR{{BlueGreen}} \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kWidth}}} \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); @@ -202,8 +198,8 @@ def main(): print(f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}") with open("myplot.tex", 'w') as f_plot: - with open("tikzplot.tex") as file: - tikz_code = file.read() + with open("preamble.tex") as file: + preamble = file.read() draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threadsPerWarp, warpsPerCTA, order) @@ -214,14 +210,18 @@ def main(): draw_wmma_str = draw_wmma_instr_cmd(waveSize) - f_plot.write(tikz_code) + f_plot.write(preamble) if plot_mode == 'blocked': + f_plot.write("\\input{blockedLayout}\n") f_plot.write(draw_blockedLayout_str) elif plot_mode == 'dot': + f_plot.write("\\input{dotLayout}\n") f_plot.write(draw_dotLayout_str) elif plot_mode == 'lds': + f_plot.write("\\input{ldsLayout}\n") f_plot.write(draw_lds_str) elif plot_mode == 'wmma': + f_plot.write("\\input{wmmaLayout}\n") f_plot.write(draw_wmma_str) run_bash_command(f"pdflatex -jobname {ofilename} myplot.tex") diff --git a/python/perf-kernels/tools/plot-layout/preamble.tex 
b/python/perf-kernels/tools/plot-layout/preamble.tex new file mode 100644 index 000000000000..1c9d7e480446 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/preamble.tex @@ -0,0 +1,26 @@ +\documentclass[tikz, border=1mm, dvipsnames]{standalone} +\usepackage{ifthen} +\usepackage{tikz} +\usetikzlibrary{arrows.meta,arrows} +\usetikzlibrary{intersections} +\usetikzlibrary{calc, quotes} +\usetikzlibrary{patterns} +\usepackage{xparse} +\newcommand{\Colors}{{ + "red", + "YellowGreen", + "blue", + "Maroon", + "orange", + "cyan", + "magenta", + "brown", + "teal", + "purple", + "gray", + "Green", + "BlueGreen", + "violet", + "olive", + "darkgray", + }} diff --git a/python/perf-kernels/tools/plot-layout/tikzplot.tex b/python/perf-kernels/tools/plot-layout/tikzplot.tex deleted file mode 100644 index 6af05c923597..000000000000 --- a/python/perf-kernels/tools/plot-layout/tikzplot.tex +++ /dev/null @@ -1,942 +0,0 @@ -\documentclass[tikz, border=1mm, dvipsnames]{standalone} -\usepackage{ifthen} -\usepackage{tikz} -\usetikzlibrary{arrows.meta,arrows} -\usetikzlibrary{intersections} -\usetikzlibrary{calc, quotes} -\usetikzlibrary{patterns} -\usepackage{xparse} - -\ExplSyntaxOn -\NewExpandableDocumentCommand{\bitwiseXor}{mm} - { - \recuenco_bitwise_xor:nn { #1 } { #2 } - } - -\cs_new:Nn \recuenco_bitwise_xor:nn - { - \int_from_bin:e - { - \__recuenco_bitwise_xor:ee { \int_to_bin:n { #1 } } { \int_to_bin:n { #2 } } - } - } -\cs_generate_variant:Nn \int_from_bin:n { e } - -\cs_new:Nn \__recuenco_bitwise_xor:nn - { - \__recuenco_bitwise_xor_binary:ee - { - \prg_replicate:nn - { - \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #1 } - } - { 0 } - #1 - } - { - \prg_replicate:nn - { - \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #2 } - } - { 0 } - #2 - } - } -\cs_generate_variant:Nn \__recuenco_bitwise_xor:nn { ee } - -\cs_new:Nn \__recuenco_bitwise_xor_binary:nn - { - \__recuenco_bitwise_xor_binary:w #1;#2; - } 
-\cs_generate_variant:Nn \__recuenco_bitwise_xor_binary:nn { ee } - -\cs_new:Npn \__recuenco_bitwise_xor_binary:w #1#2;#3#4; - { - \int_abs:n { #1-#3 } - \tl_if_empty:nF { #2 } { \__recuenco_bitwise_xor_binary:w #2;#4; } - } - -\ExplSyntaxOff - -\newcommand{\drawBlockedWave}[5]{ - %% - %% Draw a wave coverage with blocked layout - %% - %% Wave TL: pre defined top-left coordinate of the wave - %% \elem: pre defined variable - %% - %% #1: sizePerThread[0] --> sizePerThreadM - %% #2: sizePerThread[1] --> sizePerThreadN - %% #3: threadsPerWarp[0] --> threadsPerWarpM - %% #4: threadsPerWarp[1] --> threadsPerWarpN - %% #5: fastest changing dim --> order - - \pgfmathsetmacro{\sizePerThreadM}{#1} - \pgfmathsetmacro{\sizePerThreadN}{#2} - \pgfmathsetmacro{\threadsPerWarpM}{#3} - \pgfmathsetmacro{\threadsPerWarpN}{#4} - \pgfmathsetmacro{\order}{#5} - - \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} - \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} - - \foreach \tid in {0,...,63}{ - \pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)} - \pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)} - \coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$); - \pgfmathsetmacro{\ratio}{\tidM*10} - - \ifthenelse{\tid = 0}{ - \draw [line width = 0.01mm, fill=red] (Thread TL) - rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); - }{ - \draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL) - rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); - } - } - \draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem); -} - -\newcommand{\drawBlockedCTA}[7]{ - %% - %% Draw a CTA coverage with blocked layout - %% - %% CTA TL: pre defined top-left coordinate of the CTA - %% \elem: pre defined variable - %% - %% #1: sizePerThread[0] --> sizePerThreadM - %% #2: sizePerThread[1] --> sizePerThreadN - %% #3: threadsPerWarp[0] --> threadsPerWarpM - %% #4: threadsPerWarp[1] --> threadsPerWarpN - %% #5: 
warpsPerCTA[0] --> warpsPerCTAM - %% #6: warpsPerCTA[1] --> warpsPerCTAN - %% #7: fastest changing dim --> order - - \pgfmathsetmacro{\sizePerThreadM}{#1} - \pgfmathsetmacro{\sizePerThreadN}{#2} - \pgfmathsetmacro{\threadsPerWarpM}{#3} - \pgfmathsetmacro{\threadsPerWarpN}{#4} - \pgfmathsetmacro{\warpsPerCTAM}{#5} - \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\order}{#7} - - \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} - \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} - \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} - \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} - - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1} - - \coordinate (Wave TL) at (CTA TL); - \drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order} - \foreach \waveId in {0,...,\maxWaveId}{ - \ifthenelse{\order=1} - { - \pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)} - \pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)} - \pgfmathsetmacro{\rot}{0} - }{ - \pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)} - \pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)} - \pgfmathsetmacro{\rot}{90} - } - - \coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$); - \draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem) - node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId}; - } - - \draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem); -} - -\newcommand{\drawBlockedTensor}[8]{ - %% - %% Draw a tensor with blocked layout of the following parameters - %% sizePerThread[2] - %% threadsPerWarp[2] - %% warpsPerCTA[2] - %% order[2] - %% - %% TL: pre defined top-left coordinate of the tensor - %% \elem: pre defined variable - %% \dimColName: dim0Name - %% \dimRowName: dim1Name - %% - %% #1: tensorShape[0] --> M - %% #2: tensorShape[1] 
--> N - %% #3: sizePerThread[0] --> sizePerThreadM - %% #4: sizePerThread[1] --> sizePerThreadN - %% #5: threadsPerWarp[0] --> threadsPerWarpM - %% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0] - %% #6: warpsPerCTA[0] --> warpsPerCTAM - %% #7: warpsPerCTA[1] --> warpsPerCTAN - %% #8: fastest changing dim --> order - - \pgfmathsetmacro{\M}{#1} - \pgfmathsetmacro{\N}{#2} - \pgfmathsetmacro{\sizePerThreadM}{#3} - \pgfmathsetmacro{\sizePerThreadN}{#4} - \pgfmathsetmacro{\threadsPerWarpM}{#5} - \pgfmathsetmacro{\warpsPerCTAM}{#6} - \pgfmathsetmacro{\warpsPerCTAN}{#7} - \pgfmathsetmacro{\order}{#8} - - \pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM} - \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} - \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} - \pgfmathsetmacro{\CTARepM}{\M/\CTASizeM} - \pgfmathsetmacro{\CTARepN}{\N/\CTASizeN} - \pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1} - - \foreach \ctaId in {0,...,\maxCTAId}{ - \pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)} - \pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)} - \coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$); - \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order} - } - - \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {\dimColName=\M}; - \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {\dimRowName=\N}; - - \def\zoomR{1.5} - \coordinate (zoomin BL) at ($(TL)+(0, .3)$); - - \foreach \hl in {0,...,\sizePerThreadM}{ - \draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0); - } - \foreach \vl in {0,...,\sizePerThreadN}{ - \draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR); - } - - \node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$}; - \node [scale=.6*\scale, right] at ($(zoomin 
BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN}; - - \draw [densely dotted] (TL) -- (zoomin BL); - \draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$); - \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); -} - -\newcommand{\drawBlockMFMALayoutLarge}[3]{ - %% - %% Draw a single block of MFMA_32x32x8xf16 or MFMA_16x16x16xf16 - %% - %% block TL: pre-defined top-left coordinate of the block - %% \elem: pre defined variable - %% - %% #1: 1 for mfma.trans, 0 for normal mfma - %% #2: mfmaNonKDim - %% #3: verbose. 1 means draw tid in each vec; 0 means draw nothing - - \pgfmathsetmacro{\trans}{#1} - \pgfmathsetmacro{\nonTrans}{1-#1} - \pgfmathsetmacro{\nonKDim}{#2} - \pgfmathsetmacro{\maxTID}{\nonKDim-1} - \pgfmathsetmacro{\groups}{64/\nonKDim} - \pgfmathsetmacro{\maxGID}{\groups-1} - \pgfmathsetmacro{\maxIVec}{\nonKDim*\nonKDim/256-1} - \pgfmathsetmacro{\verbose}{#3} - \foreach \iVec in {0,...,\maxIVec} { - \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); - \foreach \tg in {0,...,\maxGID}{ - \pgfmathsetmacro{\colID}{\tg+4} - \pgfmathsetmacro{\col}{\Colors[\colID]} - \foreach \tid in {0,...,\maxTID} { - \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} - \ifthenelse{\verbose=0}{ - \draw [line width=0.005mm, fill=\col!\ratio!white] - ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) - rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem); - }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} - \draw [line width=0.005mm, fill=\col!\ratio!white] - ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) - rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem) - node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid}; - } - } - 
} - } - \draw [thick] (block TL) rectangle ++(\nonKDim*\elem, -\nonKDim*\elem); -} - - -\newcommand{\drawTensorMFMALayout}[6]{ - %% - %% Draw a tensor with mfma layout. - %% - %% C TL: pre defined top-left coordinates of the tensor - %% - %% #1: M - %% #2: N - %% #3: MFMA nonKDim - %% #4: warpsPerCTA[0] - %% #5: warpsPerCTA[1] - %% #6: 1 for mfma.trans, 0 for normal mfma - - \pgfmathsetmacro{\tensorShapeH}{#1} - \pgfmathsetmacro{\tensorShapeW}{#2} - \pgfmathsetmacro{\mfmaNonKDim}{#3} - \pgfmathsetmacro{\warpsPerCTAH}{#4} - \pgfmathsetmacro{\warpsPerCTAW}{#5} - \pgfmathsetmacro{\mfmaTrans}{#6} - - \coordinate (old TL) at (TL); - \coordinate (TL) at (C TL); - - - \pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH} - \pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW} - \pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1} - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1} - \pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim} - \pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim} - - - \foreach \ctaId in {0,...,\maxCTAId}{ - \pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)} - \pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)} - \coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$); - %% Draw a detailed view of wave0 in each CTA - \coordinate (block TL) at (CTA TL); - \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{0} - - \foreach \waveId in {0,...,\maxWaveId}{ - \pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)} - \pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)} - \coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$); - %% Inside the loop, only draw a rectangle - \draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem) - node [scale=.7*\mfmaNonKDim/32*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId}; - } - - %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) 
rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem); - } - - \coordinate (TL) at (old TL); -} - -\newcommand{\drawMFMAOperand}[4]{ - %% - %% Draw one mfma operand - %% - %% mfma op TL: pre defined coordinates of the top-left - %% \elem: pre defined variable - %% - %% #1: mfmNonKDim - %% #2: kpack - %% #3: 0 for opA and 1 for opB - %% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing - - \pgfmathsetmacro{\nonKDim}{#1} - \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} - \pgfmathsetmacro{\maxTID}{\nonKDim-1} - \pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\opIdxA}{#3} - \pgfmathsetmacro{\opIdxB}{1-\opIdxA} - \pgfmathsetmacro{\verbose}{#4} - - \foreach \col/\tg in {0,...,\maxGID}{ - \pgfmathsetmacro{\col}{\Colors[\tg]} - \foreach \tid in {0,...,\maxTID} { - % \pgfmathsetmacro{\ratio}{\tid*2.5+15} - \ifthenelse{\verbose=0}{ - \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); - }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} - \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA) - node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; - } - } - } -} - -\newcommand{\drawWaveOperand}[4]{ - %% - %% Draw the part of the tensor that is one operand of the wave - %% - %% Op TL: pre defined coordinates of the top-left of the operand - %% \elem: pre defined variable - %% - %% #1: K - %% #2: mfmNonKDim - %% #3: kpack - %% #4: 0 for opA and 1 for opB - - \pgfmathsetmacro{\K}{#1} - \pgfmathsetmacro{\nonKDim}{#2} - \pgfmathsetmacro{\groups}{64/\nonKDim} - \pgfmathsetmacro{\kpack}{#3} - \pgfmathsetmacro{\opIdx}{#4} - \pgfmathsetmacro{\opIdxOther}{1-\opIdx} - - \coordinate (TL) at (Op TL); - - 
\pgfmathsetmacro{\numKRep}{\K/\kpack/\groups} - \pgfmathsetmacro{\maxKRepId}{\numKRep-1} - - \foreach \repId in {0,...,\maxKRepId}{ - \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\opIdxOther, -\repId*\groups*\kpack*\elem*\opIdx)$); - \drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0} - \draw [thick] (mfma op TL) rectangle - ++(\groups*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\elem*\opIdx); - } -} - -\newcommand{\drawDotOperands}[7]{ - %% - %% Draw operand tensors of dot - %% - %% A TL and B TL: pre defined top-left coordinates of A and B tensor - %% \elem: pre defined variable - %% - %% #1: M - %% #2: N - %% #3: K - %% #4: MFMA nonKDim - %% #5: warpsPerCTA[0] - %% #6: warpsPerCTA[1] - %% #7: kpack - - \pgfmathsetmacro{\M}{#1} - \pgfmathsetmacro{\N}{#2} - \pgfmathsetmacro{\K}{#3} - \pgfmathsetmacro{\mfmaNonKDim}{#4} - \pgfmathsetmacro{\warpsPerCTAM}{#5} - \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\kpack}{#7} - - %% operand A - \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} - \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1} - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1} - \foreach \ctaId in {0,...,\maxCTAIdM}{ - \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$); - \foreach \waveId in {0,...,\maxWaveId}{ - \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$); - \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem); - } - %% Only draw the detailed view of the first wave in CTA - \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0} - - %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); - } - \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem); - - - %% operand B - \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim} - \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1} - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1} - 
\foreach \ctaId in {0,...,\maxCTAIdN}{ - \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$); - \foreach \waveId in {0,...,\maxWaveId}{ - \coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$); - \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem); - } - %% Only draw the detailed view of the first wave in CTA - \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1} - - %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); - } - \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem); -} - - -\newcommand{\drawDot}[8]{ - %% - %% Draw C = dot A, B - %% - %% C TL: pre defined top-left coordinates of the result tensor - %% \elem: pre defined variable - %% - %% #1: M - %% #2: N - %% #3: K - %% #4: MFMA nonKDim - %% #5: warpsPerCTA[0] - %% #6: warpsPerCTA[1] - %% #7: 1 for mfma.trans, 0 for normal mfma - %% #8: kpack - - \pgfmathsetmacro{\M}{#1} - \pgfmathsetmacro{\N}{#2} - \pgfmathsetmacro{\K}{#3} - \pgfmathsetmacro{\mfmaNonKDim}{#4} - \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} - \pgfmathsetmacro{\warpsPerCTAM}{#5} - \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\mfmaTrans}{#7} - \pgfmathsetmacro{\kpack}{#8} - \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} - - \pgfmathsetmacro{\gap}{\elem*20} - \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); - \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); - - \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack} - - \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} - - %% Draw labels - \node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K}; - \node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M}; - - \node [scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K}; - \node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N}; - - \node 
[scale=\scale, above left] at (A TL) {A}; - \node [scale=\scale, above left] at (B TL) {B}; - \node [scale=\scale, above left] at (C TL) {C}; - - %% label nonKDim - \node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; - \node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; - %% label kpack - \node [scale=.8*\scale, above] at ($(A TL)+(0.5*\groups*\kpack*\elem, 0)$) {\kdim}; - \node [scale=.8*\scale, left] at ($(B TL)+(0, -0.5*\groups\kpack*\elem)$) {\kdim}; -} - -\newcommand{\Colors}{{ - "red", - "YellowGreen", - "blue", - "Maroon", - "orange", - "cyan", - "magenta", - "brown", - "teal", - "purple", - "gray", - "Green", - "BlueGreen", - "violet", - "olive", - "darkgray", - }} - -\newcommand{\drawTensorLayoutGlobalMem}{ - %% - %% Draw tensor layout in global memory without any swizzling - %% - %% TL: pre defined top-left coordinates of the tensor in global memory - %% \elem: per defined variable - %% \Colors: a pre defined array of 16 colors - %% - %% The following arguments are also expected to be pre defined - %% #1: M - %% #2: K - %% #3: vec: number of elements in a group - - \pgfmathsetmacro{\numVecK}{\K/\vec} - \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} - \pgfmathsetmacro{\drawM}{20} - - %% Draw the tensor, but only draw 32 rows - \draw (TL) rectangle ++(\K*\elem, -\drawM*\elem); - %% Draw detailed vec view of the tensor - \foreach \vecId in {0,...,\maxVecId}{ - - \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} - \pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)} - \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); - - \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} - \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} - \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} - \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} - - \draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] 
{m\vecCoordM}; - - } - %% M and K dim - \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M}; - \node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16}; - \node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K}; - %% label for vecSize - \def\vecR{1.5} - \coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$); - \pgfmathsetmacro{\maxVec}{\vec-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$); - \draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec}; -} - - - -\newcommand{\drawLDSLayoutTritonSwizzling}[2]{ - %% - %% Draw tensor layout in LDS with swizzling - %% - %% TL: pre defined top-left coordinates of the tensor in global memory - %% \elem: per defined variable - %% \Colors: a pre defined array of 16 colors - %% - %% The following three arguments are expected to be pre defined - %% #1: M - %% #2: K - %% #3: vec: number of elements in a group - %% - %% #1: hasSwizzle, 0 means no swizzling and no padding, - %% 1 means optimal swizzling - %% 2 means padding - %% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write - %% For ds_write access, the following variables are assumed to be pre defined - %% \sizePerThreadK - %% \sizePerThreadM - %% \threadsPerWarpK - - \pgfmathsetmacro{\hasSwizzle}{#1} - \pgfmathsetmacro{\accessMode}{#2} - \pgfmathsetmacro{\numVecK}{\K/\vec} - - %% Assuming fp16 data type - \pgfmathsetmacro{\LDSK}{64} - \pgfmathsetmacro{\numLDSVec}{\LDSK/\vec} - \pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)} - \pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)} - - \ifthenelse{\accessMode = 2}{ - %% \accessMode == 2, draw 8 rows - \pgfmathsetmacro{\maxVecId}{8*\numVecK-1} - \pgfmathsetmacro{\drawM}{8*\K/\LDSK+4} - }{ - %% 
\accessMode == 0 or 1, draw 16 rows - \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} - \pgfmathsetmacro{\drawM}{16*\K/\LDSK+4} - } - - %% Parameters used for swizzling - \pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec} - %% perPhase = ceil(LDSK / K) - %% The number of the rows of the tensor that can share the same swizzling pattern - \pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)} - %% maxPhase: the total number of different swizzling patterns - \ifthenelse{\hasSwizzle=0}{ - %% When swizzling is disabled - \pgfmathsetmacro{\maxPhase}{1} - }{ - %% When vec is small enough, we want 16/perPhase different swizzling patterns - %% When vec is large, we can only have 64 / \vec different swizzling pattern at most - \pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)} - } - - %% Draw the LDS - \draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem); - - %% Draw detailed vec view of LDS - \foreach \vecId in {0,...,\maxVecId}{ - \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} - \pgfmathsetmacro{\vecCoordK}{int(mod(\vecId,\numVecK))} - \pgfmathsetmacro{\rawPhase}{floor(\vecId/\numVecSwizzleK)} - %% vec color - \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} - \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} - \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} - \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} - - %% old vec coordinates - \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); - - %% new vec coordinates in LDS by swizzling - %% The following two conditions correspond to the relation between \LDSK and \K - \ifthenelse{\LDSK < \K}{ - \pgfmathsetmacro{\vecLDSM}{\vecCoordM*\K/\LDSK+floor(\vecCoordK*\vec/\LDSK)} - \pgfmathsetmacro{\vecLDSK}{int(mod(\vecCoordK, \LDSK/\vec))} - }{ - \pgfmathsetmacro{\vecLDSM}{floor(\vecCoordM/\perPhase)} - \pgfmathsetmacro{\vecLDSK}{int(\vecCoordK+mod(\vecCoordM,\perPhase)*\numVecK)} - } - %% - \pgfmathsetmacro{\phase}{int(mod(\rawPhase, \maxPhase))} - %% Compute the swizzled col id - 
\pgfmathsetmacro{\vecLDSKSwizzled}{\bitwiseXor{\vecLDSK}{\phase}} - - %% new vec coordinates in LDS by padding - \pgfmathsetmacro{\numPads}{floor(\vecId/\numLDSVec)} - \pgfmathsetmacro{\bankId}{\vec/2*\vecId+\numPads} - \pgfmathsetmacro{\vecPadM}{int(\bankId/32)} - \pgfmathsetmacro{\vecPadK}{int(mod(\bankId,32))} - - \ifthenelse{\hasSwizzle = 2}{ - %% vec coordinates by padding - \coordinate (new vec TL) at ($(TL)+(\vecPadK*2*\elem, -\vecPadM*\elem)$); - \pgfmathsetmacro{\tailBankId}{int(\vecPadK+\vec/2-1)} - }{ - %% vec coordinates by swizzling - \coordinate (new vec TL) at ($(TL)+(\vecLDSKSwizzled*\vec*\elem, -\vecLDSM*\elem)$); - \pgfmathsetmacro{\tailBankId}{0} - } - - \ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{ - \pgfmathsetmacro{\nextBanks}{\tailBankId-31} - \pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks} - \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - \draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$) - rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - }{ - \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - } - - %% ds_read - %% Highlight the elements the first 16 threads access in the first cycle - %% This is used to visualize bank conflicts - \ifthenelse{\accessMode = 1}{ - \ifthenelse{\vecCoordK = 0}{ - \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); - \draw (new vec TL) -- ++(\elem, -\elem); - \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); - }{} - }{} - - %% Draw ds_write pattern - \ifthenelse{\accessMode = 2}{ - %% First compute the coverage of the first 16 threads - \pgfmathsetmacro{\covK}{min(16, \threadsPerWarpK)*\sizePerThreadK/\vec} - \pgfmathsetmacro{\covM}{ceil(16/\threadsPerWarpK)*\sizePerThreadM} - %% Check conditions for the 
first 16 threads - \pgfmathsetmacro{\vecInThread}{int(mod(\vecCoordK, \sizePerThreadK/\vec))} - \ifthenelse{\vecInThread=0}{ - \ifthenelse{\vecCoordK<\covK \AND \vecCoordM<\covM}{ - \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); - \draw (new vec TL) -- ++(\elem, -\elem); - \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); - }{} - }{} - }{} - - %% Label the phase of each line if swizzling is used - \ifthenelse{\hasSwizzle = 2}{}{ - \pgfmathsetmacro{\lastVecId}{int(64/\vec)-1} - \ifthenelse{\vecLDSKSwizzled = \lastVecId}{ - \draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0) - node [scale=.6*\scale, right] {\phase}; - }{} - } - } - - %% Draw boundary of 32 banks - %% Assume fp16 data type - \foreach \bank in {0,...,31}{ - \draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem) - node [scale=.6*\scale, right, black] {\bank}; - } - \draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem); - \node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id}; - - \node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks}; - \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM}; - - %% label phase if swizzling is used - \ifthenelse{\hasSwizzle = 2}{}{ - \node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase}; - } -} - -\newcommand{\drawMFMAInstr}[3]{ - %% - %% Draw layout of mfma instructions with tid labeled - %% - %% C TL: pre defined top-left coordinates of the output matrix - %% \elem: pre defined variable - %% - %% #1: mfmaNonKDim - %% #2: kpack - %% #3: mfmaTrans - \pgfmathsetmacro{\mfmaNonKDim}{#1} - \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} - \pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\mfmaTrans}{#3} - \pgfmathsetmacro{\nonTrans}{1-#3} - - \pgfmathsetmacro{\gap}{\elem*5} - \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elem, 0)$); - \coordinate (mfma op TL) at (mfma opA TL); - 
\drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elem)$); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} - - \coordinate (block TL) at (C TL); - \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} - - %% Draw labels - \def\vecR{1.5} - \coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$); - \pgfmathsetmacro{\maxVec}{\kpack-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$); - \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack}; - - \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$); - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$); - \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$); - \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack}; - - \node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC}; - \ifthenelse{\mfmaTrans=0}{ - \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA}; - \node [scale=\scale, above] at (mfma op TL) {opB}; - \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$); - \foreach \vecId in {0,1,2,3}{ - \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem); - \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem); - \node [scale=.8*\scale, 
above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4}; - \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; - }{ - \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB}; - \node [scale=\scale, above] at (mfma op TL) {opA}; - \coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$); - \foreach \vecId in {0,1,2,3}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem); - \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 3*\elem); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4}; - \node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True}; - } -} - -\newcommand{\drawWMMAOperand}[3]{ - %% - %% Draw the layout of one operand of WMMA instruction - %% - %% #1: opIdx. 0 for opA, 1 for opB - %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing - %% #3: mode. 
0 for w32, 1 for w64 - %% - %% wmma op TL: pre defined top-left coordinates of the operand matrix - - \pgfmathsetmacro{\isOpB}{#1} - \pgfmathsetmacro{\isOpA}{1-\isOpB} - \pgfmathsetmacro{\verbose}{#2} - \pgfmathsetmacro{\isWLarge}{#3} - - \foreach \row in {0,...,15}{ - \pgfmathsetmacro{\ratio}{\row*5+15} - \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$); - \ifthenelse{\isWLarge=1}{ - \pgfmathsetmacro{\tidone}{int(\row+16)} - \pgfmathsetmacro{\tidtwo}{int(\row+32)} - \pgfmathsetmacro{\tidthree}{int(\row+48)} - \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) - rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) - node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone, t\tidtwo, t\tidthree}; - }{ - \pgfmathsetmacro{\tidone}{int(\row+16)} - \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) - rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) - node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone}; - } - } -} - -\newcommand{\drawWMMAResult}[2]{ - %% - %% Draw layout of WMMA result tensor - %% - %% #1: verbose. 1 means draw tid in each vec; 0 means draw nothing - %% #2: mode. 
0 for w32, 1 for w64 - - \pgfmathsetmacro{\verbose}{#1} - \pgfmathsetmacro{\isWLarge}{#2} - - \pgfmathsetmacro{\numElem}{256} - \pgfmathsetmacro{\maxElemId}{\numElem-1} - - \foreach \elemId in {0,...,\maxElemId}{ - %% figure out the rowID - \pgfmathsetmacro{\rowId}{floor(\elemId/16)} - %% figure out the colID - \pgfmathsetmacro{\colId}{mod(\elemId,16)} - %% figure out the tid and color - \ifthenelse{\isWLarge=1}{ - \pgfmathsetmacro{\tid}{int(mod(\elemId,64))} - \pgfmathsetmacro{\laneId}{mod(\elemId,64)} - }{ - \pgfmathsetmacro{\tid}{int(mod(\elemId,32))} - \pgfmathsetmacro{\laneId}{mod(\elemId,32)} - } - %% figure out the color - \pgfmathsetmacro{\colorId}{floor(\laneId/16)} - \pgfmathsetmacro{\vecColor}{\Colors[\colorId]} - %% Coordinate - \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$); - \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem) - node [scale=.4*\scale, pos=.5] {t\tid}; - } - - -} - -\newcommand{\drawWMMAInstr}[2]{ - %% - %% Draw wmma instruction layouts 16x16x16 - %% - %% #1: mode. 0 for w32, 1 for w64 - %% #2: verbose. 
1 means draw tid in each vec; 0 means draw nothing - %% - %% C TL: pre defined top-left coordinates of output matrix - %% \elem: pre defined element size - - - \pgfmathsetmacro{\isWLarge}{#1} - \pgfmathsetmacro{\verbose}{#2} - - \pgfmathsetmacro{\gap}{\elem*2} - \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$); - \coordinate (wmma opA TL) at (wmma op TL); - \drawWMMAOperand{0}{\verbose}{\isWLarge} - \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$); - \drawWMMAOperand{1}{\verbose}{\isWLarge} - - \drawWMMAResult{1}{\isWLarge} - - %% labels - \pgfmathsetmacro{\gap}{\elem} - \node [above left, scale=\scale] at (wmma opA TL) {A}; - \node [above left, scale=\scale] at (wmma op TL) {B}; - \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C}; - - %% A k dim - \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16}; - \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$); - \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$); - - %% B K dim - \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16}; - \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$); - \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$); - - %% C M dim - \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16}; - \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$); - \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$); - - %% C N dim - \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16}; - \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$); - \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$); -} diff --git a/python/perf-kernels/tools/plot-layout/wmmaLayout.tex b/python/perf-kernels/tools/plot-layout/wmmaLayout.tex new file mode 100644 index 000000000000..54141b4928cc --- /dev/null +++ 
b/python/perf-kernels/tools/plot-layout/wmmaLayout.tex @@ -0,0 +1,121 @@ +\newcommand{\drawWMMAOperand}[3]{ + %% + %% Draw the layout of one operand of WMMA instruction + %% + %% #1: opIdx. 0 for opA, 1 for opB + %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #3: mode. 0 for w32, 1 for w64 + %% + %% wmma op TL: pre defined top-left coordinates of the operand matrix + + \pgfmathsetmacro{\isOpB}{#1} + \pgfmathsetmacro{\isOpA}{1-\isOpB} + \pgfmathsetmacro{\verbose}{#2} + \pgfmathsetmacro{\isWLarge}{#3} + + \foreach \row in {0,...,15}{ + \pgfmathsetmacro{\ratio}{\row*5+15} + \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$); + \ifthenelse{\isWLarge=1}{ + \pgfmathsetmacro{\tidone}{int(\row+16)} + \pgfmathsetmacro{\tidtwo}{int(\row+32)} + \pgfmathsetmacro{\tidthree}{int(\row+48)} + \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) + rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) + node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone, t\tidtwo, t\tidthree}; + }{ + \pgfmathsetmacro{\tidone}{int(\row+16)} + \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) + rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) + node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone}; + } + } +} + +\newcommand{\drawWMMAResult}[2]{ + %% + %% Draw layout of WMMA result tensor + %% + %% #1: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #2: mode. 
0 for w32, 1 for w64 + + \pgfmathsetmacro{\verbose}{#1} + \pgfmathsetmacro{\isWLarge}{#2} + + \pgfmathsetmacro{\numElem}{256} + \pgfmathsetmacro{\maxElemId}{\numElem-1} + + \foreach \elemId in {0,...,\maxElemId}{ + %% figure out the rowID + \pgfmathsetmacro{\rowId}{floor(\elemId/16)} + %% figure out the colID + \pgfmathsetmacro{\colId}{mod(\elemId,16)} + %% figure out the tid and color + \ifthenelse{\isWLarge=1}{ + \pgfmathsetmacro{\tid}{int(mod(\elemId,64))} + \pgfmathsetmacro{\laneId}{mod(\elemId,64)} + }{ + \pgfmathsetmacro{\tid}{int(mod(\elemId,32))} + \pgfmathsetmacro{\laneId}{mod(\elemId,32)} + } + %% figure out the color + \pgfmathsetmacro{\colorId}{floor(\laneId/16)} + \pgfmathsetmacro{\vecColor}{\Colors[\colorId]} + %% Coordinate + \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$); + \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem) + node [scale=.4*\scale, pos=.5] {t\tid}; + } + + +} + +\newcommand{\drawWMMAInstr}[2]{ + %% + %% Draw wmma instruction layouts 16x16x16 + %% + %% #1: mode. 0 for w32, 1 for w64 + %% #2: verbose. 
1 means draw tid in each vec; 0 means draw nothing + %% + %% C TL: pre defined top-left coordinates of output matrix + %% \elem: pre defined element size + + + \pgfmathsetmacro{\isWLarge}{#1} + \pgfmathsetmacro{\verbose}{#2} + + \pgfmathsetmacro{\gap}{\elem*2} + \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$); + \coordinate (wmma opA TL) at (wmma op TL); + \drawWMMAOperand{0}{\verbose}{\isWLarge} + \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$); + \drawWMMAOperand{1}{\verbose}{\isWLarge} + + \drawWMMAResult{1}{\isWLarge} + + %% labels + \pgfmathsetmacro{\gap}{\elem} + \node [above left, scale=\scale] at (wmma opA TL) {A}; + \node [above left, scale=\scale] at (wmma op TL) {B}; + \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C}; + + %% A k dim + \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16}; + \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$); + \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$); + + %% B K dim + \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16}; + \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$); + \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$); + + %% C M dim + \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16}; + \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$); + \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$); + + %% C N dim + \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16}; + \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$); + \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$); +} \ No newline at end of file From e66155aa4271b28a26457df0603243e44d058c16 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 23 Dec 2024 10:51:19 -0600 Subject: [PATCH 05/11] Extend dotLayout plot to support 
kWidth=32 - When kWidth is large, use a smaller elemSize horizontally to save space - Improve the labels, such as - change vec to kWidth for operands - change opA/opB to inA/inB and include operand dims - remove group dims in the operands so that they don't overlap with operand block dims - Better alignment: dot op and mfma zoomed-in pics are bottom aligned --- .../tools/plot-layout/dotLayout.tex | 92 ++++++++++--------- .../tools/plot-layout/plot_layout.py | 21 ++++- 2 files changed, 68 insertions(+), 45 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex index f9ef6df60aff..641c6b3c51d8 100644 --- a/python/perf-kernels/tools/plot-layout/dotLayout.tex +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -102,8 +102,10 @@ %% %% Draw one mfma operand %% - %% mfma op TL: pre defined coordinates of the top-left - %% \elem: pre defined variable + %% Pre-defined variables + %% mfma op TL: coordinates of the top-left + %% \elem: vertical element size of operands, element size of output + %% \elemW: honrizontal element size of operands %% %% #1: mfmNonKDim %% #2: kpack @@ -121,7 +123,6 @@ \foreach \col/\tg in {0,...,\maxGID}{ \pgfmathsetmacro{\col}{\Colors[\tg]} \foreach \tid in {0,...,\maxTID} { - % \pgfmathsetmacro{\ratio}{\tid*2.5+15} \ifthenelse{\verbose=0}{ \draw [line width=0.005mm, fill=\col] ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) @@ -129,8 +130,8 @@ }{ \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA) + ($(mfma op TL)+(\tg*\kpack*\elemW*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elemW*\opIdxA)$) + rectangle ++(\kpack*\elemW*\opIdxB + \elem*\opIdxA,
-\elem*\opIdxB-\kpack*\elemW*\opIdxA) node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; } } @@ -264,8 +265,10 @@ \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); + %% Draw both A and B operands \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack} + %% Draw result tensor \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} %% Draw labels @@ -280,19 +283,19 @@ \node [scale=\scale, above left] at (C TL) {C}; %% label nonKDim - \node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; - \node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; - %% label kpack - \node [scale=.8*\scale, above] at ($(A TL)+(0.5*\groups*\kpack*\elem, 0)$) {\kdim}; - \node [scale=.8*\scale, left] at ($(B TL)+(0, -0.5*\groups\kpack*\elem)$) {\kdim}; + \node [scale=.8*\scale, left] at ($(C TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; + \node [scale=.8*\scale, above] at ($(C TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; } \newcommand{\drawMFMAInstr}[3]{ %% %% Draw layout of mfma instructions with tid labeled %% - %% C TL: pre defined top-left coordinates of the output matrix - %% \elem: pre defined variable + %% Pre-defined variables + %% C TL: top-left coordinates of the output matrix + %% \elem: vertical element size of operands, element size of output + %% \elemW: honrizontal element size of operands + %% \scaleLabel: extra scale applied to labels according to kWidth %% %% #1: mfmaNonKDim %% #2: kpack @@ -302,58 +305,63 @@ \pgfmathsetmacro{\kpack}{#2} \pgfmathsetmacro{\mfmaTrans}{#3} \pgfmathsetmacro{\nonTrans}{1-#3} + \pgfmathsetmacro{\kDim}{int(\kpack*\groups)} \pgfmathsetmacro{\gap}{\elem*5} - \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elem, 0)$); + \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elemW, 0)$); 
\coordinate (mfma op TL) at (mfma opA TL); \drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elem)$); + \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elemW)$); \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} \coordinate (block TL) at (C TL); \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} %% Draw labels - \def\vecR{1.5} - \coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$); + %% Draw kWidth vector and lable of first operand + \coordinate (vec TL) at ($(mfma opA TL)+(0, 3*\elem)$); \pgfmathsetmacro{\maxVec}{\kpack-1} \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); } - \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$); - \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack}; - - \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$); + \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem)$); + \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elemW, 0)$) -- ($(vec TL)+(\kpack*\elem, -\elem)$); + \node [scale=.8*\scaleLabel, above] at ($(vec TL)+(.5*\kpack*\elem, 0)$) {kWidth=\kpack}; + %% Draw kWidth vector and lable of second operand + \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem, 0)$); \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); } - \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$); - \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$); - \node [scale=.8*\scale, 
above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack}; + \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem,0)$); + \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elemW)$) -- ($(vec TL)+(\elem, -\kpack*\elem)$); + \node [scale=.8*\scaleLabel, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem)$) {kWidth=\kpack}; - \node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC}; + %% Draw labels according to mfma.trans or not \ifthenelse{\mfmaTrans=0}{ - \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA}; - \node [scale=\scale, above] at (mfma op TL) {opB}; - \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$); + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups, 0)$) + {inA:$\mfmaNonKDim \times \kDim$}; + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kpack*\elemW)$) + {inB:$\kDim \times \mfmaNonKDim$}; + \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem,0)$); \foreach \vecId in {0,1,2,3}{ - \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); } - \draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem); - \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem); - \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4}; + \draw [densely dotted] (block TL) -- ++(-3*\elem,0); + \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem,0); + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem)$) {vec=4$\times$f32}; \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; }{ - \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB}; - \node [scale=\scale, above] at (mfma op TL) {opA}; - 
\coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$); + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups, 0)$) + {inB:$\kDim \times \mfmaNonKDim^T$}; + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kpack*\elemW)$) + {inA:$\mfmaNonKDim \times \kDim^T$}; + \coordinate (vec TL) at ($(block TL)+(0, 3*\elem+\elem)$); \foreach \vecId in {0,1,2,3}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); } - \draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem); - \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 3*\elem); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4}; - \node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True}; + \draw [densely dotted] (block TL) -- ++(0, 3*\elem); + \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(0, 3*\elem); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem, 0)$) {vec=4$\times$f32}; + \node [scale=.6*\scale, above, align=center] at ($(block TL)+(8*\elem, 0)$) {mfmaLayout\\trans=True}; } } \ No newline at end of file diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 298eab79c5a2..f86c0df79736 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -5,10 +5,22 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): + elemSmall = 0.04 + elemLarge = 0.16 + if kWidth == 16: + ratio = 0.8 + elif kWidth == 32: + ratio = 0.6 + else: + ratio = 1 + elemWidth = elemLarge * ratio + + scaleLabel = 0.7 if kWidth == 4 else 1 + return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} - \\def\\elem{{0.04}} + \\def\\elem{{{elemSmall}}} \\coordinate (C TL) at (0,0); 
\\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kWidth}}} @@ -16,11 +28,14 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): \\def\\mfmaTrans{{{trans}}} %% Draw zoomed in view of mfma - \\def\\elem{{.16}} + \\def\\scaleLabel{{{scaleLabel}}} + \\pgfmathsetmacro{{\\oldElem}}{{\\elem}} + \\def\\elem{{{elemLarge}}} + \\def\\elemW{{{elemWidth}}} \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} - \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*\\elem, 0)$); + \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{\\mfmaTrans}} \\end{{tikzpicture}} From 9016bece4b249176f0a1e6e1c98eede9953ad7de Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Tue, 24 Dec 2024 09:44:54 -0600 Subject: [PATCH 06/11] [API change] Add support for kGroup kGroup is defined as total elements per thread / kWidth for one mfma instruction. We need kGroup = 2 only for the newly added mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4 with f8 input type on MI350. --- .../perf-kernels/tools/plot-layout/README.md | 21 ++-- .../tools/plot-layout/dotLayout.tex | 102 ++++++++++-------- .../tools/plot-layout/plot_layout.py | 29 +++-- 3 files changed, 91 insertions(+), 61 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index 563e42c697d9..f84b6164ff45 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -5,21 +5,26 @@ Here is the help info from the script. 
```bash >$ python3 plot_layout.py -h -usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] - [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] +usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] + [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] options: -h, --help show this help message and exit - -shape SHAPE SHAPE SHAPE - Tensor shape in the form of M,N,K + -tensorShape TENSORSHAPE TENSORSHAPE + 2D tensor shape in the form of dim0,dim1 + -dotShape DOTSHAPE DOTSHAPE DOTSHAPE + Dot op shape in the form of M,N,K -plot {blocked,dot,wmma,lds} choose plot mode - -nonKDim {16,32} mfma instruction dim + -dim0 DIM0 tensor dim0 name + -dim1 DIM1 tensor dim1 name -sizePerThread SIZEPERTHREAD SIZEPERTHREAD -threadsPerWarp THREADSPERWARP THREADSPERWARP -warpsPerCTA WARPSPERCTA WARPSPERCTA -order ORDER ORDER - -kWidth {4,8,16} number of elements per thread + -nonKDim {16,32} mfma instruction dim + -kWidth {4,8,16,32} number of contiguous elements per thread + -kGroup {1,2} total number of elements / kWidth per mfma instruction -lds_layout {swizzle,padding,none} choose the LDS data layout -lds_access {read,write,none} @@ -69,6 +74,7 @@ python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 
-kWidth 8 -mfmaTrans python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 -kGroup 2 ``` This mode draws two graphs: @@ -79,6 +85,9 @@ This mode draws two graphs: Knobs - `-kWidth`: the number of elements that will be loaded into one thread at once +- `-kGroup`: total number of elements / kWidth for on mfma instruction. + This is 1 for all mfma instructions except for mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4 + with fp8 input types (CBSZ=0 or 1 and/or BLGP=0 or 1) - `-nonKDim`: 16 ot 32, which is used to control the mfma instruction size - `-mfmaTrans`: if set, the transposed mfma layout will be plotted. diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex index 641c6b3c51d8..c79b0f2158fe 100644 --- a/python/perf-kernels/tools/plot-layout/dotLayout.tex +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -98,7 +98,7 @@ \coordinate (TL) at (old TL); } -\newcommand{\drawMFMAOperand}[4]{ +\newcommand{\drawMFMAOperand}[5]{ %% %% Draw one mfma operand %% @@ -109,6 +109,7 @@ %% %% #1: mfmNonKDim %% #2: kpack + %% #2: kGroup %% #3: 0 for opA and 1 for opB %% #4: verbose. 
1 means draw tid in each vec; 0 means draw nothing @@ -116,29 +117,34 @@ \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} \pgfmathsetmacro{\maxTID}{\nonKDim-1} \pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\opIdxA}{#3} + \pgfmathsetmacro{\kGroup}{#3} + \pgfmathsetmacro{\maxGroupId}{\kGroup-1} + \pgfmathsetmacro{\opIdxA}{#4} \pgfmathsetmacro{\opIdxB}{1-\opIdxA} - \pgfmathsetmacro{\verbose}{#4} - - \foreach \col/\tg in {0,...,\maxGID}{ - \pgfmathsetmacro{\col}{\Colors[\tg]} - \foreach \tid in {0,...,\maxTID} { - \ifthenelse{\verbose=0}{ - \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); - }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} - \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elemW*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elemW*\opIdxA)$) - rectangle ++(\kpack*\elemW*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elemW*\opIdxA) - node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; + \pgfmathsetmacro{\verbose}{#5} + + \foreach \gp in {0,...,\maxGroupId}{ + \coordinate (group TL) at ($(mfma op TL)+(\gp*\kpack*64*\elemW/\nonKDim*\opIdxB, -\gp*\kpack*64*\elemW/\nonKDim*\opIdxA)$); + \foreach \col/\tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\col}{\Colors[\tg]} + \foreach \tid in {0,...,\maxTID} { + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col] + ($(group TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) + rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col] + ($(group TL)+(\tg*\kpack*\elemW*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elemW*\opIdxA)$) + rectangle ++(\kpack*\elemW*\opIdxB + \elem*\opIdxA, 
-\elem*\opIdxB-\kpack*\elemW*\opIdxA) + node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; + } } } } } -\newcommand{\drawWaveOperand}[4]{ +\newcommand{\drawWaveOperand}[5]{ %% %% Draw the part of the tensor that is one operand of the wave %% @@ -148,29 +154,31 @@ %% #1: K %% #2: mfmNonKDim %% #3: kpack - %% #4: 0 for opA and 1 for opB + %% #4: kGroup + %% #5: 0 for opA and 1 for opB \pgfmathsetmacro{\K}{#1} \pgfmathsetmacro{\nonKDim}{#2} \pgfmathsetmacro{\groups}{64/\nonKDim} \pgfmathsetmacro{\kpack}{#3} - \pgfmathsetmacro{\opIdx}{#4} + \pgfmathsetmacro{\kGroup}{#4} + \pgfmathsetmacro{\opIdx}{#5} \pgfmathsetmacro{\opIdxOther}{1-\opIdx} \coordinate (TL) at (Op TL); - \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups} + \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups/\kGroup} \pgfmathsetmacro{\maxKRepId}{\numKRep-1} \foreach \repId in {0,...,\maxKRepId}{ - \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\opIdxOther, -\repId*\groups*\kpack*\elem*\opIdx)$); - \drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0} + \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\kGroup*\opIdxOther, -\repId*\groups*\kpack*\kGroup*\elem*\opIdx)$); + \drawMFMAOperand{\nonKDim}{\kpack}{\kGroup}{\opIdx}{0} \draw [thick] (mfma op TL) rectangle - ++(\groups*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\elem*\opIdx); + ++(\groups*\kpack*\kGroup*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\kGroup*\elem*\opIdx); } } -\newcommand{\drawDotOperands}[7]{ +\newcommand{\drawDotOperands}[8]{ %% %% Draw operand tensors of dot %% @@ -184,6 +192,7 @@ %% #5: warpsPerCTA[0] %% #6: warpsPerCTA[1] %% #7: kpack + %% #8: kGroup \pgfmathsetmacro{\M}{#1} \pgfmathsetmacro{\N}{#2} @@ -192,6 +201,7 @@ \pgfmathsetmacro{\warpsPerCTAM}{#5} \pgfmathsetmacro{\warpsPerCTAN}{#6} \pgfmathsetmacro{\kpack}{#7} + \pgfmathsetmacro{\kGroup}{#8} %% operand A \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} 
@@ -205,7 +215,7 @@ } %% Only draw the detailed view of the first wave in CTA \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0} + \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{\kGroup}{0} %% Draw the outline of each CTA rep \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); @@ -225,7 +235,7 @@ } %% Only draw the detailed view of the first wave in CTA \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1} + \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{\kGroup}{1} %% Draw the outline of each CTA rep \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); @@ -234,7 +244,7 @@ } -\newcommand{\drawDot}[8]{ +\newcommand{\drawDot}[9]{ %% %% Draw C = dot A, B %% @@ -249,6 +259,7 @@ %% #6: warpsPerCTA[1] %% #7: 1 for mfma.trans, 0 for normal mfma %% #8: kpack + %% #9: kGroup \pgfmathsetmacro{\M}{#1} \pgfmathsetmacro{\N}{#2} @@ -259,6 +270,7 @@ \pgfmathsetmacro{\warpsPerCTAN}{#6} \pgfmathsetmacro{\mfmaTrans}{#7} \pgfmathsetmacro{\kpack}{#8} + \pgfmathsetmacro{\kGroup}{#9} \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} \pgfmathsetmacro{\gap}{\elem*20} @@ -266,7 +278,7 @@ \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); %% Draw both A and B operands - \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack} + \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack}{\kGroup} %% Draw result tensor \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} @@ -287,7 +299,7 @@ \node [scale=.8*\scale, above] at ($(C TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; } -\newcommand{\drawMFMAInstr}[3]{ +\newcommand{\drawMFMAInstr}[4]{ %% %% Draw layout of mfma instructions with tid labeled %% @@ -299,27 +311,29 @@ %% %% #1: mfmaNonKDim %% #2: kpack - %% #3: mfmaTrans + %% #3: kGroup + %% #4: mfmaTrans \pgfmathsetmacro{\mfmaNonKDim}{#1} \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} 
\pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\mfmaTrans}{#3} - \pgfmathsetmacro{\nonTrans}{1-#3} - \pgfmathsetmacro{\kDim}{int(\kpack*\groups)} + \pgfmathsetmacro{\kGroup}{#3} + \pgfmathsetmacro{\mfmaTrans}{#4} + \pgfmathsetmacro{\nonTrans}{1-#4} + \pgfmathsetmacro{\kDim}{int(\kpack*\groups*\kGroup)} \pgfmathsetmacro{\gap}{\elem*5} - \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elemW, 0)$); + \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elemW*\kGroup, 0)$); \coordinate (mfma op TL) at (mfma opA TL); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elemW)$); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} + \drawMFMAOperand{\mfmaNonKDim}{\kpack}{\kGroup}{0}{1} + \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elemW*\kGroup)$); + \drawMFMAOperand{\mfmaNonKDim}{\kpack}{\kGroup}{1}{1} \coordinate (block TL) at (C TL); \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} %% Draw labels %% Draw kWidth vector and lable of first operand - \coordinate (vec TL) at ($(mfma opA TL)+(0, 3*\elem)$); + \coordinate (vec TL) at ($(mfma opA TL)+(0, 5*\elem)$); \pgfmathsetmacro{\maxVec}{\kpack-1} \foreach \vecId in {0,...,\maxVec}{ \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); @@ -328,7 +342,7 @@ \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elemW, 0)$) -- ($(vec TL)+(\kpack*\elem, -\elem)$); \node [scale=.8*\scaleLabel, above] at ($(vec TL)+(.5*\kpack*\elem, 0)$) {kWidth=\kpack}; %% Draw kWidth vector and lable of second operand - \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem, 0)$); + \coordinate (vec TL) at ($(mfma op TL)+(-5*\elem, 0)$); \foreach \vecId in {0,...,\maxVec}{ \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); } @@ -338,9 +352,9 @@ %% Draw labels according to mfma.trans or not \ifthenelse{\mfmaTrans=0}{ - \node 
[scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups, 0)$) + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups*\kGroup, 0)$) {inA:$\mfmaNonKDim \times \kDim$}; - \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kpack*\elemW)$) + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kpack*\elemW*\kGroup)$) {inB:$\kDim \times \mfmaNonKDim$}; \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem,0)$); \foreach \vecId in {0,1,2,3}{ @@ -351,9 +365,9 @@ \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem)$) {vec=4$\times$f32}; \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; }{ - \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups, 0)$) + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups*\kGroup, 0)$) {inB:$\kDim \times \mfmaNonKDim^T$}; - \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kpack*\elemW)$) + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kpack*\elemW*\kGroup)$) {inA:$\mfmaNonKDim \times \kDim^T$}; \coordinate (vec TL) at ($(block TL)+(0, 3*\elem+\elem)$); \foreach \vecId in {0,1,2,3}{ diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index f86c0df79736..1235cd2db490 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -4,12 +4,13 @@ import subprocess -def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): +def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup): elemSmall = 0.04 elemLarge = 0.16 - if kWidth == 16: + elemPerThread = kWidth * kGroup + if elemPerThread == 16: ratio = 0.8 - elif kWidth == 32: + elif elemPerThread == 32: ratio = 0.6 else: 
ratio = 1 @@ -21,8 +22,9 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): \\begin{{tikzpicture}} \\def\\scale{{1}} \\def\\elem{{{elemSmall}}} + \\def\\elemW{{\\elem}} \\coordinate (C TL) at (0,0); - \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kWidth}}} + \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kWidth}}}{{{kGroup}}} \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); \\def\\mfmaTrans{{{trans}}} @@ -35,8 +37,8 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth): \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} - \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); - \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{\\mfmaTrans}} + \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*{kGroup}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); + \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{{kGroup}}}{{\\mfmaTrans}} \\end{{tikzpicture}} \\end{{document}}''' @@ -120,7 +122,6 @@ def parse_args(): parser.add_argument("-dotShape", type=int, nargs=3, default=(32, 128, 64), help='Dot op shape in the form of M,N,K') parser.add_argument("-plot", type=str, default="blocked", choices=['blocked', 'dot', 'wmma', 'lds'], help='choose plot mode') - parser.add_argument("-nonKDim", type=int, default=32, choices=[16, 32], help='mfma instruction dim') parser.add_argument("-dim0", type=str, default="M", help='tensor dim0 name') parser.add_argument("-dim1", type=str, default="K", help='tensor dim1 name') ## blocked layout parameters @@ -128,9 +129,13 @@ def parse_args(): parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4)) parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4)) 
parser.add_argument("-order", type=int, nargs=2, default=(1, 0)) - ## LDS access parameters + ## dot layout parameters + parser.add_argument("-nonKDim", type=int, default=16, choices=[16, 32], help='mfma instruction dim') parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16, 32], help='number of contiguous elements per thread') + parser.add_argument("-kGroup", type=int, default=1, choices=[1, 2], + help='total number of elements / kWidth per mfma instruction') + ## LDS access parameters parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') parser.add_argument("-lds_access", type=str, default="none", choices=['read', 'write', 'none'], @@ -162,6 +167,7 @@ def main(): plot_mode = args.plot mfmaNonKDim = args.nonKDim kWidth = args.kWidth + kGroup = args.kGroup trans = 1 if args.mfmaTrans else 0 ofilename = args.o keepSrc = args.keep @@ -190,7 +196,7 @@ def main(): mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" mfma_trans_str = ".trans" if trans else "" print(f"Plotting dot operation with shapes {M=},{N=},{K=}") - print("MFMA: " + mfma_inst_str + mfma_trans_str + f" {kWidth=}", end=" ") + print("MFMA: " + mfma_inst_str + mfma_trans_str + f" {kWidth=}, {kGroup=}", end=" ") print(f"{warpsPerCTA=}", end=" ") CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) @@ -205,7 +211,8 @@ def main(): if plot_mode == 'dot': assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" - assert K != 0 and K % (2 * kWidth) == 0, "bad tensor dimension K" + kDim = kWidth * kGroup * 64 / mfmaNonKDim + assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" if plot_mode == 'lds': print(f"Plotting LDS access for tensor M={M},K={K} with vec={kWidth}") @@ -219,7 +226,7 
@@ def main(): draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threadsPerWarp, warpsPerCTA, order) - draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth) + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup) draw_lds_str = draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) From 97e94f6e49ec97b4b3f4a3694904a9d79f092d52 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 25 Dec 2024 11:20:37 -0600 Subject: [PATCH 07/11] [API change] Add support for data types of both operands And print mfma instruction name accordingly. For now, mixed precision mfma between 8-bit and 4- or 6-bit is not supported yet. --- .../perf-kernels/tools/plot-layout/README.md | 47 +++-- .../tools/plot-layout/dotLayout.tex | 22 ++- .../tools/plot-layout/plot_layout.py | 187 ++++++++++++++++-- 3 files changed, 211 insertions(+), 45 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index f84b6164ff45..30f5ddfa0ab1 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -6,7 +6,8 @@ Here is the help info from the script. 
```bash >$ python3 plot_layout.py -h usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] - [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] + [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] + [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] options: -h, --help show this help message and exit @@ -25,6 +26,10 @@ options: -nonKDim {16,32} mfma instruction dim -kWidth {4,8,16,32} number of contiguous elements per thread -kGroup {1,2} total number of elements / kWidth per mfma instruction + -dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8} + element type of operand A + -dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8} + element type of operand B -lds_layout {swizzle,padding,none} choose the LDS data layout -lds_access {read,write,none} @@ -69,37 +74,41 @@ Notes Examples: ```bash -python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4 -python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans -python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 -python3 plot_layout.py -plot dot -dotShape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 -python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 -kGroup 2 +## i8 inputs 
+python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a i8 -dtype_b i8 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a i8 -dtype_b i8 +## fp16/bf16 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 4 -dtype_a fp16 -dtype_b fp16 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp16 -dtype_b fp16 +## fp8/bf8 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp8 -dtype_b bf8 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a fp8 -dtype_b bf8 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp8 -dtype_b bf8 +## f4 and fp6/bf6 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6 ``` +One can add `-nonKDim [16,32]` and `-mfmaTrans` to all of the above examples. + This mode draws two graphs: -1. The layout of the whole tile for tile A, B, and C +1. The layout of the dot operation, i.e. tile C = tile A x tile B 2. The layout of a single mfma block, operands and results of one or more mfma instructions that share the same accumulating VGPRs. - This view has thread distributions among tensor elements. Knobs -- `-kWidth`: the number of elements that will be loaded into one thread at once -- `-kGroup`: total number of elements / kWidth for on mfma instruction. +- `-kWidth [4,8,16,32]`: the number of elements that will be loaded into one thread at once +- `-kGroup [1,2]`: total number of elements / kWidth for on mfma instruction. This is 1 for all mfma instructions except for mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4 with fp8 input types (CBSZ=0 or 1 and/or BLGP=0 or 1) -- `-nonKDim`: 16 ot 32, which is used to control the mfma instruction size +- `-nonKDim [16,32]`: mfma instruction size. 
The default is set to 16. - `-mfmaTrans`: if set, the transposed mfma layout will be plotted. +- `-dtype_a` and `-dtype_b`: element types of operand A and B. The default value is fp16. Notes - The layout shows the mapping from the threads/wave to the elements in the - original tensor. It does not care if the elements are re-arranged in LDS, like - swizzling to avoid bank conflicts. -- The script does not allow settings for data type or k dim of the mfma instruction. - This can be controled by the `-kWidth` flag. - - For example, if we want `mfma_32x32x8xf16`, we can set `-nonKDim 32` and `-kWidth 4`. - - If we want `mfma_32x32x16xf8`, we can set `-nonKDim 32` and `-kWidth 8`. - + original tensor. It does not matter if LDS is used. +- The script does not allow settings for k dim of the mfma instruction. + This can be controled by the `-kWidth` and `-kGroup`. ## Draw LDS access (`-plot lds`) diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex index c79b0f2158fe..17824c6fa5b2 100644 --- a/python/perf-kernels/tools/plot-layout/dotLayout.tex +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -299,7 +299,7 @@ \node [scale=.8*\scale, above] at ($(C TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; } -\newcommand{\drawMFMAInstr}[4]{ +\newcommand{\drawMFMAInstr}[7]{ %% %% Draw layout of mfma instructions with tid labeled %% @@ -313,6 +313,10 @@ %% #2: kpack %% #3: kGroup %% #4: mfmaTrans + %% #5: dtype_a + %% #6: dtype_b + %% #7: outType + \pgfmathsetmacro{\mfmaNonKDim}{#1} \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} \pgfmathsetmacro{\kpack}{#2} @@ -332,6 +336,10 @@ \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} %% Draw labels + %% Set data types + \def\opAType{#5} + \def\opBType{#6} + \def\outType{#7} %% Draw kWidth vector and lable of first operand \coordinate (vec TL) at ($(mfma opA TL)+(0, 5*\elem)$); \pgfmathsetmacro{\maxVec}{\kpack-1} @@ -353,29 +361,29 @@ %% Draw labels according to 
mfma.trans or not \ifthenelse{\mfmaTrans=0}{ \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups*\kGroup, 0)$) - {inA:$\mfmaNonKDim \times \kDim$}; + {inA:$\mfmaNonKDim \times \kDim \times $\opAType}; \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kpack*\elemW*\kGroup)$) - {inB:$\kDim \times \mfmaNonKDim$}; + {inB:$\kDim \times \mfmaNonKDim \times $\opBType}; \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem,0)$); \foreach \vecId in {0,1,2,3}{ \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); } \draw [densely dotted] (block TL) -- ++(-3*\elem,0); \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem,0); - \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem)$) {vec=4$\times$f32}; + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem)$) {vec=4$\times$\outType}; \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; }{ \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups*\kGroup, 0)$) - {inB:$\kDim \times \mfmaNonKDim^T$}; + {inB:$\kDim \times \mfmaNonKDim^T \times $\opBType}; \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kpack*\elemW*\kGroup)$) - {inA:$\mfmaNonKDim \times \kDim^T$}; + {inA:$\mfmaNonKDim \times \kDim^T \times $\opAType}; \coordinate (vec TL) at ($(block TL)+(0, 3*\elem+\elem)$); \foreach \vecId in {0,1,2,3}{ \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); } \draw [densely dotted] (block TL) -- ++(0, 3*\elem); \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(0, 3*\elem); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem, 0)$) {vec=4$\times$f32}; + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem, 0)$) {vec=4$\times$\outType}; \node [scale=.6*\scale, above, align=center] at ($(block TL)+(8*\elem, 0)$) {mfmaLayout\\trans=True}; } } \ No newline at end of 
file diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 1235cd2db490..dc5e4d1fb632 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -4,7 +4,8 @@ import subprocess -def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup): +def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, dtype_b, mfma_inst_str, + kpack): elemSmall = 0.04 elemLarge = 0.16 elemPerThread = kWidth * kGroup @@ -16,7 +17,9 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup ratio = 1 elemWidth = elemLarge * ratio - scaleLabel = 0.7 if kWidth == 4 else 1 + scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1 + + outType = 'i32' if dtype_a == 'i8' else 'f32' return f'''\\begin{{document}} \\begin{{tikzpicture}} @@ -38,7 +41,9 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*{kGroup}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); - \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{{kGroup}}}{{\\mfmaTrans}} + \\coordinate (mfma instr) at ($(C TL)+(-.5*\\gap-0.6*\\nonTrans*\\gap-0.4*\\mfmaTrans*\\gap, 1.5*\\gap+.5*\\mfmaTrans*\\gap)$); + \\node [scale=\scaleLabel, above left, align=left, draw=black, fill=white] at (mfma instr) {{{mfma_inst_str}}}; + \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{{kGroup}}}{{\\mfmaTrans}}{{{dtype_a}}}{{{dtype_b}}}{{{outType}}} \\end{{tikzpicture}} \\end{{document}}''' @@ -106,6 +111,125 @@ def draw_wmma_instr_cmd(waveSize): \\end{{document}}''' +matrixFormatTable = {'fp8': 0, 'bf8': 1, 'fp6': 2, 'bf6': 3, 'f4': 4} + + +def matrixFormat(dtype_a, dtype_b): + ''' + return CBSZ and BLGP according 
to data types + b000: E4M3(FP8) + b001: E5M2(BF8) + b010: E2M3(FP6) + b011: E3M2(BF6) + b100: E2M1(FP4) + ''' + return matrixFormatTable[dtype_a], matrixFormatTable[dtype_b] + + +def isType4Or6Bit(dtype): + return dtype == 'fp6' or dtype == 'bf6' or dtype == 'f4' + + +def isType8BitFloat(dtype): + return dtype == 'fp8' or dtype == 'bf8' + + +def isType16Bit(dtype): + return dtype == 'bf16' or dtype == 'fp16' + + +def isMixedPrecType(dtype): + return isType8BitFloat(dtype) or isType4Or6Bit(dtype) + + +def isMixedPrecBtwF8AndF4OrF6(dtype_a, dtype_b): + return (isType8BitFloat(dtype_a) and isType4Or6Bit(dtype_b)) or (isType8BitFloat(dtype_b) + and isType4Or6Bit(dtype_a)) + + +def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans): + ## Check input types + ## Mixed precision is only allowed within f8, f6 and f4 + assert (isMixedPrecType(dtype_a) and isMixedPrecType(dtype_b)) or ( + dtype_a == dtype_b), f"Cannot do mixed precision mfma with {dtype_a} and {dtype_b}" + ''' + Check mfma size according to data types + * refers to newly added instructions on MI350 + Both dtyes are f4 or fp6 or bf6 + *mfma_f32_16x16x128_f8f6f4: kWidth = 32, kGroup = 1 + *mfma_f32_32x32x64_f8f6f4: kWidth = 32, kGroup = 1 + One dtype is fp8 or bf8 + When the other operand is f4, fp6, or bf6 + *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2 + *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2 + When the other operand is fp8 or bf8 + *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2 + mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack=2 + mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1 + *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2 + mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack=2 + mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1 + Both dtypes are bf16 or bf16 + *mfma_f32_16x16x32_f16/bf16: kWidth = 8, kGroup = 1 + mfma_f32_16x16x16_f16/bf16: kWidth = 4, kGroup = 1 + *mfma_f32_32x32x16_f16/bf16: kWidth = 8, 
kGroup = 1 + mfma_f32_32x32x8_f16/bf16: kWidth = 4, kGroup = 1 + Both types are i8 + *mfma_i32_16x16x64_i8: kWidth = 16, kGroup = 1 + mfma_i32_16x16x32_i8: kWidth = 8, kGroup = 1 + *mfma_i32_32x32x32_i8: kWidth = 16, kGroup = 1 + mfma_i32_32x32x16_i8: kWidth = 8, kGroup = 1 + + Return mfma instruction name and kpack + ''' + kDim = 64 / mfmaNonKDim * kWidth * kGroup + ## Both dtyes are f4 or fp6 or bf6 + if isType4Or6Bit(dtype_a) and isType4Or6Bit(dtype_b): + assert kWidth == 32 and kGroup == 1, f"Only kWidth=32 and kGroup=1 is supported for {dtype_a} x {dtype_b}" + kpack = 1 + CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] + BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4", kpack, CBSZ, BLGP + + ## Both dtypes are fp8 or bf8 + if isType8BitFloat(dtype_a) and isType8BitFloat(dtype_b): + assert (kWidth == 8 and kGroup == 1) or ( + kWidth == 16), f"Not a valid mfma instruction for {dtype_a} x {dtype_b} with {kWidth=} and {kGroup=}" + kpack = 2 if (kWidth == 16 and kGroup == 1) else 1 + if kGroup == 2: + suffix = "f8f6f4" + CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] + BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + else: + suffix = f"{dtype_b}_{dtype_a}" if trans else f"{dtype_a}_{dtype_b}" + CBSZ = -1 + BLGP = -1 + kDim = kDim / 2 if kpack == 2 else kDim + return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{suffix}", kpack, CBSZ, BLGP + + ## Both types are fp16 or bf16 + if isType16Bit(dtype_a) and isType16Bit(dtype_b): + assert ( + kWidth == 8 or kWidth == 4 + ) and kGroup == 1, f"Not a valid mfma instruction for {dtype_a} x {dtype_b} with {kWidth=} and {kGroup=}" + kpack = 1 + CBSZ = -1 + BLGP = -1 + return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP + + ## Both types are i8 + if dtype_a == 'i8' and dtype_b == 'i8': + assert ( + kWidth == 16 or 
kWidth == 8 + ) and kGroup == 1, f"Not a valid mfma instruction for {dtype_a} x {dtype_b} with {kWidth=} and {kGroup=}" + kpack = 1 + CBSZ = -1 + BLGP = -1 + return f"mfma_i32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP + + assert False, "Mixed precision between fp8/bf8 and fp6/bf6/f4 not supported in this mode" + + def run_bash_command(commandstring): proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) return proc.stdout.splitlines() @@ -135,6 +259,12 @@ def parse_args(): help='number of contiguous elements per thread') parser.add_argument("-kGroup", type=int, default=1, choices=[1, 2], help='total number of elements / kWidth per mfma instruction') + parser.add_argument("-dtype_a", type=str, default='fp16', + choices=['fp16', 'bf16', 'fp8', 'bf8', 'fp6', 'bf6', 'f4', + 'i8'], help='element type of operand A') + parser.add_argument("-dtype_b", type=str, default='fp16', + choices=['fp16', 'bf16', 'fp8', 'bf8', 'fp6', 'bf6', 'f4', + 'i8'], help='element type of operand B') ## LDS access parameters parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') @@ -168,6 +298,8 @@ def main(): mfmaNonKDim = args.nonKDim kWidth = args.kWidth kGroup = args.kGroup + dtype_a = args.dtype_a + dtype_b = args.dtype_b trans = 1 if args.mfmaTrans else 0 ofilename = args.o keepSrc = args.keep @@ -191,28 +323,44 @@ def main(): print(f"{order=}", end=" ") CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0]) CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1]) - - if plot_mode == 'dot': - mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" - mfma_trans_str = ".trans" if trans else "" - print(f"Plotting dot operation with shapes {M=},{N=},{K=}") - print("MFMA: " + mfma_inst_str + mfma_trans_str + f" {kWidth=}, {kGroup=}", end=" ") - print(f"{warpsPerCTA=}", end=" ") - 
CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) - CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) - - if plot_mode == 'blocked' or plot_mode == 'dot': print(f"CTAShape={CTAShape}") - - if plot_mode == 'blocked': assert dim0 != 0 and CTAShape[0] <= dim0 and dim0 % CTAShape[0] == 0, "bad tensor dimension " + dim0Name assert dim1 != 0 and CTAShape[1] <= dim1 and dim1 % CTAShape[1] == 0, "bad tensor dimension " + dim1Name if plot_mode == 'dot': + CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) + CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) + print(f"Plotting dot operation with shapes=M{M}-N{N}-K{K},{kWidth=},{kGroup=},{warpsPerCTA=},{CTAShape=}") assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" - kDim = kWidth * kGroup * 64 / mfmaNonKDim - assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + if isMixedPrecBtwF8AndF4OrF6(dtype_a, dtype_b): + ## In the case of mixed precision between 8-bit and 4 or 6-bit, + ## ignore kWidth and kGroup since inA and inB have different kWidth and kGroup values + kDim = 128 + assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + kpack = 1 + CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] + BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + mfma_inst_str = f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4" + else: + kDim = kWidth * kGroup * 64 / mfmaNonKDim + assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + mfma_inst_str, kpack, CBSZ, BLGP = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans) + flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}" + print(f"MFMA: {mfma_inst_str} x {kpack}{flag}", end="") + mfma_inst_str = mfma_inst_str.replace("_", 
"\\_") + mfma_inst_str = mfma_inst_str + flag + if kpack == 2: + mfma_inst_str = mfma_inst_str + " $\\times$ 2" + if ((dtype_a == 'fp16' or dtype_a == 'bf16') and kWidth == 8) or (dtype_a == 'i8' and kWidth == 16): + kDim = 64 / mfmaNonKDim * kWidth / 2 + outType = "i32" if dtype_a == 'i8' else "f32" + old_instr = f"mfma_{outType}_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}" + print(f" or {old_instr} x 2") + old_instr = old_instr.replace("_", "\\_") + mfma_inst_str = mfma_inst_str + " or\\\\" + old_instr + "$\\times$2" + else: + print("") if plot_mode == 'lds': print(f"Plotting LDS access for tensor M={M},K={K} with vec={kWidth}") @@ -226,7 +374,8 @@ def main(): draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threadsPerWarp, warpsPerCTA, order) - draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup) + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, + dtype_b, mfma_inst_str, kpack) draw_lds_str = draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) From d338835ca85aa7f5e970423f4621f9c973a0f158 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 25 Dec 2024 18:22:04 -0600 Subject: [PATCH 08/11] Support mixed mfma with bf8/fp8 and fp6/bf6/f4 --- .../perf-kernels/tools/plot-layout/README.md | 2 + .../tools/plot-layout/dotLayout.tex | 110 +++++++++--------- .../tools/plot-layout/plot_layout.py | 41 +++++-- 3 files changed, 88 insertions(+), 65 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index 30f5ddfa0ab1..bb81152b6bd2 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -86,6 +86,8 @@ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 
-kWidth 16 -kGroup 2 -dtype_a fp8 -dtype_b bf8 ## f4 and fp6/bf6 inputs python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6 +## fp8/bf8 and fp6/bf6/f4 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 ``` One can add `-nonKDim [16,32]` and `-mfmaTrans` to all of the above examples. diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex index 17824c6fa5b2..805f2dc852e7 100644 --- a/python/perf-kernels/tools/plot-layout/dotLayout.tex +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -108,7 +108,7 @@ %% \elemW: honrizontal element size of operands %% %% #1: mfmNonKDim - %% #2: kpack + %% #2: kWidth %% #2: kGroup %% #3: 0 for opA and 1 for opB %% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing @@ -116,7 +116,7 @@ \pgfmathsetmacro{\nonKDim}{#1} \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} \pgfmathsetmacro{\maxTID}{\nonKDim-1} - \pgfmathsetmacro{\kpack}{#2} + \pgfmathsetmacro{\kWidth}{#2} \pgfmathsetmacro{\kGroup}{#3} \pgfmathsetmacro{\maxGroupId}{\kGroup-1} \pgfmathsetmacro{\opIdxA}{#4} @@ -124,19 +124,19 @@ \pgfmathsetmacro{\verbose}{#5} \foreach \gp in {0,...,\maxGroupId}{ - \coordinate (group TL) at ($(mfma op TL)+(\gp*\kpack*64*\elemW/\nonKDim*\opIdxB, -\gp*\kpack*64*\elemW/\nonKDim*\opIdxA)$); + \coordinate (group TL) at ($(mfma op TL)+(\gp*\kWidth*64*\elemW/\nonKDim*\opIdxB, -\gp*\kWidth*64*\elemW/\nonKDim*\opIdxA)$); \foreach \col/\tg in {0,...,\maxGID}{ \pgfmathsetmacro{\col}{\Colors[\tg]} \foreach \tid in {0,...,\maxTID} { \ifthenelse{\verbose=0}{ \draw [line width=0.005mm, fill=\col] - ($(group TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); + ($(group 
TL)+(\tg*\kWidth*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kWidth*\elem*\opIdxA)$) + rectangle ++(\kWidth*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kWidth*\elem*\opIdxA); }{ \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} \draw [line width=0.005mm, fill=\col] - ($(group TL)+(\tg*\kpack*\elemW*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elemW*\opIdxA)$) - rectangle ++(\kpack*\elemW*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elemW*\opIdxA) + ($(group TL)+(\tg*\kWidth*\elemW*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kWidth*\elemW*\opIdxA)$) + rectangle ++(\kWidth*\elemW*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kWidth*\elemW*\opIdxA) node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; } } @@ -153,32 +153,32 @@ %% %% #1: K %% #2: mfmNonKDim - %% #3: kpack + %% #3: kWidth %% #4: kGroup %% #5: 0 for opA and 1 for opB \pgfmathsetmacro{\K}{#1} \pgfmathsetmacro{\nonKDim}{#2} \pgfmathsetmacro{\groups}{64/\nonKDim} - \pgfmathsetmacro{\kpack}{#3} + \pgfmathsetmacro{\kWidth}{#3} \pgfmathsetmacro{\kGroup}{#4} \pgfmathsetmacro{\opIdx}{#5} \pgfmathsetmacro{\opIdxOther}{1-\opIdx} \coordinate (TL) at (Op TL); - \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups/\kGroup} + \pgfmathsetmacro{\numKRep}{\K/\kWidth/\groups/\kGroup} \pgfmathsetmacro{\maxKRepId}{\numKRep-1} \foreach \repId in {0,...,\maxKRepId}{ - \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\kGroup*\opIdxOther, -\repId*\groups*\kpack*\kGroup*\elem*\opIdx)$); - \drawMFMAOperand{\nonKDim}{\kpack}{\kGroup}{\opIdx}{0} + \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kWidth*\elem*\kGroup*\opIdxOther, -\repId*\groups*\kWidth*\kGroup*\elem*\opIdx)$); + \drawMFMAOperand{\nonKDim}{\kWidth}{\kGroup}{\opIdx}{0} \draw [thick] (mfma op TL) rectangle - ++(\groups*\kpack*\kGroup*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\kGroup*\elem*\opIdx); + ++(\groups*\kWidth*\kGroup*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, 
-\nonKDim*\opIdxOther*\elem-\groups*\kWidth*\kGroup*\elem*\opIdx); } } -\newcommand{\drawDotOperands}[8]{ +\newcommand{\drawDotOperands}[6]{ %% %% Draw operand tensors of dot %% @@ -191,8 +191,6 @@ %% #4: MFMA nonKDim %% #5: warpsPerCTA[0] %% #6: warpsPerCTA[1] - %% #7: kpack - %% #8: kGroup \pgfmathsetmacro{\M}{#1} \pgfmathsetmacro{\N}{#2} @@ -200,8 +198,6 @@ \pgfmathsetmacro{\mfmaNonKDim}{#4} \pgfmathsetmacro{\warpsPerCTAM}{#5} \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\kpack}{#7} - \pgfmathsetmacro{\kGroup}{#8} %% operand A \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} @@ -215,7 +211,7 @@ } %% Only draw the detailed view of the first wave in CTA \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{\kGroup}{0} + \drawWaveOperand{\K}{\mfmaNonKDim}{\kWidthA}{\kGroupA}{0} %% Draw the outline of each CTA rep \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); @@ -235,7 +231,7 @@ } %% Only draw the detailed view of the first wave in CTA \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{\kGroup}{1} + \drawWaveOperand{\K}{\mfmaNonKDim}{\kWidthB}{\kGroupB}{1} %% Draw the outline of each CTA rep \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); @@ -244,7 +240,7 @@ } -\newcommand{\drawDot}[9]{ +\newcommand{\drawDot}[7]{ %% %% Draw C = dot A, B %% @@ -258,8 +254,6 @@ %% #5: warpsPerCTA[0] %% #6: warpsPerCTA[1] %% #7: 1 for mfma.trans, 0 for normal mfma - %% #8: kpack - %% #9: kGroup \pgfmathsetmacro{\M}{#1} \pgfmathsetmacro{\N}{#2} @@ -269,16 +263,13 @@ \pgfmathsetmacro{\warpsPerCTAM}{#5} \pgfmathsetmacro{\warpsPerCTAN}{#6} \pgfmathsetmacro{\mfmaTrans}{#7} - \pgfmathsetmacro{\kpack}{#8} - \pgfmathsetmacro{\kGroup}{#9} - \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} \pgfmathsetmacro{\gap}{\elem*20} \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); %% Draw both A and B 
operands - \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack}{\kGroup} + \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN} %% Draw result tensor \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} @@ -299,7 +290,7 @@ \node [scale=.8*\scale, above] at ($(C TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; } -\newcommand{\drawMFMAInstr}[7]{ +\newcommand{\drawMFMAInstr}[5]{ %% %% Draw layout of mfma instructions with tid labeled %% @@ -310,59 +301,68 @@ %% \scaleLabel: extra scale applied to labels according to kWidth %% %% #1: mfmaNonKDim - %% #2: kpack - %% #3: kGroup - %% #4: mfmaTrans - %% #5: dtype_a - %% #6: dtype_b - %% #7: outType + %% #2: mfmaTrans + %% #3: dtype_a + %% #4: dtype_b + %% #5: outType \pgfmathsetmacro{\mfmaNonKDim}{#1} \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} - \pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\kGroup}{#3} - \pgfmathsetmacro{\mfmaTrans}{#4} - \pgfmathsetmacro{\nonTrans}{1-#4} - \pgfmathsetmacro{\kDim}{int(\kpack*\groups*\kGroup)} + \pgfmathsetmacro{\mfmaTrans}{#2} + \pgfmathsetmacro{\nonTrans}{1-#2} + + \ifthenelse{\mfmaTrans=0}{ + \pgfmathsetmacro{\kWidthLeft}{\kWidthA} + \pgfmathsetmacro{\kWidthRight}{\kWidthB} + \pgfmathsetmacro{\kGroupLeft}{\kGroupA} + \pgfmathsetmacro{\kGroupRight}{\kGroupB} + }{ + \pgfmathsetmacro{\kWidthLeft}{\kWidthB} + \pgfmathsetmacro{\kWidthRight}{\kWidthA} + \pgfmathsetmacro{\kGroupLeft}{\kGroupB} + \pgfmathsetmacro{\kGroupRight}{\kGroupA} + } + \pgfmathsetmacro{\kDim}{int(\kWidthLeft*\groups*\kGroupLeft)} \pgfmathsetmacro{\gap}{\elem*5} - \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elemW*\kGroup, 0)$); + \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kWidthLeft*\elemW*\kGroupLeft, 0)$); \coordinate (mfma op TL) at (mfma opA TL); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{\kGroup}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 
1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elemW*\kGroup)$); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{\kGroup}{1}{1} + \drawMFMAOperand{\mfmaNonKDim}{\kWidthLeft}{\kGroupLeft}{0}{1} + \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kWidthRight*\elemW*\kGroupRight)$); + \drawMFMAOperand{\mfmaNonKDim}{\kWidthRight}{\kGroupRight}{1}{1} \coordinate (block TL) at (C TL); \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} %% Draw labels %% Set data types - \def\opAType{#5} - \def\opBType{#6} - \def\outType{#7} + \def\opAType{#3} + \def\opBType{#4} + \def\outType{#5} %% Draw kWidth vector and lable of first operand \coordinate (vec TL) at ($(mfma opA TL)+(0, 5*\elem)$); - \pgfmathsetmacro{\maxVec}{\kpack-1} + \pgfmathsetmacro{\maxVec}{\kWidthLeft-1} \foreach \vecId in {0,...,\maxVec}{ \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); } \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem)$); - \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elemW, 0)$) -- ($(vec TL)+(\kpack*\elem, -\elem)$); - \node [scale=.8*\scaleLabel, above] at ($(vec TL)+(.5*\kpack*\elem, 0)$) {kWidth=\kpack}; + \draw [densely dotted] ($(mfma opA TL)+(\kWidthLeft*\elemW, 0)$) -- ($(vec TL)+(\kWidthLeft*\elem, -\elem)$); + \node [scale=.8*\scaleLabel, above] at ($(vec TL)+(.5*\kWidthLeft*\elem, 0)$) {kWidth=\kWidthLeft}; %% Draw kWidth vector and lable of second operand \coordinate (vec TL) at ($(mfma op TL)+(-5*\elem, 0)$); + \pgfmathsetmacro{\maxVec}{\kWidthRight-1} \foreach \vecId in {0,...,\maxVec}{ \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); } \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem,0)$); - \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elemW)$) -- ($(vec TL)+(\elem, -\kpack*\elem)$); - \node [scale=.8*\scaleLabel, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem)$) {kWidth=\kpack}; + \draw [densely dotted] ($(mfma op TL)+(0, -\kWidthRight*\elemW)$) -- ($(vec TL)+(\elem, 
-\kWidthRight*\elem)$); + \node [scale=.8*\scaleLabel, above, rotate=90] at ($(vec TL)+(0, -.5*\kWidthRight*\elem)$) {kWidth=\kWidthRight}; %% Draw labels according to mfma.trans or not \ifthenelse{\mfmaTrans=0}{ - \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups*\kGroup, 0)$) + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kWidthLeft*\elemW*\groups*\kGroupLeft, 0)$) {inA:$\mfmaNonKDim \times \kDim \times $\opAType}; - \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kpack*\elemW*\kGroup)$) + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kWidthRight*\elemW*\kGroupRight)$) {inB:$\kDim \times \mfmaNonKDim \times $\opBType}; \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem,0)$); \foreach \vecId in {0,1,2,3}{ @@ -373,9 +373,9 @@ \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem)$) {vec=4$\times$\outType}; \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; }{ - \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kpack*\elemW*\groups*\kGroup, 0)$) + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kWidthLeft*\elemW*\groups*\kGroupLeft, 0)$) {inB:$\kDim \times \mfmaNonKDim^T \times $\opBType}; - \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kpack*\elemW*\kGroup)$) + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kWidthRight*\elemW*\kGroupRight)$) {inA:$\mfmaNonKDim \times \kDim^T \times $\opAType}; \coordinate (vec TL) at ($(block TL)+(0, 3*\elem+\elem)$); \foreach \vecId in {0,1,2,3}{ diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index dc5e4d1fb632..5af4ce910a6d 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -5,10 +5,29 @@ def 
draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, dtype_b, mfma_inst_str, - kpack): + kpack, isMixed864): + scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1 + + outType = 'i32' if dtype_a == 'i8' else 'f32' + kWidth_a = kWidth_b = kWidth + kGroup_a = kGroup_b = kGroup + if isMixed864: + if isType8BitFloat(dtype_a): + kWidth_a = 16 + kGroup_a = 2 + kWidth_b = 32 + kGroup_b = 1 + else: + kWidth_a = 32 + kGroup_a = 1 + kWidth_b = 16 + kGroup_b = 2 + kWidth_left = kWidth_b if trans else kWidth_a + kGroup_left = kGroup_b if trans else kGroup_a + elemSmall = 0.04 elemLarge = 0.16 - elemPerThread = kWidth * kGroup + elemPerThread = kWidth_a * kGroup_a if elemPerThread == 16: ratio = 0.8 elif elemPerThread == 32: @@ -17,17 +36,17 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup ratio = 1 elemWidth = elemLarge * ratio - scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1 - - outType = 'i32' if dtype_a == 'i8' else 'f32' - return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} \\def\\elem{{{elemSmall}}} \\def\\elemW{{\\elem}} + \\def\\kWidthA{{{kWidth_a}}} + \\def\\kWidthB{{{kWidth_b}}} + \\def\\kGroupA{{{kGroup_a}}} + \\def\\kGroupB{{{kGroup_b}}} \\coordinate (C TL) at (0,0); - \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kWidth}}}{{{kGroup}}} + \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}} \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); \\def\\mfmaTrans{{{trans}}} @@ -40,10 +59,10 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} - \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth}*{kGroup}*\\elemW, 
-{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); + \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth_left}*{kGroup_left}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); \\coordinate (mfma instr) at ($(C TL)+(-.5*\\gap-0.6*\\nonTrans*\\gap-0.4*\\mfmaTrans*\\gap, 1.5*\\gap+.5*\\mfmaTrans*\\gap)$); \\node [scale=\scaleLabel, above left, align=left, draw=black, fill=white] at (mfma instr) {{{mfma_inst_str}}}; - \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kWidth}}}{{{kGroup}}}{{\\mfmaTrans}}{{{dtype_a}}}{{{dtype_b}}}{{{outType}}} + \\drawMFMAInstr{{{mfmaNonKDim}}}{{\\mfmaTrans}}{{{dtype_a}}}{{{dtype_b}}}{{{outType}}} \\end{{tikzpicture}} \\end{{document}}''' @@ -342,10 +361,12 @@ def main(): CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] mfma_inst_str = f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4" + isMixed864 = True else: kDim = kWidth * kGroup * 64 / mfmaNonKDim assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" mfma_inst_str, kpack, CBSZ, BLGP = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans) + isMixed864 = False flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}" print(f"MFMA: {mfma_inst_str} x {kpack}{flag}", end="") mfma_inst_str = mfma_inst_str.replace("_", "\\_") @@ -375,7 +396,7 @@ def main(): warpsPerCTA, order) draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, - dtype_b, mfma_inst_str, kpack) + dtype_b, mfma_inst_str, kpack, isMixed864) draw_lds_str = draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) From f59dd58f33988118241d8c57cadf9f71cec3456a Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 25 Dec 2024 23:40:18 -0600 Subject: [PATCH 09/11] [API change] Add support for scale --- .../perf-kernels/tools/plot-layout/README.md | 11 +- 
.../tools/plot-layout/dotLayout.tex | 101 ++++++++++++++---- .../tools/plot-layout/plot_layout.py | 41 ++++--- .../tools/plot-layout/preamble.tex | 20 ++-- 4 files changed, 125 insertions(+), 48 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index bb81152b6bd2..a41a74a29c28 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -6,8 +6,8 @@ Here is the help info from the script. ```bash >$ python3 plot_layout.py -h usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] - [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] - [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] + [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-mfmaTrans] [-scale] + [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-keep] options: -h, --help show this help message and exit @@ -30,13 +30,14 @@ options: element type of operand A -dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8} element type of operand B + -mfmaTrans If set, then use mfma.trans layout + -scale If set, plot the scale tensor for mfma_f8f6f4 instructions -lds_layout {swizzle,padding,none} choose the LDS data layout -lds_access {read,write,none} choose LDS access mode -wave_size {32,64} choose the wmma instruction mode -o O output pdf file name (without surfix) - 
-mfmaTrans If set, then use mfma.trans layout -keep If set, keep the generated .tex file ``` @@ -88,6 +89,8 @@ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6 ## fp8/bf8 and fp6/bf6/f4 inputs python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 +## mixed precision with scaling +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 -scale ``` One can add `-nonKDim [16,32]` and `-mfmaTrans` to all of the above examples. @@ -105,6 +108,8 @@ Knobs - `-nonKDim [16,32]`: mfma instruction size. The default is set to 16. - `-mfmaTrans`: if set, the transposed mfma layout will be plotted. - `-dtype_a` and `-dtype_b`: element types of operand A and B. The default value is fp16. +- `-scale`: plot scale tensors for A and B. This is only supported with f4/f6 and f8 with `kGroup=2`. + If `-scale` is set but not supported, it's ignored. 
Notes - The layout shows the mapping from the threads/wave to the elements in the diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex index 805f2dc852e7..b8a259b60231 100644 --- a/python/perf-kernels/tools/plot-layout/dotLayout.tex +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -20,7 +20,7 @@ \foreach \iVec in {0,...,\maxIVec} { \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); \foreach \tg in {0,...,\maxGID}{ - \pgfmathsetmacro{\colID}{\tg+4} + \pgfmathsetmacro{\colID}{\tg} \pgfmathsetmacro{\col}{\Colors[\colID]} \foreach \tid in {0,...,\maxTID} { \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} @@ -290,7 +290,35 @@ \node [scale=.8*\scale, above] at ($(C TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; } -\newcommand{\drawMFMAInstr}[5]{ +\newcommand{\drawZoomInVec}[3]{ + %% + %% Draw zoomed in view of vector of elements + %% + %% predefined variables + %% vec TL: top-left coordinates of the vector + %% orig TL: top-left coordinates of the original small vector + %% \elem: vertical element size of operands, element size of output + %% \elemW: honrizontal element size of operands + %% \scaleLabel: extra scale applied to labels according to kWidth + %% + %% #1: number of elements + %% #2: 0 for opLeft, 1 for opRight + %% #3: label + + \pgfmathsetmacro{\kWidth}{#1} + \pgfmathsetmacro{\opLeft}{#2} + \pgfmathsetmacro{\opRight}{1-#2} + + \pgfmathsetmacro{\maxVec}{\kWidth-1} + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(\vecId*\elem*\opRight, -\vecId*\elem*\opLeft)$) rectangle ++(\elem, -\elem); + } + \draw [densely dotted] (orig TL) -- ($(vec TL)+(\elem*\opLeft, -\elem*\opRight)$); + \draw [densely dotted] ($(orig TL)+(\kWidth*\elemW*\opRight, -\kWidth*\elemW*\opLeft)$) -- ($(vec TL)+(\kWidth*\elem*\opRight+\elem*\opLeft, -\elem*\opRight-\kWidth*\elem*\opLeft)$); + \node [scale=.8*\scaleLabel, above, rotate=90*\opLeft] at 
($(vec TL)+(.5*\kWidth*\elem*\opRight, -.5*\kWidth*\elem*\opLeft)$) {#3}; +} + +\newcommand{\drawMFMAInstr}[6]{ %% %% Draw layout of mfma instructions with tid labeled %% @@ -305,11 +333,13 @@ %% #3: dtype_a %% #4: dtype_b %% #5: outType + %% #6: scaling: if set, draw scaling tensors \pgfmathsetmacro{\mfmaNonKDim}{#1} \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} \pgfmathsetmacro{\mfmaTrans}{#2} \pgfmathsetmacro{\nonTrans}{1-#2} + \pgfmathsetmacro{\drawScale}{#6} \ifthenelse{\mfmaTrans=0}{ \pgfmathsetmacro{\kWidthLeft}{\kWidthA} @@ -324,13 +354,27 @@ } \pgfmathsetmacro{\kDim}{int(\kWidthLeft*\groups*\kGroupLeft)} + %% Draw operand left \pgfmathsetmacro{\gap}{\elem*5} \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kWidthLeft*\elemW*\kGroupLeft, 0)$); \coordinate (mfma op TL) at (mfma opA TL); \drawMFMAOperand{\mfmaNonKDim}{\kWidthLeft}{\kGroupLeft}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kWidthRight*\elemW*\kGroupRight)$); + %% Draw operand right + \coordinate (mfma opB TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kWidthRight*\elemW*\kGroupRight)$); + \coordinate (mfma op TL) at (mfma opB TL); \drawMFMAOperand{\mfmaNonKDim}{\kWidthRight}{\kGroupRight}{1}{1} + %% Draw scaling tensors if needed + \ifthenelse{\drawScale=1}{ + \coordinate (left scaling TL) at ($(mfma opA TL)+(-0.3*\gap-\groups*4*\elemW, 0)$); + \coordinate (mfma op TL) at (left scaling TL); + \drawMFMAOperand{\mfmaNonKDim}{4}{1}{0}{1} + + \coordinate (right scaling TL) at ($(mfma opB TL)+(0, 0.3*\gap+\groups*4*\elemW)$); + \coordinate (mfma op TL) at (right scaling TL); + \drawMFMAOperand{\mfmaNonKDim}{4}{1}{1}{1} + }{} + \coordinate (block TL) at (C TL); \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} @@ -339,31 +383,38 @@ \def\opAType{#3} \def\opBType{#4} \def\outType{#5} - %% Draw kWidth vector and lable of first operand + + %% Draw kWidth vector and label of first operand \coordinate (vec TL) at ($(mfma opA 
TL)+(0, 5*\elem)$); - \pgfmathsetmacro{\maxVec}{\kWidthLeft-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); - } - \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem)$); - \draw [densely dotted] ($(mfma opA TL)+(\kWidthLeft*\elemW, 0)$) -- ($(vec TL)+(\kWidthLeft*\elem, -\elem)$); - \node [scale=.8*\scaleLabel, above] at ($(vec TL)+(.5*\kWidthLeft*\elem, 0)$) {kWidth=\kWidthLeft}; - %% Draw kWidth vector and lable of second operand - \coordinate (vec TL) at ($(mfma op TL)+(-5*\elem, 0)$); - \pgfmathsetmacro{\maxVec}{\kWidthRight-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); - } - \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem,0)$); - \draw [densely dotted] ($(mfma op TL)+(0, -\kWidthRight*\elemW)$) -- ($(vec TL)+(\elem, -\kWidthRight*\elem)$); - \node [scale=.8*\scaleLabel, above, rotate=90] at ($(vec TL)+(0, -.5*\kWidthRight*\elem)$) {kWidth=\kWidthRight}; + \coordinate (orig TL) at (mfma opA TL); + \drawZoomInVec{\kWidthLeft}{0}{kWidth=\kWidthLeft} + + %% Draw kWidth vector and label of second operand + \coordinate (vec TL) at ($(mfma opB TL)+(-5*\elem, 0)$); + \coordinate (orig TL) at (mfma opB TL); + \drawZoomInVec{\kWidthRight}{1}{kWidth=\kWidthRight} + + \ifthenelse{\drawScale=1}{ + %% Draw vec and label of scalingLeft + \coordinate (vec TL) at ($(left scaling TL)+(0, 5*\elem)$); + \coordinate (orig TL) at (left scaling TL); + \drawZoomInVec{4}{0}{vec=4$\times$e8m0} + %% Draw vec and label of scalingRight + \coordinate (vec TL) at ($(right scaling TL)+(-5*\elem, 0)$); + \coordinate (orig TL) at (right scaling TL); + \drawZoomInVec{4}{1}{vec=4$\times$e8m0} + }{} %% Draw labels according to mfma.trans or not \ifthenelse{\mfmaTrans=0}{ \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kWidthLeft*\elemW*\groups*\kGroupLeft, 0)$) {inA:$\mfmaNonKDim \times \kDim \times $\opAType}; - \node [scale=\scaleLabel, above 
right, rotate=90] at ($(mfma op TL)+(0,-\groups*\kWidthRight*\elemW*\kGroupRight)$) + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma opB TL)+(0,-\groups*\kWidthRight*\elemW*\kGroupRight)$) {inB:$\kDim \times \mfmaNonKDim \times $\opBType}; + \ifthenelse{\drawScale=1}{ + \node [scale=\scaleLabel, above] at ($(left scaling TL)+(0.5*4*\elemW*\groups, 0)$) {scaleA}; + \node [scale=\scaleLabel, above, rotate=90] at ($(right scaling TL)+(0,-0.5*\groups*4*\elemW)$) {scaleB}; + }{} \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem,0)$); \foreach \vecId in {0,1,2,3}{ \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); @@ -375,8 +426,12 @@ }{ \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kWidthLeft*\elemW*\groups*\kGroupLeft, 0)$) {inB:$\kDim \times \mfmaNonKDim^T \times $\opBType}; - \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma op TL)+(0, -\groups*\kWidthRight*\elemW*\kGroupRight)$) + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma opB TL)+(0, -\groups*\kWidthRight*\elemW*\kGroupRight)$) {inA:$\mfmaNonKDim \times \kDim^T \times $\opAType}; + \ifthenelse{\drawScale=1}{ + \node [scale=\scaleLabel, above] at ($(left scaling TL)+(.5*4*\elemW*\groups, 0)$) {scaleB}; + \node [scale=\scaleLabel, above, rotate=90] at ($(right scaling TL)+(0, -.5*\groups*4*\elemW)$) {scaleA}; + }{} \coordinate (vec TL) at ($(block TL)+(0, 3*\elem+\elem)$); \foreach \vecId in {0,1,2,3}{ \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); @@ -384,6 +439,6 @@ \draw [densely dotted] (block TL) -- ++(0, 3*\elem); \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(0, 3*\elem); \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem, 0)$) {vec=4$\times$\outType}; - \node [scale=.6*\scale, above, align=center] at ($(block TL)+(8*\elem, 0)$) {mfmaLayout\\trans=True}; + \node [scale=.6*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=True}; } } \ No newline at end of file 
diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 5af4ce910a6d..9833aedf28c1 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -5,7 +5,7 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, dtype_b, mfma_inst_str, - kpack, isMixed864): + kpack, isMixed864, plot_scale): scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1 outType = 'i32' if dtype_a == 'i8' else 'f32' @@ -36,6 +36,8 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup ratio = 1 elemWidth = elemLarge * ratio + scaling = 1 if plot_scale else 0 + return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} @@ -59,10 +61,10 @@ def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} - \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth_left}*{kGroup_left}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); + \\coordinate (C TL) at ($(C TL)+({scaling}*0.3*\\gap+{scaling}*\\groups*4*\elemW+.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth_left}*{kGroup_left}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); \\coordinate (mfma instr) at ($(C TL)+(-.5*\\gap-0.6*\\nonTrans*\\gap-0.4*\\mfmaTrans*\\gap, 1.5*\\gap+.5*\\mfmaTrans*\\gap)$); \\node [scale=\scaleLabel, above left, align=left, draw=black, fill=white] at (mfma instr) {{{mfma_inst_str}}}; - \\drawMFMAInstr{{{mfmaNonKDim}}}{{\\mfmaTrans}}{{{dtype_a}}}{{{dtype_b}}}{{{outType}}} + \\drawMFMAInstr{{{mfmaNonKDim}}}{{\\mfmaTrans}}{{{dtype_a}}}{{{dtype_b}}}{{{outType}}}{{{scaling}}} \\end{{tikzpicture}} \\end{{document}}''' @@ -166,7 +168,7 @@ def isMixedPrecBtwF8AndF4OrF6(dtype_a, dtype_b): and isType4Or6Bit(dtype_a)) 
-def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans): +def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans, scale): ## Check input types ## Mixed precision is only allowed within f8, f6 and f4 assert (isMixedPrecType(dtype_a) and isMixedPrecType(dtype_b)) or ( @@ -208,7 +210,8 @@ def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans): kpack = 1 CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] - return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4", kpack, CBSZ, BLGP + scale_str = 'scale_' if scale else '' + return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4", kpack, CBSZ, BLGP, scale ## Both dtypes are fp8 or bf8 if isType8BitFloat(dtype_a) and isType8BitFloat(dtype_b): @@ -219,12 +222,16 @@ def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans): suffix = "f8f6f4" CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + plot_scale = scale + scale_str = 'scale_' if scale else '' else: suffix = f"{dtype_b}_{dtype_a}" if trans else f"{dtype_a}_{dtype_b}" CBSZ = -1 BLGP = -1 + plot_scale = False + scale_str = '' kDim = kDim / 2 if kpack == 2 else kDim - return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{suffix}", kpack, CBSZ, BLGP + return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{suffix}", kpack, CBSZ, BLGP, plot_scale ## Both types are fp16 or bf16 if isType16Bit(dtype_a) and isType16Bit(dtype_b): @@ -234,7 +241,7 @@ def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans): kpack = 1 CBSZ = -1 BLGP = -1 - return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP + return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP, False ## Both types are i8 if 
dtype_a == 'i8' and dtype_b == 'i8': @@ -244,7 +251,7 @@ def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans): kpack = 1 CBSZ = -1 BLGP = -1 - return f"mfma_i32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP + return f"mfma_i32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP, False assert False, "Mixed precision between fp8/bf8 and fp6/bf6/f4 not supported in this mode" @@ -284,6 +291,9 @@ def parse_args(): parser.add_argument("-dtype_b", type=str, default='fp16', choices=['fp16', 'bf16', 'fp8', 'bf8', 'fp6', 'bf6', 'f4', 'i8'], help='element type of operand B') + parser.add_argument("-mfmaTrans", action='store_true', default=False, help='If set, then use mfma.trans layout') + parser.add_argument("-scale", action='store_true', default=False, + help='If set, plot the scale tensor for mfma_f8f6f4 instructions') ## LDS access parameters parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') @@ -291,9 +301,7 @@ def parse_args(): help='choose LDS access mode') ## wmma instruction layout parameter parser.add_argument("-wave_size", type=int, default=32, choices=[32, 64], help='choose the wmma instruction mode') - parser.add_argument("-o", type=str, default="myplot", help='output pdf file name (without surfix)') - parser.add_argument("-mfmaTrans", action='store_true', default=False, help='If set, then use mfma.trans layout') parser.add_argument("-keep", action='store_true', default=False, help='If set, keep the generated .tex file') args = parser.parse_args() @@ -320,6 +328,7 @@ def main(): dtype_a = args.dtype_a dtype_b = args.dtype_b trans = 1 if args.mfmaTrans else 0 + scale = 1 if args.scale else 0 ofilename = args.o keepSrc = args.keep @@ -360,15 +369,19 @@ def main(): kpack = 1 CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] BLGP = matrixFormatTable[dtype_a] if trans else 
matrixFormatTable[dtype_b] - mfma_inst_str = f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4" + scale_str = 'scale_' if scale else '' + mfma_inst_str = f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4" isMixed864 = True + plot_scale = scale else: kDim = kWidth * kGroup * 64 / mfmaNonKDim assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" - mfma_inst_str, kpack, CBSZ, BLGP = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans) + mfma_inst_str, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, + dtype_b, trans, scale) isMixed864 = False flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}" - print(f"MFMA: {mfma_inst_str} x {kpack}{flag}", end="") + scale_info = f" (scale is not supported hence ignored)" if (scale and not plot_scale) else '' + print(f"MFMA: {mfma_inst_str} x {kpack}{flag}{scale_info}", end="") mfma_inst_str = mfma_inst_str.replace("_", "\\_") mfma_inst_str = mfma_inst_str + flag if kpack == 2: @@ -396,7 +409,7 @@ def main(): warpsPerCTA, order) draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, - dtype_b, mfma_inst_str, kpack, isMixed864) + dtype_b, mfma_inst_str, kpack, isMixed864, plot_scale) draw_lds_str = draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) diff --git a/python/perf-kernels/tools/plot-layout/preamble.tex b/python/perf-kernels/tools/plot-layout/preamble.tex index 1c9d7e480446..387f08d3c427 100644 --- a/python/perf-kernels/tools/plot-layout/preamble.tex +++ b/python/perf-kernels/tools/plot-layout/preamble.tex @@ -1,4 +1,4 @@ -\documentclass[tikz, border=1mm, dvipsnames]{standalone} +\documentclass[tikz, border=1mm, dvipsnames, x11names]{standalone} \usepackage{ifthen} \usepackage{tikz} \usetikzlibrary{arrows.meta,arrows} @@ -6,16 +6,20 @@ \usetikzlibrary{calc, quotes} \usetikzlibrary{patterns} 
\usepackage{xparse} +\definecolor{RoyalPurple}{HTML}{CC79A7} +\definecolor{CrimsonRed}{HTML}{D41159} +\definecolor{Gold}{HTML}{F1C40F} +\definecolor{DeepViolet}{HTML}{7E3F8F} \newcommand{\Colors}{{ - "red", - "YellowGreen", - "blue", - "Maroon", + "SkyBlue", "orange", - "cyan", - "magenta", - "brown", + "ForestGreen", + "RoyalPurple", + "CrimsonRed", "teal", + "Gold", + "DeepViolet", + "cyan", "purple", "gray", "Green", From 210b232242548f185fe97a33f1eb6a2997feaeb4 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Wed, 25 Dec 2024 23:38:11 -0600 Subject: [PATCH 10/11] [NFC] Fix format --- python/perf-kernels/tools/plot-layout/dotLayout.tex | 4 ++-- python/perf-kernels/tools/plot-layout/plot_layout.py | 2 +- python/perf-kernels/tools/plot-layout/wmmaLayout.tex | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex index b8a259b60231..633d4af01023 100644 --- a/python/perf-kernels/tools/plot-layout/dotLayout.tex +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -334,7 +334,7 @@ %% #4: dtype_b %% #5: outType %% #6: scaling: if set, draw scaling tensors - + \pgfmathsetmacro{\mfmaNonKDim}{#1} \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} \pgfmathsetmacro{\mfmaTrans}{#2} @@ -441,4 +441,4 @@ \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem, 0)$) {vec=4$\times$\outType}; \node [scale=.6*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=True}; } -} \ No newline at end of file +} diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 9833aedf28c1..b4e9c2d8316d 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -380,7 +380,7 @@ def main(): dtype_b, trans, scale) isMixed864 = False flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}" - scale_info = f" 
(scale is not supported hence ignored)" if (scale and not plot_scale) else '' + scale_info = " (scale is not supported hence ignored)" if (scale and not plot_scale) else '' print(f"MFMA: {mfma_inst_str} x {kpack}{flag}{scale_info}", end="") mfma_inst_str = mfma_inst_str.replace("_", "\\_") mfma_inst_str = mfma_inst_str + flag diff --git a/python/perf-kernels/tools/plot-layout/wmmaLayout.tex b/python/perf-kernels/tools/plot-layout/wmmaLayout.tex index 54141b4928cc..25d459a1d0dd 100644 --- a/python/perf-kernels/tools/plot-layout/wmmaLayout.tex +++ b/python/perf-kernels/tools/plot-layout/wmmaLayout.tex @@ -118,4 +118,4 @@ \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16}; \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$); \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$); -} \ No newline at end of file +} From fd5c641afbe777dc9a29ba6094aee7dd1549af99 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 30 Dec 2024 00:05:30 -0600 Subject: [PATCH 11/11] [API change] Refactor tensor and LDS layout - Support data types - Support both 32 and 64 banks - Still working on LDS accesses --- .../perf-kernels/tools/plot-layout/README.md | 20 +- .../tools/plot-layout/ldsLayout.tex | 220 +++++++++++------- .../tools/plot-layout/plot_layout.py | 67 ++++-- .../tools/plot-layout/preamble.tex | 1 + 4 files changed, 202 insertions(+), 106 deletions(-) diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index a41a74a29c28..bbeb485ba524 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -7,7 +7,7 @@ Here is the help info from the script. 
>$ python3 plot_layout.py -h usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-mfmaTrans] [-scale] - [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-keep] + [-banks {32,64}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-keep] options: -h, --help show this help message and exit @@ -32,6 +32,7 @@ options: element type of operand B -mfmaTrans If set, then use mfma.trans layout -scale If set, plot the scale tensor for mfma_f8f6f4 instructions + -banks {32,64} choose the number of banks in LDS -lds_layout {swizzle,padding,none} choose the LDS data layout -lds_access {read,write,none} @@ -121,20 +122,21 @@ Notes Examples: ```bash -python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8 +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 8 +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 32 -dtype_a f4 +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64 +python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64 ``` Knobs -- `kWidth` here means the vector size when accessing LDS +- `kWidth`: the vector size (in units of elements) when accessing LDS +- `banks`: the number of banks in LDS.
(64 for gfx950, 32 for pre-gfx950) - Three options for `-lds_layout`: - `none`: no swizzling, no padding - - `padding`: padding at every 128B - - `swizzling`: apply the swizzling pattern, which is derived from tensor shape and kWidth. + - `swizzle`: apply the swizzling pattern, which is derived from tensor shape and kWidth. + - `padding`: tbd - Three options for `-lds_access`: - `none`: do not plot access pattern - `read`: plot accessed elements during ds_read - - `write`: plot accessed elements during ds_write. Note that this needs some infomation from + - `write`: plot accessed elements during ds_write. Note that this needs some information from global load. Therefore, we need to provide `-sizePerThread` and `-threadsPerWarp`. - -Notes -- This mode is rarely used. If you have any questions, please contact Lixun Zhang directly. diff --git a/python/perf-kernels/tools/plot-layout/ldsLayout.tex b/python/perf-kernels/tools/plot-layout/ldsLayout.tex index 393c709a29f5..15a66f54a090 100644 --- a/python/perf-kernels/tools/plot-layout/ldsLayout.tex +++ b/python/perf-kernels/tools/plot-layout/ldsLayout.tex @@ -54,55 +54,77 @@ %% Draw tensor layout in global memory without any swizzling %% %% TL: pre defined top-left coordinates of the tensor in global memory - %% \elem: per defined variable + %% \elemH: The height of each element + %% \bsize: The width of each byte %% \Colors: a pre defined array of 16 colors %% %% The following arguments are also expected to be pre defined %% #1: M %% #2: K - %% #3: vec: number of elements in a group - - \pgfmathsetmacro{\numVecK}{\K/\vec} - \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} - \pgfmathsetmacro{\drawM}{20} - - %% Draw the tensor, but only draw 32 rows - \draw (TL) rectangle ++(\K*\elem, -\drawM*\elem); - %% Draw detailed vec view of the tensor - \foreach \vecId in {0,...,\maxVecId}{ - - \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} - \pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)} - \coordinate (vec TL) at 
($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); - - \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} - \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} - \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} - \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} + %% #3: vecInBytes: number of bytes in a group + + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\maxGpId}{\groups-1} + \pgfmathsetmacro{\maxRowId}{\mfmaNonKDim-1} + \pgfmathsetmacro{\elemsPerVec}{\vec} + + \pgfmathsetmacro{\vecInK}{\K/\elemsPerVec} + \pgfmathsetmacro{\maxKVecId}{\vecInK-1} + + \foreach \gp in {0,...,\maxKVecId}{ + \pgfmathsetmacro{\gpCol}{int(mod(\gp, 16))} + \pgfmathsetmacro{\vecColor}{\Colors[\gpCol]} + \pgfmathsetmacro{\kStart}{int(\gp*\elemsPerVec)} + \pgfmathsetmacro{\kEnd}{int(\kStart+\elemsPerVec-1)} + \foreach \row in {0,...,\maxRowId}{ + \coordinate (vec TL) at ($(TL)+(\gp*\vecInBytes*\bsize, -\row*\elemH)$); + \draw [ultra thin, fill=\vecColor] (vec TL) rectangle ++(\vecInBytes*\bsize, -\elemH) + node [pos=.5, scale=.6*\bankLabelScale*\scale, white] {m\row,k\kStart:\kEnd}; + } + } + %% M and K dim + \def\gap{3} + \pgfmathsetmacro{\drawM}{\mfmaNonKDim*\elemH+\gap*\elemH} + \pgfmathsetmacro{\drawK}{\vecInK*\vecInBytes*\bsize} + \draw [ultra thick] (TL) rectangle ++(\drawK, -\drawM); + %\node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM)$) {BLOCK\_M=\M}; + \node [scale=\scale, above] at ($(TL)+(.5*\drawK, 0)$) {Tile:\M$\times$\K$\times$\dtype}; + \node [scale=\scale, rotate=90] at ($(TL)+(0.5*\K*\bytesPerElem*\bsize, -\drawM+.5*\gap*\elemH)$) {$\ldots$}; + %\node [scale=\scale, below] at ($(TL)+(0.5*\drawK, -\drawM)$) {Tile:\M$\times$\K$\times$\dtype}; +} - \draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; +\newcommand{\drawLDSDiagram}[1]{ + %% + %% Draw the diagram of LDS without any data + %% + %% Pre-defined variables + %% TL: top-left coordinates of 
first elements in LDSaccess + %% bsize: size of a byte + %% mfmaNonKDim + %% K + %% bytesPerElem + %% + %% #1: number of banks + + \pgfmathsetmacro{\banks}{#1} + \pgfmathsetmacro{\maxBankId}{\banks-1} + \pgfmathsetmacro{\tensorHeight}{\mfmaNonKDim*\K*\bytesPerElem/4/\banks*\elemH} + \def\gapT{4} + \def\gapB{2} + \pgfmathsetmacro{\LDSHeight}{\tensorHeight+\gapT*\elemH+\gapB*\elemH} + \coordinate (LDS TL) at ($(TL)+(0, \gapT*\elemH)$); + \foreach \bank in {0,...,\maxBankId}{ + \coordinate (bank TL) at ($(LDS TL)+(\bank*4*\bsize, 0)$); + \draw [ultra thick] (bank TL) rectangle ++(4*\bsize, -\LDSHeight) + node [scale=.6*\bankLabelScale*\scale, pos=0, below right, align=center] {bank\\\bank}; + \node [scale=0.8*\bankLabelScale*\scale, rotate=90] at ($(TL)+(2*\bsize+\bank*4*\bsize, -\tensorHeight-0.5*\gapB*\elemH)$) {$\ldots$}; } - %% M and K dim - \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M}; - \node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16}; - \node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K}; - %% label for vecSize - \def\vecR{1.5} - \coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$); - \pgfmathsetmacro{\maxVec}{\vec-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$); - \draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec}; + \node [scale=\scale, above] at ($(TL)+(0.5*\banks*4*\bsize, 4*\elemH)$) {LDS \banks\ banks}; } - -\newcommand{\drawLDSLayoutTritonSwizzling}[2]{ +\newcommand{\drawLDSLayoutTritonSwizzling}[3]{ %% %% Draw tensor layout in LDS with swizzling %% @@ -119,6 +141,7 @@ %% 1 means optimal swizzling %% 2 means padding %% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write + %% 
#3: number of banks %% For ds_write access, the following variables are assumed to be pre defined %% \sizePerThreadK %% \sizePerThreadM @@ -127,25 +150,19 @@ \pgfmathsetmacro{\hasSwizzle}{#1} \pgfmathsetmacro{\accessMode}{#2} \pgfmathsetmacro{\numVecK}{\K/\vec} + \pgfmathsetmacro{\banks}{#3} + + \drawLDSDiagram{#3} %% Assuming fp16 data type - \pgfmathsetmacro{\LDSK}{64} + \pgfmathsetmacro{\LDSK}{int(\banks*4/\bytesPerElem)} \pgfmathsetmacro{\numLDSVec}{\LDSK/\vec} \pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)} \pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)} - - \ifthenelse{\accessMode = 2}{ - %% \accessMode == 2, draw 8 rows - \pgfmathsetmacro{\maxVecId}{8*\numVecK-1} - \pgfmathsetmacro{\drawM}{8*\K/\LDSK+4} - }{ - %% \accessMode == 0 or 1, draw 16 rows - \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} - \pgfmathsetmacro{\drawM}{16*\K/\LDSK+4} - } + \pgfmathsetmacro{\maxKVecId}{\K/\vec-1} %% Parameters used for swizzling - \pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec} + %% perPhase = ceil(LDSK / K) %% The number of the rows of the tensor that can share the same swizzling pattern \pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)} @@ -156,12 +173,49 @@ }{ %% When vec is small enough, we want 16/perPhase different swizzling patterns %% When vec is large, we can only have 64 / \vec different swizzling pattern at most - \pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)} + \pgfmathsetmacro{\maxPhase}{min(16/\perPhase,\banks*4/\bytesPerElem/\vec)} } - %% Draw the LDS - \draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem); + %% Draw the vectors according to the swizzling pattern + \foreach \gp in {0,...,\maxKVecId}{ + \pgfmathsetmacro{\gpCol}{int(mod(\gp, 16))} + \pgfmathsetmacro{\vecColor}{\Colors[\gpCol]} + \pgfmathsetmacro{\kStart}{int(\gp*\elemsPerVec)} + \pgfmathsetmacro{\kEnd}{int(\kStart+\elemsPerVec-1)} + \foreach \row in {0,...,\maxRowId}{ + %% Compute some info of the current vec + \pgfmathsetmacro{\offVec}{\row*\K/\vec+\gp} %% global offset in unit of vec + 
\pgfmathsetmacro{\LDSRow}{int(\offVec/\numLDSVec)} %% which row of LDS + \pgfmathsetmacro{\LDSVecRaw}{int(mod(\offVec,\numLDSVec))} %% offset in the current LDS row in the unit of vec + \pgfmathsetmacro{\phaseRaw}{int(\row/\perPhase)} + \pgfmathsetmacro{\phase}{int(mod(\phaseRaw, \maxPhase))} + \pgfmathsetmacro{\LDSVec}{\bitwiseXor{\LDSVecRaw}{\phase}} + + \coordinate (vec TL) at ($(TL)+(\LDSVec*\vecInBytes*\bsize, -\LDSRow*\elemH)$); + \draw [ultra thin, fill=\vecColor] (vec TL) rectangle ++(\vecInBytes*\bsize, -\elemH) + node [pos=.5, scale=.6*\bankLabelScale*\scale, white] {m\row,k\kStart:\kEnd}; + + %% draw phase of each LDS row + \pgfmathsetmacro{\lastVecId}{\numLDSVec-1} + \ifthenelse{\LDSVec=\lastVecId}{ + \draw [ultra thin] ($(vec TL)+(\vec*\bytesPerElem*\bsize, -.5*\bsize)$) -- ++(\elemH, 0) + node [scale=0.6*\bankLabelScale*\scale, right] {\phase}; + }{} + } + } + \node [scale=0.6*\bankLabelScale*\scale, above right] at($(TL)+(\banks*4*\bsize, 0)$) {phase}; + %% Start old code + \ifthenelse{\accessMode = 2}{ + %% \accessMode == 2, draw 8 rows + \pgfmathsetmacro{\maxVecId}{8*\numVecK-1} + \pgfmathsetmacro{\drawM}{8*\K/\LDSK+4} + }{ + %% \accessMode == 0 or 1, draw 16 rows + \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} + \pgfmathsetmacro{\drawM}{16*\K/\LDSK+4} + } + \pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec} %% Draw detailed vec view of LDS \foreach \vecId in {0,...,\maxVecId}{ \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} @@ -206,17 +260,17 @@ \pgfmathsetmacro{\tailBankId}{0} } - \ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{ - \pgfmathsetmacro{\nextBanks}{\tailBankId-31} - \pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks} - \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - \draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$) - rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, 
white] {m\vecCoordM}; - }{ - \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - } + %\ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{ + % \pgfmathsetmacro{\nextBanks}{\tailBankId-31} + % \pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks} + % \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem) + % node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + % \draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$) + % rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + %}{ + % \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem) + % node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + %} %% ds_read %% Highlight the elements the first 16 threads access in the first cycle @@ -246,29 +300,29 @@ }{} %% Label the phase of each line if swizzling is used - \ifthenelse{\hasSwizzle = 2}{}{ - \pgfmathsetmacro{\lastVecId}{int(64/\vec)-1} - \ifthenelse{\vecLDSKSwizzled = \lastVecId}{ - \draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0) - node [scale=.6*\scale, right] {\phase}; - }{} - } + %\ifthenelse{\hasSwizzle = 2}{}{ + % \pgfmathsetmacro{\lastVecId}{int(64/\vec)-1} + % \ifthenelse{\vecLDSKSwizzled = \lastVecId}{ + % \draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0) + % node [scale=.6*\scale, right] {\phase}; + % }{} + %} } %% Draw boundary of 32 banks %% Assume fp16 data type - \foreach \bank in {0,...,31}{ - \draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem) - node [scale=.6*\scale, right, black] {\bank}; - } - \draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem); - \node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id}; - - \node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks}; - \node 
[scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM}; + %\foreach \bank in {0,...,31}{ + % \draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem) + % node [scale=.6*\scale, right, black] {\bank}; + %} + %\draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem); + %\node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id}; +% + %\node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS \banks banks}; + %\node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM}; %% label phase if swizzling is used - \ifthenelse{\hasSwizzle = 2}{}{ - \node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase}; - } + %\ifthenelse{\hasSwizzle = 2}{}{ + % \node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase}; + %} } diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index b4e9c2d8316d..b54afd597636 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -83,7 +83,30 @@ def draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threa \\end{{document}}''' -def draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp): +def typeToBytes(dtype): + if dtype == 'bf16' or dtype == 'fp16': + return 2 + if dtype == 'bf8' or dtype == 'fp8' or dtype == 'i8': + return 1 + if dtype == 'f4': + return 0.5 + if dtype == 'fp6' or dtype == 'bf6': + return 0.75 + + +def maxKDimInBytes(dtype, mfmaNonKDim, kWidth): + groups = 64 / mfmaNonKDim + if (dtype == 'bf8' or dtype == 'fp8') and kWidth == 16: + groups *= 2 + return groups * kWidth * typeToBytes(dtype) + + +def calcPerPhase(banks, dtype, K): + bytesPerBank = 4 + return max(banks * bytesPerBank / (K * typeToBytes(dtype)), 1) + + +def draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp, dtype, 
mfmaNonKDim, banks): if ldsLayout == 'swizzle': hasSwizzle = 1 elif ldsLayout == 'padding': @@ -98,24 +121,41 @@ def draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threa else: accessMode = 0 + elemTypeInBytes = typeToBytes(dtype) + dimKInBytes = K * elemTypeInBytes + mfmaKDimInBytes = min(dimKInBytes, maxKDimInBytes(dtype, mfmaNonKDim, kWidth)) + vecInBytes = kWidth * elemTypeInBytes + + bsize = 0.12 + bankLabelScale = bsize / 0.15 + return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} \\def\\M{{{M}}} \\def\\K{{{K}}} + \\def\\mfmaK{{{mfmaKDimInBytes}}} \\def\\vec{{{kWidth}}} + \\def\\vecInBytes{{{vecInBytes}}} + \\def\\bytesPerElem{{{elemTypeInBytes}}} \\def\\hasSwizzle{{{hasSwizzle}}} \\def\\accessMode{{{accessMode}}} + \\def\\mfmaNonKDim{{{mfmaNonKDim}}} + \\def\\dtype{{{dtype}}} + \\def\\sizePerThreadK{{{sizePerThread[1]}}} \\def\\sizePerThreadM{{{sizePerThread[0]}}} \\def\\threadsPerWarpK{{{threadsPerWarp[1]}}} + \\def\\elemH{{0.18}} \\def\\elem{{0.18}} + \\def\\bsize{{{bsize}}} + \\def\\bankLabelScale{{{bankLabelScale}}} \\coordinate (TL) at (0,0); \\drawTensorLayoutGlobalMem - \\coordinate (TL) at ($(TL)+(0, -24*\\elem-10*\\elem)$); - \\drawLDSLayoutTritonSwizzling{{\\hasSwizzle}}{{\\accessMode}} + \\coordinate (TL) at ($(TL)+(0, -\drawM-8*\\elemH)$); + \\drawLDSLayoutTritonSwizzling{{\\hasSwizzle}}{{\\accessMode}}{{{banks}}} \\end{{tikzpicture}} \\end{{document}}''' @@ -295,6 +335,7 @@ def parse_args(): parser.add_argument("-scale", action='store_true', default=False, help='If set, plot the scale tensor for mfma_f8f6f4 instructions') ## LDS access parameters + parser.add_argument("-banks", type=int, default=32, choices=[32, 64], help='choose the number of banks in LDS') parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') parser.add_argument("-lds_access", type=str, default="none", choices=['read', 'write', 'none'], @@ -334,6 +375,7 
@@ def main(): ldsLayout = args.lds_layout ldsAccess = args.lds_access + banks = args.banks waveSize = args.wave_size @@ -397,7 +439,7 @@ def main(): print("") if plot_mode == 'lds': - print(f"Plotting LDS access for tensor M={M},K={K} with vec={kWidth}") + print(f"Plotting LDS access for tensor M={dim0},K={dim1} with vec={kWidth}") if ldsAccess == 'write': print(f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}") @@ -405,27 +447,24 @@ def main(): with open("preamble.tex") as file: preamble = file.read() - draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, threadsPerWarp, - warpsPerCTA, order) - - draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, - dtype_b, mfma_inst_str, kpack, isMixed864, plot_scale) - - draw_lds_str = draw_lds_access_cmd(M, K, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) - - draw_wmma_str = draw_wmma_instr_cmd(waveSize) - f_plot.write(preamble) if plot_mode == 'blocked': + draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, sizePerThread, + threadsPerWarp, warpsPerCTA, order) f_plot.write("\input{blockedLayout}\n") f_plot.write(draw_blockedLayout_str) elif plot_mode == 'dot': + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kWidth, kGroup, dtype_a, + dtype_b, mfma_inst_str, kpack, isMixed864, plot_scale) f_plot.write("\input{dotLayout}\n") f_plot.write(draw_dotLayout_str) elif plot_mode == 'lds': + draw_lds_str = draw_lds_access_cmd(dim0, dim1, kWidth, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp, + dtype_a, mfmaNonKDim, banks) f_plot.write("\input{ldsLayout}\n") f_plot.write(draw_lds_str) elif plot_mode == 'wmma': + draw_wmma_str = draw_wmma_instr_cmd(waveSize) f_plot.write("\input{wmmaLayout}\n") f_plot.write(draw_wmma_str) diff --git a/python/perf-kernels/tools/plot-layout/preamble.tex b/python/perf-kernels/tools/plot-layout/preamble.tex 
index 387f08d3c427..b016b1123391 100644 --- a/python/perf-kernels/tools/plot-layout/preamble.tex +++ b/python/perf-kernels/tools/plot-layout/preamble.tex @@ -6,6 +6,7 @@ \usetikzlibrary{calc, quotes} \usetikzlibrary{patterns} \usepackage{xparse} +\usepackage{libertinus} \definecolor{RoyalPurple}{HTML}{CC79A7} \definecolor{CrimsonRed}{HTML}{D41159} \definecolor{Gold}{HTML}{F1C40F}