Skip to content

Commit

Permalink
Fix chained dynamic values in index mappings
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 committed Dec 19, 2024
1 parent 08b5752 commit ce646d1
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 12 deletions.
24 changes: 18 additions & 6 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,16 +1013,22 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
return {
k: v.apply_expr(subs[k2], mapping[k2])
for k, (k2, v) in zip(arg.type.symbolic_shape, index.items())
}

return index

def get_derived_indices(
self,
) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
def transform_idx(arg):
new_index = self.transform_index_backwards(self.index, arg)
return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
return {
k: v
for k, v in self.transform_index_backwards(self.index, arg).items()
if v.start != 0
}

return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]

Expand Down Expand Up @@ -1253,16 +1259,22 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
return {
k: v.apply_expr(subs[k2], mapping[k2])
for k, (k2, v) in zip(arg.type.symbolic_shape, index.items())
}

return index

def get_derived_indices(
self,
) -> list[tuple[dict[IndexSymbol, IndexSequence], fx.Node]]:
def transform_idx(arg):
new_index = self.transform_index_backwards(self.index, arg)
return {k: v for k, v in zip(arg.type.symbolic_shape, new_index.values())}
return {
k: v
for k, v in self.transform_index_backwards(self.index, arg).items()
if v.start != 0
}

return [(arg, transform_idx(arg)) for arg in self.mapping_dynamic_vals]

