diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index b2657ab4..c9de0844 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -1014,8 +1014,8 @@ def transform_index_backwards( mapping = self.mapping.dynamic_val_mappings[i] subs = {v: k for k, v in zip(iters, mapping.keys())} return { - k: v.apply_expr(subs[k2], mapping[k2]) - for k, (k2, v) in zip(arg.type.symbolic_shape, index.items()) + k: v.apply_expr(subs[k], mapping[k]) + for k, v in zip(arg.type.symbolic_shape, index.values()) } return index @@ -1260,8 +1260,8 @@ def transform_index_backwards( mapping = self.mapping.dynamic_val_mappings[i] subs = {v: k for k, v in zip(iters, mapping.keys())} return { - k: v.apply_expr(subs[k2], mapping[k2]) - for k, (k2, v) in zip(arg.type.symbolic_shape, index.items()) + k: v.apply_expr(subs[k], mapping[k]) + for k, v in zip(arg.type.symbolic_shape, index.values()) } return index diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index fc2590ac..1858b45b 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -482,7 +482,7 @@ def test_read_write_dynamic_mapping_broadcast(): num_iterators=2, inputs={M: i, N: k + j % 16}, outputs={M: i, N: j}, - dynamic_val_mappings={M: i, N: j // 16}, + dynamic_val_mappings={M: i, ONE: j // 16}, ) @tkw.wave(constraints) @@ -537,13 +537,13 @@ def test_read_write_dynamic_mapping_chain(): num_iterators=2, inputs={M: i, SIZE2: k}, outputs={M: i, SIZE2: j}, - dynamic_val_mappings={M: i, SIZE2: j // 2}, + dynamic_val_mappings={M: i, SIZE1: j // 2}, ) mapping2 = tkw.IndexMapping( num_iterators=2, inputs={M: i, N: k + j % 4}, outputs={M: i, N: j}, - dynamic_val_mappings={M: i, N: j // 4}, + dynamic_val_mappings={M: i, SIZE2: j // 4}, ) @tkw.wave(constraints) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 7f5f82aa..fb5dd710 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -475,7 +475,7 @@ def test_offset_read_one(shape, request): num_iterators=2, inputs={M: k, N: j}, outputs={M: i, N: j}, - dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD}, + dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD}, ) @tkw.wave(constraints) @@ -640,7 +640,7 @@ def test_offset_write_one(shape, request): num_iterators=2, inputs={M: i, N: j}, outputs={M: i, N: k + j % ELEMS_PER_THREAD}, - dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD}, + dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD}, ) @tkw.wave(constraints)