Skip to content

Commit

Permalink
Redo expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-nod committed Dec 26, 2024
1 parent 20507b7 commit aae1352
Show file tree
Hide file tree
Showing 14 changed files with 1,899 additions and 1,652 deletions.
71 changes: 61 additions & 10 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,18 @@ def update_arg(self, idx_or_name: int | str, value: CustomOp | fx.Node):
else:
raise IndexError("Index out of range")

def copy_core_attributes(self, new_node: fx.Node):
"""
Copy core attributes from the current node to the new node.
"""
core_attributes = ["index", "vector_shapes", "reduction_dim", "iter_idx"]
for attr_name in core_attributes:
if hasattr(self.fx_node, attr_name):
attr = getattr(self.fx_node, attr_name)
if attr_name == "index":
attr = copy.deepcopy(attr)
setattr(new_node, attr_name, attr)

def copy(
self,
new_name: Optional[str] = None,
Expand All @@ -421,14 +433,9 @@ def copy(
new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform)
new_node.tkw_op = self
new_node.tkw_op_name = self.tkw_op_name
if hasattr(self.fx_node, "index"):
new_node.index = copy.deepcopy(self.fx_node.index)
self.copy_core_attributes(new_node)
if new_name:
new_node.name = new_name
if hasattr(self.fx_node, "vector_shapes"):
new_node.vector_shapes = self.fx_node.vector_shapes
if hasattr(self.fx_node, "reduction_dim"):
new_node.reduction_dim = self.fx_node.reduction_dim
return get_custom(new_node)

def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
Expand All @@ -437,6 +444,31 @@ def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
new_node = new_node.fx_node
self.fx_node.replace_all_uses_with(new_node)

def replace_all_uses_with_except(
self, new_node: CustomOp | fx.Node, except_nodes: list[CustomOp]
):
"""Replace all uses of the current node with the new node except for the nodes in except_nodes."""
for user in self.users:
if user in except_nodes:
continue
indices = user.get_node_arg_index(self)
if not isinstance(indices, Sequence):
indices = [indices]
for idx in indices:
if isinstance(user.node_args[idx], Sequence):
sub_idx = user.node_args[idx].index(self)
new_nodes = [
(
user.node_args[idx][x].fx_node
if x != sub_idx
else new_node.fx_node
)
for x in range(len(user.node_args[idx]))
]
user.update_arg(idx, new_nodes)
else:
user.update_arg(idx, new_node.fx_node)

def erase(self):
"""Erase the current node from the graph where it exists."""
assert (
Expand Down Expand Up @@ -470,7 +502,18 @@ def node_args(self) -> dict[int, Any]:
return custom_args

def get_node_arg_index(self, arg: CustomOp) -> Optional[CustomOp | list[CustomOp]]:
return next(key for key, value in self.node_args.items() if value == arg)
keys = []
for key, value in self.node_args.items():
if isinstance(value, Sequence):
if arg in value:
keys.append(key)
elif value == arg:
keys.append(key)
if not keys:
return None
if len(keys) == 1:
return keys[0]
return keys

@property
def users(self) -> list[Any]:
Expand Down Expand Up @@ -785,9 +828,15 @@ class IterArg(Placeholder):
def parent_op(self):
return get_custom(self.graph.parent_op)

def get_iter_idx(self):
src_reduction = self.parent_op()
return src_reduction.iter_args(self.graph).index(self.fx_node)
@property
def iter_idx(self):
if hasattr(self.fx_node, "iter_idx"):
return self.fx_node.iter_idx
return None

@iter_idx.setter
def iter_idx(self, value):
self.fx_node.iter_idx = value


# Ops modeling TKW operations in the kernel language
Expand Down Expand Up @@ -1157,6 +1206,8 @@ def iter_args(self, graph: fx.Graph) -> list[fx.Node]:
custom = get_custom(nested_node)
if isinstance(custom, IterArg):
iter_args.append(nested_node)
# Sort by iter_idx.
iter_args = sorted(iter_args, key=lambda x: get_custom(x).iter_idx)
return iter_args

def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
Expand Down
Loading

0 comments on commit aae1352

Please sign in to comment.