Expand Down
7 changes: 5 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def combine_derived_index(

new_index = copy(src_index)
for dim, new_idx in dst_index.items():
assert dim in src_index, f"Dim {dim} not in index {src_index}"
if dim not in src_index:
continue

old_idx = src_index[dim]
if old_idx == new_idx:
continue
Expand All @@ -306,7 +308,8 @@ def set_derived_index(trace):
custom = get_custom(current)
custom.index = combine_derived_index(custom.index, index)
for inp in get_inputs(current)[0]:
worklist.append((inp, index))
new_index = custom.transform_index_backwards(custom.index, inp)
worklist.append((inp, new_index))


def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
Expand Down
89 changes: 85 additions & 4 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
N = tkl.sym.N
K = tkl.sym.K
B = tkl.sym.B
ONE = tkl.sym.ONE
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
Expand All @@ -29,17 +28,19 @@
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
def codegen_test_context(
canonicalize: bool = False, dynamic_symbols=[], additional_symbols={}
):
bindings = {
M: 16,
N: 16,
K: 16,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_K: 16,
ONE: 1,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
}
bindings.update(additional_symbols)

# Remove dynamic symbols from the bindings.
for sym in dynamic_symbols:
Expand Down Expand Up @@ -461,6 +462,7 @@ def test(

@run_test
def test_read_write_dynamic_mapping_broadcast():
ONE = tkl.sym.ONE
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
Expand Down Expand Up @@ -498,7 +500,7 @@ def test(
)
tkw.write(res, b, elements_per_thread=16)

with codegen_test_context(canonicalize=True):
with codegen_test_context(canonicalize=True, additional_symbols={ONE: 1}):
a = torch.randn(16, 16, dtype=torch.float16)
off = torch.randint(16, (16, 1), dtype=torch.int32)
b = torch.zeros(16, 16, dtype=torch.float16)
Expand All @@ -512,6 +514,85 @@ def test(
# CHECK: vector.store %[[RES]], %{{.*}}[%[[M]], %{{.*}}] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>


@run_test
def test_read_write_dynamic_mapping_chain():
    # Lit test for *chained* dynamic-value mappings: the dynamic value fed
    # into the second gather (`mapping2`) is itself produced by a gather
    # that uses a dynamic value (`mapping1`).
    # Symbolic sizes of the two offset tensors used in the chain.
    SIZE1 = tkl.sym.SIZE1
    SIZE2 = tkl.sym.SIZE2
    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            vector_shapes={M: 16, N: 4, SIZE1: 1, SIZE2: 1},
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N)]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)
    k = tkw.IndexMapping.dynamic_val(0)
    # First link of the chain: reads `off2` at [i, k], where the dynamic
    # value k is taken from `offset1` at [i, j // 2].
    mapping1 = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: i, SIZE2: k},
        outputs={M: i, SIZE2: j},
        dynamic_val_mappings={M: i, SIZE2: j // 2},
    )
    # Second link: reads `a` at [i, k + j % 4], where the dynamic value k
    # is `offset2` (the result of the first link) at [i, j // 4].
    mapping2 = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: i, N: k + j % 4},
        outputs={M: i, N: j},
        dynamic_val_mappings={M: i, N: j // 4},
    )

    @tkw.wave(constraints)
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        off1: tkl.Memory[M, SIZE1, ADDRESS_SPACE, tkl.i32],
        off2: tkl.Memory[M, SIZE2, ADDRESS_SPACE, tkl.i32],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        # offset1 -> offset2 -> res: each read feeds the next read's
        # mapping_dynamic_vals, exercising the chained case.
        offset1 = tkw.read(off1, elements_per_thread=1)
        offset2 = tkw.read(
            off2,
            mapping=mapping1,
            mapping_dynamic_vals=(offset1,),
            elements_per_thread=1,
        )
        res = tkw.read(
            a,
            mapping=mapping2,
            mapping_dynamic_vals=(offset2,),
            elements_per_thread=4,
        )
        tkw.write(res, b, elements_per_thread=4)

    # BLOCK_N/SIZE1/SIZE2 are overridden from the context defaults so the
    # offset tensors (16x2 and 16x4) match the mapping divisors above.
    with codegen_test_context(
        canonicalize=True, additional_symbols={BLOCK_N: 4, SIZE1: 2, SIZE2: 4}
    ):
        a = torch.randn(16, 16, dtype=torch.float16)
        off1 = torch.randint(2, (16, 2), dtype=torch.int32)
        off2 = torch.randint(16, (16, 4), dtype=torch.int32)
        b = torch.zeros(16, 16, dtype=torch.float16)
        print(test(a, off1, off2, b).module_op)

        # FileCheck directives below verify the emitted MLIR: two chained
        # i32->index casts feeding successive vector.loads. Kept verbatim.
        # CHECK-LABEL: func.func @test
        # CHECK: %[[C8:.*]] = arith.constant 8 : index
        # CHECK: %[[thread_id_y:.*]] = gpu.thread_id y
        # CHECK: %[[D7:.*]] = arith.addi %{{.*}}, %[[thread_id_y]] overflow<nsw, nuw> : index
        # CHECK: %[[D8:.*]] = vector.load %{{.*}}[%[[D5:.*]], %[[D7]]] : memref<16x2xi32, strided<[2, 1], offset: ?>>, vector<1xi32>
        # CHECK: %[[D10:.*]] = arith.index_cast %[[D8]] : vector<1xi32> to vector<1xindex>
        # CHECK: %[[D11:.*]] = vector.extract %[[D10]][0] : index from vector<1xindex>
        # CHECK: %[[D12:.*]] = vector.load %{{.*}}[%[[D5]], %[[D11]]] : memref<16x4xi32, strided<[4, 1], offset: ?>>, vector<1xi32>
        # CHECK: %[[D14:.*]] = arith.index_cast %[[D12]] : vector<1xi32> to vector<1xindex>
        # CHECK: %[[D15:.*]] = vector.extract %[[D14]][0] : index from vector<1xindex>
        # CHECK: %[[D16:.*]] = vector.load %{{.*}}[%[[D5]], %[[D15]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
        # CHECK: %[[D19:.*]] = arith.muli %[[thread_id_y]], %[[C8]] overflow<nsw, nuw> : index
        # CHECK: %[[D20:.*]] = arith.addi %[[D19]], %{{.*}} overflow<nsw, nuw> : index
        # CHECK: vector.store %[[D16]], %{{.*}}[%[[D5]], %[[D20]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>


@run_test
def test_dynamic_copy():
constraints: list[tkw.Constraint] = [
Expand Down

0 comments on commit ce646d1

Please sign in to comment.