From aae1352058dad48759a1bd11cef1d3e3348fd1bd Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Sun, 22 Dec 2024 09:01:09 -0800 Subject: [PATCH] Redo expansion --- iree/turbine/kernel/ops/wave_ops.py | 71 +- iree/turbine/kernel/wave/expansion.py | 816 ------------------ .../kernel/wave/expansion/expansion.py | 698 +++++++++++++++ .../kernel/wave/expansion/expansion_utils.py | 253 ++++++ .../wave/scheduling/loop_reconstruction.py | 8 +- iree/turbine/kernel/wave/utils.py | 42 +- iree/turbine/kernel/wave/wave.py | 6 +- lit_tests/kernel/wave/attention.py | 33 +- lit_tests/kernel/wave/barriers.py | 173 ++-- lit_tests/kernel/wave/codegen.py | 8 +- lit_tests/kernel/wave/expansion.py | 718 +++++++-------- .../kernel/wave/index_sequence_analysis.py | 295 +++---- .../kernel/wave/minimize_global_loads.py | 211 ++--- lit_tests/kernel/wave/scheduling.py | 219 ++--- 14 files changed, 1899 insertions(+), 1652 deletions(-) delete mode 100644 iree/turbine/kernel/wave/expansion.py create mode 100644 iree/turbine/kernel/wave/expansion/expansion.py create mode 100644 iree/turbine/kernel/wave/expansion/expansion_utils.py diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index b1fbc7ff..2f5680a0 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -404,6 +404,18 @@ def update_arg(self, idx_or_name: int | str, value: CustomOp | fx.Node): else: raise IndexError("Index out of range") + def copy_core_attributes(self, new_node: fx.Node): + """ + Copy core attributes from the current node to the new node. + """ + core_attributes = ["index", "vector_shapes", "reduction_dim", "iter_idx"] + for attr_name in core_attributes: + if hasattr(self.fx_node, attr_name): + attr = getattr(self.fx_node, attr_name) + if attr_name == "index": + attr = copy.deepcopy(attr) + setattr(new_node, attr_name, attr) + def copy( self, new_name: Optional[str] = None, @@ -421,14 +433,9 @@ def copy( new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform) new_node.tkw_op = self new_node.tkw_op_name = self.tkw_op_name - if hasattr(self.fx_node, "index"): - new_node.index = copy.deepcopy(self.fx_node.index) + self.copy_core_attributes(new_node) if new_name: new_node.name = new_name - if hasattr(self.fx_node, "vector_shapes"): - new_node.vector_shapes = self.fx_node.vector_shapes - if hasattr(self.fx_node, "reduction_dim"): - new_node.reduction_dim = self.fx_node.reduction_dim return get_custom(new_node) def replace_all_uses_with(self, new_node: CustomOp | fx.Node): @@ -437,6 +444,31 @@ def replace_all_uses_with(self, new_node: CustomOp | fx.Node): new_node = new_node.fx_node self.fx_node.replace_all_uses_with(new_node) + def replace_all_uses_with_except( + self, new_node: CustomOp | fx.Node, except_nodes: list[CustomOp] + ): + """Replace all uses of the current node with the new node except for the nodes in except_nodes.""" + for user in self.users: + if user in except_nodes: + continue + indices = user.get_node_arg_index(self) + if not isinstance(indices, Sequence): + indices = [indices] + for idx in indices: + if isinstance(user.node_args[idx], Sequence): + sub_idx = user.node_args[idx].index(self) + new_nodes = [ + ( + user.node_args[idx][x].fx_node + if x != sub_idx + else new_node.fx_node + ) + for x in range(len(user.node_args[idx])) + ] + user.update_arg(idx, new_nodes) + else: + user.update_arg(idx, new_node.fx_node) + def erase(self): """Erase the current node from the graph where it exists.""" assert ( @@ -470,7 +502,18 @@ def node_args(self) -> 
dict[int, Any]: return custom_args def get_node_arg_index(self, arg: CustomOp) -> Optional[CustomOp | list[CustomOp]]: - return next(key for key, value in self.node_args.items() if value == arg) + keys = [] + for key, value in self.node_args.items(): + if isinstance(value, Sequence): + if arg in value: + keys.append(key) + elif value == arg: + keys.append(key) + if not keys: + return None + if len(keys) == 1: + return keys[0] + return keys @property def users(self) -> list[Any]: @@ -785,9 +828,15 @@ class IterArg(Placeholder): def parent_op(self): return get_custom(self.graph.parent_op) - def get_iter_idx(self): - src_reduction = self.parent_op() - return src_reduction.iter_args(self.graph).index(self.fx_node) + @property + def iter_idx(self): + if hasattr(self.fx_node, "iter_idx"): + return self.fx_node.iter_idx + return None + + @iter_idx.setter + def iter_idx(self, value): + self.fx_node.iter_idx = value # Ops modeling TKW operations in the kernel language @@ -1157,6 +1206,8 @@ def iter_args(self, graph: fx.Graph) -> list[fx.Node]: custom = get_custom(nested_node) if isinstance(custom, IterArg): iter_args.append(nested_node) + # Sort by iter_idx. + iter_args = sorted(iter_args, key=lambda x: get_custom(x).iter_idx) return iter_args def captured_vars(self, graph: fx.Graph) -> list[fx.Node]: diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py deleted file mode 100644 index e3926433..00000000 --- a/iree/turbine/kernel/wave/expansion.py +++ /dev/null @@ -1,816 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import itertools -import torch.fx as fx -from typing import Any, TypeAlias, Sequence, Type, Callable -from functools import partial - -from .symbolic_constraints import SymbolicAlias - -from .constraints import ( - Constraint, - HardwareConstraint, - WorkgroupConstraint, - TilingConstraint, -) -from ..ops.wave_ops import ( - Allocate, - BinaryPyOp, - CustomOp, - GetResult, - Getitem, - IterArg, - MMA, - Output, - Placeholder, - Read, - ReduceOp, - Reduction, - Reshape, - Write, - get_custom, -) -from .._support.indexing import IndexingContext, IndexSymbol -from ...support.logging import get_logger -from .._support.tracing import CapturedTrace -from .utils import ( - get_mma_dimensional_mapping, - specialize_index_sequence, - get_hardware_constraint, - get_workgroup_constraints, -) -from ..lang.global_symbols import * - -logger = get_logger("turbine.wave.expansion") -# This represents a mapping of a node + indexing + res_idx(output index for op with multiple results) -# of node into the dimensions to the corresponding expanded node in these specific dimensions. -# An example for a record in this map is (read_0_0_0, ((M,0),(N,0),(K,1), 0) -> read_0_0_1. 
-ExpandedNodeMap: TypeAlias = dict[ - tuple[CustomOp, tuple[tuple[IndexSymbol, int], int, ...]], CustomOp -] - - -def already_expanded_iter_arg(node: CustomOp, dims: dict[IndexSymbol, int]) -> bool: - return ( - hasattr(node.fx_node, "expanded_dims") - and isinstance(node, IterArg) - and ( - filter_and_zero_unselected_dims(dims, node.indexing_dims) - == node.fx_node.expanded_dims # type: ignore - ) - ) - - -def expansion_needed( - dims: dict[IndexSymbol, int], selection: Sequence[IndexSymbol] -) -> bool: - """Check if any of the dimensions in the selection are non-zero.""" - return any(dim[1] != 0 and dim[0] in selection for dim in dims.items()) - - -def filter_and_zero_unselected_dims( - dims: dict[IndexSymbol, int], selection: Sequence[IndexSymbol] -) -> dict[IndexSymbol, int]: - """ - Filters dimensions based on selection and sets unselected dimensions' values to zero. - """ - return {dim: val if dim in selection else 0 for dim, val in dims.items()} - - -def get_dim_combinations( - all_dims: dict[IndexSymbol, int], - selection: Sequence[IndexSymbol], -): - """ - Returns all combinations of sizes for the selected dimensions. - Other dimensions are clamped to 0. - """ - adjusted_dimension_sizes = [ - list(range(all_dims[dim])) if dim in selection else [0] for dim in all_dims - ] - return itertools.product(*adjusted_dimension_sizes) - - -def get_indexed_dims( - all_dims: dict[IndexSymbol, int], nodeOrDims: CustomOp | Sequence[IndexSymbol] -) -> tuple[tuple[IndexSymbol, int], ...]: - """ - Generates a tuple of (key, value) pairs from the provided dimensions. - If given a CustomOp instance, it uses its indexing_dims attribute. - """ - if isinstance(nodeOrDims, CustomOp): - nodeOrDims = nodeOrDims.indexing_dims - # Flatten dims for node with multiple values or expanded Reduction. - if all(isinstance(el, Sequence) for el in nodeOrDims): - flattened_dims = list(itertools.chain.from_iterable(nodeOrDims)) - flatten_dims_set = dict.fromkeys(flattened_dims) - nodeOrDims = list(flatten_dims_set) - return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims) - - -def get_last(node_list: fx.graph._node_list) -> fx.Node: # type: ignore - """Get the last element of the fx node_list structure""" - return next(iter(reversed(node_list))) # type: ignore - - -def is_expandable(arg: Any) -> bool: - """Check if an argument is expandable.""" - if isinstance(arg, list): - return all(is_expandable(a) for a in arg) - # Placeholder nodes are only expanded if they are a reduction init arg - if isinstance(arg, Placeholder) and not isinstance(arg, IterArg): - return False - return isinstance(arg, CustomOp) - - -def expand_graph( - trace: CapturedTrace, - constraints_or_scaling: Sequence[Constraint] | dict[IndexSymbol, int], -): - """ - Create a graph that represents the expanded version of the wave function. - The expansion is done in the dimensions specified by the constraints. - """ - get_node_dim_scaling = partial(get_dim_scaling, constraints_or_scaling) - - # Start from the back and expand in the corresponding indexing dimensions of a node - # Then proceed to the operands - leaf_nodes: list[Type[CustomOp]] = [Write] - - # Some graphs may not have a write node, so we need to add the leaf nodes present in the - # graph, excluding output nodes. 
- all_fx_nodes_reversed = list(reversed(trace.get_root_graph().nodes)) - has_write = any( - isinstance(get_custom(fx_node), Write) for fx_node in all_fx_nodes_reversed - ) - if not has_write: - for node in (get_custom(fx_node) for fx_node in all_fx_nodes_reversed): - if isinstance(node, Output): - continue - leaf_nodes.append(node.__class__) - break - - expansion_context: ExpandedNodeMap = {} - for node in (get_custom(fx_node) for fx_node in all_fx_nodes_reversed): - - # Expansion begins at the leaf nodes - if node.__class__ not in leaf_nodes: - continue - - dim_scaling = get_node_dim_scaling(node) - for dim_combination in get_dim_combinations(dim_scaling, node.indexing_dims): - expand_dims = { - dim: val for dim, val in zip(dim_scaling.keys(), dim_combination) - } - logger.debug(f"Starting expansion at leaf:{node} in dims:{expand_dims}") - _expand_node( - node, - trace, - expand_dims, - dim_scaling, - expansion_context, - get_node_dim_scaling, - ) - - -def _expand_node( - node: CustomOp | list[CustomOp], - trace: CapturedTrace, - dim_query: dict[IndexSymbol, int], - dim_scaling: dict[IndexSymbol, int], - context: ExpandedNodeMap, - get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], - res_idx: int = 0, -) -> CustomOp: - """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs.""" - if isinstance(node, list): - expanded_nodes = [] - for elem in node: - expanded_nodes.append( - _expand_node( - elem, - trace, - dim_query, - get_node_dim_scaling(elem), - context, - get_node_dim_scaling, - res_idx, - ).fx_node - ) - return expanded_nodes - # If we expanded a node in the same dimensions before, we can reuse it - if (node, get_indexed_dims(dim_query, node), res_idx) in context: - logger.debug(f"Already expanded node: {node} in {dim_query}") - return context[(node, get_indexed_dims(dim_query, node), res_idx)] - elif isinstance(node, MMA): - # Handle expansion of MMA nodes whose reduction dim is not the same as the reduction - # dim of the parent reduction op or when there is no parent reduction op. - has_parent_op = hasattr(node.graph, "parent_op") - reduction_axes_different = False - if has_parent_op: - reduction: Reduction = get_custom(node.graph.parent_op) - reduction_axes_different = reduction.axis != node.reduction_dim - parallel_dim_query = node.reduction_dim not in dim_query - if (not has_parent_op or reduction_axes_different) and parallel_dim_query: - return _expand_mma_reduction( - node, - trace, - dim_query, - dim_scaling, - context, - get_node_dim_scaling, - res_idx, - ) - elif isinstance(node, Reduction): - return _expand_reduction( - node, trace, dim_query, dim_scaling, context, get_node_dim_scaling, res_idx - ) - elif isinstance(node, Getitem): - res_idx = node.res_idx - elif isinstance(node, GetResult) and not isinstance(node, Getitem): - # The presence of a GetResult node indicates that the reduction has already - # been expanded. Simply return the corresponding node. - reduction = get_custom(node.value) - return context[(reduction, get_indexed_dims(dim_query, reduction), res_idx)] - elif isinstance(node, Allocate): - # Allocate nodes are not expanded. 
- return node - - # Filter out the dimensions that are not indexed by the node - restricted_dims = filter_and_zero_unselected_dims(dim_query, node.indexing_dims) - logger.debug(f"Expanding node: {node} in {restricted_dims}") - - # For iter args, we want to insert - if not hasattr(_expand_node, "last_expanded_iter_arg"): - _expand_node.last_expanded_iter_arg = None - - # Clone the node for the new expansion. The original node is reused for the - # case of all dimensions being zero. - if expansion_needed(restricted_dims, node.indexing_dims): - new_node = node.copy( - anchor=( - _expand_node.last_expanded_iter_arg - if isinstance(node, IterArg) - else None - ) - ) - else: - new_node = node - logger.debug(f"did not clone node: {node} in {restricted_dims}") - - if isinstance(node, IterArg): - _expand_node.last_expanded_iter_arg = new_node.fx_node - - new_node.expanded_dims = restricted_dims - new_node.fx_node.name = get_expanded_name(node, restricted_dims) - - # For reshapes, we need more explicit control over how the arguments are expanded. - if isinstance(new_node, Reshape): - _expand_reshape( - new_node, - trace, - dim_query, - dim_scaling, - context, - get_node_dim_scaling, - res_idx, - ) - context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node - return new_node - - # Proceed with expansion of the arguments - for i, arg in node.node_args.items(): - arg_list = arg - unpack = lambda x: x - if isinstance(arg, list): - if not all(is_expandable(a) for a in arg): - continue - else: - arg_list = [arg] - unpack = lambda x: x[0] - if not is_expandable(arg): - continue - - new_args = [] - for subarg in arg_list: - new_subarg = _expand_node( - subarg, - trace, - restricted_dims, - get_node_dim_scaling(subarg), - context, - get_node_dim_scaling, - res_idx, - ) - new_args.append(new_subarg.fx_node) - new_node.update_arg(i, unpack(new_args)) - - context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node - return new_node - - -def _expand_reduction( - reduction: Reduction, - trace: CapturedTrace, - dim_query: dict[IndexSymbol, int], - dim_scaling: dict[IndexSymbol, int], - context: ExpandedNodeMap, - get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], - res_idx: int = 0, -) -> CustomOp: - """Expand a reduction in a specific dimension and recursively proceed to its inputs.""" - # Determine the dimensions to expand the reduction from the indexing of its users - users = reduction.users - expand_dims: list[IndexSymbol] = [] - for user in users: - dim_scaling.update(get_node_dim_scaling(user)) - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) - logger.debug(f"expanding reduction in dims: {expand_dims}") - - # Get the output node of the reduction - reduction_subgraph = trace.get_subgraph(reduction.subgraph_name) - output = get_custom(get_last(reduction_subgraph.nodes)) - if not isinstance(output, Output): - raise ValueError( - "fx.Graph malformed: The last node of a subgraph must be an output node" - ) - - new_output_args = [] - new_init_args = [] - for dim_vals in get_dim_combinations(dim_scaling, expand_dims): - return_vals = output.return_vals[0] - dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} - if not isinstance(return_vals, Sequence): - return_vals = [return_vals] - # Proceed with expansion inside the reduction - for arg_idx, arg in enumerate(return_vals): - arg = get_custom(arg) - # Add GetResult nodes for the corresponding dimensions - 
reduction.graph.inserting_after(reduction.fx_node) - new_node = GetResult(reduction.fx_node, len(new_output_args)) - # Usually we would rely on infer_types inside add_to_graph to figure out - # the type of the new node. However, in this case, the logic to determine - # the type requires the reduction node to have its init_args set, which has - # not happened yet (it happens later). So instead, since we have access to - # arg, we just set the type directly. - new_node.add_to_graph(reduction.graph, arg.type) - new_node.fx_node.name = get_expanded_name(new_node, dims) - context[ - (reduction, get_indexed_dims(dims, expand_dims), arg_idx) - ] = new_node - - expanded_output = _expand_node( - arg, - trace, - dims, - get_node_dim_scaling(arg), - context, - get_node_dim_scaling, - res_idx, - ) - # If condition below is needed to skip over induction variable - # who doesn't have all dims of ReductionOp. For example, - # a reduction Op that has induction variables of types - # (max, mma) -> [M], [M, N] - # will have indexing dims of ([M, N]). - # However, the 1st induction variable won't expand in N-dim - # M:0, N:0 expand(max) -> max_0_0_0 - # M:0, N:1 expand(max) -> max_0_0_0 - # but will get added to the `new_output_args` without the if condition. - - # TODO: Handle expansion of induction variables with "non-complete" dims - # by checking on the indexing_dims on each induction variable. - if expanded_output in new_output_args: - continue - new_output_args.append(expanded_output) - - # Proceed with expansion outside the reduction - for init_arg in reduction.init_args: - custom_init_arg = get_custom(init_arg) - expanded_init_arg = _expand_node( - custom_init_arg, - trace, - dims, - get_node_dim_scaling(custom_init_arg), - context, - get_node_dim_scaling, - res_idx, - ) - # TODO: Handle expansion of induction variables with "non-complete" dims - # by checking on the indexing_dims on each induction variable. - if expanded_init_arg in new_init_args: - continue - new_init_args.append(expanded_init_arg) - - # Update init_args and return values - reduction.update_arg( - "init_args", [new_init_arg.fx_node for new_init_arg in new_init_args] - ) - output.update_arg("return_vals", [node.fx_node for node in new_output_args]) - _handle_reduction_dim( - reduction, - output, - trace, - dim_scaling, - context, - get_node_dim_scaling, - res_idx, - ) - # Even though we expanded the reduction in multiple dimensions, we only return - # the node corresponding to the original query - return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)] - - -def _expand_mma_reduction( - mma: MMA, - trace: CapturedTrace, - dim_query: dict[IndexSymbol, int], - dim_scaling: dict[IndexSymbol, int], - context: ExpandedNodeMap, - get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], - res_idx: int, -) -> CustomOp: - """ - This function expands an MMA node along its reduction dimension. It is called - P times where P is the product of all of its parallel dimensions. For each - invocation, we expand the reduction dimension. - - We first compute the dim scaling along the reduction dimension and then append - it to the dim query so that the expanded node and its arguments can use the - expanded dim query with the appropriate value of the reduction dimension. - - Unlike the reduction expansion, where we can do a separate expansion for each iter_arg, - here we only have a single MMA node to start with. So we keep track of it and re-use - it for all the expansions. 
We also keep track of the accumulator value to be used as - the accumulator for the first expansion along the reduction dimension. - """ - - logger.debug(f"Expanding MMA reduction: {mma} in dims: {dim_query}") - expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim]) - - idxc = IndexingContext.current() - for dim in mma.indexing_dims: - if dim not in dim_scaling and mma.vector_shapes[dim] > 0: - tile_size = idxc.get_static_value(dim) - dim_scaling[dim] = max(tile_size // mma.vector_shapes[dim], 1) - - # Store the original mma node and accumulator value for expansion. - # When we begin expansion, we have a single mma node with the correct accumulator. - # This node corresponds to the dim query with all 0s and for this we reuse the - # original mma node. For all other queries, we create a new node. - # So say we have parallel dimensions {M, K2} and reduction dimension {K1}. - # For M = 0, K2 = 0, K1 = 0, we use the original mma node. - # For M = 0, K2 = 0, K1 = 1, we create a new node. - # Now, when it is time to expand along new parallel dimensions, we use the original node - # For M = 0, K2 = 1, K1 = 0, we use the original mma node so that the last cloned node's - # accumulator value is not modified. - - dim_query_dims = tuple(dim_query.keys()) - if not hasattr(_expand_mma_reduction, "acc"): - _expand_mma_reduction.acc = {} - if not hasattr(_expand_mma_reduction, "mma"): - _expand_mma_reduction.mma = {} - if ( - dim_query_dims not in _expand_mma_reduction.mma - or _expand_mma_reduction.mma[dim_query_dims].graph != mma.graph - ): - _expand_mma_reduction.mma[dim_query_dims] = mma - _expand_mma_reduction.acc[dim_query_dims] = mma.acc - - context_key = ( - _expand_mma_reduction.mma[dim_query_dims], - get_indexed_dims(dim_query, expand_dims), - res_idx, - ) - - user = _expand_mma_reduction.mma[dim_query_dims] - for scale_idx in range(dim_scaling[mma.reduction_dim]): - if isinstance(user, Output): - continue - - dims = dim_query - dims[mma.reduction_dim] = scale_idx - # Temporarily replace the loop carried arg here to avoid - # duplicated expansion. Otherwise we have the following situation: - # Suppose we have: - # mma_0_0_0(..., acc_0_0_0) - # mma_0_0_1(..., mma_0_0_0) - # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg - # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node. - # To avoid this we temporarily replace the use of it with a dummy - # placeholder which will not trigger further expansion. - index = user.get_node_arg_index(get_custom(user.acc)) - dummy = Placeholder("dummy").add_to_graph(user.graph) - dummy.type = None - - saved_arg = user.node_args[index] - user.update_arg(index, dummy) - new_node = _expand_node( - user, - trace, - dims, - get_node_dim_scaling(user), - context, - get_node_dim_scaling, - ) - - # Update the new node accumulator with the user, except the first one. 
- if scale_idx > 0: - new_node.update_arg(index, user) - else: - new_node.update_arg(index, _expand_mma_reduction.acc[dim_query_dims]) - user.update_arg(index, saved_arg) - user.graph.erase_node(dummy) - user = new_node - - context[context_key] = new_node - return new_node - - -def _expand_reshape( - reshape: Reshape, - trace: CapturedTrace, - dim_query: dict[IndexSymbol, int], - dim_scaling: dict[IndexSymbol, int], - context: ExpandedNodeMap, - get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], - res_idx: int, -) -> CustomOp: - """ - When expanding a reshape, we have to expand the arguments of the reshape and then concatenate them together - for the expanded node. Say we have a node with indexing dims = [M, N] with vector shapes m=8, n=2 and - the reshape wants to map it to m=4, n=4. So we start by expanding the node - node: {m = 0, n = 0} - arg: {m = 0, n = 0} - arg: {m = 0, n = 1} - node: {m = 1, n = 0} - arg: {m = 0, n = 0} - arg: {m = 0, n = 1} - node: {m = 2, n = 0} - arg: {m = 1, n = 0} - arg: {m = 1, n = 1} - node: {m = 3, n = 0} - arg: {m = 1, n = 0} - arg: {m = 1, n = 1} - ... - In general, - For the (m = i, n = j) expansion of the reshape node, we expand the arguments of the reshape node - using the following recipe: - - if m_src < m_dst, => we have a one to many mapping from source to destination - so we expand the arguments along m = i // (m_dst / m_src) and we expand the argument only once. - - if m_src > m_dst, => we have a many to one mapping from source to destination - so we expand the arguments along m = i * (m_src / m_dst), ... and we expand the argument m_dst / m_src times. - - In situations where the argument has been expanded along the same dimension, we reuse the expanded node - by making use of the context. - """ - - dim_combinations = {} - for dim, value in dim_query.items(): - if dim not in reshape.target_vector_shape: - continue - if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]: - scale_factor = ( - reshape.target_vector_shape[dim] // reshape.vector_shapes[dim] - ) - dim_combinations[dim] = [value // scale_factor] - else: - scale_factor = ( - reshape.vector_shapes[dim] // reshape.target_vector_shape[dim] - ) - begin = value * scale_factor - dim_combinations[dim] = list(range(begin, begin + scale_factor)) - reshape_dim_combinations = list(itertools.product(*dim_combinations.values())) - - new_args = [] - for i, arg_dim_query in enumerate(reshape_dim_combinations): - arg_dim_query = { - dim: val for dim, val in zip(dim_combinations.keys(), arg_dim_query) - } - if isinstance(reshape.args, Sequence): - custom_arg = get_custom(reshape.args[i]) - else: - custom_arg = get_custom(reshape.args) - new_node = _expand_node( - custom_arg, - trace, - arg_dim_query, - get_node_dim_scaling(custom_arg.fx_node), - context, - get_node_dim_scaling, - res_idx, - ) - new_args.append(new_node.fx_node) - - reshape.update_arg("args", new_args) - - -def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: - """Returns the name of a node with the dimensions appended.""" - - separated = node.fx_node.name.split("_") - node_name = separated[0] - if isinstance(node, Read) or isinstance(node, Write): - if get_custom(node.memory).type.address_space == SHARED_ADDRESS_SPACE: - node_name = node_name + "_shared" - # Special case for get_result op - if node_name == "get": - node_name = node_name + separated[1] - for val in dims.values(): - node_name += f"_{val}" - return node_name - - -def _contains(elem, container): - if container is None: - 
return False - - return elem in container - - -def get_dim_scaling( - constraints: Sequence[Constraint], node: fx.Node -) -> dict[IndexSymbol, int]: - """Get the number of expansions for the dimensions based on the constraints for a specific node.""" - dim_scaling: dict[IndexSymbol, int] = {} - if node.vector_shapes is None: - return dim_scaling - - hardware_constraints: list[HardwareConstraint] = [ - constraint - for constraint in constraints - if isinstance(constraint, HardwareConstraint) - ] - if len(hardware_constraints) != 1: - raise ValueError("Exactly one hardware constraint must be provided") - - aliased_dims: list[IndexSymbol] = [ - constraint.source - for constraint in constraints - if isinstance(constraint, SymbolicAlias) - ] - - idxc = IndexingContext.current() - for constraint in constraints: - if isinstance(constraint, WorkgroupConstraint) or isinstance( - constraint, TilingConstraint - ): - if constraint.dim in aliased_dims: - continue - hw_cons = hardware_constraints[0] - tile_size = idxc.get_static_value(constraint.tile_size) - if constraint.dim not in node.vector_shapes: - continue - vector_size = node.vector_shapes[constraint.dim] - - # No dim scaling for dims with 0 vector size. - if vector_size == 0: - continue - - wave_count = 1 - if isinstance(constraint, WorkgroupConstraint): - wave_count = hw_cons.waves_per_block[constraint.workgroup_dim] - if tile_size is None or wave_count is None or vector_size is None: - raise ValueError( - "Tile size, wave count and vector size must be statically known" - ) - if ( - tile_size % wave_count != 0 - or (tile_size / wave_count) % vector_size != 0 - ): - raise ValueError( - "Tile size must be divisible by wave count and vector size" - ) - dim_scaling[constraint.dim] = tile_size // wave_count // vector_size - - return dim_scaling - - -def _expand_mma_tiled_reduction( - mma: MMA, - trace: CapturedTrace, - dim_query: dict[IndexSymbol, int], - dim_scaling: dict[IndexSymbol, int], - context: ExpandedNodeMap, - get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], - res_idx: int, -) -> CustomOp: - latest_reduced_op = mma - # The initial nodes are expanded in the first dimension, so we start from 1 - for scale_idx in range(1, dim_scaling[mma.reduction_dim]): - dim_query[mma.reduction_dim] = scale_idx - # Temporarily replace the loop carried arg here to avoid - # duplicated expansion. Otherwise we have the following situation: - # Suppose we have: - # mma_0_0_0(..., acc_0_0_0) - # mma_0_0_1(..., mma_0_0_0) - # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg - # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node. - # To avoid this we temporarily replace the use of it with a dummy - # placeholder which will not trigger further expansion. - dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph) - dummy.type = None - - saved_acc = latest_reduced_op.acc - latest_reduced_op.update_arg("acc", dummy) - new_node = _expand_node( - latest_reduced_op, - trace, - dim_query, - dim_scaling, - context, - get_node_dim_scaling, - res_idx, - ) - - # Node is always cloned; Hence, will never be equal to latest reduced op - assert new_node != latest_reduced_op - # Update MMA_{t} to accumulate on MMA_{t-1}, and then save - # current MMA_{t} to outputs for use in next loop. 
- latest_reduced_op.update_arg("acc", saved_acc) - new_node.update_arg("acc", latest_reduced_op) - latest_reduced_op.graph.erase_node(dummy) - latest_reduced_op = new_node - return latest_reduced_op - - -def _handle_reduction_dim( - reduction: Reduction, - output: Output, - trace: CapturedTrace, - dim_scaling: dict[IndexSymbol, int], - context: ExpandedNodeMap, - get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], - res_idx: int, -): - # Rediscover iter args - # TODO: Register iter args with the reduction initially so accessing them is easier - reduction_subgraph = trace.get_subgraph(reduction.subgraph_name) - - # TODO: Handle case where MMAs/ReduceOps do not have Output as direct consumer. - def get_output_index(custom: CustomOp): - output_users = [ - get_custom(user) - for user in custom.fx_node.users - if isinstance(get_custom(user), Output) - ] - if len(output_users) != 1: - raise NotImplementedError( - "NYI: Currently only handle direct and 1:1 MMA -> Output case." - ) - return output_users[0].return_vals[0].index(custom.fx_node) - - # Collect MMA and ReduceOp who's reduction axis matches parent ReductionOp. - reduction_root_ops = [] - for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes): - if isinstance(node, (MMA, ReduceOp)) and reduction.axis == node.reduction_dim: - reduction_root_ops.append(node) - - new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name))) - # Users of the loop carried nodes will be duplicated - for root_op in reduction_root_ops: - dim_scaling = get_node_dim_scaling(root_op) - dims = dict(root_op.fx_node.expanded_dims) - latest_reduced_op = root_op - op_output_index = get_output_index(root_op) - if isinstance(root_op, MMA): - latest_reduced_op = _expand_mma_tiled_reduction( - root_op, - trace, - dims, - dim_scaling, - context, - get_node_dim_scaling, - res_idx, - ) - elif isinstance(root_op, ReduceOp): - original_src = latest_reduced_op.arg - # The initial nodes are expanded in the first dimension, so we start from 1 - for scale_idx in range(1, dim_scaling[reduction.axis]): - dims[root_op.reduction_dim] = scale_idx - current_src = latest_reduced_op.arg - if not isinstance(current_src, Sequence): - current_src = [current_src] - expanded_src = _expand_node( - get_custom(original_src), - trace, - dims, - dim_scaling, - context, - get_node_dim_scaling, - res_idx, - ) - current_src.append(expanded_src.fx_node) - latest_reduced_op.update_arg("arg", current_src) - new_outputs[op_output_index] = latest_reduced_op.fx_node - init_dims = root_op.fx_node.expanded_dims - context[ - (root_op, get_indexed_dims(init_dims, root_op), res_idx) - ] = latest_reduced_op - output.update_arg("return_vals", new_outputs) diff --git a/iree/turbine/kernel/wave/expansion/expansion.py b/iree/turbine/kernel/wave/expansion/expansion.py new file mode 100644 index 00000000..7dec0170 --- /dev/null +++ b/iree/turbine/kernel/wave/expansion/expansion.py @@ -0,0 +1,698 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..._support.tracing import CapturedTrace +from typing import Sequence, Type, Any +from ..constraints import ( + Constraint, +) +from ...ops.wave_ops import ( + Allocate, + CustomOp, + get_custom, + Output, + Write, + Reduction, + ReduceOp, + IterArg, + Reshape, + GetResult, + MMA, +) +from ..._support.indexing import IndexingContext, IndexSymbol +import itertools +from torch import fx +from dataclasses import dataclass +from .expansion_utils import ( + get_dim_scaling, + flatten_list, + get_indexed_dims, + is_expandable, + get_expanded_name, + compute_strides, + ExpansionMetadata, + get_reshape_dim_queries, + remove_original_nodes, + remove_unused_registers, +) +from ..utils import ( + get_users, + get_inputs, +) +from copy import deepcopy +import math + + +@dataclass(frozen=True) +class ExpansionInfo: + """ + Key used to store and look up nodes during expansion. + """ + + node: CustomOp + indexed_dims: tuple[tuple[IndexSymbol, int], ...] + + +class ReductionInfo: + """ + Contains fixup information for a reduction node. + """ + + def __init__(self, reduction: Reduction): + self.reduction = reduction + self.outputs: dict[int, ExpansionInfo] = {} + self.init_args: dict[int, ExpansionInfo] = {} + self.get_results: dict[int, ExpansionInfo] = {} + + +class ExpansionContext: + """ + Context used to store information during expansion. + """ + + def __init__(self): + self.expansion_context: dict[ExpansionInfo, CustomOp] = {} + # Additional operator specific information. + self.reduction_context: dict[Reduction, ReductionInfo] = {} + self.mma_connections: list[tuple[MMA, MMA]] = [] + self.mma_nodes: list[tuple[MMA]] = [] + + def __getitem__(self, key: ExpansionInfo): + return self.expansion_context[key] + + def __contains__(self, key: ExpansionInfo): + return key in self.expansion_context + + def __setitem__(self, key: ExpansionInfo, value: CustomOp): + self.expansion_context[key] = value + + +def get_leaf_node_types( + trace: CapturedTrace, all_nodes_reversed: list[fx.Node] +) -> list[Type[CustomOp]]: + """ + Get the types of leaf nodes of the trace. These are of type write. + If there are no write nodes, then the type of the last node in the trace is + used. + """ + leaf_nodes: list[Type[CustomOp]] = [Write] + if not leaf_nodes: + for node in (get_custom(node) for node in all_nodes_reversed): + if isinstance(node, Output): + continue + leaf_nodes.append(node.__class__) + break + + return leaf_nodes + + +def get_dim_combinations( + node: CustomOp, + constraints: Sequence[Constraint], +): + """ + Returns all combinations of sizes for the selected dimensions. + Other dimensions are clamped to 0. A dictionary is return where + the keys are the dimensions and the values are the combination. + """ + dim_scaling = get_dim_scaling(constraints, node) + adjusted_dimension_sizes = [ + list(range(dim_scaling[dim])) if dim in node.indexing_dims else [0] + for dim in dim_scaling + ] + dim_combinations = itertools.product(*adjusted_dimension_sizes) + return [ + {dim: val for dim, val in zip(dim_scaling.keys(), dim_combination)} + for dim_combination in dim_combinations + ] + + +def filter_expandable_args(args: list[Any]) -> list[Any]: + """ + Filter out the arguments that can be expanded. These are the arguments + that are of type CustomOp. 
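+    Sequences are kept only when every element is expandable, and the result
+    is flattened into a single list.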
+ """ + filtered_args = [] + for arg in args: + arg_list = arg + if isinstance(arg, Sequence): + if not all(is_expandable(arg) for arg in arg): + continue + else: + if not is_expandable(arg): + continue + arg_list = [arg] + filtered_args.append(arg_list) + return flatten_list(filtered_args) + + +def filter_and_zero_unselected_dims( + dims: dict[IndexSymbol, int], selection: Sequence[IndexSymbol] +) -> dict[IndexSymbol, int]: + """ + Filters dimensions based on selection and sets unselected dimensions' values to zero. + """ + return {dim: val if dim in selection else 0 for dim, val in dims.items()} + + +def compute_result_index( + dim_query: dict[IndexSymbol, int], + dim_scaling: dict[IndexSymbol, int], + node: fx.Node, + outputs: list[fx.Node], +): + """ + Compute the result index for a reduction node based on the dim + query and dim scaling. + + Say we have a reduction with output: + (max, sum, mma) + + and say that the indexing dims are: + max -> (M) + sum -> (M) + mma -> (M, N) + + and if the dim_scaling is: + M -> 2 + N -> 2 + + then we know that there will be a total of 8 results and + we can arrange them as follows: + (max_M:0, max_M:1, sum_M:0, sum_M:1, mma_M:0_N:0, mma_M:0_N:1, mma_M:1_N:0, mma_M:1_N:1) + + This means that each of these results has a predefined index that we can compute + based on the dim_scaling and dim_query. + + The formula for inputs in general is: + + # Global offset. + global_index = 0 + for i in range(index(input)): + global_index += product(dim_scaling[d] for d in node[i].indexing_dims) + + # Local offset. + local_index += dim_query * dim_strides + + """ + input_index = outputs.index(node) + get_shape = lambda x: get_custom(x).type.symbolic_shape + result_index = sum( + math.prod(dim_scaling[d] for d in get_shape(outputs[i]) if d in dim_scaling) + for i in range(input_index) + ) + node_shape = get_shape(node) + restricted_dim_scaling = {k: v for k, v in dim_scaling.items() if k in node_shape} + restricted_dim_query = {k: v for k, v in dim_query.items() if k in node_shape} + result_index += sum( + x * y + for x, y in zip( + compute_strides(restricted_dim_scaling), restricted_dim_query.values() + ) + ) + return result_index + + +def to_tuple(d: dict[IndexSymbol, int]) -> tuple[int, ...]: + return tuple([(k, v) for k, v in d.items()]) + + +def to_dict(t: tuple[int, ...]) -> dict[IndexSymbol, int]: + return {k: v for k, v in t} + + +def handle_reduction_entry( + reduction: Reduction, + inputs: list[CustomOp], + new_node: CustomOp, + node: CustomOp, + dim_query: dict[IndexSymbol, int], + dim_scaling: dict[IndexSymbol, int], + expansion_context: ExpansionContext, +): + # TODO: GetItems may not always be emitted if there is only one output. 
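+    # When a GetResult node is expanded, record the expanded reduction output
+    # it maps to (keyed by its computed result index) so that
+    # fixup_reduction_nodes can later rebuild the reduction's return values
+    # and GetResult nodes in the right order.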
+ reduction_context = expansion_context.reduction_context + if isinstance(new_node, GetResult): + assert len(inputs) == 1, f"Expected one input, got {inputs}" + outputs = reduction.outputs(inputs[0].graph) + if not isinstance(outputs, Sequence): + outputs = [outputs] + if reduction not in reduction_context: + reduction_context[reduction] = ReductionInfo(reduction) + result_index = compute_result_index(dim_query, dim_scaling, inputs[0], outputs) + custom = get_custom(inputs[0]) + key = ExpansionInfo(custom, get_indexed_dims(dim_query, custom)) + reduction_context[reduction].outputs[result_index] = key + reduction_context[reduction].get_results[result_index] = new_node + + +def handle_reduction_exit( + reduction: Reduction, + inputs: list[CustomOp], + new_node: CustomOp, + node: CustomOp, + dim_query: dict[IndexSymbol, int], + dim_scaling: dict[IndexSymbol, int], + expansion_context: ExpansionContext, +): + # If we are an iter arg, then we are exiting a reduction. + reduction_context = expansion_context.reduction_context + if isinstance(new_node, IterArg): + assert len(inputs) == 1, f"Expected one input, got {inputs}" + reduction = new_node.parent_op() + result_index = compute_result_index( + dim_query, dim_scaling, inputs[0], reduction.init_args + ) + assert reduction in reduction_context, f"Reduction not found: {reduction}" + new_node.iter_idx = result_index + custom = get_custom(inputs[0]) + key = ExpansionInfo(custom, get_indexed_dims(dim_query, custom)) + reduction_context[reduction].init_args[result_index] = key + + +def concatenate_outputs( + user: CustomOp, + new_user: CustomOp, + node: CustomOp, + new_node: CustomOp, + i: int, + metadata: ExpansionMetadata, +): + reshape_check = isinstance(new_user, Reshape) + reduce_check = isinstance(new_user, ReduceOp) and i == 0 + if reshape_check or reduce_check: + if metadata.query_index == 0: + new_node = [new_node.fx_node] + else: + assert ( + metadata.query_index > 0 + ), f"Expected query index > 0, got {metadata.query_index}" + new_node = [x.fx_node for x in new_user.node_args[i]] + [new_node.fx_node] + return new_node + return replace_node(user, new_user, node, new_node, i) + + +def replace_node( + user: CustomOp, new_user: CustomOp, node: CustomOp, new_node: CustomOp, i: int +): + # If we are updating a single value in a sequence, then we need to + # insert the new node at the correct location. + if isinstance(user.node_args[i], Sequence) and not isinstance(new_node, Sequence): + new_node = [ + x.fx_node if x != node else new_node.fx_node for x in new_user.node_args[i] + ] + return new_node + + +def update_users( + node: CustomOp, + new_node: CustomOp, + metadata: ExpansionMetadata, + expansion_context: ExpansionContext, +): + users, _ = get_users(node.fx_node, None) + for user in users: + user = get_custom(user) + dim_query = metadata.dim_query + # For reshapes and reduces, multiple users can share the same source. + if isinstance(user, (Reshape, ReduceOp)): + if not metadata.source_dim_query: + continue + dim_query = metadata.source_dim_query + key = ExpansionInfo(user, get_indexed_dims(dim_query, user)) + if key in expansion_context: + new_user = expansion_context[key] + if not new_user.node_args: + continue + indices = user.get_node_arg_index(node) + if indices is None: + continue + if not isinstance(indices, Sequence): + indices = [indices] + for i in indices: + # Check if an update is required. + if isinstance(new_user.node_args[i], Sequence): + # If node is already in the list, then we don't need to update. 
+ if any(x == new_node for x in new_user.node_args[i]): + continue + else: + if new_user.node_args[i] == new_node: + continue + new_arg = concatenate_outputs( + user, new_user, node, new_node, i, metadata + ) + new_user.update_arg(i, new_arg) + + +def add_to_outputs(node: CustomOp, new_node: CustomOp): + """ + Add the new node to the outputs of the node at the correct index. + """ + output = [x for x in node.users if isinstance(get_custom(x), Output)] + if not output: + return + output = get_custom(output[0]) + users, _ = get_users(new_node.fx_node, None) + get_result = [x for x in users if isinstance(get_custom(x), GetResult)] + assert len(get_result) == 1, f"Expected one GetResult, got {get_result}" + result_index = get_result[0].result_index + if len(output.return_vals[0]) < result_index: + new_return_vals = output.return_vals[0] + [None] * ( + result_index - len(output.return_vals[0]) + 1 + ) + new_return_vals[result_index] = new_node + output.return_vals = [new_return_vals] + + +def get_node( + dim_query: dict[IndexSymbol, int], + node: CustomOp, + expansion_context: ExpansionContext, +): + key = ExpansionInfo(node, get_indexed_dims(dim_query, node)) + assert key in expansion_context, f"Key not found: {key}" + return expansion_context[key] + + +def get_mma_reduction_count(arg: MMA, dim_scaling: dict[IndexSymbol, int]) -> int: + if arg.reduction_dim in dim_scaling: + reduction_count = dim_scaling[arg.reduction_dim] + else: + idxc = IndexingContext.current() + tile_size = idxc.get_static_value(arg.reduction_dim) + assert tile_size, f"Dimension not known : {arg.reduction_dim}" + reduction_count = max(tile_size // arg.vector_shapes[arg.reduction_dim], 1) + return reduction_count + + +def add_get_results(trace: CapturedTrace): + reductions = trace.walk(lambda x: isinstance(get_custom(x), Reduction)) + for reduction in reductions: + reduction = get_custom(reduction) + if len(reduction.init_args) == 1: + reduction.graph.inserting_after(reduction.fx_node) + get_result = get_custom( + GetResult(reduction.fx_node, 0).add_to_graph(reduction.graph) + ) + get_result.vector_shapes = reduction.init_args[0].vector_shapes + reduction.replace_all_uses_with_except(get_result, [get_result]) + + +def populate_inputs( + node: CustomOp, + inputs: list[fx.Node], + metadata: ExpansionMetadata, + dim_scaling: dict[IndexSymbol, int], + nodes_to_expand: list[tuple[CustomOp, dict[IndexSymbol, int]]], + expansion_context: ExpansionContext, +): + expandable_args = filter_expandable_args([get_custom(x) for x in inputs]) + new_nodes_to_expand = [] + + if isinstance(node, (Reshape, ReduceOp)): + match node: + case Reshape(): + dim_queries = get_reshape_dim_queries( + node, metadata, dim_scaling, new_nodes_to_expand + ) + case ReduceOp(): + reduction_count = dim_scaling[node.reduction_dim] + dim_queries = [] + for i in range(reduction_count): + dim_query = deepcopy(metadata.dim_query) + dim_query[node.reduction_dim] = i + dim_queries.append(dim_query) + + count = 0 + for i, arg in enumerate(expandable_args): + # For the init arg of the reduce op, if it exists, we expand only once. 
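+        # For a ReduceOp, the init arg keeps the original dim query; only the
+        # reduced operands get one query per step along the reduction dimension.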
+ if isinstance(node, ReduceOp) and node.init and arg == node.init: + nodes_to_expand.append((arg, metadata)) + continue + for j, query in enumerate(dim_queries): + new_metadata = deepcopy(metadata) + new_metadata.dim_query = query + new_metadata.source_dim_query = metadata.dim_query + new_metadata.num_queries = len(dim_queries) + new_metadata.query_index = count + new_nodes_to_expand.append((arg, new_metadata)) + count += 1 + nodes_to_expand.extend(new_nodes_to_expand) + return nodes_to_expand + + for arg in expandable_args: + match arg: + case MMA(): + reduction_count = get_mma_reduction_count(arg, dim_scaling) + for i in range(reduction_count): + mma_metadata = deepcopy(metadata) + mma_metadata.dim_query[arg.reduction_dim] = i + if i == reduction_count - 1: + mma_metadata.last_mma_node = True + new_nodes_to_expand.append((arg, mma_metadata)) + continue + case Allocate(): + alloc_metadata = deepcopy(metadata) + alloc_metadata.do_not_expand = True + new_nodes_to_expand.append((arg, alloc_metadata)) + continue + + new_nodes_to_expand.append((arg, metadata)) + + nodes_to_expand.extend(new_nodes_to_expand) + return nodes_to_expand + + +def store_fixup_data( + node: CustomOp, + new_node: CustomOp, + expanded_dims: dict[IndexSymbol, int], + dim_scaling: dict[IndexSymbol, int], + metadata: ExpansionMetadata, + expansion_context: ExpansionContext, +): + """ + Keep track of which MMA nodes need to be connected and replaced + for the fixup phase. + """ + match node: + case MMA(): + if expanded_dims[node.reduction_dim] == 0: + return + + def get_dim_query(new_v: int): + dims = { + k: v if k != node.reduction_dim else new_v + for k, v in expanded_dims.items() + } + return dims + + # Update accumulator. + last_dim_query = get_dim_query(expanded_dims[node.reduction_dim] - 1) + last_node = get_node(last_dim_query, node, expansion_context) + expansion_context.mma_connections.append((new_node, last_node)) + + # Keep track of fixup nodes. + if metadata.last_mma_node: + first_node = get_node(get_dim_query(0), node, expansion_context) + second_node = get_node(get_dim_query(1), node, expansion_context) + expansion_context.mma_nodes.append((first_node, new_node, second_node)) + + +def expand_node( + node: CustomOp, + dim_scaling: dict[IndexSymbol, int], + nodes_to_expand: list[tuple[CustomOp, dict[IndexSymbol, int], int]], + metadata: ExpansionMetadata, + expansion_context: ExpansionContext, +): + """ + When we expand a node, we clone it and add its arguments to the + list of nodes to be expanded. + """ + # TODO: Handle reduce_ops + # TODO: Handle reshape + + # Filter out the dimensions that are not selected in the query. + expanded_dims = filter_and_zero_unselected_dims( + metadata.dim_query, node.indexing_dims + ) + + # Check if the node has already been expanded, if so return early. + key = ExpansionInfo(node, get_indexed_dims(expanded_dims, node)) + if key in expansion_context: + update_users(node, expansion_context[key], metadata, expansion_context) + return nodes_to_expand + + if metadata.do_not_expand: + return nodes_to_expand + + # Make a copy of the node, adjust its name and set any metadata. + new_node = node.copy(anchor=(node.fx_node.prev)) + new_node.fx_node.name = get_expanded_name(node, metadata.dim_query) + new_node.expanded_dims = expanded_dims + + # Add new node to expansion context. + expansion_context[key] = new_node + + # Store information needed for the fixup phase. 
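+    # For MMA nodes this records how accumulators chain along the reduction
+    # dimension and which expanded MMA is the last in the chain, so that
+    # fixup_mma_nodes can rewire the accumulators afterwards.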
+ store_fixup_data( + node, + new_node, + expanded_dims, + dim_scaling, + metadata, + expansion_context, + ) + + # Check for any expanded users and update their arguments. + update_users(node, new_node, metadata, expansion_context) + + # Check if node should not be expanded. + if metadata.do_not_expand: + return nodes_to_expand + + # Add expandable inputs to the list of nodes to expand. + inputs, reduction = get_inputs(node.fx_node, None) + + handle_reduction_entry( + reduction, + inputs, + new_node, + node, + metadata.dim_query, + dim_scaling, + expansion_context, + ) + handle_reduction_exit( + reduction, + inputs, + new_node, + node, + metadata.dim_query, + dim_scaling, + expansion_context, + ) + + nodes_to_expand = populate_inputs( + node, inputs, metadata, dim_scaling, nodes_to_expand, expansion_context + ) + return nodes_to_expand + + +def dfs( + node: CustomOp, + dim_query: dict[IndexSymbol, int], + constraints: Sequence[Constraint], + expansion_context: ExpansionContext, +): + """ + Perform a depth-first search on the graph starting at the given node + for the given dimension combination. + """ + + visited = set() + nodes_to_expand = [(node, ExpansionMetadata(dim_query))] + while nodes_to_expand: + node, metadata = nodes_to_expand.pop(0) + if (node, metadata) in visited: + continue + visited.add((node, metadata)) + dim_scaling = get_dim_scaling(constraints, node) + nodes_to_expand = expand_node( + node, dim_scaling, nodes_to_expand, metadata, expansion_context + ) + + +def get_last(node_list: fx.graph._node_list) -> fx.Node: # type: ignore + """Get the last element of the fx node_list structure""" + return next(iter(reversed(node_list))) # type: ignore + + +def fixup_mma_nodes(trace: CapturedTrace, expansion_context: ExpansionContext): + # Chain MMA connections + for current, last in expansion_context.mma_connections: + current.update_arg(2, last) + # Use the last MMA node instead of the first one. 
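+    # `second` is the final MMA in the accumulator chain; `exclude` is the MMA
+    # that legitimately consumes `first` as its accumulator.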
+ for first, second, exclude in expansion_context.mma_nodes: + first.replace_all_uses_with_except(second, [exclude]) + + +def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext): + reduction_context = expansion_context.reduction_context + for reduction in trace.walk(lambda x: isinstance(get_custom(x), Reduction)): + reduction = get_custom(reduction) + reduction_subgraph = trace.get_subgraph(reduction.subgraph_name) + output = get_custom(get_last(reduction_subgraph.nodes)) + return_vals = output.return_vals[0] + if isinstance(return_vals, Sequence): + return_vals = [get_custom(x) for x in output.return_vals[0]] + else: + return_vals = [get_custom(return_vals)] + reduction_info = reduction_context[reduction] + sorted_keys = dict(sorted(reduction_info.outputs.items(), key=lambda x: x[0])) + new_outputs = [] + for key in sorted_keys.values(): + new_outputs.append(expansion_context[key].fx_node) + output.update_arg("return_vals", new_outputs) + + sorted_keys = dict(sorted(reduction_info.init_args.items(), key=lambda x: x[0])) + new_init_args = [] + for key in sorted_keys.values(): + new_init_args.append(expansion_context[key].fx_node) + reduction.update_arg("init_args", new_init_args) + + for result_index, get_item in reduction_info.get_results.items(): + get_item.graph.inserting_before(get_item.fx_node) + get_result = GetResult(get_item.value, result_index).add_to_graph( + get_item.graph, get_item.type + ) + get_result.name = get_item.fx_node.name + get_item.replace_all_uses_with(get_custom(get_result)) + get_item.graph.erase_node(get_item.fx_node) + + remove_original_nodes(return_vals) + + +def expand_graph( + trace: CapturedTrace, + constraints: Sequence[Constraint], +): + """ + Create a graph that represents the expanded version of the wave function. + The constraints are used to determine how the graph should be expanded. + The expansion does a DFS starting at the leaf nodes and expanding them + to the root of the graph. + """ + + add_get_results(trace) + + all_nodes_reversed = list(reversed(trace.get_root_graph().nodes)) + leaf_node_types = get_leaf_node_types(trace, all_nodes_reversed) + leaf_nodes = [] + expansion_context = ExpansionContext() + for node in all_nodes_reversed: + custom = get_custom(node) + type_check = custom.__class__ in leaf_node_types + unused_reduction_result = isinstance(custom, GetResult) and not custom.users + if not (type_check or unused_reduction_result): + continue + leaf_nodes.append(custom) + for dim_combination in get_dim_combinations(custom, constraints): + dfs( + custom, + dim_combination, + constraints, + expansion_context, + ) + + # Fixup all reduction nodes. + fixup_reduction_nodes(trace, expansion_context) + # Fixup all mma nodes. + fixup_mma_nodes(trace, expansion_context) + # Remove original nodes in root graph. + remove_original_nodes(leaf_nodes) + remove_unused_registers(trace) diff --git a/iree/turbine/kernel/wave/expansion/expansion_utils.py b/iree/turbine/kernel/wave/expansion/expansion_utils.py new file mode 100644 index 00000000..04697d5e --- /dev/null +++ b/iree/turbine/kernel/wave/expansion/expansion_utils.py @@ -0,0 +1,253 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..._support.tracing import CapturedTrace +from typing import Sequence, Any +from ..constraints import ( + Constraint, + HardwareConstraint, + WorkgroupConstraint, + TilingConstraint, +) +from torch import fx +from ..._support.indexing import IndexingContext, IndexSymbol +from ...ops.wave_ops import ( + get_custom, + CustomOp, + Placeholder, + IterArg, + Read, + Write, + Reshape, + NewRegister, +) +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +import itertools +from ..utils import ( + get_inputs, +) + + +class ExpansionMetadata: + def __init__(self, dim_query: dict[IndexSymbol, int] = None): + self.do_not_expand: bool = False + self.dim_query = dim_query + self.last_mma_node = False + self.source_dim_query = None + self.num_queries = None + self.query_index = None + + def __str__(self): + return str(self.__dict__) + + +def get_dim_scaling( + constraints: Sequence[Constraint], node: fx.Node +) -> dict[IndexSymbol, int]: + """Get the number of expansions for the dimensions based on the constraints for a specific node.""" + dim_scaling: dict[IndexSymbol, int] = {} + if node.vector_shapes is None: + return dim_scaling + + hardware_constraints: list[HardwareConstraint] = [ + constraint + for constraint in constraints + if isinstance(constraint, HardwareConstraint) + ] + if len(hardware_constraints) != 1: + raise ValueError("Exactly one hardware constraint must be provided") + + idxc = IndexingContext.current() + for constraint in constraints: + if isinstance(constraint, WorkgroupConstraint) or isinstance( + constraint, TilingConstraint + ): + hw_cons = hardware_constraints[0] + tile_size = idxc.get_static_value(constraint.tile_size) + if constraint.dim not in node.vector_shapes: + continue + vector_size = node.vector_shapes[constraint.dim] + + # No dim scaling for dims with 0 vector size. + if vector_size == 0: + continue + + wave_count = 1 + if isinstance(constraint, WorkgroupConstraint): + wave_count = hw_cons.waves_per_block[constraint.workgroup_dim] + if tile_size is None or wave_count is None or vector_size is None: + raise ValueError( + "Tile size, wave count and vector size must be statically known" + ) + if ( + tile_size % wave_count != 0 + or (tile_size / wave_count) % vector_size != 0 + ): + raise ValueError( + "Tile size must be divisible by wave count and vector size" + ) + dim_scaling[constraint.dim] = tile_size // wave_count // vector_size + + return dim_scaling + + +def flatten_list(nested_list): + flat_list = [] + for item in nested_list: + if isinstance(item, list): + flat_list.extend(flatten_list(item)) + else: + flat_list.append(item) + return flat_list + + +def get_indexed_dims( + all_dims: dict[IndexSymbol, int], nodeOrDims: CustomOp | Sequence[IndexSymbol] +) -> tuple[tuple[IndexSymbol, int], ...]: + """ + Generates a tuple of (key, value) pairs from the provided dimensions. + If given a CustomOp instance, it uses its indexing_dims attribute. + """ + if isinstance(nodeOrDims, CustomOp): + nodeOrDims = nodeOrDims.indexing_dims + # Flatten dims for node with multiple values or expanded Reduction. 
+    if all(isinstance(el, Sequence) for el in nodeOrDims):
+        flattened_dims = list(itertools.chain.from_iterable(nodeOrDims))
+        flatten_dims_set = dict.fromkeys(flattened_dims)
+        nodeOrDims = list(flatten_dims_set)
+    return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims)
+
+
+def is_expandable(arg: Any) -> bool:
+    """Check if an argument is expandable."""
+    if isinstance(arg, Sequence):
+        return all(is_expandable(a) for a in arg)
+    # Placeholder nodes are only expanded if they are a reduction init arg
+    if isinstance(arg, Placeholder) and not isinstance(arg, IterArg):
+        return False
+    return isinstance(arg, CustomOp)
+
+
+def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:
+    """Returns the name of a node with the dimensions appended."""
+
+    separated = node.fx_node.name.split("_")
+    node_name = separated[0]
+    if isinstance(node, Read) or isinstance(node, Write):
+        if get_custom(node.memory).type.address_space == SHARED_ADDRESS_SPACE:
+            node_name = node_name + "_shared"
+    # Special case for get_result op
+    if node_name == "get":
+        node_name = node_name + separated[1]
+    max_chars = 4
+    for key, val in dims.items():
+        key_str = str(key)
+        if len(key_str) > max_chars:
+            key_str = key_str[0:4] + "*"
+        node_name += f"_{key_str}:{val}"
+    return node_name
+
+
+def compute_strides(dim_scaling: dict[IndexSymbol, int]) -> list[int]:
+    """
+    Compute the strides for each dimension based on the dim scaling.
+    """
+    strides = [1] * len(dim_scaling)
+    stride = 1
+    for i, dim in enumerate(reversed(dim_scaling.keys())):
+        strides[i] = stride
+        stride *= dim_scaling[dim]
+    return strides[::-1]
+
+
+def filter_non_cloned_nodes(nodes: list[CustomOp]) -> list[CustomOp]:
+    """
+    Filter out nodes that have been cloned.
+    """
+    global expansion_context
+    return [node for node in nodes if node not in expansion_context.values()]
+
+
+def get_reshape_dim_queries(
+    reshape: Reshape,
+    metadata: ExpansionMetadata,
+    dim_scaling: dict[IndexSymbol, int],
+    nodes_to_expand: list[tuple[CustomOp, dict[IndexSymbol, int]]],
+):
+    """
+    When expanding a reshape, we have to expand the arguments of the reshape and then
+    concatenate them together for the expanded node. Say we have a node with indexing
+    dims = [M, N] with vector shapes m=8, n=2 and the reshape wants to map it to m=4, n=4.
+    So we start by expanding the node:
+    node: {m = 0, n = 0}
+        arg: {m = 0, n = 0}
+        arg: {m = 0, n = 1}
+    node: {m = 1, n = 0}
+        arg: {m = 0, n = 0}
+        arg: {m = 0, n = 1}
+    node: {m = 2, n = 0}
+        arg: {m = 1, n = 0}
+        arg: {m = 1, n = 1}
+    node: {m = 3, n = 0}
+        arg: {m = 1, n = 0}
+        arg: {m = 1, n = 1}
+    ...
+    In general, for the (m = i, n = j) expansion of the reshape node, we expand the arguments
+    of the reshape node using the following recipe, where m_src and m_dst denote the number of
+    expansions along m of the source (the reshape's argument) and of the destination (the
+    reshape) respectively:
+    - if m_src < m_dst, we have a one-to-many mapping from source to destination,
+      so we expand the arguments along m = i // (m_dst / m_src) and we expand the argument only once.
+    - if m_src > m_dst, we have a many-to-one mapping from source to destination,
+      so we expand the arguments along m = i * (m_src / m_dst), ... and we expand the argument
+      m_src / m_dst times.
+
+    In situations where the argument has already been expanded along the same dimensions, we reuse
+    the expanded node by making use of the context.
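+
+    Concretely, for the example above, the (m = i, n = j) expansion of the reshape gathers its
+    argument at the queries {m = i // 2, n = 2*j} and {m = i // 2, n = 2*j + 1}. For instance, a
+    (hypothetical) dim query of {m: 3, n: 1} yields the argument queries
+    [{m: 1, n: 2}, {m: 1, n: 3}]: the m query collapses two-to-one while the n query fans out
+    one-to-two.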
+ """ + + dim_combinations = {} + for dim, value in metadata.dim_query.items(): + if dim not in reshape.target_vector_shape: + continue + if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]: + scale_factor = ( + reshape.target_vector_shape[dim] // reshape.vector_shapes[dim] + ) + dim_combinations[dim] = [value // scale_factor] + else: + scale_factor = ( + reshape.vector_shapes[dim] // reshape.target_vector_shape[dim] + ) + begin = value * scale_factor + dim_combinations[dim] = list(range(begin, begin + scale_factor)) + reshape_dim_combinations = list(itertools.product(*dim_combinations.values())) + return [ + {dim: val for dim, val in zip(dim_combinations.keys(), combination)} + for combination in reshape_dim_combinations + ] + + +def remove_original_nodes(leaf_nodes: list[CustomOp]): + """ + Remove the original nodes from the graph. + """ + stack = leaf_nodes + while stack: + node = stack.pop(0) + inputs, _ = get_inputs(node.fx_node, None) + for input in inputs: + stack.append(get_custom(input)) + if not node.users: + if node.fx_node._erased: + continue + node.graph.erase_node(node.fx_node) + + +def remove_unused_registers(trace: CapturedTrace): + """ + Remove registers that are not used in the graph. + """ + for node in trace.walk(lambda x: isinstance(get_custom(x), NewRegister)): + if not node.users: + node.graph.erase_node(node) diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py index 0f5e9877..806cc606 100644 --- a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -340,6 +340,7 @@ def push_rotating_registers( """ new_rotating_registers: dict[fx.Node, deque[fx.Node]] = {} count = 0 + iter_arg_count = len(arg_context.iter_args) for node, registers in rotating_registers.items(): new_registers: deque[fx.Node] = deque() custom = get_custom(node) @@ -354,6 +355,8 @@ def push_rotating_registers( iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph) iter_arg.type = get_custom(node).type iter_arg.index = get_custom(node).index + iter_arg.iter_idx = iter_arg_count + iter_arg_count += 1 new_registers.append(iter_arg) mapped_value = iter_arg else: @@ -421,6 +424,7 @@ def construct_kernel( iter_arg = IterArg(node.name).add_to_graph(pipelined_reduction_graph) iter_arg.type = get_custom(node).type iter_arg.index = get_custom(node).index + iter_arg.iter_idx = get_custom(node).iter_idx arg_context.map_arg_all(node, iter_arg) # Push the rotating registers into the argument context. @@ -520,7 +524,7 @@ def construct_epilogue( if i in existing_indices: continue with pipelined_reduction.graph.inserting_before( - existing_get_results[0].fx_node.next + existing_get_results[-1].fx_node.next ): result = GetResult(pipelined_reduction.fx_node, i).add_to_graph( pipelined_reduction.graph, type=iter_args[i].type @@ -534,7 +538,7 @@ def construct_epilogue( arg_context.map_arg_all(iter_arg, get_result.fx_node) with pipelined_reduction.graph.inserting_before( - existing_get_results[0].fx_node.next + existing_get_results[-1].fx_node.next ): # Add get result nodes for the rotating registers and update the # argument map with them. 
diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 5fb4ff30..397a5548 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -828,14 +828,21 @@ def get_inputs( local_reduction = reduction if reduction is None: local_reduction = custom.parent_op() - iter_arg_idx = custom.get_iter_idx() + iter_arg_idx = custom.iter_idx inputs.append(local_reduction.init_args[iter_arg_idx]) elif isinstance(custom, GetResult): reduction = get_custom(custom.value) assert isinstance(reduction, Reduction), "GetResult must be used by a Reduction" # Map get result to output reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] - inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) + if len(reduction.init_args) == 1: + outputs = reduction.outputs(reduction_subgraph) + if isinstance(outputs, Sequence): + inputs += outputs + else: + inputs.append(outputs) + else: + inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) elif isinstance(custom, Reduction): reduction_subgraph = custom.get_root_graph().subgraphs[custom.subgraph_name] inputs.append(custom.outputs(reduction_subgraph)) @@ -1358,3 +1365,34 @@ def get_largest_index_and_size(indices: dict[IndexExpr, IndexSequence]): key=lambda x: (-x[2], -x[0]), ) return sorted_values[0][1:] + + +def print_graph(graph: fx.Graph): + """ + Pretty-print the graph containing this node. + """ + graph_str = str(graph) + graph_str = graph_str.replace( + "iree.turbine.kernel.lang.kernel_buffer.KernelBufferMeta.new_subtype..SubType", + "", + ) + graph_str = graph_str.replace("target=iree.turbine.kernel.ops.wave_ops.", "") + graph_str = graph_str.replace("call_function", "") + print(graph_str) + + +def initialize_iter_args(trace: CapturedTrace) -> None: + """ + Initializes the IterArgs in each reduction with an index + based on their location in the graph. + + """ + reductions = trace.walk(lambda node: isinstance(get_custom(node), Reduction)) + for reduction in reductions: + reduction_graph = trace.get_subgraph(get_custom(reduction).subgraph_name) + count = 0 + for node in reduction_graph.nodes: + custom = get_custom(node) + if isinstance(custom, IterArg): + custom.iter_idx = count + count += 1 diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 51b1d051..a8481c73 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -22,7 +22,7 @@ get_grid_shape, ) from .codegen import WaveEmitter -from .expansion import expand_graph +from .expansion.expansion import expand_graph from .promotion import promote_placeholders from .hoisting import hoist_loop_invariant_ops from .utils import ( @@ -35,6 +35,7 @@ subs_idxc, delinearize_index, _write_file, + initialize_iter_args, ) from .minimize_global_loads import minimize_global_loads from .decompose_reduce_ops import decompose_reduce_ops @@ -42,7 +43,7 @@ from ..lang import Grid, IndexMapping from ..lang.global_symbols import * from ..ops import wave_ops -from ..ops.wave_ops import Reduction, CustomOp, get_custom +from ..ops.wave_ops import Reduction, CustomOp, get_custom, IterArg from .index_sequence_analysis import ( partition_ops_with_gpr_offsets, partition_strided_operators, @@ -315,6 +316,7 @@ def _trace_and_get_kernel_signature( # Trace the function. 
graph = self._trace() + initialize_iter_args(graph) self.create_induction_vars(graph) self.initialize_wave_constraints(graph) self.initialize_reductions(graph) diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index a51cbe46..ab4ffd22 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -338,16 +338,22 @@ def repeat( print(dynamic_attention_pipelined(q, k, v, output).module_op) # CHECK-LABEL: func.func @dynamic_attention_pipelined - # CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}} + # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}} # CHECK: {{.*}} = scf.for + # CHECK-COUNT-4: {{.*}} = vector.load {{.*}} # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}} - # CHECK-COUNT-4: {{.*}} = amdgpu.mfma - # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-4: {{.*}} = amdgpu.mfma # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-10: {{.*}} = amdgpu.mfma - # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} # CHECK-COUNT-2: {{.*}} = amdgpu.mfma + # CHECK-COUNT-4: {{.*}} = vector.load {{.*}} + # CHECK-COUNT-2: {{.*}} = amdgpu.mfma + # CHECK-COUNT-4: {{.*}} = vector.load {{.*}} + # CHECK-COUNT-6: {{.*}} = amdgpu.mfma + # CHECK-COUNT-4: {{.*}} = vector.load {{.*}} + # CHECK-COUNT-11: {{.*}} = amdgpu.mfma + # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-1: {{.*}} = amdgpu.mfma + # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-1: {{.*}} = amdgpu.mfma # CHECK-COUNT-16: vector.maskedstore {{.*}} @@ -467,17 +473,14 @@ def repeat( # CHECK-LABEL: func.func @base_attention_pipelined # CHECK: {{.*}} = scf.for - # CHECK-COUNT-4: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-4: {{.*}} = amdgpu.mfma + # CHECK-COUNT-1: {{.*}} = amdgpu.mfma # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-10: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma + # CHECK-COUNT-19: {{.*}} = amdgpu.mfma + # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-1: {{.*}} = amdgpu.mfma + # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-1: {{.*}} = amdgpu.mfma @run_test diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index ac8cd6f5..d105452d 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -11,13 +11,13 @@ from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops -from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.expansion.expansion import expand_graph from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.ops.wave_ops import * -from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.utils import run_test, print_trace, initialize_iter_args def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: @@ -97,39 +97,39 @@ def test_read_write_equal_sizes(): # 
CHECK-NEXT: %c # CHECK-NEXT: %allocate # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) - # CHECK-NEXT: %read_0_0 + # CHECK-NEXT: %read_M:0_N:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_1 + # CHECK-NEXT: %read_M:0_N:1 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0 + # CHECK-NEXT: %read_M:1_N:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_0_1 + # CHECK-NEXT: %read_M:1_N:1 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %write_shared_0_0 - # CHECK-SAME: (%read_0_0, %allocate, 4, None, ()) - # CHECK-NEXT: %write_shared_1_1 - # CHECK-SAME: (%read_1_1, %allocate, 4, None, ()) - # CHECK-NEXT: %write_shared_1_0 - # CHECK-SAME: (%read_1_0, %allocate, 4, None, ()) - # CHECK-NEXT: %write_shared_0_1 - # CHECK-SAME: (%read_0_1, %allocate, 4, None, ()) + # CHECK-NEXT: %write_shared_M:0_N:0 + # CHECK-SAME: (%read_M:0_N:0, %allocate, 4, None, ()) + # CHECK-NEXT: %write_shared_M:0_N:1 + # CHECK-SAME: (%read_M:0_N:1, %allocate, 4, None, ()) + # CHECK-NEXT: %write_shared_M:1_N:0 + # CHECK-SAME: (%read_M:1_N:0, %allocate, 4, None, ()) + # CHECK-NEXT: %write_shared_M:1_N:1 + # CHECK-SAME: (%read_M:1_N:1, %allocate, 4, None, ()) # CHECK-NEXT: %shared_memory_barrier - # CHECK-NEXT: %read_shared_0_0 - # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_0_0] - # CHECK-NEXT: %read_shared_1_1 - # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_1_1] - # CHECK-NEXT: %read_shared_1_0 - # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_1_0] - # CHECK-NEXT: %read_shared_0_1 - # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_0_1] - # CHECK-NEXT: %write_0_0 - # CHECK-SAME: (%read_shared_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_1 - # CHECK-SAME: (%read_shared_1_1, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0 - # CHECK-SAME: (%read_shared_1_0, %c, 4, None, ()) - # CHECK-NEXT: %write_0_1 - # CHECK-SAME: (%read_shared_0_1, %c, 4, None, ()) + # CHECK-NEXT: %read_shared_M:0_N:0 + # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_M:0_N:0] + # CHECK-NEXT: %read_shared_M:0_N:1 + # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_M:0_N:1] + # CHECK-NEXT: %read_shared_M:1_N:0 + # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_M:1_N:0] + # CHECK-NEXT: %read_shared_M:1_N:1 + # CHECK-SAME: (%allocate, 4, None, (), [%write_shared_M:1_N:1] + # CHECK-NEXT: %write_M:0_N:0 + # CHECK-SAME: (%read_shared_M:0_N:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:1 + # CHECK-SAME: (%read_shared_M:0_N:1, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0 + # CHECK-SAME: (%read_shared_M:1_N:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:1 + # CHECK-SAME: (%read_shared_M:1_N:1, %c, 4, None, ()) # CHECK-NEXT: return None # CHECK: ----- @@ -171,6 +171,7 @@ def test_gemm(): trace: CapturedTrace = gemm() graph: fx.Graph = trace.get_subgraph("region_0") IndexingContext.current().finalize() + initialize_iter_args(trace) infer_types(trace) read_nodes = get_read_nodes(graph) for read_node in read_nodes: @@ -186,76 +187,76 @@ def test_gemm(): # CHECK: %a # CHECK-NEXT: %b # CHECK-NEXT: %c - # CHECK-NEXT: %register_0_0_0 - # CHECK-NEXT: %register_1_1_0 - # CHECK-NEXT: %register_1_0_0 - # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %register_M:0_N:0_K:0 + # CHECK-NEXT: %register_M:0_N:1_K:0 + # CHECK-NEXT: %register_M:1_N:0_K:0 + # CHECK-NEXT: %register_M:1_N:1_K:0 # CHECK-NEXT: %allocate # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, 
$SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction - # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] - # CHECK-NEXT: %getresult_1_1_0 - # CHECK-SAME: (%reduction, 3) - # CHECK-NEXT: %getresult_1_0_0 - # CHECK-SAME: (%reduction, 2) - # CHECK-NEXT: %getresult_0_1_0 - # CHECK-SAME: (%reduction, 1) - # CHECK-NEXT: %getresult_0_0_0 + # CHECK-SAME: (K, [%register_M:0_N:0_K:0, %register_M:0_N:1_K:0, %register_M:1_N:0_K:0, %register_M:1_N:1_K:0] + # CHECK-NEXT: %getresult_M:0_N:0_K:0 # CHECK-SAME: (%reduction, 0) - # CHECK-NEXT: %write_0_0_0 - # CHECK-SAME: (%getresult_0_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_1_0 - # CHECK-SAME: (%getresult_1_1_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_0 - # CHECK-SAME: (%getresult_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_0_1_0 - # CHECK-SAME: (%getresult_0_1_0, %c, 4, None, ()) + # CHECK-NEXT: %getresult_M:0_N:1_K:0 + # CHECK-SAME: (%reduction, 1) + # CHECK-NEXT: %getresult_M:1_N:0_K:0 + # CHECK-SAME: (%reduction, 2) + # CHECK-NEXT: %getresult_M:1_N:1_K:0 + # CHECK-SAME: (%reduction, 3) + # CHECK-NEXT: %write_M:0_N:0_K:0 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:1_K:0 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0 + # CHECK-SAME: (%getresult_M:1_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:1_K:0 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, %c, 4, None, ()) # CHECK-NEXT: return None # Reduction subgraph: - # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_0_1_0 - # CHECK-NEXT: %acc_1_0_0 - # CHECK-NEXT: %acc_1_1_0 + # CHECK: %acc_M:0_N:0_K:0 + # CHECK-NEXT: %acc_M:0_N:1_K:0 + # CHECK-NEXT: %acc_M:1_N:0_K:0 + # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a - # CHECK-NEXT: %read_0_0_0 - # CHECK-NEXT: %read_0_0_1 - # CHECK-NEXT: %read_1_0_0 - # CHECK-NEXT: %read_1_0_1 - # CHECK-NEXT: %write_shared_0_0_0 - # CHECK-NEXT: %write_shared_0_0_1 - # CHECK-NEXT: %write_shared_1_0_0 - # CHECK-NEXT: %write_shared_1_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:0 + # CHECK-NEXT: %read_M:0_N:0_K:1 + # CHECK-NEXT: %read_M:1_N:0_K:0 + # CHECK-NEXT: %read_M:1_N:0_K:1 + # CHECK-NEXT: %write_shared_M:0_N:0_K:0 + # CHECK-NEXT: %write_shared_M:0_N:0_K:1 + # CHECK-NEXT: %write_shared_M:1_N:0_K:0 + # CHECK-NEXT: %write_shared_M:1_N:0_K:1 # CHECK-NEXT: %shared_memory_barrier - # CHECK-NEXT: %read_shared_0_0_0 - # CHECK-NEXT: %read_shared_0_0_1 - # CHECK-NEXT: %read_shared_1_0_0 - # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %read_shared_M:0_N:0_K:0 + # CHECK-NEXT: %read_shared_M:0_N:0_K:1 + # CHECK-NEXT: %read_shared_M:1_N:0_K:0 + # CHECK-NEXT: %read_shared_M:1_N:0_K:1 # CHECK-NEXT: %b - # CHECK-NEXT: %read_0_0_0 - # CHECK-NEXT: %read_0_0_1 - # CHECK-NEXT: %read_0_1_0 - # CHECK-NEXT: %read_0_1_1 + # CHECK-NEXT: %read_M:0_N:0_K:0 + # CHECK-NEXT: %read_M:0_N:0_K:1 + # CHECK-NEXT: %read_M:0_N:1_K:0 + # CHECK-NEXT: %read_M:0_N:1_K:1 # CHECK-NEXT: %shared_memory_barrier_1 - # CHECK-NEXT: %write_shared_0_0_0 - # CHECK-NEXT: %write_shared_0_0_1 - # CHECK-NEXT: %write_shared_0_1_0 - # CHECK-NEXT: %write_shared_0_1_1 + # CHECK-NEXT: %write_shared_M:0_N:0_K:0 + # CHECK-NEXT: %write_shared_M:0_N:0_K:1 + # CHECK-NEXT: %write_shared_M:0_N:1_K:0 + # CHECK-NEXT: %write_shared_M:0_N:1_K:1 # CHECK-NEXT: %shared_memory_barrier_2 - # CHECK-NEXT: %read_shared_0_0_0 - # CHECK-NEXT: %read_shared_0_0_1 - # CHECK-NEXT: %read_shared_0_1_0 - # CHECK-NEXT: %read_shared_0_1_1 - # CHECK-NEXT: %mma_0_0_0 - # CHECK-NEXT: %mma_0_0_1 - # CHECK-NEXT: %mma_1_1_0 - # CHECK-NEXT: %mma_1_1_1 - # 
CHECK-NEXT: %mma_1_0_0 - # CHECK-NEXT: %mma_1_0_1 - # CHECK-NEXT: %mma_0_1_0 - # CHECK-NEXT: %mma_0_1_1 + # CHECK-NEXT: %read_shared_M:0_N:0_K:0 + # CHECK-NEXT: %read_shared_M:0_N:0_K:1 + # CHECK-NEXT: %read_shared_M:0_N:1_K:0 + # CHECK-NEXT: %read_shared_M:0_N:1_K:1 + # CHECK-NEXT: %mma_M:0_N:0_K:0 + # CHECK-NEXT: %mma_M:0_N:0_K:1 + # CHECK-NEXT: %mma_M:0_N:1_K:0 + # CHECK-NEXT: %mma_M:0_N:1_K:1 + # CHECK-NEXT: %mma_M:1_N:0_K:0 + # CHECK-NEXT: %mma_M:1_N:0_K:1 + # CHECK-NEXT: %mma_M:1_N:1_K:0 + # CHECK-NEXT: %mma_M:1_N:1_K:1 if __name__ == "__main__": diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 1858b45b..2d8e57e3 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -1155,7 +1155,7 @@ def repeat( # Tile Reduction Loop # CHECK: %[[TILED:.+]]:4 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] - # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_SUM]], %[[ACC2:.+]] = %[[INIT_MAX]], %[[ACC3:.+]] = %[[INIT_SUM]]) + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_MAX]], %[[ACC2:.+]] = %[[INIT_SUM]], %[[ACC3:.+]] = %[[INIT_SUM]]) # CHECK-SAME: -> (vector<1xf16>, vector<1xf16>, vector<1xf16>, vector<1xf16>) { # 1st Expanded Local Max Reduction # CHECK: arith.maximumf {{.*}} : vector<1xf16> @@ -1169,14 +1169,14 @@ def repeat( # 2nd Expanded Global Max Reduction # CHECK-COUNT-6: gpu.shuffle xor # 2nd Expanded Accumulator Max Reduction - # CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC2]], %{{.*}} + # CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC1]], %{{.*}} # 1st Expanded Local Sum Reduction # CHECK: arith.addf {{.*}} : vector<1xf16> # 1st Expanded Global Sum Reduction # CHECK-COUNT-6: gpu.shuffle xor # 1st Expanded Accumulator Sum Reduction - # CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC1]], %{{.*}} + # CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC2]], %{{.*}} # 2nd Expanded Local Sum Reduction # CHECK: arith.addf {{.*}} : vector<1xf16> @@ -1185,7 +1185,7 @@ def repeat( # 2nd Expanded Accumulator Sum Reduction # CHECK: %[[ACC_SUM_1:.+]] = arith.addf %[[ACC3]], %{{.*}} - # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]] + # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_MAX_1]], %[[ACC_SUM_0]], %[[ACC_SUM_1]] # This test is used to ensure: diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 7d022eb1..e0aa960f 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -4,7 +4,7 @@ import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw -from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.expansion.expansion import expand_graph from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.wave.index_sequence_analysis import ( set_node_indices, @@ -12,7 +12,7 @@ ) from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.lang.global_symbols import * -from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.utils import run_test, print_trace, initialize_iter_args from iree.turbine.kernel.wave.constraints import MMAType import sympy @@ -76,22 +76,22 @@ def test_read_write_equal_sizes(): print_trace(graph) # CHECK: %a # CHECK-NEXT: %c - # CHECK-NEXT: %read_0_0 + # CHECK-NEXT: %read_M:0_N:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_1 + # 
CHECK-NEXT: %read_M:0_N:1 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0 + # CHECK-NEXT: %read_M:1_N:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_0_1 + # CHECK-NEXT: %read_M:1_N:1 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %write_0_0 - # CHECK-SAME: (%read_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_1 - # CHECK-SAME: (%read_1_1, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0 - # CHECK-SAME: (%read_1_0, %c, 4, None, ()) - # CHECK-NEXT: %write_0_1 - # CHECK-SAME: (%read_0_1, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:0 + # CHECK-SAME: (%read_M:0_N:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:1 + # CHECK-SAME: (%read_M:0_N:1, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0 + # CHECK-SAME: (%read_M:1_N:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:1 + # CHECK-SAME: (%read_M:1_N:1, %c, 4, None, ()) # CHECK-NEXT: return # Custom format: @@ -100,19 +100,19 @@ def test_read_write_equal_sizes(): # CHECK-NEXT: read(memory=a # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N + 16 : 4 : 1} + # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N + 16 : 4 : 1} # CHECK-NEXT: read(memory=a # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a - # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N + 16 : 4 : 1} - # CHECK-NEXT: write(register_=read_0_0 - # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: write(register_=read_1_1 # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N + 16 : 4 : 1} - # CHECK-NEXT: write(register_=read_1_0 - # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: write(register_=read_0_1 + # CHECK-NEXT: write(register_=read_M:0_N:0 + # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: write(register_=read_M:0_N:1 # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N + 16 : 4 : 1} + # CHECK-NEXT: write(register_=read_M:1_N:0 + # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: write(register_=read_M:1_N:1 + # CHECK-SAME: index={M: $T0 + $WG0*BLOCK_M + BLOCK_M*floor($T0/64) + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N + 16 : 4 : 1} # CHECK-NEXT: output # CHECK: ----- @@ -158,18 +158,18 @@ def test_read_write(): print_trace(graph) # CHECK: %a # CHECK-NEXT: %c - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0_0 + # CHECK-NEXT: %read_M:1_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %write_0_0_0 - # CHECK-SAME: (%read_0_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_1 - # CHECK-SAME: (%read_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_0 - # CHECK-SAME: (%read_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: 
%write_0_0_1 - # CHECK-SAME: (%read_0_0_0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:0_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:0_K:1 + # CHECK-SAME: (%read_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0_K:0 + # CHECK-SAME: (%read_M:1_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0_K:1 + # CHECK-SAME: (%read_M:1_N:0_K:0, %c, 4, None, ()) # CHECK-NEXT: return None # Custom format: @@ -179,14 +179,14 @@ def test_read_write(): # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: write(register_=read_0_0_0 + # CHECK-NEXT: write(register_=read_M:0_N:0_K:0 # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M : 1 : 16, K: $T2*BLOCK_K + 4*$T2 + $WG2*BLOCK_K : 4 : 1} - # CHECK-NEXT: write(register_=read_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: $T2*BLOCK_K + 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} - # CHECK-NEXT: write(register_=read_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: $T2*BLOCK_K + 4*$T2 + $WG2*BLOCK_K : 4 : 1} - # CHECK-NEXT: write(register_=read_0_0_0 + # CHECK-NEXT: write(register_=read_M:0_N:0_K:0 # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M : 1 : 16, K: $T2*BLOCK_K + 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-NEXT: write(register_=read_M:1_N:0_K:0 + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: $T2*BLOCK_K + 4*$T2 + $WG2*BLOCK_K : 4 : 1} + # CHECK-NEXT: write(register_=read_M:1_N:0_K:0 + # CHECK-SAME: index={M: $T0*BLOCK_M/64 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: $T2*BLOCK_K + 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} # CHECK-NEXT: output # CHECK: ----- @@ -229,6 +229,7 @@ def test_gemm(): ): graph = gemm() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) @@ -238,24 +239,24 @@ def test_gemm(): # CHECK: %a # CHECK-NEXT: %b # CHECK-NEXT: %c - # CHECK-NEXT: %register_0_0_0 - # CHECK-NEXT: %register_1_1_0 - # CHECK-NEXT: %register_1_0_0 - # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %register_M:0_N:0_K:0 + # CHECK-NEXT: %register_M:0_N:1_K:0 + # CHECK-NEXT: %register_M:1_N:0_K:0 + # CHECK-NEXT: %register_M:1_N:1_K:0 # CHECK-NEXT: %reduction - # CHECK-SAME: %register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0 - # CHECK-NEXT: %getresult_1_1_0 - # CHECK-NEXT: %getresult_1_0_0 - # CHECK-NEXT: %getresult_0_1_0 - # CHECK-NEXT: %getresult_0_0_0 - # CHECK-NEXT: %write_0_0_0 - # CHECK-SAME: (%getresult_0_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_1_0 - # CHECK-SAME: (%getresult_1_1_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_0 - # CHECK-SAME: (%getresult_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_0_1_0 - # CHECK-SAME: (%getresult_0_1_0, %c, 4, None, ()) + # CHECK-SAME: %register_M:0_N:0_K:0, %register_M:0_N:1_K:0, %register_M:1_N:0_K:0, %register_M:1_N:1_K:0 + # CHECK-NEXT: %getresult_M:0_N:0_K:0 + # CHECK-NEXT: %getresult_M:0_N:1_K:0 + # CHECK-NEXT: %getresult_M:1_N:0_K:0 + # CHECK-NEXT: %getresult_M:1_N:1_K:0 + # CHECK-NEXT: %write_M:0_N:0_K:0 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:1_K:0 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0_K:0 + # 
CHECK-SAME: (%getresult_M:1_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:1_K:0 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, %c, 4, None, ()) # CHECK-NEXT: return None # Custom format: @@ -263,74 +264,73 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1}) - # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) - # CHECK-NEXT: get_result(value=reduction, res_idx=3) - # CHECK-NEXT: get_result(value=reduction, res_idx=2) - # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1}) + # CHECK-NEXT: reduction(axis=K, init_args=[register_M:0_N:0_K:0, register_M:0_N:1_K:0, register_M:1_N:0_K:0, register_M:1_N:1_K:0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=0) - # CHECK-NEXT: write(register_=getresult_0_0_0 + # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: get_result(value=reduction, res_idx=2) + # CHECK-NEXT: get_result(value=reduction, res_idx=3) + # CHECK-NEXT: write(register_=getresult_M:0_N:0_K:0 # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1} - # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1} - # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1} - # CHECK-NEXT: write(register_=getresult_0_1_0 + # CHECK-NEXT: write(register_=getresult_M:0_N:1_K:0 # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1} + # CHECK-NEXT: write(register_=getresult_M:1_N:0_K:0 + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1} + # CHECK-NEXT: write(register_=getresult_M:1_N:1_K:0 + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + 
$WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1} # CHECK-NEXT: output # Reduction subgraph: - - # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_0_1_0 - # CHECK-NEXT: %acc_1_0_0 - # CHECK-NEXT: %acc_1_1_0 + # CHECK: %acc_M:0_N:0_K:0 + # CHECK-NEXT: %acc_M:0_N:1_K:0 + # CHECK-NEXT: %acc_M:1_N:0_K:0 + # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_0_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0_0 + # CHECK-NEXT: %read_M:1_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0_1 + # CHECK-NEXT: %read_M:1_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None) # CHECK-NEXT: %b - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_1_0 + # CHECK-NEXT: %read_M:0_N:1_K:0 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_1_1 + # CHECK-NEXT: %read_M:0_N:1_K:1 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %mma_0_0_0 - # CHECK-SAME: (%read_0_0_0, %read_0_0_0, %acc_0_0_0, None) - # CHECK-NEXT: %mma_0_0_1 - # CHECK-SAME: (%read_0_0_1, %read_0_0_1, %mma_0_0_0, None) - # CHECK-NEXT: %mma_1_1_0 - # CHECK-SAME: (%read_1_0_0, %read_0_1_0, %acc_1_1_0, None) - # CHECK-NEXT: %mma_1_1_1 - # CHECK-SAME: (%read_1_0_1, %read_0_1_1, %mma_1_1_0, None) - # CHECK-NEXT: %mma_1_0_0 - # CHECK-SAME: (%read_1_0_0, %read_0_0_0, %acc_1_0_0, None) - # CHECK-NEXT: %mma_1_0_1 - # CHECK-SAME: (%read_1_0_1, %read_0_0_1, %mma_1_0_0, None) - # CHECK-NEXT: %mma_0_1_0 - # CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0, None) - # CHECK-NEXT: %mma_0_1_1 - # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0, None) - # CHECK-NEXT: return [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1] + # CHECK-NEXT: %mma_M:0_N:0_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None) + # CHECK-NEXT: %mma_M:0_N:0_K:1 + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None) + # CHECK-NEXT: %mma_M:0_N:1_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None) + # CHECK-NEXT: %mma_M:0_N:1_K:1 + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None) + # CHECK-NEXT: %mma_M:1_N:0_K:0 + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None) + # CHECK-NEXT: %mma_M:1_N:0_K:1 + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None) + # CHECK-NEXT: %mma_M:1_N:1_K:0 + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None) + # CHECK-NEXT: %mma_M:1_N:1_K:1 + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None) + # CHECK-NEXT: return [mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1] # Custom format: - # CHECK-NEXT: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_0_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 + # CHECK-NEXT: placeholder(_name=acc_M:0_N:0_K:0 + # CHECK-NEXT: placeholder(_name=acc_M:0_N:1_K:0 + # CHECK-NEXT: placeholder(_name=acc_M:1_N:0_K:0 + # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: 
read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -341,31 +341,31 @@ def test_gemm(): # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_1_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_1_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_1_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 
Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_1_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_0_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_0_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1],)) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:0_N:1_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:0_N:1_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: 
rhs=read_M:0_N:0_K:0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:1_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:1_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:1_N:1_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:1_N:1_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: output(return_vals=([mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1],)) # CHECK-NEXT: ----- @@ -416,6 +416,7 @@ def test_batched_gemm(): ): graph = batched_gemm() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) @@ -425,24 +426,24 @@ def test_batched_gemm(): # CHECK: %a # CHECK-NEXT: %b # CHECK-NEXT: %c - # CHECK-NEXT: %register_0_0_0 - # CHECK-NEXT: %register_1_1_0 - # CHECK-NEXT: %register_1_0_0 - # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %register_M:0_N:0_K:0 + # CHECK-NEXT: %register_M:0_N:1_K:0 + # CHECK-NEXT: %register_M:1_N:0_K:0 + # CHECK-NEXT: %register_M:1_N:1_K:0 # CHECK-NEXT: %reduction - # CHECK-SAME: %register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0 - # CHECK-NEXT: %getresult_1_1_0 - # CHECK-NEXT: %getresult_1_0_0 - # CHECK-NEXT: %getresult_0_1_0 - # CHECK-NEXT: %getresult_0_0_0 - # CHECK-NEXT: %write_0_0_0 - # CHECK-SAME: (%getresult_0_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_1_0 - # CHECK-SAME: (%getresult_1_1_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_0 - # CHECK-SAME: (%getresult_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_0_1_0 - # CHECK-SAME: (%getresult_0_1_0, %c, 4, None, ()) + # CHECK-SAME: %register_M:0_N:0_K:0, %register_M:0_N:1_K:0, %register_M:1_N:0_K:0, %register_M:1_N:1_K:0 + # CHECK-NEXT: %getresult_M:0_N:0_K:0 + # CHECK-NEXT: %getresult_M:0_N:1_K:0 + # CHECK-NEXT: %getresult_M:1_N:0_K:0 + # CHECK-NEXT: %getresult_M:1_N:1_K:0 + # CHECK-NEXT: %write_M:0_N:0_K:0 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: 
%write_M:0_N:1_K:0 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0_K:0 + # CHECK-SAME: (%getresult_M:1_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:1_K:0 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, %c, 4, None, ()) # CHECK-NEXT: return None # Custom format: @@ -450,74 +451,74 @@ def test_batched_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) - # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1}) - # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1}) - # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) - # CHECK-NEXT: get_result(value=reduction, res_idx=3) - # CHECK-NEXT: get_result(value=reduction, res_idx=2) - # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1}) + # CHECK-NEXT: reduction(axis=K, init_args=[register_M:0_N:0_K:0, register_M:0_N:1_K:0, register_M:1_N:0_K:0, register_M:1_N:1_K:0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=0) - # CHECK-NEXT: write(register_=getresult_0_0_0 + # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: get_result(value=reduction, res_idx=2) + # CHECK-NEXT: get_result(value=reduction, res_idx=3) + # CHECK-NEXT: write(register_=getresult_M:0_N:0_K:0 # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1} - # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1} - # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1} - # CHECK-NEXT: write(register_=getresult_0_1_0 + # CHECK-NEXT: write(register_=getresult_M:0_N:1_K:0 # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 
: 1} + # CHECK-NEXT: write(register_=getresult_M:1_N:0_K:0 + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1} + # CHECK-NEXT: write(register_=getresult_M:1_N:1_K:0 + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1} # CHECK-NEXT: output # Reduction subgraph: - # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_0_1_0 - # CHECK-NEXT: %acc_1_0_0 - # CHECK-NEXT: %acc_1_1_0 + # CHECK: %acc_M:0_N:0_K:0 + # CHECK-NEXT: %acc_M:0_N:1_K:0 + # CHECK-NEXT: %acc_M:1_N:0_K:0 + # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_0_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0_0 + # CHECK-NEXT: %read_M:1_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0_1 + # CHECK-NEXT: %read_M:1_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None) # CHECK-NEXT: %b - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_1_0 + # CHECK-NEXT: %read_M:0_N:1_K:0 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_1_1 + # CHECK-NEXT: %read_M:0_N:1_K:1 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %mma_0_0_0 - # CHECK-SAME: (%read_0_0_0, %read_0_0_0, %acc_0_0_0, None) - # CHECK-NEXT: %mma_0_0_1 - # CHECK-SAME: (%read_0_0_1, %read_0_0_1, %mma_0_0_0, None) - # CHECK-NEXT: %mma_1_1_0 - # CHECK-SAME: (%read_1_0_0, %read_0_1_0, %acc_1_1_0, None) - # CHECK-NEXT: %mma_1_1_1 - # CHECK-SAME: (%read_1_0_1, %read_0_1_1, %mma_1_1_0, None) - # CHECK-NEXT: %mma_1_0_0 - # CHECK-SAME: (%read_1_0_0, %read_0_0_0, %acc_1_0_0, None) - # CHECK-NEXT: %mma_1_0_1 - # CHECK-SAME: (%read_1_0_1, %read_0_0_1, %mma_1_0_0, None) - # CHECK-NEXT: %mma_0_1_0 - # CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0, None) - # CHECK-NEXT: %mma_0_1_1 - # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0, None) - # CHECK-NEXT: return [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1] + # CHECK-NEXT: %mma_M:0_N:0_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None) + # CHECK-NEXT: %mma_M:0_N:0_K:1 + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None) + # CHECK-NEXT: %mma_M:0_N:1_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:0_N:1_K:0, None) + # CHECK-NEXT: %mma_M:0_N:1_K:1 + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None) + # CHECK-NEXT: %mma_M:1_N:0_K:0 + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:1_N:0_K:0, None) + # CHECK-NEXT: %mma_M:1_N:0_K:1 + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None) + # CHECK-NEXT: %mma_M:1_N:1_K:0 + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %acc_M:1_N:1_K:0, None) + # CHECK-NEXT: %mma_M:1_N:1_K:1 + # CHECK-SAME: (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None) + # CHECK-NEXT: return [mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1] # Custom format: - # CHECK-NEXT: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_0_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 + # 
CHECK-NEXT: placeholder(_name=acc_M:0_N:0_K:0 + # CHECK-NEXT: placeholder(_name=acc_M:0_N:1_K:0 + # CHECK-NEXT: placeholder(_name=acc_M:1_N:0_K:0 + # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -528,31 +529,31 @@ def test_batched_gemm(): # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_0_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_0_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_1_1_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_1_1_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: 
$T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_1_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_1_0_1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_1_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_0_1_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_1_1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_0_1_0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) - # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1],)) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:0_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:0_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + 
$WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:0_N:1_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:0_N:1_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:1_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:1_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:1_N:1_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: mma(lhs=read_M:1_N:0_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:1_K:1 (index = {B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:1_N:1_K:0 (index = {B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + 
$WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16 : 1 : 1})) + # CHECK-NEXT: output(return_vals=([mma_M:0_N:0_K:1, mma_M:0_N:1_K:1, mma_M:1_N:0_K:1, mma_M:1_N:1_K:1],)) # CHECK-NEXT: ----- @@ -597,35 +598,36 @@ def test_gemm_non_direct_acc(): ): graph = gemm_non_direct_acc() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) print_trace(graph) - # CHECK: %add_0_0_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {}) - # CHECK: %add_1_1_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_1_0, %acc_1_1_0), kwargs = {}) - # CHECK: %add_1_0_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {}) - # CHECK: %add_0_1_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_1_0, %acc_0_1_0), kwargs = {}) - # CHECK: %mma_0_0_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0, None), kwargs = {}) - # CHECK: %mma_0_0_1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_0_1, %mma_0_0_0, None), kwargs = {}) - # CHECK: %mma_1_1_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_1_0, %add_1_1_0, None), kwargs = {}) - # CHECK: %mma_1_1_1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_1_1, %mma_1_1_0, None), kwargs = {}) - # CHECK: %mma_1_0_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_0_0, %add_1_0_0, None), kwargs = {}) - # CHECK: %mma_1_0_1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_0_1, %mma_1_0_0, None), kwargs = {}) - # CHECK: %mma_0_1_0 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_1_0, %add_0_1_0, None), kwargs = {}) - # CHECK: %mma_0_1_1 - # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_1_1, %mma_0_1_0, None), kwargs = {}) + # CHECK: %add_M:0_N:0_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_M:0_N:0_K:0, %acc_M:0_N:0_K:0), kwargs = {}) + # CHECK: %add_M:0_N:1_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_M:0_N:1_K:0, %acc_M:0_N:1_K:0), kwargs = {}) + # CHECK: %add_M:1_N:0_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_M:1_N:0_K:0, %acc_M:1_N:0_K:0), kwargs = {}) + # CHECK: %add_M:1_N:1_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_M:1_N:1_K:0, %acc_M:1_N:1_K:0), kwargs = {}) + # CHECK: %mma_M:0_N:0_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %add_M:0_N:0_K:0, None), kwargs = {}) + # CHECK: %mma_M:0_N:0_K:1 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None), kwargs = {}) + # CHECK: %mma_M:0_N:1_K:0 + # CHECK-SAME: 
call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:0, %read_M:0_N:1_K:0, %add_M:0_N:1_K:0, None), kwargs = {}) + # CHECK: %mma_M:0_N:1_K:1 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:0_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:0_N:1_K:0, None), kwargs = {}) + # CHECK: %mma_M:1_N:0_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:0_K:0, %add_M:1_N:0_K:0, None), kwargs = {}) + # CHECK: %mma_M:1_N:0_K:1 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:1_N:0_K:0, None), kwargs = {}) + # CHECK: %mma_M:1_N:1_K:0 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:0, %read_M:0_N:1_K:0, %add_M:1_N:1_K:0, None), kwargs = {}) + # CHECK: %mma_M:1_N:1_K:1 + # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_M:1_N:0_K:1, %read_M:0_N:1_K:1, %mma_M:1_N:1_K:0, None), kwargs = {}) @tkw.wave_trace_only() @@ -664,14 +666,15 @@ def test_tiled_max(): ): graph = tiled_max() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) print_trace(graph) - # CHECK: max(arg=[read_0_0, read_0_1, read_0_2, read_0_3, read_0_4, read_0_5, read_0_6, read_0_7], init=acc_0_0 - # CHECK: max(arg=[read_1_0, read_1_1, read_1_2, read_1_3, read_1_4, read_1_5, read_1_6, read_1_7], init=acc_1_0 - # CHECK: output(return_vals=([max_0_0, max_1_0],)) + # CHECK: max(arg=[read_M:0_K:0, read_M:0_K:1, read_M:0_K:2, read_M:0_K:3, read_M:0_K:4, read_M:0_K:5, read_M:0_K:6, read_M:0_K:7], init=acc_M:0_K:0 + # CHECK: max(arg=[read_M:1_K:0, read_M:1_K:1, read_M:1_K:2, read_M:1_K:3, read_M:1_K:4, read_M:1_K:5, read_M:1_K:6, read_M:1_K:7], init=acc_M:1_K:0 + # CHECK: output(return_vals=([max_M:0_K:0, max_M:1_K:0],)) # CHECK-NEXT: ----- @@ -696,6 +699,7 @@ def test_gemm_reduction_expansion_only(): ): graph = gemm() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) @@ -705,11 +709,11 @@ def test_gemm_reduction_expansion_only(): # CHECK: %a # CHECK-NEXT: %b # CHECK-NEXT: %c - # CHECK-NEXT: %register_0_0_0 + # CHECK-NEXT: %register_M:0_N:0_K:0 # CHECK-NEXT: %reduction - # CHECK-NEXT: %getresult_0_0_0 - # CHECK-NEXT: %write_0_0_0 - # CHECK-SAME: (%getresult_0_0_0, %c, 4, None, ()) + # CHECK-NEXT: %getresult_M:0_N:0_K:0 + # CHECK-NEXT: %write_M:0_N:0_K:0 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, %c, 4, None, ()) # CHECK-NEXT: return None # Custom format: @@ -717,51 +721,51 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) - # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] + # CHECK-NEXT: reduction(axis=K, init_args=[register_M:0_N:0_K:0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) - # CHECK-NEXT: write(register_=getresult_0_0_0 + # CHECK-NEXT: write(register_=getresult_M:0_N:0_K:0 # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1}) # 
CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: - # CHECK: %acc_0_0_0 + # CHECK: %acc_M:0_N:0_K:0 # CHECK-NEXT: %a - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_0_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%a, 4, None, (), None) # CHECK-NEXT: %b - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %read_0_0_1 + # CHECK-NEXT: %read_M:0_N:0_K:1 # CHECK-SAME: (%b, 4, None, (), None) - # CHECK-NEXT: %mma_0_0_0 - # CHECK-SAME: (%read_0_0_0, %read_0_0_0, %acc_0_0_0, None) - # CHECK-NEXT: %mma_0_0_1 - # CHECK-SAME: (%read_0_0_1, %read_0_0_1, %mma_0_0_0, None) + # CHECK-NEXT: %mma_M:0_N:0_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0, %acc_M:0_N:0_K:0, None) + # CHECK-NEXT: %mma_M:0_N:0_K:1 + # CHECK-SAME: (%read_M:0_N:0_K:1, %read_M:0_N:0_K:1, %mma_M:0_N:0_K:0, None) - # CHECK-NEXT: return [mma_0_0_1] + # CHECK-NEXT: return [mma_M:0_N:0_K:1] # Custom format: - # CHECK-NEXT: placeholder(_name=acc_0_0_0 + # CHECK-NEXT: placeholder(_name=acc_M:0_N:0_K:0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=b, elements_per_thread=4, mapping_dynamic_vals=(), index={N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) - # CHECK-SAME: acc=acc_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: rhs=read_0_0_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK-SAME: acc=mma_0_0_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) - # CHECK-NEXT: output(return_vals=([mma_0_0_1],)) + # CHECK-NEXT: mma(lhs=read_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:0 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: 
mma(lhs=read_M:0_N:0_K:1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_M:0_N:0_K:1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_M:0_N:0_K:0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1})) + # CHECK-NEXT: output(return_vals=([mma_M:0_N:0_K:1],)) # CHECK-NEXT: ----- @@ -841,6 +845,7 @@ def test_attention(): ): graph = attention() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) @@ -848,19 +853,25 @@ def test_attention(): print_trace(graph) # Root graph: - # CHECK: write(register_=truediv_0_0_0, + # CHECK: write(register_=truediv_M:0_N:0_K2:0, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1}) - # CHECK: write(register_=truediv_1_1_0, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1}) - # CHECK: write(register_=truediv_1_0_0, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1}) - # CHECK: write(register_=truediv_0_1_0, + # CHECK: write(register_=truediv_M:0_N:1_K2:0, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1}) + # CHECK: write(register_=truediv_M:1_N:0_K2:0, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1}) + # CHECK: write(register_=truediv_M:1_N:1_K2:0, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1}) # Reduction graph: # CHECK: read(memory=q, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) + # CHECK: read(memory=q, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK: read(memory=q, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -868,16 +879,16 @@ def test_attention(): # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 
: 1, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) # CHECK: read(memory=q, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) - # CHECK: read(memory=q, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK: read(memory=q, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) - # CHECK: read(memory=q, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) # CHECK: read(memory=k, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) + # CHECK: read(memory=k, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK: read(memory=k, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -885,12 +896,6 @@ def test_attention(): # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) # CHECK: read(memory=k, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16 : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) - # CHECK: read(memory=k, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) - # CHECK: read(memory=k, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) - # CHECK: read(memory=k, - # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, K2: ARGK*BLOCK_K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) # CHECK: read(memory=v, # CHECK-SAME: index={B: $WG2*BLOCK_B : 1 : 1, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) : 1 : 1, K2: ARGK*BLOCK_K2 + 4*floor((Mod($T0, 64))/16) : 4 : 1}) @@ -943,30 +948,30 @@ def py_arithmetic_different_dims(): print_trace(graph) # CHECK: %a # CHECK-NEXT: %c - # CHECK-NEXT: %read_0_0_0 + # CHECK-NEXT: %read_M:0_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %read_1_0_0 + # CHECK-NEXT: %read_M:1_N:0_K:0 # CHECK-SAME: (%a, 4, None, (), None) - # CHECK-NEXT: %add_0_0_0 - # CHECK-SAME: (%read_0_0_0, %read_0_0_0) - # CHECK-NEXT: %add_1_0_0 - # CHECK-SAME: (%read_1_0_0, %read_1_0_0) - # CHECK-NEXT: %sub_0_0_0 - # CHECK-SAME: (%add_0_0_0, %read_0_0_0) - # CHECK-NEXT: %sub_1_0_0 - # CHECK-SAME: (%add_1_0_0, %read_1_0_0) - # CHECK-NEXT: %neg_0_0_0 - # CHECK-SAME: (%sub_0_0_0,) - # CHECK-NEXT: %neg_1_0_0 - # CHECK-SAME: (%sub_1_0_0,) - # CHECK-NEXT: %write_0_0_0 - # CHECK-SAME: (%neg_0_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_1 - # 
CHECK-SAME: (%neg_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_1_0_0 - # CHECK-SAME: (%neg_1_0_0, %c, 4, None, ()) - # CHECK-NEXT: %write_0_0_1 - # CHECK-SAME: (%neg_0_0_0, %c, 4, None, ()) + # CHECK-NEXT: %add_M:0_N:0_K:0 + # CHECK-SAME: (%read_M:0_N:0_K:0, %read_M:0_N:0_K:0) + # CHECK-NEXT: %add_M:1_N:0_K:0 + # CHECK-SAME: (%read_M:1_N:0_K:0, %read_M:1_N:0_K:0) + # CHECK-NEXT: %sub_M:0_N:0_K:0 + # CHECK-SAME: (%add_M:0_N:0_K:0, %read_M:0_N:0_K:0) + # CHECK-NEXT: %sub_M:1_N:0_K:0 + # CHECK-SAME: (%add_M:1_N:0_K:0, %read_M:1_N:0_K:0) + # CHECK-NEXT: %neg_M:0_N:0_K:0 + # CHECK-SAME: (%sub_M:0_N:0_K:0,) + # CHECK-NEXT: %neg_M:1_N:0_K:0 + # CHECK-SAME: (%sub_M:1_N:0_K:0,) + # CHECK-NEXT: %write_M:0_N:0_K:0 + # CHECK-SAME: (%neg_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:0_N:0_K:1 + # CHECK-SAME: (%neg_M:0_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0_K:0 + # CHECK-SAME: (%neg_M:1_N:0_K:0, %c, 4, None, ()) + # CHECK-NEXT: %write_M:1_N:0_K:1 + # CHECK-SAME: (%neg_M:1_N:0_K:0, %c, 4, None, ()) # CHECK-NEXT: return None # Custom format: @@ -974,16 +979,16 @@ def py_arithmetic_different_dims(): # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} # CHECK-NEXT: read(memory=a, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} - # CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1} - # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} - # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1} - # CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-NEXT: add(lhs=read_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: add(lhs=read_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: 
$T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: sub(lhs=add_M:0_N:0_K:0, rhs=read_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: sub(lhs=add_M:1_N:0_K:0, rhs=read_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: neg(arg=sub_M:0_N:0_K:0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: neg(arg=sub_M:1_N:0_K:0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:0_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1} + # CHECK-NEXT: write(register_=neg_M:1_N:0_K:0, memory=c, elements_per_thread=4, mapping_dynamic_vals=(), index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16 : 1 : 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1} # CHECK: ----- @@ -1042,70 +1047,71 @@ def test_chained_gemm_32x32x8(): ): graph = chained_gemm_32x32x8() IndexingContext.current().finalize() + initialize_iter_args(graph) infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) print_trace(graph) - # CHECK: %acc_0_0_0 + # CHECK: %acc_M:0_N:0_K2:0 # CHECK: %register # CHECK: %q - # CHECK: %read_0_0_0 + # CHECK: %read_M:0_K2:0_K1:0 # CHECK-SAME: (args = (%q, 4, None, (), None) - # CHECK: %read_0_0_1 + # CHECK: %read_M:0_K2:0_K1:1 # CHECK-SAME: (args = (%q, 4, None, (), None) - # CHECK: %read_0_0_2 + # CHECK: %read_M:0_K2:0_K1:2 # CHECK-SAME: (args = (%q, 4, None, (), None) - # CHECK: %read_0_0_3 + # CHECK: %read_M:0_K2:0_K1:3 # CHECK-SAME: (args = (%q, 4, None, (), None) # CHECK: %k - # CHECK: %read_shared_0_0_0 + # CHECK: %read_shared_M:0_K2:0_K1:0 # CHECK-SAME: (args = (%k, 4, None, (), None) - # CHECK: %read_shared_0_0_1 + # CHECK: %read_shared_M:0_K2:0_K1:1 # CHECK-SAME: (args = (%k, 4, None, (), None) - # CHECK: %read_shared_0_0_2 + # CHECK: %read_shared_M:0_K2:0_K1:2 # CHECK-SAME: (args = (%k, 4, None, (), None) - # CHECK: %read_shared_0_0_3 + # CHECK: %read_shared_M:0_K2:0_K1:3 # CHECK-SAME: (args = (%k, 4, None, (), None) - # CHECK: %mma_0_0_0 - # CHECK-SAME: (args = (%read_shared_0_0_0, %read_0_0_0, %register, None) - # CHECK: %mma_0_0_1 - # CHECK-SAME: (args = (%read_shared_0_0_1, %read_0_0_1, %mma_0_0_0, None) - # CHECK: %mma_0_0_2 - # CHECK-SAME: (args = (%read_shared_0_0_2, %read_0_0_2, %mma_0_0_1, None) - # CHECK: %mma_0_0_3 - # CHECK-SAME: (args = (%read_shared_0_0_3, %read_0_0_3, %mma_0_0_2, None) - # CHECK: %permute_0_0 - # CHECK-SAME: (args = (%mma_0_0_3, [B, M, K2]) - # CHECK: %cast_0_0 - # CHECK-SAME: (args = (%permute_0_0, f16) + # CHECK: %mma_M:0_K2:0_K1:0 + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:0, %read_M:0_K2:0_K1:0, %register_M:0_K2:0_K1:0, None) + # CHECK: %mma_M:0_K2:0_K1:1 + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:1, %read_M:0_K2:0_K1:1, 
%mma_M:0_K2:0_K1:0, None) + # CHECK: %mma_M:0_K2:0_K1:2 + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:2, %read_M:0_K2:0_K1:2, %mma_M:0_K2:0_K1:1, None) + # CHECK: %mma_M:0_K2:0_K1:3 + # CHECK-SAME: (args = (%read_shared_M:0_K2:0_K1:3, %read_M:0_K2:0_K1:3, %mma_M:0_K2:0_K1:2, None) + # CHECK: %permute_M:0_K2:0 + # CHECK-SAME: (args = (%mma_M:0_K2:0_K1:3, [B, M, K2]) + # CHECK: %cast_M:0_K2:0 + # CHECK-SAME: (args = (%permute_M:0_K2:0, f16) # CHECK: %v - # CHECK: %read_shared_0_0_0 + # CHECK: %read_shared_M:0_N:0_K2:0 # CHECK-SAME: (args = (%v, 4, None, (), None) - # CHECK: %read_shared_0_0_1 + # CHECK: %read_shared_M:0_N:0_K2:1 # CHECK-SAME: (args = (%v, 4, None, (), None) - # CHECK: %read_shared_0_0_2 + # CHECK: %read_shared_M:0_N:0_K2:2 # CHECK-SAME: (args = (%v, 4, None, (), None) - # CHECK: %read_shared_0_0_3 + # CHECK: %read_shared_M:0_N:0_K2:3 # CHECK-SAME: (args = (%v, 4, None, (), None) - # CHECK: %reshape_0_0_0 - # CHECK-SAME: (args = ([%cast_0_0], {K2: 32, M: 32, K1: 8, B: 0}) - # CHECK: %reshape_0_0_1 - # CHECK-SAME: (args = ([%cast_0_0], {K2: 32, M: 32, K1: 8, B: 0}) - # CHECK: %reshape_0_0_2 - # CHECK-SAME: (args = ([%cast_0_0], {K2: 32, M: 32, K1: 8, B: 0}) - # CHECK: %reshape_0_0_3 - # CHECK-SAME: (args = ([%cast_0_0], {K2: 32, M: 32, K1: 8, B: 0}) - # CHECK: %mma_0_0_0 - # CHECK-SAME: (args = (%reshape_0_0_0, %read_shared_0_0_0, %acc_0_0_0, None) - # CHECK: %mma_0_0_1 - # CHECK-SAME: (args = (%reshape_0_0_1, %read_shared_0_0_1, %mma_0_0_0, None) - # CHECK: %mma_0_0_2 - # CHECK-SAME: (args = (%reshape_0_0_2, %read_shared_0_0_2, %mma_0_0_1, None) - # CHECK: %mma_0_0_3 - # CHECK-SAME: (args = (%reshape_0_0_3, %read_shared_0_0_3, %mma_0_0_2, None) - # CHECK: return [mma_0_0_3] + # CHECK: %reshape_M:0_N:0_K2:0 + # CHECK-SAME: (args = ([%cast_M:0_K2:0], {K2: 32, M: 32, K1: 8, B: 0}) + # CHECK: %reshape_M:0_N:0_K2:1 + # CHECK-SAME: (args = ([%cast_M:0_K2:0], {K2: 32, M: 32, K1: 8, B: 0}) + # CHECK: %reshape_M:0_N:0_K2:2 + # CHECK-SAME: (args = ([%cast_M:0_K2:0], {K2: 32, M: 32, K1: 8, B: 0}) + # CHECK: %reshape_M:0_N:0_K2:3 + # CHECK-SAME: (args = ([%cast_M:0_K2:0], {K2: 32, M: 32, K1: 8, B: 0}) + # CHECK: %mma_M:0_N:0_K2:0 + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:0, %read_shared_M:0_N:0_K2:0, %acc_M:0_N:0_K2:0, None) + # CHECK: %mma_M:0_N:0_K2:1 + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:1, %read_shared_M:0_N:0_K2:1, %mma_M:0_N:0_K2:0, None) + # CHECK: %mma_M:0_N:0_K2:2 + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:2, %read_shared_M:0_N:0_K2:2, %mma_M:0_N:0_K2:1, None) + # CHECK: %mma_M:0_N:0_K2:3 + # CHECK-SAME: (args = (%reshape_M:0_N:0_K2:3, %read_shared_M:0_N:0_K2:3, %mma_M:0_N:0_K2:2, None) + # CHECK: return [mma_M:0_N:0_K2:3] if __name__ == "__main__": diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 814c0089..a6de6328 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -6,13 +6,13 @@ import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.promotion import promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops -from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.expansion.expansion import expand_graph from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext 
from iree.turbine.kernel.ops.wave_ops import * -from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.utils import run_test, print_trace, initialize_iter_args from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, @@ -84,6 +84,7 @@ def test_gemm(): ): trace: CapturedTrace = gemm() IndexingContext.current().finalize() + initialize_iter_args(trace) infer_types(trace) promote_placeholders(trace, constraints) set_node_indices(trace, constraints) @@ -98,87 +99,87 @@ def test_gemm(): # CHECK: %a # CHECK-NEXT: %b # CHECK-NEXT: %c - # CHECK-NEXT: %register_0_0_0 - # CHECK-NEXT: %register_1_1_0 - # CHECK-NEXT: %register_1_0_0 - # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %register_M:0_N:0_K:0 + # CHECK-NEXT: %register_M:0_N:1_K:0 + # CHECK-NEXT: %register_M:1_N:0_K:0 + # CHECK-NEXT: %register_M:1_N:1_K:0 # CHECK-NEXT: %allocate # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction - # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] - # CHECK-NEXT: %getresult_1_1_0 - # CHECK-SAME: (%reduction, 3) - # CHECK-NEXT: %getresult_1_0_0 - # CHECK-SAME: (%reduction, 2) - # CHECK-NEXT: %getresult_0_1_0 - # CHECK-SAME: (%reduction, 1) - # CHECK-NEXT: %getresult_0_0_0 + # CHECK-SAME (K, [%register_M:0_N:0_K:0, %register_M:0_N:1_K:0, %register_M:1_N:0_K:0, %register_M:1_N:1_K:0] + # CHECK-NEXT: %getresult_M:0_N:0_K:0 # CHECK-SAME: (%reduction, 0) + # CHECK-NEXT: %getresult_M:0_N:1_K:0 + # CHECK-SAME: (%reduction, 1) + # CHECK-NEXT: %getresult_M:1_N:0_K:0 + # CHECK-SAME: (%reduction, 2) + # CHECK-NEXT: %getresult_M:1_N:1_K:0 + # CHECK-SAME: (%reduction, 3) # CHECK-NEXT: extract_slice - # CHECK-SAME: (%getresult_0_0_0, [0], [1], [1]) - # CHECK-NEXT: %write_1 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, [0], [1], [1]) + # CHECK-NEXT: %write_5 # CHECK-SAME: (%extract_slice, %c, 1, None, ()) # CHECK-NEXT: extract_slice_1 - # CHECK-SAME: (%getresult_0_0_0, [1], [1], [1]) - # CHECK-NEXT: %write_2 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, [1], [1], [1]) + # CHECK-NEXT: %write_6 # CHECK-SAME: (%extract_slice_1, %c, 1, None, ()) # CHECK-NEXT: extract_slice_2 - # CHECK-SAME: (%getresult_0_0_0, [2], [1], [1]) - # CHECK-NEXT: %write_3 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, [2], [1], [1]) + # CHECK-NEXT: %write_7 # CHECK-SAME: (%extract_slice_2, %c, 1, None, ()) # CHECK-NEXT: extract_slice_3 - # CHECK-SAME: (%getresult_0_0_0, [3], [1], [1]) - # CHECK-NEXT: %write_4 + # CHECK-SAME: (%getresult_M:0_N:0_K:0, [3], [1], [1]) + # CHECK-NEXT: %write_8 # CHECK-SAME: (%extract_slice_3, %c, 1, None, ()) # CHECK-NEXT: extract_slice_4 - # CHECK-SAME: (%getresult_1_1_0, [0], [1], [1]) - # CHECK-NEXT: %write_5 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, [0], [1], [1]) + # CHECK-NEXT: %write_9 # CHECK-SAME: (%extract_slice_4, %c, 1, None, ()) # CHECK-NEXT: extract_slice_5 - # CHECK-SAME: (%getresult_1_1_0, [1], [1], [1]) - # CHECK-NEXT: %write_6 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, [1], [1], [1]) + # CHECK-NEXT: %write_10 # CHECK-SAME: (%extract_slice_5, %c, 1, None, ()) # CHECK-NEXT: extract_slice_6 - # CHECK-SAME: (%getresult_1_1_0, [2], [1], [1]) - # CHECK-NEXT: %write_7 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, [2], [1], [1]) + # CHECK-NEXT: %write_11 # CHECK-SAME: (%extract_slice_6, %c, 1, 
None, ()) # CHECK-NEXT: extract_slice_7 - # CHECK-SAME: (%getresult_1_1_0, [3], [1], [1]) - # CHECK-NEXT: %write_8 + # CHECK-SAME: (%getresult_M:0_N:1_K:0, [3], [1], [1]) + # CHECK-NEXT: %write_12 # CHECK-SAME: (%extract_slice_7, %c, 1, None, ()) # CHECK-NEXT: extract_slice_8 - # CHECK-SAME: (%getresult_1_0_0, [0], [1], [1]) - # CHECK-NEXT: %write_9 + # CHECK-SAME: (%getresult_M:1_N:0_K:0, [0], [1], [1]) + # CHECK-NEXT: %write_13 # CHECK-SAME: (%extract_slice_8, %c, 1, None, ()) # CHECK-NEXT: extract_slice_9 - # CHECK-SAME: (%getresult_1_0_0, [1], [1], [1]) - # CHECK-NEXT: %write_10 + # CHECK-SAME: (%getresult_M:1_N:0_K:0, [1], [1], [1]) + # CHECK-NEXT: %write_14 # CHECK-SAME: (%extract_slice_9, %c, 1, None, ()) # CHECK-NEXT: extract_slice_10 - # CHECK-SAME: (%getresult_1_0_0, [2], [1], [1]) - # CHECK-NEXT: %write_11 + # CHECK-SAME: (%getresult_M:1_N:0_K:0, [2], [1], [1]) + # CHECK-NEXT: %write_15 # CHECK-SAME: (%extract_slice_10, %c, 1, None, ()) # CHECK-NEXT: extract_slice_11 - # CHECK-SAME: (%getresult_1_0_0, [3], [1], [1]) - # CHECK-NEXT: %write_12 + # CHECK-SAME: (%getresult_M:1_N:0_K:0, [3], [1], [1]) + # CHECK-NEXT: %write_16 # CHECK-SAME: (%extract_slice_11, %c, 1, None, ()) # CHECK-NEXT: extract_slice_12 - # CHECK-SAME: (%getresult_0_1_0, [0], [1], [1]) - # CHECK-NEXT: %write_13 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, [0], [1], [1]) + # CHECK-NEXT: %write_17 # CHECK-SAME: (%extract_slice_12, %c, 1, None, ()) # CHECK-NEXT: extract_slice_13 - # CHECK-SAME: (%getresult_0_1_0, [1], [1], [1]) - # CHECK-NEXT: %write_14 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, [1], [1], [1]) + # CHECK-NEXT: %write_18 # CHECK-SAME: (%extract_slice_13, %c, 1, None, ()) # CHECK-NEXT: extract_slice_14 - # CHECK-SAME: (%getresult_0_1_0, [2], [1], [1]) - # CHECK-NEXT: %write_15 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, [2], [1], [1]) + # CHECK-NEXT: %write_19 # CHECK-SAME: (%extract_slice_14, %c, 1, None, ()) # CHECK-NEXT: extract_slice_15 - # CHECK-SAME: (%getresult_0_1_0, [3], [1], [1]) - # CHECK-NEXT: %write_16 + # CHECK-SAME: (%getresult_M:1_N:1_K:0, [3], [1], [1]) + # CHECK-NEXT: %write_20 # CHECK-SAME: (%extract_slice_15, %c, 1, None, ()) # CHECK-NEXT: return None @@ -189,162 +190,162 @@ def test_gemm(): # CHECK-NEXT: register # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1}) # CHECK-NEXT: register( # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( - # CHECK-NEXT: get_result(value=reduction, res_idx=3) - # CHECK-NEXT: get_result(value=reduction, res_idx=2) - # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) - # CHECK-NEXT: extract_slice(register_=getresult_0_0_0, offset=[0], size=[1], stride=[1]) + # CHECK-NEXT: 
get_result(value=reduction, res_idx=1) + # CHECK-NEXT: get_result(value=reduction, res_idx=2) + # CHECK-NEXT: get_result(value=reduction, res_idx=3) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:0_K:0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_0_0, offset=[1], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:0_K:0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_1, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_0_0, offset=[2], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:0_K:0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_2, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_0_0, offset=[3], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:0_K:0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_3, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[0], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:1_K:0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_4, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[1], size=[1], stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:1_K:0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_5, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[2], size=[1], stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:1_K:0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_6, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[3], size=[1], stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:0_N:1_K:0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_7, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[0], size=[1], 
stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:0_K:0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_8, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[1], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:0_K:0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_9, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[2], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:0_K:0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_10, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[3], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:0_K:0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_11, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 32 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[0], size=[1], stride=[1]) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:1_K:0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_12, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[1], size=[1], stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:1_K:0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_13, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[2], size=[1], stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:1_K:0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_14, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) - # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[3], size=[1], stride=[1]) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-NEXT: extract_slice(register_=getresult_M:1_N:1_K:0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_15, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3 : 1 : 1, N: 64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19 : 1 : 1, N: 
64*$WG1 + Mod($T0, 16) + 48 : 1 : 1}) # Reduction subgraph: - # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_0_1_0 - # CHECK-NEXT: %acc_1_0_0 - # CHECK-NEXT: %acc_1_1_0 + # CHECK: %acc_M:0_N:0_K:0 + # CHECK-NEXT: %acc_M:0_N:1_K:0 + # CHECK-NEXT: %acc_M:1_N:0_K:0 + # CHECK-NEXT: %acc_M:1_N:1_K:0 # CHECK-NEXT: %a - # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_36 # CHECK-SAME: (%a, 8, None, (), None) - # CHECK-NEXT: %write_2 - # CHECK-SAME: (%read_4, %allocate, 8, None, ()) - # CHECK-NEXT: %read_5 + # CHECK-NEXT: %write_18 + # CHECK-SAME: (%read_36, %allocate, 8, None, ()) + # CHECK-NEXT: %read_37 # CHECK-SAME: (%a, 8, None, (), None) - # CHECK-NEXT: %write_3 - # CHECK-SAME: (%read_5, %allocate, 8, None, ()) - # CHECK-NEXT: %read_shared_0_0_0 - # CHECK-NEXT: %read_shared_0_0_1 - # CHECK-NEXT: %read_shared_0_0_2 - # CHECK-NEXT: %read_shared_0_0_3 - # CHECK-NEXT: %read_shared_1_0_0 - # CHECK-NEXT: %read_shared_1_0_1 - # CHECK-NEXT: %read_shared_1_0_2 - # CHECK-NEXT: %read_shared_1_0_3 + # CHECK-NEXT: %write_19 + # CHECK-SAME: (%read_37, %allocate, 8, None, ()) + # CHECK-NEXT: %read_shared_M:0_N:0_K:0 + # CHECK-NEXT: %read_shared_M:0_N:0_K:1 + # CHECK-NEXT: %read_shared_M:0_N:0_K:2 + # CHECK-NEXT: %read_shared_M:0_N:0_K:3 + # CHECK-NEXT: %read_shared_M:1_N:0_K:0 + # CHECK-NEXT: %read_shared_M:1_N:0_K:1 + # CHECK-NEXT: %read_shared_M:1_N:0_K:2 + # CHECK-NEXT: %read_shared_M:1_N:0_K:3 # CHECK-NEXT: %b - # CHECK-NEXT: %read_6 + # CHECK-NEXT: %read_38 # CHECK-SAME: (%b, 8, None, (), None) - # CHECK-NEXT: %write_4 - # CHECK-SAME: (%read_6, %allocate_1, 8, None, ()) - # CHECK-NEXT: %read_7 + # CHECK-NEXT: %write_20 + # CHECK-SAME: (%read_38, %allocate_1, 8, None, ()) + # CHECK-NEXT: %read_39 # CHECK-SAME: (%b, 8, None, (), None) - # CHECK-NEXT: %write_5 - # CHECK-SAME: (%read_7, %allocate_1, 8, None, ()) - # CHECK-NEXT: %read_shared_0_0_0 - # CHECK-NEXT: %read_shared_0_0_1 - # CHECK-NEXT: %read_shared_0_0_2 - # CHECK-NEXT: %read_shared_0_0_3 - # CHECK-NEXT: %read_shared_0_1_0 - # CHECK-NEXT: %read_shared_0_1_1 - # CHECK-NEXT: %read_shared_0_1_2 - # CHECK-NEXT: %read_shared_0_1_3 - # CHECK-NEXT: %mma_0_0_0 - # CHECK-NEXT: %mma_0_0_1 - # CHECK-NEXT: %mma_0_0_2 - # CHECK-NEXT: %mma_0_0_3 - # CHECK-NEXT: %mma_1_1_0 - # CHECK-NEXT: %mma_1_1_1 - # CHECK-NEXT: %mma_1_1_2 - # CHECK-NEXT: %mma_1_1_3 - # CHECK-NEXT: %mma_1_0_0 - # CHECK-NEXT: %mma_1_0_1 - # CHECK-NEXT: %mma_1_0_2 - # CHECK-NEXT: %mma_1_0_3 - # CHECK-NEXT: %mma_0_1_0 - # CHECK-NEXT: %mma_0_1_1 - # CHECK-NEXT: %mma_0_1_2 - # CHECK-NEXT: %mma_0_1_3 + # CHECK-NEXT: %write_21 + # CHECK-SAME: (%read_39, %allocate_1, 8, None, ()) + # CHECK-NEXT: %read_shared_M:0_N:0_K:0 + # CHECK-NEXT: %read_shared_M:0_N:0_K:1 + # CHECK-NEXT: %read_shared_M:0_N:0_K:2 + # CHECK-NEXT: %read_shared_M:0_N:0_K:3 + # CHECK-NEXT: %read_shared_M:0_N:1_K:0 + # CHECK-NEXT: %read_shared_M:0_N:1_K:1 + # CHECK-NEXT: %read_shared_M:0_N:1_K:2 + # CHECK-NEXT: %read_shared_M:0_N:1_K:3 + # CHECK-NEXT: %mma_M:0_N:0_K:0 + # CHECK-NEXT: %mma_M:0_N:0_K:1 + # CHECK-NEXT: %mma_M:0_N:0_K:2 + # CHECK-NEXT: %mma_M:0_N:0_K:3 + # CHECK-NEXT: %mma_M:0_N:1_K:0 + # CHECK-NEXT: %mma_M:0_N:1_K:1 + # CHECK-NEXT: %mma_M:0_N:1_K:2 + # CHECK-NEXT: %mma_M:0_N:1_K:3 + # CHECK-NEXT: %mma_M:1_N:0_K:0 + # CHECK-NEXT: %mma_M:1_N:0_K:1 + # CHECK-NEXT: %mma_M:1_N:0_K:2 + # CHECK-NEXT: %mma_M:1_N:0_K:3 + # CHECK-NEXT: %mma_M:1_N:1_K:0 + # CHECK-NEXT: %mma_M:1_N:1_K:1 + # CHECK-NEXT: %mma_M:1_N:1_K:2 + # CHECK-NEXT: %mma_M:1_N:1_K:3 # Reduction subgraph (custom format): - # CHECK: placeholder(_name=acc_0_0_0 - # 
CHECK-NEXT: placeholder(_name=acc_0_1_0
-        # CHECK-NEXT: placeholder(_name=acc_1_0_0
-        # CHECK-NEXT: placeholder(_name=acc_1_1_0
+        # CHECK: placeholder(_name=acc_M:0_N:0_K:0
+        # CHECK-NEXT: placeholder(_name=acc_M:0_N:1_K:0
+        # CHECK-NEXT: placeholder(_name=acc_M:1_N:0_K:0
+        # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0
         # CHECK-NEXT: placeholder(_name=a
         # CHECK-NEXT: read(memory=a, elements_per_thread=8,
         # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_4, memory=allocate, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_36, memory=allocate, elements_per_thread=8,
         # CHECK-SAME: index={M: Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: read(memory=a, elements_per_thread=8,
         # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_5, memory=allocate, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_37, memory=allocate, elements_per_thread=8,
         # CHECK-SAME: index={M: Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
         # CHECK-NEXT: placeholder(_name=b, _type=Memory[N, K].of(f16))
         # CHECK-NEXT: read(memory=b, elements_per_thread=8,
         # CHECK-SAME: index={N: $WG1*BLOCK_N + BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_6, memory=allocate_1, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_38, memory=allocate_1, elements_per_thread=8,
         # CHECK-SAME: index={N: BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: read(memory=b, elements_per_thread=8,
         # CHECK-SAME: index={N: $WG1*BLOCK_N + BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_7, memory=allocate_1, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_39, memory=allocate_1, elements_per_thread=8,
         # CHECK-SMAE: index={N: BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64), K: 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
 
 if __name__ == "__main__":
diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py
index 5a5402fd..695c5854 100644
--- a/lit_tests/kernel/wave/minimize_global_loads.py
+++ b/lit_tests/kernel/wave/minimize_global_loads.py
@@ -7,13 +7,13 @@
 from iree.turbine.kernel.wave.promotion import promote_placeholders
 from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
-from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.expansion.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
 from iree.turbine.kernel.ops.wave_ops import *
-from iree.turbine.kernel.wave.utils import run_test, print_trace
+from iree.turbine.kernel.wave.utils import run_test, print_trace, initialize_iter_args
 from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
 from iree.turbine.kernel.wave.visualization import visualize_graph
 from iree.turbine.kernel.wave.shared_memory_indexing import (
@@ -86,6 +86,7 @@ def test_gemm():
         trace: CapturedTrace = gemm()
         visualize = False
         IndexingContext.current().finalize()
+        initialize_iter_args(trace)
         infer_types(trace)
         promote_placeholders(trace, constraints)
         set_node_indices(trace, constraints)
@@ -104,32 +105,32 @@ def test_gemm():
         # CHECK: %a
         # CHECK-NEXT: %b
         # CHECK-NEXT: %c
-        # CHECK-NEXT: %register_0_0_0
-        # CHECK-NEXT: %register_1_1_0
-        # CHECK-NEXT: %register_1_0_0
-        # CHECK-NEXT: %register_0_1_0
+        # CHECK-NEXT: %register_M:0_N:0_K:0
+        # CHECK-NEXT: %register_M:0_N:1_K:0
+        # CHECK-NEXT: %register_M:1_N:0_K:0
+        # CHECK-NEXT: %register_M:1_N:1_K:0
         # CHECK-NEXT: %allocate
         # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
         # CHECK-NEXT: %allocate_1
         # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
         # CHECK-NEXT: reduction
-        # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
-        # CHECK-NEXT: %getresult_1_1_0
-        # CHECK-SAME: (%reduction, 3)
-        # CHECK-NEXT: %getresult_1_0_0
-        # CHECK-SAME: (%reduction, 2)
-        # CHECK-NEXT: %getresult_0_1_0
-        # CHECK-SAME: (%reduction, 1)
-        # CHECK-NEXT: %getresult_0_0_0
+        # CHECK-SAME (K, [%register_M:0_N:0_K:0, %register_M:1_N:1_K:0, %register_M:1_N:0_K:0, %register_M:0_N:1_K:0]
+        # CHECK-NEXT: %getresult_M:0_N:0_K:0
         # CHECK-SAME: (%reduction, 0)
-        # CHECK-NEXT: %write_0_0_0
-        # CHECK-SAME: (%getresult_0_0_0, %c, 4, None, ())
-        # CHECK-NEXT: %write_1_1_0
-        # CHECK-SAME: (%getresult_1_1_0, %c, 4, None, ())
-        # CHECK-NEXT: %write_1_0_0
-        # CHECK-SAME: (%getresult_1_0_0, %c, 4, None, ())
-        # CHECK-NEXT: %write_0_1_0
-        # CHECK-SAME: (%getresult_0_1_0, %c, 4, None, ())
+        # CHECK-NEXT: %getresult_M:0_N:1_K:0
+        # CHECK-SAME: (%reduction, 1)
+        # CHECK-NEXT: %getresult_M:1_N:0_K:0
+        # CHECK-SAME: (%reduction, 2)
+        # CHECK-NEXT: %getresult_M:1_N:1_K:0
+        # CHECK-SAME: (%reduction, 3)
+        # CHECK-NEXT: %write_M:0_N:0_K:0
+        # CHECK-SAME: (%getresult_M:0_N:0_K:0, %c, 4, None, ())
+        # CHECK-NEXT: %write_M:0_N:1_K:0
+        # CHECK-SAME: (%getresult_M:0_N:1_K:0, %c, 4, None, ())
+        # CHECK-NEXT: %write_M:1_N:0_K:0
+        # CHECK-SAME: (%getresult_M:1_N:0_K:0, %c, 4, None, ())
+        # CHECK-NEXT: %write_M:1_N:1_K:0
+        # CHECK-SAME: (%getresult_M:1_N:1_K:0, %c, 4, None, ())
         # CHECK-NEXT: return None
 
         # Root graph (custom format):
@@ -139,128 +140,128 @@ def test_gemm():
         # CHECK-NEXT: register
         # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1})
         # CHECK-NEXT: register(
-        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
+        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
         # CHECK-NEXT: register(
         # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1})
         # CHECK-NEXT: register(
-        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
+        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
         # CHECK-NEXT: allocate(
         # CHECK-NEXT: allocate(
         # CHECK-NEXT: reduction(
-        # CHECK-NEXT: get_result(value=reduction, res_idx=3)
-        # CHECK-NEXT: get_result(value=reduction, res_idx=2)
-        # CHECK-NEXT: get_result(value=reduction, res_idx=1)
         # CHECK-NEXT: get_result(value=reduction, res_idx=0)
-        # CHECK-NEXT: write(register_=getresult_0_0_0, memory=c
+        # CHECK-NEXT: get_result(value=reduction, res_idx=1)
+        # CHECK-NEXT: get_result(value=reduction, res_idx=2)
+        # CHECK-NEXT: get_result(value=reduction, res_idx=3)
+        # CHECK-NEXT: write(register_=getresult_M:0_N:0_K:0, memory=c
         # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1})
-        # CHECK-NEXT: write(register_=getresult_1_1_0, memory=c
-        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
-        # CHECK-NEXT: write(register_=getresult_1_0_0, memory=c
-        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1})
-        # CHECK-NEXT: write(register_=getresult_0_1_0, memory=c
+        # CHECK-NEXT: write(register_=getresult_M:0_N:1_K:0, memory=c
         # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
+        # CHECK-NEXT: write(register_=getresult_M:1_N:0_K:0, memory=c
+        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) : 1 : 1})
+        # CHECK-NEXT: write(register_=getresult_M:1_N:1_K:0, memory=c
+        # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1})
         # Reduction subgraph:
-        # CHECK: %acc_0_0_0
-        # CHECK-NEXT: %acc_0_1_0
-        # CHECK-NEXT: %acc_1_0_0
-        # CHECK-NEXT: %acc_1_1_0
+        # CHECK: %acc_M:0_N:0_K:0
+        # CHECK-NEXT: %acc_M:0_N:1_K:0
+        # CHECK-NEXT: %acc_M:1_N:0_K:0
+        # CHECK-NEXT: %acc_M:1_N:1_K:0
         # CHECK-NEXT: %a
-        # CHECK-NEXT: %read_4
+        # CHECK-NEXT: %read_36
         # CHECK-SAME: (%a, 8, None, (), None)
-        # CHECK-NEXT: %write_2
-        # CHECK-SAME: (%read_4, %allocate, 8, None, ())
-        # CHECK-NEXT: %read_5
+        # CHECK-NEXT: %write_18
+        # CHECK-SAME: (%read_36, %allocate, 8, None, ())
+        # CHECK-NEXT: %read_37
         # CHECK-SAME: (%a, 8, None, (), None)
-        # CHECK-NEXT: %write_3
-        # CHECK-SAME: (%read_5, %allocate, 8, None, ())
+        # CHECK-NEXT: %write_19
+        # CHECK-SAME: (%read_37, %allocate, 8, None, ())
         # CHECK-NEXT: %shared_memory_barrier
-        # CHECK-NEXT: %read_shared_0_0_0
-        # CHECK-NEXT: %read_shared_0_0_1
-        # CHECK-NEXT: %read_shared_0_0_2
-        # CHECK-NEXT: %read_shared_0_0_3
-        # CHECK-NEXT: %read_shared_1_0_0
-        # CHECK-NEXT: %read_shared_1_0_1
-        # CHECK-NEXT: %read_shared_1_0_2
-        # CHECK-NEXT: %read_shared_1_0_3
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:0
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:1
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:2
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:3
+        # CHECK-NEXT: %read_shared_M:1_N:0_K:0
+        # CHECK-NEXT: %read_shared_M:1_N:0_K:1
+        # CHECK-NEXT: %read_shared_M:1_N:0_K:2
+        # CHECK-NEXT: %read_shared_M:1_N:0_K:3
         # CHECK-NEXT: %b
-        # CHECK-NEXT: %read_6
+        # CHECK-NEXT: %read_38
         # CHECK-SAME: (%b, 8, None, (), None)
         # CHECK-NEXT: %shared_memory_barrier_1
-        # CHECK-NEXT: %write_4
-        # CHECK-SAME: (%read_6, %allocate_1, 8, None, ())
-        # CHECK-NEXT: %read_7
+        # CHECK-NEXT: %write_20
+        # CHECK-SAME: (%read_38, %allocate_1, 8, None, ())
+        # CHECK-NEXT: %read_39
         # CHECK-SAME: (%b, 8, None, (), None)
-        # CHECK-NEXT: %write_5
-        # CHECK-SAME: (%read_7, %allocate_1, 8, None, ())
+        # CHECK-NEXT: %write_21
+        # CHECK-SAME: (%read_39, %allocate_1, 8, None, ())
         # CHECK-NEXT: %shared_memory_barrier_2
-        # CHECK-NEXT: %read_shared_0_0_0
-        # CHECK-NEXT: %read_shared_0_0_1
-        # CHECK-NEXT: %read_shared_0_0_2
-        # CHECK-NEXT: %read_shared_0_0_3
-        # CHECK-NEXT: %read_shared_0_1_0
-        # CHECK-NEXT: %read_shared_0_1_1
-        # CHECK-NEXT: %read_shared_0_1_2
-        # CHECK-NEXT: %read_shared_0_1_3
-        # CHECK-NEXT: %mma_0_0_0
-        # CHECK-NEXT: %mma_0_0_1
-        # CHECK-NEXT: %mma_0_0_2
-        # CHECK-NEXT: %mma_0_0_3
-        # CHECK-NEXT: %mma_1_1_0
-        # CHECK-NEXT: %mma_1_1_1
-        # CHECK-NEXT: %mma_1_1_2
-        # CHECK-NEXT: %mma_1_1_3
-        # CHECK-NEXT: %mma_1_0_0
-        # CHECK-NEXT: %mma_1_0_1
-        # CHECK-NEXT: %mma_1_0_2
-        # CHECK-NEXT: %mma_1_0_3
-        # CHECK-NEXT: %mma_0_1_0
-        # CHECK-NEXT: %mma_0_1_1
-        # CHECK-NEXT: %mma_0_1_2
-        # CHECK-NEXT: %mma_0_1_3
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:0
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:1
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:2
+        # CHECK-NEXT: %read_shared_M:0_N:0_K:3
+        # CHECK-NEXT: %read_shared_M:0_N:1_K:0
+        # CHECK-NEXT: %read_shared_M:0_N:1_K:1
+        # CHECK-NEXT: %read_shared_M:0_N:1_K:2
+        # CHECK-NEXT: %read_shared_M:0_N:1_K:3
+        # CHECK-NEXT: %mma_M:0_N:0_K:0
+        # CHECK-NEXT: %mma_M:0_N:0_K:1
+        # CHECK-NEXT: %mma_M:0_N:0_K:2
+        # CHECK-NEXT: %mma_M:0_N:0_K:3
+        # CHECK-NEXT: %mma_M:0_N:1_K:0
+        # CHECK-NEXT: %mma_M:0_N:1_K:1
+        # CHECK-NEXT: %mma_M:0_N:1_K:2
+        # CHECK-NEXT: %mma_M:0_N:1_K:3
+        # CHECK-NEXT: %mma_M:1_N:0_K:0
+        # CHECK-NEXT: %mma_M:1_N:0_K:1
+        # CHECK-NEXT: %mma_M:1_N:0_K:2
+        # CHECK-NEXT: %mma_M:1_N:0_K:3
+        # CHECK-NEXT: %mma_M:1_N:1_K:0
+        # CHECK-NEXT: %mma_M:1_N:1_K:1
+        # CHECK-NEXT: %mma_M:1_N:1_K:2
+        # CHECK-NEXT: %mma_M:1_N:1_K:3
         # Reduction subgraph (custom format):
-        # CHECK: placeholder(_name=acc_0_0_0
-        # CHECK-NEXT: placeholder(_name=acc_0_1_0
-        # CHECK-NEXT: placeholder(_name=acc_1_0_0
-        # CHECK-NEXT: placeholder(_name=acc_1_1_0
+        # CHECK: placeholder(_name=acc_M:0_N:0_K:0
+        # CHECK-NEXT: placeholder(_name=acc_M:0_N:1_K:0
+        # CHECK-NEXT: placeholder(_name=acc_M:1_N:0_K:0
+        # CHECK-NEXT: placeholder(_name=acc_M:1_N:1_K:0
         # CHECK-NEXT: placeholder(_name=a
         # CHECK-NEXT: read(memory=a, elements_per_thread=8,
         # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_4, memory=allocate, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_36, memory=allocate, elements_per_thread=8,
         # CHECK-SAME: index={M: Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: read(memory=a, elements_per_thread=8,
         # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_5, memory=allocate, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_37, memory=allocate, elements_per_thread=8,
         # CHECK-SAME: index={M: Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: shared_memory_barrier()
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_2, write_3], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_18, write_19], index={M: Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
         # CHECK-NEXT: placeholder(_name=b, _type=Memory[N, K].of(f16))
         # CHECK-NEXT: read(memory=b, elements_per_thread=8,
         # CHECK-SAME: index={N: $WG1*BLOCK_N + BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: shared_memory_barrier()
-        # CHECK-NEXT: write(register_=read_6, memory=allocate_1, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_38, memory=allocate_1, elements_per_thread=8,
         # CHECK-SAME: index={N: BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: read(memory=b, elements_per_thread=8,
         # CHECK-SAME: index={N: $WG1*BLOCK_N + BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1})
-        # CHECK-NEXT: write(register_=read_7, memory=allocate_1, elements_per_thread=8,
+        # CHECK-NEXT: write(register_=read_39, memory=allocate_1, elements_per_thread=8,
         # CHECK-SMAE: index={N: BLOCK_N/2 + Mod(16*$T1 + 32*$T2 + floor($T0/8) + 32, 64) : 1 : 1, K: 8*(Mod($T0, 8)) : 8 : 1})
         # CHECK-NEXT: barrier()
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
-        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_4, write_5], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1})
+        # CHECK-NEXT: read(memory=allocate_1, elements_per_thread=4, mapping_dynamic_vals=(), _write_dependency=[write_20, write_21], index={N: BLOCK_N/2 + Mod($T0, 16) + 16 : 1 : 1, K: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1})
 
 if __name__ == "__main__":
diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py
index 3f6004fd..7b4405eb 100644
--- a/lit_tests/kernel/wave/scheduling.py
+++ b/lit_tests/kernel/wave/scheduling.py
@@ -6,13 +6,17 @@
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.promotion import promote_placeholders
 from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
-from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.expansion.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
 from iree.turbine.kernel.ops.wave_ops import *
-from iree.turbine.kernel.wave.utils import run_test, print_subgraph
+from iree.turbine.kernel.wave.utils import (
+    run_test,
+    print_subgraph,
+    initialize_iter_args,
+)
 from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
 from iree.turbine.kernel.wave.shared_memory_indexing import (
     apply_shared_memory_indexing_corrections,
@@ -96,6 +100,7 @@ def test_gemm_pipelined():
     ):
         trace: CapturedTrace = gemm_pipelined()
         IndexingContext.current().finalize()
+        initialize_iter_args(trace)
         infer_types(trace)
         promote_placeholders(trace, constraints)
         set_node_indices(trace, constraints)
@@ -107,10 +112,10 @@ def test_gemm_pipelined():
         schedule_graph(trace, constraints, True)
 
         print_subgraph(trace, "pipelined_reduction", False)
-        # CHECK: %acc_0_0_0
-        # CHECK-NEXT: %acc_0_1_0
-        # CHECK-NEXT: %acc_1_0_0
-        # CHECK-NEXT: %acc_1_1_0
+        # CHECK: %acc_m_0_n_0_k_0
+        # CHECK-NEXT: %acc_m_0_n_1_k_0
+        # CHECK-NEXT: %acc_m_1_n_0_k_0
+        # CHECK-NEXT: %acc_m_1_n_1_k_0
         # CHECK-NEXT: %rotating_reg_0
         # CHECK-NEXT: %rotating_reg_1
         # CHECK-NEXT: %rotating_reg_2
@@ -118,128 +123,128 @@ def test_gemm_pipelined():
         # CHECK-NEXT: %rotating_reg_4
         # CHECK-NEXT: %rotating_reg_5
         # CHECK-NEXT: %rotating_reg_6
-        # CHECK-NEXT: %mma_1_1_1
+        # CHECK-NEXT: %mma_M_1_N_1_K_1
         # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6, None)
-        # CHECK-NEXT: %read_shared_0_0_0
-        # CHECK-NEXT: %read_shared_0_0_1
-        # CHECK-NEXT: %read_4
-        # CHECK-NEXT: %read_5
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_0
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_1
+        # CHECK-NEXT: %read_20
+        # CHECK-NEXT: %read_21
         # CHECK-NEXT: %scheduling_group_barrier
         # CHECK-SAME: ({Operation.MMA: 1, Operation.READ_SHARED: 2, Operation.READ_GLOBAL: 2}, 0)
-        # CHECK-NEXT: %read_shared_1_0_0
-        # CHECK-NEXT: %read_shared_1_0_1
-        # CHECK-NEXT: %mma_0_0_0
-        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0, None)
-        # CHECK-NEXT: %mma_0_1_0
-        # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0, None)
+        # CHECK-NEXT: %read_shared_M_1_N_0_K_0
+        # CHECK-NEXT: %read_shared_M_1_N_0_K_1
+        # CHECK-NEXT: %mma_M_0_N_0_K_0
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_0_K_1, %acc_m_0_n_0_k_0, None)
+        # CHECK-NEXT: %mma_M_0_N_1_K_0
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %rotating_reg_3, %acc_m_0_n_1_k_0, None)
         # CHECK-NEXT: %scheduling_group_barrier
         # CHECK-SAME: ({Operation.READ_SHARED: 2, Operation.MMA: 2}, 0)
-        # CHECK-NEXT: %mma_0_0_1
-        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0, None)
-        # CHECK-NEXT: %mma_1_0_0
-        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0, None)
-        # CHECK-NEXT: %write_2
-        # CHECK-NEXT: %write_3
+        # CHECK-NEXT: %mma_M_0_N_0_K_1
+        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_M_0_N_0_K_0, None)
+        # CHECK-NEXT: %mma_M_1_N_0_K_0
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_0_K_1, %acc_m_1_n_0_k_0, None)
+        # CHECK-NEXT: %write_10
+        # CHECK-NEXT: %write_11
         # CHECK-NEXT: %scheduling_group_barrier
         # CHECK-SAME: ({Operation.MMA: 2, Operation.WRITE_SHARED: 2}, 0)
-        # CHECK-NEXT: %mma_1_0_1
-        # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0, None)
-        # CHECK-NEXT: %mma_0_1_1
-        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0, None)
-        # CHECK-NEXT: %read_shared_0_1_0
-        # CHECK-NEXT: %read_shared_0_1_1
+        # CHECK-NEXT: %mma_M_0_N_1_K_1
+        # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_M_0_N_1_K_0, None)
+        # CHECK-NEXT: %mma_M_1_N_0_K_1
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_1, %rotating_reg_2, %mma_M_1_N_0_K_0, None)
+        # CHECK-NEXT: %read_shared_M_0_N_1_K_0
+        # CHECK-NEXT: %read_shared_M_0_N_1_K_1
         # CHECK-NEXT: %scheduling_group_barrier
         # CHECK-SAME: ({Operation.MMA: 2, Operation.READ_SHARED: 2}, 0)
-        # CHECK-NEXT: %mma_1_1_0
-        # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1, None)
-        # CHECK-NEXT: %read_shared_0_0_2
-        # CHECK-NEXT: %read_shared_0_0_3
+        # CHECK-NEXT: %mma_M_1_N_1_K_0
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %rotating_reg_3, %mma_M_1_N_1_K_1, None)
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_2
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_3
         # CHECK-NEXT: %scheduling_group_barrier
         # CHECK-SAME: ({Operation.MMA: 1, Operation.READ_SHARED: 2}, 0)
-        # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0]
+        # CHECK-NEXT: [mma_M_0_N_0_K_1, mma_M_0_N_1_K_1, mma_M_1_N_0_K_1, mma_M_1_N_1_K_1, read_shared_M_0_N_0_K_2, read_shared_M_1_N_0_K_1, read_shared_M_0_N_0_K_3, read_shared_M_0_N_1_K_0, rotating_reg_5, read_shared_M_0_N_1_K_1, mma_M_1_N_1_K_0]
 
         print_subgraph(trace, "region_1", False)
         # CHECK: %a
         # CHECK-NEXT: %b
         # CHECK-NEXT: %c
-        # CHECK-NEXT: %register_0_0_0
-        # CHECK-NEXT: %register_1_1_0
-        # CHECK-NEXT: %register_1_0_0
-        # CHECK-NEXT: %register_0_1_0
+        # CHECK-NEXT: %register_M:0_N:0_K:0
+        # CHECK-NEXT: %register_M:0_N:1_K:0
+        # CHECK-NEXT: %register_M:1_N:0_K:0
+        # CHECK-NEXT: %register_M:1_N:1_K:0
         # CHECK-NEXT: %allocate
         # CHECK-NEXT: %allocate_1
-        # CHECK-NEXT: %read_4
-        # CHECK-NEXT: %read_5
-        # CHECK-NEXT: %write_2
-        # CHECK-NEXT: %write_3
-        # CHECK-NEXT: %read_shared_0_1_0
-        # CHECK-NEXT: %read_shared_0_1_1
-        # CHECK-NEXT: %read_shared_0_0_1
-        # CHECK-NEXT: %read_shared_0_0_2
-        # CHECK-NEXT: %read_shared_0_0_0
-        # CHECK-NEXT: %read_shared_0_0_3
-        # CHECK-NEXT: %read_6
-        # CHECK-NEXT: %read_7
-        # CHECK-NEXT: %read_shared_1_0_0
-        # CHECK-NEXT: %read_shared_1_0_1
-        # CHECK-NEXT: %mma_0_0_0
-        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0, None)
-        # CHECK-NEXT: %mma_0_1_0
-        # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0, None)
-        # CHECK-NEXT: %mma_0_0_1
-        # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, %mma_0_0_0, None)
-        # CHECK-NEXT: %mma_1_0_0
-        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0, None)
-        # CHECK-NEXT: %write_4
-        # CHECK-NEXT: %write_5
-        # CHECK-NEXT: %mma_1_0_1
-        # CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0, None)
-        # CHECK-NEXT: %mma_0_1_1
-        # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0, None)
-        # CHECK-NEXT: %read_shared_0_1_2
-        # CHECK-NEXT: %read_shared_0_1_3
-        # CHECK-NEXT: %mma_1_1_0
-        # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0, None)
-        # CHECK-NEXT: %read_shared_0_0_4
-        # CHECK-NEXT: %read_shared_0_0_5
+        # CHECK-NEXT: %read_20
+        # CHECK-NEXT: %read_21
+        # CHECK-NEXT: %write_10
+        # CHECK-NEXT: %write_11
+        # CHECK-NEXT: %read_shared_M_0_N_1_K_0
+        # CHECK-NEXT: %read_shared_M_0_N_1_K_1
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_1
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_2
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_0
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_3
+        # CHECK-NEXT: %read_22
+        # CHECK-NEXT: %read_23
+        # CHECK-NEXT: %read_shared_M_1_N_0_K_0
+        # CHECK-NEXT: %read_shared_M_1_N_0_K_1
+        # CHECK-NEXT: %mma_M_0_N_0_K_0
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_0_K_3, %register_M:0_N:0_K:0, None)
+        # CHECK-NEXT: %mma_M_0_N_1_K_0
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_0, %read_shared_M_0_N_1_K_0, %register_M:0_N:1_K:0, None)
+        # CHECK-NEXT: %mma_M_0_N_0_K_1
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_1, %read_shared_M_0_N_0_K_2, %mma_M_0_N_0_K_0, None)
+        # CHECK-NEXT: %mma_M_1_N_0_K_0
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_0_K_3, %register_M:1_N:0_K:0, None)
+        # CHECK-NEXT: %write_12
+        # CHECK-NEXT: %write_13
+        # CHECK-NEXT: %mma_M_0_N_1_K_1
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_1, %read_shared_M_0_N_1_K_1, %mma_M_0_N_1_K_0, None)
+        # CHECK-NEXT: %mma_M_1_N_0_K_1
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_1, %read_shared_M_0_N_0_K_2, %mma_M_1_N_0_K_0, None)
+        # CHECK-NEXT: %read_shared_M_0_N_1_K_2
+        # CHECK-NEXT: %read_shared_M_0_N_1_K_3
+        # CHECK-NEXT: %mma_M_1_N_1_K_0
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_0, %read_shared_M_0_N_1_K_0, %register_M:1_N:1_K:0, None)
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_4
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_5
         # CHECK-NEXT: %reduction_1
-        # CHECK-NEXT: %getresult_1_1_0
-        # CHECK-NEXT: %getresult_1_0_0
-        # CHECK-NEXT: %getresult_0_1_0
-        # CHECK-NEXT: %getresult_0_0_0
-        # CHECK-NEXT: %get_result_4
-        # CHECK-NEXT: %get_result_5
-        # CHECK-NEXT: %get_result_6
-        # CHECK-NEXT: %get_result_7
-        # CHECK-NEXT: %get_result_8
+        # CHECK-NEXT: %getresult_M:0_N:0_K:0
+        # CHECK-NEXT: %getresult_M:0_N:1_K:0
+        # CHECK-NEXT: %getresult_M:1_N:0_K:0
+        # CHECK-NEXT: %getresult_M:1_N:1_K:0
         # CHECK-NEXT: %get_result_9
         # CHECK-NEXT: %get_result_10
-        # CHECK-NEXT: %mma_1_1_1
-        # CHECK-SAME: (%get_result_5, %get_result_8, %get_result_10, None)
-        # CHECK-NEXT: %read_shared_0_0_6
-        # CHECK-NEXT: %read_shared_0_0_7
-        # CHECK-NEXT: %read_shared_1_0_2
-        # CHECK-NEXT: %read_shared_1_0_3
-        # CHECK-NEXT: %mma_0_0_2
-        # CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0, None)
-        # CHECK-NEXT: %mma_0_1_2
-        # CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0, None)
-        # CHECK-NEXT: %mma_0_0_3
-        # CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2, None)
-        # CHECK-NEXT: %mma_1_0_2
-        # CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0, None)
-        # CHECK-NEXT: %mma_1_0_3
-        # CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2, None)
-        # CHECK-NEXT: %mma_0_1_3
-        # CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2, None)
-        # CHECK-NEXT: %mma_1_1_2
-        # CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1, None)
-        # CHECK-NEXT: %mma_1_1_3
-        # CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2, None)
-        # CHECK-NEXT: %write_0_0_0
-        # CHECK-NEXT: %write_1_1_0
-        # CHECK-NEXT: %write_1_0_0
-        # CHECK-NEXT: %write_0_1_0
+        # CHECK-NEXT: %get_result_11
+        # CHECK-NEXT: %get_result_12
+        # CHECK-NEXT: %get_result_13
+        # CHECK-NEXT: %get_result_14
+        # CHECK-NEXT: %get_result_15
+        # CHECK-NEXT: %mma_M_1_N_1_K_1
+        # CHECK-SAME: (%get_result_10, %get_result_13, %get_result_15, None)
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_6
+        # CHECK-NEXT: %read_shared_M_0_N_0_K_7
+        # CHECK-NEXT: %read_shared_M_1_N_0_K_2
+        # CHECK-NEXT: %read_shared_M_1_N_0_K_3
+        # CHECK-NEXT: %mma_M_0_N_0_K_2
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_6, %read_shared_M_0_N_0_K_7, %getresult_M:0_N:0_K:0, None)
+        # CHECK-NEXT: %mma_M_0_N_1_K_2
+        # CHECK-SAME: (%read_shared_M_0_N_0_K_6, %get_result_12, %getresult_M:0_N:1_K:0, None)
+        # CHECK-NEXT: %mma_M_0_N_0_K_3
+        # CHECK-SAME: (%get_result_9, %get_result_11, %mma_M_0_N_0_K_2, None)
+        # CHECK-NEXT: %mma_M_1_N_0_K_2
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_2, %read_shared_M_0_N_0_K_7, %getresult_M:1_N:0_K:0, None)
+        # CHECK-NEXT: %mma_M_0_N_1_K_3
+        # CHECK-SAME: (%get_result_9, %get_result_14, %mma_M_0_N_1_K_2, None)
+        # CHECK-NEXT: %mma_M_1_N_0_K_3
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_3, %get_result_11, %mma_M_1_N_0_K_2, None)
+        # CHECK-NEXT: %mma_M_1_N_1_K_2
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_2, %get_result_12, %mma_M_1_N_1_K_1, None)
+        # CHECK-NEXT: %mma_M_1_N_1_K_3
+        # CHECK-SAME: (%read_shared_M_1_N_0_K_3, %get_result_14, %mma_M_1_N_1_K_2, None)
+        # CHECK-NEXT: %write_M:0_N:0_K:0
+        # CHECK-NEXT: %write_M:0_N:1_K:0
+        # CHECK-NEXT: %write_M:1_N:0_K:0
+        # CHECK-NEXT: %write_M:1_N:1_K:0
         # CHECK-NEXT: return None