From ef8142eed8f045e28aa04a9e0bae4707e5090309 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 20 Dec 2024 19:44:33 +0100
Subject: [PATCH] [TKW] Fix chained dynamic vals (#350)

* Properly propagate index through chained read/write ops using `transform_index_backwards`
* Use proper symbols for dynamic vals
* Do not propagate 0 index

---------

Signed-off-by: Ivan Butygin
---
 iree/turbine/kernel/_support/indexing.py   | 17 ++--
 iree/turbine/kernel/ops/wave_ops.py        | 41 +++++++--
 .../kernel/wave/index_sequence_analysis.py |  7 +-
 lit_tests/kernel/wave/codegen.py           | 91 ++++++++++++++++++-
 tests/kernel/wave/wave_e2e_test.py         |  4 +-
 5 files changed, 137 insertions(+), 23 deletions(-)

diff --git a/iree/turbine/kernel/_support/indexing.py b/iree/turbine/kernel/_support/indexing.py
index 6be5f043..4613b141 100644
--- a/iree/turbine/kernel/_support/indexing.py
+++ b/iree/turbine/kernel/_support/indexing.py
@@ -410,8 +410,9 @@ class IndexSequence:
     size: IndexExpr | int
     stride: Optional[IndexExpr | int] = 1
 
+    @staticmethod
     def _subs(
-        self, value: int | IndexExpr, map: dict[IndexExpr, IndexExpr]
+        value: int | IndexExpr, map: dict[IndexExpr, IndexExpr]
     ) -> int | IndexExpr:
         if isinstance(value, (sympy.Basic, IndexSequence)):
             return value.subs(map)
@@ -423,11 +424,15 @@ def subs(self, map: dict[IndexExpr, IndexExpr]):
         stride = self._subs(self.stride, map)
         return IndexSequence(start, size, stride)
 
-    def apply_expr(self, symbol: IndexExpr, expr: IndexExpr):
-        start = self._subs(expr, {symbol: self.start})
-        size = self._subs(expr, {symbol: self.size})
-        stride = self._subs(expr, {symbol: self.stride})
+    @staticmethod
+    def from_expr(expr: IndexExpr, subs: dict[IndexExpr, Any]):
+        start_subs = {k: v.start for k, v in subs.items()}
+        size_subs = {k: v.size for k, v in subs.items()}
+        stride_subs = {k: v.stride for k, v in subs.items()}
+        start = IndexSequence._subs(expr, start_subs)
+        size = IndexSequence._subs(expr, size_subs)
+        stride = IndexSequence._subs(expr, stride_subs)
         return IndexSequence(start, size, stride)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.start} : {self.size} : {self.stride}"
diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index ffa73618..ca80f7d4 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -1012,8 +1012,15 @@ def transform_index_backwards(
             i = self.mapping_dynamic_vals.index(arg)
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
-            subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+
+            # This logic relies on the fact that the output mapping is the identity.
+            subs = {
+                k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
+            }
+            return {
+                k: IndexSequence.from_expr(mapping[k], subs)
+                for k in arg.type.symbolic_shape
+            }
 
         return index
 
@@ -1021,8 +1028,16 @@ def get_derived_indices(
         self,
     ) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
         def transform_idx(arg):
-            new_index = self.transform_index_backwards(self.index, arg)
-            return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
+            # Treat a zero index as 'not set' and do not propagate it.
+            # TODO: `set_thread_independent_index` currently blindly sets a
+            # zero index on all dims that do not participate in constraints;
+            # we need to refactor `index_sequence_analysis` into a proper
+            # dataflow analysis.
+            return {
+                k: v
+                for k, v in self.transform_index_backwards(self.index, arg).items()
+                if v.start != 0
+            }
 
         return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]
 
@@ -1252,8 +1267,15 @@ def transform_index_backwards(
             i = self.mapping_dynamic_vals.index(arg)
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
-            subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+
+            # This logic relies on the fact that the input mapping is the identity.
+            subs = {
+                k: index[v] for k, v in zip(iters, self.mapping.input_mapping.keys())
+            }
+            return {
+                k: IndexSequence.from_expr(mapping[k], subs)
+                for k in arg.type.symbolic_shape
+            }
 
         return index
 
@@ -1261,8 +1283,11 @@ def get_derived_indices(
         self,
     ) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
         def transform_idx(arg):
-            new_index = self.transform_index_backwards(self.index, arg)
-            return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
+            return {
+                k: v
+                for k, v in self.transform_index_backwards(self.index, arg).items()
+                if v.start != 0
+            }
 
         return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index 6bf5f639..dbbd5330 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -281,7 +281,9 @@ def combine_derived_index(
 
     new_index = copy(src_index)
     for dim, new_idx in dst_index.items():
-        assert dim in src_index, f"Dim {dim} not in index {src_index}"
+        if dim not in src_index:
+            continue
+
         old_idx = src_index[dim]
         if old_idx == new_idx:
             continue
@@ -306,7 +308,8 @@ def set_derived_index(trace):
         custom = get_custom(current)
         custom.index = combine_derived_index(custom.index, index)
         for inp in get_inputs(current)[0]:
-            worklist.append((inp, index))
+            new_index = custom.transform_index_backwards(custom.index, inp)
+            worklist.append((inp, new_index))
 
 
 def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 8bbf75fc..1858b45b 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -18,7 +18,6 @@
 N = tkl.sym.N
 K = tkl.sym.K
 B = tkl.sym.B
-ONE = tkl.sym.ONE
 BLOCK_M = tkl.sym.BLOCK_M
 BLOCK_N = tkl.sym.BLOCK_N
 BLOCK_K = tkl.sym.BLOCK_K
@@ -29,7 +28,9 @@
 ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
 
 
-def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
+def codegen_test_context(
+    canonicalize: bool = False, dynamic_symbols=[], additional_symbols={}
+):
     bindings = {
         M: 16,
         N: 16,
@@ -37,9 +38,9 @@ def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
         BLOCK_M: 16,
         BLOCK_N: 16,
         BLOCK_K: 16,
-        ONE: 1,
         ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
     }
+    bindings.update(additional_symbols)
 
     # Remove dynamic symbols from the bindings.
     for sym in dynamic_symbols:
@@ -461,6 +462,7 @@ def test(
 
 @run_test
 def test_read_write_dynamic_mapping_broadcast():
+    ONE = tkl.sym.ONE
     constraints: list[tkw.Constraint] = [
         tkw.HardwareConstraint(
             threads_per_wave=64,
@@ -480,7 +482,7 @@ def test_read_write_dynamic_mapping_broadcast():
         num_iterators=2,
         inputs={M: i, N: k + j % 16},
         outputs={M: i, N: j},
-        dynamic_val_mappings={M: i, N: j // 16},
+        dynamic_val_mappings={M: i, ONE: j // 16},
     )
 
     @tkw.wave(constraints)
@@ -498,7 +500,7 @@ def test(
         )
         tkw.write(res, b, elements_per_thread=16)
 
-    with codegen_test_context(canonicalize=True):
+    with codegen_test_context(canonicalize=True, additional_symbols={ONE: 1}):
         a = torch.randn(16, 16, dtype=torch.float16)
         off = torch.randint(16, (16, 1), dtype=torch.int32)
         b = torch.zeros(16, 16, dtype=torch.float16)
@@ -512,6 +514,85 @@ def test(
     # CHECK: vector.store %[[RES]], %{{.*}}[%[[M]], %{{.*}}] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>
 
 
+@run_test
+def test_read_write_dynamic_mapping_chain():
+    SIZE1 = tkl.sym.SIZE1
+    SIZE2 = tkl.sym.SIZE2
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+            vector_shapes={M: 16, N: 4, SIZE1: 1, SIZE2: 1},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.dynamic_val(0)
+    mapping1 = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, SIZE2: k},
+        outputs={M: i, SIZE2: j},
+        dynamic_val_mappings={M: i, SIZE1: j // 2},
+    )
+    mapping2 = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, N: k + j % 4},
+        outputs={M: i, N: j},
+        dynamic_val_mappings={M: i, SIZE2: j // 4},
+    )
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        off1: tkl.Memory[M, SIZE1, ADDRESS_SPACE, tkl.i32],
+        off2: tkl.Memory[M, SIZE2, ADDRESS_SPACE, tkl.i32],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+    ):
+        offset1 = tkw.read(off1, elements_per_thread=1)
+        offset2 = tkw.read(
+            off2,
+            mapping=mapping1,
+            mapping_dynamic_vals=(offset1,),
+            elements_per_thread=1,
+        )
+        res = tkw.read(
+            a,
+            mapping=mapping2,
+            mapping_dynamic_vals=(offset2,),
+            elements_per_thread=4,
+        )
+        tkw.write(res, b, elements_per_thread=4)
+
+    with codegen_test_context(
+        canonicalize=True, additional_symbols={BLOCK_N: 4, SIZE1: 2, SIZE2: 4}
+    ):
+        a = torch.randn(16, 16, dtype=torch.float16)
+        off1 = torch.randint(2, (16, 2), dtype=torch.int32)
+        off2 = torch.randint(16, (16, 4), dtype=torch.int32)
+        b = torch.zeros(16, 16, dtype=torch.float16)
+        print(test(a, off1, off2, b).module_op)
+
+    # CHECK-LABEL: func.func @test
+    # CHECK: %[[C8:.*]] = arith.constant 8 : index
+    # CHECK: %[[thread_id_y:.*]] = gpu.thread_id y
+    # CHECK: %[[D7:.*]] = arith.addi %{{.*}}, %[[thread_id_y]] overflow<nsw> : index
+    # CHECK: %[[D8:.*]] = vector.load %{{.*}}[%[[D5:.*]], %[[D7]]] : memref<16x2xi32, strided<[2, 1], offset: ?>>, vector<1xi32>
+    # CHECK: %[[D10:.*]] = arith.index_cast %[[D8]] : vector<1xi32> to vector<1xindex>
+    # CHECK: %[[D11:.*]] = vector.extract %[[D10]][0] : index from vector<1xindex>
+    # CHECK: %[[D12:.*]] = vector.load %{{.*}}[%[[D5]], %[[D11]]] : memref<16x4xi32, strided<[4, 1], offset: ?>>, vector<1xi32>
+    # CHECK: %[[D14:.*]] = arith.index_cast %[[D12]] : vector<1xi32> to vector<1xindex>
+    # CHECK: %[[D15:.*]] = vector.extract %[[D14]][0] : index from vector<1xindex>
+    # CHECK: %[[D16:.*]] = vector.load %{{.*}}[%[[D5]], %[[D15]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
+    # CHECK: %[[D19:.*]] = arith.muli %[[thread_id_y]], %[[C8]] overflow<nsw> : index
+    # CHECK: %[[D20:.*]] = arith.addi %[[D19]], %{{.*}} overflow<nsw> : index
+    # CHECK: vector.store %[[D16]], %{{.*}}[%[[D5]], %[[D20]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
+
+
 @run_test
 def test_dynamic_copy():
     constraints: list[tkw.Constraint] = [
diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 7f5f82aa..fb5dd710 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -475,7 +475,7 @@ def test_offset_read_one(shape, request):
         num_iterators=2,
         inputs={M: k, N: j},
         outputs={M: i, N: j},
-        dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD},
+        dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD},
     )
 
     @tkw.wave(constraints)
@@ -640,7 +640,7 @@ def test_offset_write_one(shape, request):
         num_iterators=2,
         inputs={M: i, N: j},
         outputs={M: i, N: k + j % ELEMS_PER_THREAD},
-        dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD},
+        dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD},
     )
 
     @tkw.wave(constraints)
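
Reviewer note: the heart of this patch is replacing the single-symbol `apply_expr` with the multi-symbol `from_expr`, so that a dynamic-val mapping expression over several symbols can be evaluated component-wise (start, size, stride) over the index sequences bound to those symbols. The following is a minimal, self-contained sketch of that substitution semantics, not code from the repository: the simplified `IndexSequence` mirrors the patched class in `iree/turbine/kernel/_support/indexing.py`, and the symbols `K` and `off` are made up for the example.

from dataclasses import dataclass
from typing import Any

import sympy


@dataclass
class IndexSequence:
    start: Any
    size: Any
    stride: Any = 1

    @staticmethod
    def _subs(value, map):
        # Substitute only into sympy expressions; plain ints pass through.
        return value.subs(map) if isinstance(value, sympy.Basic) else value

    @staticmethod
    def from_expr(expr, subs):
        # One substitution map per component: the start of the result is the
        # expression evaluated over the starts of the inputs, and likewise
        # for size and stride.
        start = IndexSequence._subs(expr, {k: v.start for k, v in subs.items()})
        size = IndexSequence._subs(expr, {k: v.size for k, v in subs.items()})
        stride = IndexSequence._subs(expr, {k: v.stride for k, v in subs.items()})
        return IndexSequence(start, size, stride)

    def __repr__(self) -> str:
        return f"{self.start} : {self.size} : {self.stride}"


K = sympy.Symbol("K")      # hypothetical symbol standing in for a dynamic val
off = sympy.Symbol("off")  # hypothetical symbolic start of the offsets read

# The dynamic val's own index sequence: starts at `off`, 8 elements, stride 1.
dyn = IndexSequence(off, 8, 1)

# Substituting it into the mapping expression 2*K scales all three components.
print(IndexSequence.from_expr(2 * K, {K: dyn}))  # prints: 2*off : 16 : 2

Where the old `apply_expr` could substitute only one `(symbol, sequence)` pair at a time, `from_expr` accepts a whole substitution map, which is what lets the chained lit test above feed `offset2` (itself produced through `mapping1`) into `mapping2`.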