diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index ffa73618..b2657ab4 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -1013,7 +1013,10 @@ def transform_index_backwards(
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
             subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+            return {
+                k: v.apply_expr(subs[k2], mapping[k2])
+                for k, (k2, v) in zip(arg.type.symbolic_shape, index.items())
+            }
 
         return index
 
@@ -1021,8 +1024,11 @@ def get_derived_indices(
         self,
     ) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
         def transform_idx(arg):
-            new_index = self.transform_index_backwards(self.index, arg)
-            return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
+            return {
+                k: v
+                for k, v in self.transform_index_backwards(self.index, arg).items()
+                if v.start != 0
+            }
 
         return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]
 
@@ -1253,7 +1259,10 @@ def transform_index_backwards(
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
             subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+            return {
+                k: v.apply_expr(subs[k2], mapping[k2])
+                for k, (k2, v) in zip(arg.type.symbolic_shape, index.items())
+            }
 
         return index
 
@@ -1261,8 +1270,11 @@ def get_derived_indices(
         self,
     ) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
         def transform_idx(arg):
-            new_index = self.transform_index_backwards(self.index, arg)
-            return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
+            return {
+                k: v
+                for k, v in self.transform_index_backwards(self.index, arg).items()
+                if v.start != 0
+            }
 
         return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]
 
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index 6bf5f639..dbbd5330 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -281,7 +281,9 @@ def combine_derived_index(
     new_index = copy(src_index)
     for dim, new_idx in dst_index.items():
-        assert dim in src_index, f"Dim {dim} not in index {src_index}"
+        if dim not in src_index:
+            continue
+
         old_idx = src_index[dim]
         if old_idx == new_idx:
             continue
 
@@ -306,7 +308,8 @@ def set_derived_index(trace):
         custom = get_custom(current)
         custom.index = combine_derived_index(custom.index, index)
         for inp in get_inputs(current)[0]:
-            worklist.append((inp, index))
+            new_index = custom.transform_index_backwards(custom.index, inp)
+            worklist.append((inp, new_index))
 
 
 def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 8bbf75fc..fc2590ac 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -18,7 +18,6 @@
 N = tkl.sym.N
 K = tkl.sym.K
 B = tkl.sym.B
-ONE = tkl.sym.ONE
 BLOCK_M = tkl.sym.BLOCK_M
 BLOCK_N = tkl.sym.BLOCK_N
 BLOCK_K = tkl.sym.BLOCK_K
@@ -29,7 +28,9 @@
 ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
 
 
-def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
+def codegen_test_context(
+    canonicalize: bool = False, dynamic_symbols=[], additional_symbols={}
+):
     bindings = {
         M: 16,
         N: 16,
@@ -37,9 +38,9 @@ def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
         BLOCK_M: 16,
         BLOCK_N: 16,
         BLOCK_K: 16,
-        ONE: 1,
         ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
     }
+    bindings.update(additional_symbols)
 
     # Remove dynamic symbols from the bindings.
     for sym in dynamic_symbols:
@@ -461,6 +462,7 @@ def test(
 
 @run_test
 def test_read_write_dynamic_mapping_broadcast():
+    ONE = tkl.sym.ONE
     constraints: list[tkw.Constraint] = [
         tkw.HardwareConstraint(
             threads_per_wave=64,
@@ -498,7 +500,7 @@ def test(
         )
         tkw.write(res, b, elements_per_thread=16)
 
-    with codegen_test_context(canonicalize=True):
+    with codegen_test_context(canonicalize=True, additional_symbols={ONE: 1}):
         a = torch.randn(16, 16, dtype=torch.float16)
         off = torch.randint(16, (16, 1), dtype=torch.int32)
         b = torch.zeros(16, 16, dtype=torch.float16)
@@ -512,6 +514,85 @@ def test(
         # CHECK: vector.store %[[RES]], %{{.*}}[%[[M]], %{{.*}}] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>
 
 
+@run_test
+def test_read_write_dynamic_mapping_chain():
+    SIZE1 = tkl.sym.SIZE1
+    SIZE2 = tkl.sym.SIZE2
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+            vector_shapes={M: 16, N: 4, SIZE1: 1, SIZE2: 1},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.dynamic_val(0)
+    mapping1 = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, SIZE2: k},
+        outputs={M: i, SIZE2: j},
+        dynamic_val_mappings={M: i, SIZE2: j // 2},
+    )
+    mapping2 = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, N: k + j % 4},
+        outputs={M: i, N: j},
+        dynamic_val_mappings={M: i, N: j // 4},
+    )
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        off1: tkl.Memory[M, SIZE1, ADDRESS_SPACE, tkl.i32],
+        off2: tkl.Memory[M, SIZE2, ADDRESS_SPACE, tkl.i32],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+    ):
+        offset1 = tkw.read(off1, elements_per_thread=1)
+        offset2 = tkw.read(
+            off2,
+            mapping=mapping1,
+            mapping_dynamic_vals=(offset1,),
+            elements_per_thread=1,
+        )
+        res = tkw.read(
+            a,
+            mapping=mapping2,
+            mapping_dynamic_vals=(offset2,),
+            elements_per_thread=4,
+        )
+        tkw.write(res, b, elements_per_thread=4)
+
+    with codegen_test_context(
+        canonicalize=True, additional_symbols={BLOCK_N: 4, SIZE1: 2, SIZE2: 4}
+    ):
+        a = torch.randn(16, 16, dtype=torch.float16)
+        off1 = torch.randint(2, (16, 2), dtype=torch.int32)
+        off2 = torch.randint(16, (16, 4), dtype=torch.int32)
+        b = torch.zeros(16, 16, dtype=torch.float16)
+        print(test(a, off1, off2, b).module_op)
+
+        # CHECK-LABEL: func.func @test
+        # CHECK: %[[C8:.*]] = arith.constant 8 : index
+        # CHECK: %[[thread_id_y:.*]] = gpu.thread_id y
+        # CHECK: %[[D7:.*]] = arith.addi %{{.*}}, %[[thread_id_y]] overflow<nsw, nuw> : index
+        # CHECK: %[[D8:.*]] = vector.load %{{.*}}[%[[D5:.*]], %[[D7]]] : memref<16x2xi32, strided<[2, 1], offset: ?>>, vector<1xi32>
+        # CHECK: %[[D10:.*]] = arith.index_cast %[[D8]] : vector<1xi32> to vector<1xindex>
+        # CHECK: %[[D11:.*]] = vector.extract %[[D10]][0] : index from vector<1xindex>
+        # CHECK: %[[D12:.*]] = vector.load %{{.*}}[%[[D5]], %[[D11]]] : memref<16x4xi32, strided<[4, 1], offset: ?>>, vector<1xi32>
+        # CHECK: %[[D14:.*]] = arith.index_cast %[[D12]] : vector<1xi32> to vector<1xindex>
+        # CHECK: %[[D15:.*]] = vector.extract %[[D14]][0] : index from vector<1xindex>
+        # CHECK: %[[D16:.*]] = vector.load %{{.*}}[%[[D5]], %[[D15]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
+        # CHECK: %[[D19:.*]] = arith.muli %[[thread_id_y]], %[[C8]] overflow<nsw, nuw> : index
+        # CHECK: %[[D20:.*]] = arith.addi %[[D19]], %{{.*}} overflow<nsw, nuw> : index
+        # CHECK: vector.store %[[D16]], %{{.*}}[%[[D5]], %[[D20]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
+
+
 @run_test
 def test_dynamic_copy():
     constraints: list[tkw.Constraint] = [
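
Note (not part of the patch): the gather chain exercised by the new `test_read_write_dynamic_mapping_chain` lit test is easier to follow as plain PyTorch. The sketch below reflects one reading of `mapping1`/`mapping2` (each `off1` entry selects a column of `off2`, and that `off2` value gives the starting column of a 4-wide contiguous gather from `a`); `chain_gather_reference` is a hypothetical helper for illustration only, and the lit test itself just FileCheck-verifies the generated MLIR rather than executing numerically.

```python
import torch

def chain_gather_reference(a, off1, off2):
    # b[i, j] = a[i, off2[i, off1[i, j // 8]] + j % 4]
    # j // 4 picks the off2 entry for this 4-column group (SIZE2 = 4);
    # (j // 4) // 2 picks the off1 entry feeding that group (SIZE1 = 2).
    M_, N_ = a.shape
    b = torch.zeros_like(a)
    for i in range(M_):
        for j in range(N_):
            s = j // 4   # off2 column-group index
            t = s // 2   # off1 entry selecting the off2 column
            col = int(off2[i, off1[i, t]]) + j % 4
            b[i, j] = a[i, col]
    return b

a = torch.randn(16, 16, dtype=torch.float16)
off1 = torch.randint(2, (16, 2), dtype=torch.int32)
off2 = torch.randint(13, (16, 4), dtype=torch.int32)  # < 13 keeps col + 3 in bounds
b = chain_gather_reference(a, off1, off2)
```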