[TKW] Fix chained dynamic vals (#350)
* Properly propagate index through chained read/write ops using
`transform_index_backwards`
* Use proper symbols for dynamic vals
* Do not propagate 0 index

---------

Signed-off-by: Ivan Butygin <[email protected]>
Hardcode84 authored Dec 20, 2024
1 parent d4ef311 commit ef8142e
Showing 5 changed files with 137 additions and 23 deletions.
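
To make the first bullet of the commit message concrete, here is a minimal sympy-only sketch (illustrative names, not the TKW API) of what propagating an index backwards through chained dynamic-value mappings looks like: the offsets loaded by one read feed the index mapping of the next, so the index of the final read has to be rewritten, mapping by mapping, into the dimensions of the earlier offset reads.

import sympy

# Iterators of the final read's mapping.
i, j = sympy.symbols("i j", integer=True)
# Iterators of the intermediate offset read's mapping.
i2, j2 = sympy.symbols("i2 j2", integer=True)

# mapping2 (final read): the dynamic value `offset2` is addressed at
# {M: i, SIZE2: j // 4}, written in the final read's iterators.
offset2_index = {"M": i, "SIZE2": sympy.floor(j / 4)}

# mapping1 (offset2 read): the dynamic value `offset1` is addressed at
# {M: i2, SIZE1: j2 // 2}, written in the offset2 read's own iterators.
offset1_exprs = {"M": i2, "SIZE1": sympy.floor(j2 / 2)}

# Backwards propagation: substitute offset2's already-derived index for the
# offset2 read's iterators, yielding offset1's index in the final iterators.
subs = {i2: offset2_index["M"], j2: offset2_index["SIZE2"]}
offset1_index = {dim: expr.subs(subs) for dim, expr in offset1_exprs.items()}

print(offset1_index)  # {'M': i, 'SIZE1': floor(floor(j/4)/2)}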
17 changes: 11 additions & 6 deletions iree/turbine/kernel/_support/indexing.py
@@ -410,8 +410,9 @@ class IndexSequence:
     size: IndexExpr | int
     stride: Optional[IndexExpr | int] = 1

+    @staticmethod
     def _subs(
-        self, value: int | IndexExpr, map: dict[IndexExpr, IndexExpr]
+        value: int | IndexExpr, map: dict[IndexExpr, IndexExpr]
     ) -> int | IndexExpr:
         if isinstance(value, (sympy.Basic, IndexSequence)):
             return value.subs(map)
@@ -423,11 +424,15 @@ def subs(self, map: dict[IndexExpr, IndexExpr]):
         stride = self._subs(self.stride, map)
         return IndexSequence(start, size, stride)

-    def apply_expr(self, symbol: IndexExpr, expr: IndexExpr):
-        start = self._subs(expr, {symbol: self.start})
-        size = self._subs(expr, {symbol: self.size})
-        stride = self._subs(expr, {symbol: self.stride})
+    @staticmethod
+    def from_expr(expr: IndexExpr, subs: dict[IndexExpr, Any]):
+        start_subs = {k: v.start for k, v in subs.items()}
+        size_subs = {k: v.size for k, v in subs.items()}
+        stride_subs = {k: v.stride for k, v in subs.items()}
+        start = IndexSequence._subs(expr, start_subs)
+        size = IndexSequence._subs(expr, size_subs)
+        stride = IndexSequence._subs(expr, stride_subs)
         return IndexSequence(start, size, stride)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.start} : {self.size} : {self.stride}"
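
A rough, self-contained sketch of the substitution rule the new `from_expr` above implements (a toy stand-in, not the library class): the mapping expression is substituted three times, once against the start, size, and stride components of the bound index sequences, so the result is again a start/size/stride triple.

from dataclasses import dataclass

import sympy


@dataclass
class ToyIndexSequence:
    start: object
    size: object
    stride: object = 1

    @staticmethod
    def from_expr(expr, subs):
        # Substitute start/size/stride of every bound symbol separately.
        start = expr.subs({k: v.start for k, v in subs.items()})
        size = expr.subs({k: v.size for k, v in subs.items()})
        stride = expr.subs({k: v.stride for k, v in subs.items()})
        return ToyIndexSequence(start, size, stride)


i, j, tid = sympy.symbols("i j tid", integer=True)
seq_i = ToyIndexSequence(start=tid, size=1)
seq_j = ToyIndexSequence(start=4 * tid, size=4)
derived = ToyIndexSequence.from_expr(sympy.floor(j / 2), {i: seq_i, j: seq_j})
print(derived)  # start=floor(4*tid/2)=2*tid, size=floor(4/2)=2, stride=floor(1/2)=0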
41 changes: 33 additions & 8 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1012,17 +1012,32 @@ def transform_index_backwards(
             i = self.mapping_dynamic_vals.index(arg)
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
-            subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+
+            # This logic relies on the fact that the output mapping is identity.
+            subs = {
+                k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
+            }
+            return {
+                k: IndexSequence.from_expr(mapping[k], subs)
+                for k in arg.type.symbolic_shape
+            }

         return index

     def get_derived_indices(
         self,
     ) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
         def transform_idx(arg):
-            new_index = self.transform_index_backwards(self.index, arg)
-            return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
+            # Treat a zero index as 'not set' and do not propagate it.
+            # TODO: `set_thread_independent_index` currently blindly sets a zero
+            # index on all dims which are not participating in constraints; we
+            # need to refactor `index_sequence_analysis` into a proper dataflow
+            # analysis.
+            return {
+                k: v
+                for k, v in self.transform_index_backwards(self.index, arg).items()
+                if v.start != 0
+            }

         return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]

@@ -1252,17 +1267,27 @@ def transform_index_backwards(
             i = self.mapping_dynamic_vals.index(arg)
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
-            subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+
+            # This logic relies on the fact that the input mapping is identity.
+            subs = {
+                k: index[v] for k, v in zip(iters, self.mapping.input_mapping.keys())
+            }
+            return {
+                k: IndexSequence.from_expr(mapping[k], subs)
+                for k in arg.type.symbolic_shape
+            }

         return index

     def get_derived_indices(
         self,
     ) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
         def transform_idx(arg):
-            new_index = self.transform_index_backwards(self.index, arg)
-            return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
+            return {
+                k: v
+                for k, v in self.transform_index_backwards(self.index, arg).items()
+                if v.start != 0
+            }

         return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]

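The other half of the change above is the `if v.start != 0` filter in both `get_derived_indices` bodies, matching the "do not propagate 0 index" bullet. A tiny hedged sketch of that filtering, with (start, size, stride) tuples standing in for IndexSequence:

import sympy

tid_x = sympy.Symbol("tid_x")

# Index derived for a dynamic value, keyed by its symbolic dims.  The zero
# start on "N" is the placeholder blindly assigned by thread-independent
# indexing, so it carries no information and should not be propagated to the
# producer.
derived = {
    "M": (tid_x, 1, 1),
    "N": (0, 4, 1),
}

propagated = {dim: seq for dim, seq in derived.items() if seq[0] != 0}
print(propagated)  # {'M': (tid_x, 1, 1)} -- the unset 'N' entry is dropped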
7 changes: 5 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -281,7 +281,9 @@ def combine_derived_index(

     new_index = copy(src_index)
     for dim, new_idx in dst_index.items():
-        assert dim in src_index, f"Dim {dim} not in index {src_index}"
+        if dim not in src_index:
+            continue
+
         old_idx = src_index[dim]
         if old_idx == new_idx:
             continue
@@ -306,7 +308,8 @@ def set_derived_index(trace):
         custom = get_custom(current)
         custom.index = combine_derived_index(custom.index, index)
         for inp in get_inputs(current)[0]:
-            worklist.append((inp, index))
+            new_index = custom.transform_index_backwards(custom.index, inp)
+            worklist.append((inp, new_index))


 def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
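For the worklist change just above, a minimal sketch (hypothetical graph plumbing, not the wave pass itself) of the difference: instead of pushing the consumer's index unchanged to every producer, the index is first translated through the consumer's backwards transform for that particular input.

def set_derived_index_toy(roots, get_inputs, transform_index_backwards, combine):
    # roots: list of (node, index) seeds; index maps dim -> start expression.
    worklist = list(roots)
    indices = {}
    while worklist:
        node, index = worklist.pop()
        indices[node] = combine(indices.get(node, {}), index)
        for inp in get_inputs(node):
            # The fix: derive the producer's index from this node's combined
            # index rather than forwarding `index` verbatim.
            new_index = transform_index_backwards(node, indices[node], inp)
            worklist.append((inp, new_index))
    return indices


# Two-node chain: `read_a` consumes `offset`; the backwards transform for the
# offset operand halves dim N (e.g. a j // 2 dynamic-val mapping).
producers = {"read_a": ["offset"], "offset": []}
derived = set_derived_index_toy(
    roots=[("read_a", {"N": 8})],
    get_inputs=lambda n: producers[n],
    transform_index_backwards=lambda node, idx, inp: {"N": idx["N"] // 2},
    combine=lambda old, new: {**old, **new},
)
print(derived)  # {'read_a': {'N': 8}, 'offset': {'N': 4}}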
91 changes: 86 additions & 5 deletions lit_tests/kernel/wave/codegen.py
@@ -18,7 +18,6 @@
 N = tkl.sym.N
 K = tkl.sym.K
 B = tkl.sym.B
-ONE = tkl.sym.ONE
 BLOCK_M = tkl.sym.BLOCK_M
 BLOCK_N = tkl.sym.BLOCK_N
 BLOCK_K = tkl.sym.BLOCK_K
@@ -29,17 +28,19 @@
 ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


-def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
+def codegen_test_context(
+    canonicalize: bool = False, dynamic_symbols=[], additional_symbols={}
+):
     bindings = {
         M: 16,
         N: 16,
         K: 16,
         BLOCK_M: 16,
         BLOCK_N: 16,
         BLOCK_K: 16,
-        ONE: 1,
         ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
     }
+    bindings.update(additional_symbols)

     # Remove dynamic symbols from the bindings.
     for sym in dynamic_symbols:
@@ -461,6 +462,7 @@ def test(

 @run_test
 def test_read_write_dynamic_mapping_broadcast():
+    ONE = tkl.sym.ONE
     constraints: list[tkw.Constraint] = [
         tkw.HardwareConstraint(
             threads_per_wave=64,
@@ -480,7 +482,7 @@ def test_read_write_dynamic_mapping_broadcast():
         num_iterators=2,
         inputs={M: i, N: k + j % 16},
         outputs={M: i, N: j},
-        dynamic_val_mappings={M: i, N: j // 16},
+        dynamic_val_mappings={M: i, ONE: j // 16},
     )

     @tkw.wave(constraints)
@@ -498,7 +500,7 @@ def test(
         )
         tkw.write(res, b, elements_per_thread=16)

-    with codegen_test_context(canonicalize=True):
+    with codegen_test_context(canonicalize=True, additional_symbols={ONE: 1}):
         a = torch.randn(16, 16, dtype=torch.float16)
         off = torch.randint(16, (16, 1), dtype=torch.int32)
         b = torch.zeros(16, 16, dtype=torch.float16)
@@ -512,6 +514,85 @@ def test(
# CHECK: vector.store %[[RES]], %{{.*}}[%[[M]], %{{.*}}] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>


@run_test
def test_read_write_dynamic_mapping_chain():
SIZE1 = tkl.sym.SIZE1
SIZE2 = tkl.sym.SIZE2
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 16, N: 4, SIZE1: 1, SIZE2: 1},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.dynamic_val(0)
mapping1 = tkw.IndexMapping(
num_iterators=2,
inputs={M: i, SIZE2: k},
outputs={M: i, SIZE2: j},
dynamic_val_mappings={M: i, SIZE1: j // 2},
)
mapping2 = tkw.IndexMapping(
num_iterators=2,
inputs={M: i, N: k + j % 4},
outputs={M: i, N: j},
dynamic_val_mappings={M: i, SIZE2: j // 4},
)

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
off1: tkl.Memory[M, SIZE1, ADDRESS_SPACE, tkl.i32],
off2: tkl.Memory[M, SIZE2, ADDRESS_SPACE, tkl.i32],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
offset1 = tkw.read(off1, elements_per_thread=1)
offset2 = tkw.read(
off2,
mapping=mapping1,
mapping_dynamic_vals=(offset1,),
elements_per_thread=1,
)
res = tkw.read(
a,
mapping=mapping2,
mapping_dynamic_vals=(offset2,),
elements_per_thread=4,
)
tkw.write(res, b, elements_per_thread=4)

with codegen_test_context(
canonicalize=True, additional_symbols={BLOCK_N: 4, SIZE1: 2, SIZE2: 4}
):
a = torch.randn(16, 16, dtype=torch.float16)
off1 = torch.randint(2, (16, 2), dtype=torch.int32)
off2 = torch.randint(16, (16, 4), dtype=torch.int32)
b = torch.zeros(16, 16, dtype=torch.float16)
print(test(a, off1, off2, b).module_op)

# CHECK-LABEL: func.func @test
# CHECK: %[[C8:.*]] = arith.constant 8 : index
# CHECK: %[[thread_id_y:.*]] = gpu.thread_id y
# CHECK: %[[D7:.*]] = arith.addi %{{.*}}, %[[thread_id_y]] overflow<nsw, nuw> : index
# CHECK: %[[D8:.*]] = vector.load %{{.*}}[%[[D5:.*]], %[[D7]]] : memref<16x2xi32, strided<[2, 1], offset: ?>>, vector<1xi32>
# CHECK: %[[D10:.*]] = arith.index_cast %[[D8]] : vector<1xi32> to vector<1xindex>
# CHECK: %[[D11:.*]] = vector.extract %[[D10]][0] : index from vector<1xindex>
# CHECK: %[[D12:.*]] = vector.load %{{.*}}[%[[D5]], %[[D11]]] : memref<16x4xi32, strided<[4, 1], offset: ?>>, vector<1xi32>
# CHECK: %[[D14:.*]] = arith.index_cast %[[D12]] : vector<1xi32> to vector<1xindex>
# CHECK: %[[D15:.*]] = vector.extract %[[D14]][0] : index from vector<1xindex>
# CHECK: %[[D16:.*]] = vector.load %{{.*}}[%[[D5]], %[[D15]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
# CHECK: %[[D19:.*]] = arith.muli %[[thread_id_y]], %[[C8]] overflow<nsw, nuw> : index
# CHECK: %[[D20:.*]] = arith.addi %[[D19]], %{{.*}} overflow<nsw, nuw> : index
# CHECK: vector.store %[[D16]], %{{.*}}[%[[D5]], %[[D20]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>


@run_test
def test_dynamic_copy():
constraints: list[tkw.Constraint] = [
4 changes: 2 additions & 2 deletions tests/kernel/wave/wave_e2e_test.py
@@ -475,7 +475,7 @@ def test_offset_read_one(shape, request):
         num_iterators=2,
         inputs={M: k, N: j},
         outputs={M: i, N: j},
-        dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD},
+        dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD},
     )

     @tkw.wave(constraints)
@@ -640,7 +640,7 @@ def test_offset_write_one(shape, request):
         num_iterators=2,
         inputs={M: i, N: j},
         outputs={M: i, N: k + j % ELEMS_PER_THREAD},
-        dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD},
+        dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD},
     )

     @tkw.wave(constraints)
