Skip to content

Commit

Permalink
fix symbols
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 committed Dec 19, 2024
1 parent ce646d1 commit e2eb995
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,8 +1014,8 @@ def transform_index_backwards(
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {
k: v.apply_expr(subs[k2], mapping[k2])
for k, (k2, v) in zip(arg.type.symbolic_shape, index.items())
k: v.apply_expr(subs[k], mapping[k])
for k, v in zip(arg.type.symbolic_shape, index.values())
}

return index
Expand Down Expand Up @@ -1260,8 +1260,8 @@ def transform_index_backwards(
mapping = self.mapping.dynamic_val_mappings[i]
subs = {v: k for k, v in zip(iters, mapping.keys())}
return {
k: v.apply_expr(subs[k2], mapping[k2])
for k, (k2, v) in zip(arg.type.symbolic_shape, index.items())
k: v.apply_expr(subs[k], mapping[k])
for k, v in zip(arg.type.symbolic_shape, index.values())
}

return index
Expand Down
6 changes: 3 additions & 3 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def test_read_write_dynamic_mapping_broadcast():
num_iterators=2,
inputs={M: i, N: k + j % 16},
outputs={M: i, N: j},
dynamic_val_mappings={M: i, N: j // 16},
dynamic_val_mappings={M: i, ONE: j // 16},
)

@tkw.wave(constraints)
Expand Down Expand Up @@ -537,13 +537,13 @@ def test_read_write_dynamic_mapping_chain():
num_iterators=2,
inputs={M: i, SIZE2: k},
outputs={M: i, SIZE2: j},
dynamic_val_mappings={M: i, SIZE2: j // 2},
dynamic_val_mappings={M: i, SIZE1: j // 2},
)
mapping2 = tkw.IndexMapping(
num_iterators=2,
inputs={M: i, N: k + j % 4},
outputs={M: i, N: j},
dynamic_val_mappings={M: i, N: j // 4},
dynamic_val_mappings={M: i, SIZE2: j // 4},
)

@tkw.wave(constraints)
Expand Down
4 changes: 2 additions & 2 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def test_offset_read_one(shape, request):
num_iterators=2,
inputs={M: k, N: j},
outputs={M: i, N: j},
dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD},
dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD},
)

@tkw.wave(constraints)
Expand Down Expand Up @@ -640,7 +640,7 @@ def test_offset_write_one(shape, request):
num_iterators=2,
inputs={M: i, N: j},
outputs={M: i, N: k + j % ELEMS_PER_THREAD},
dynamic_val_mappings={M: i, N: j // ELEMS_PER_THREAD},
dynamic_val_mappings={M: i, N1: j // ELEMS_PER_THREAD},
)

@tkw.wave(constraints)
Expand Down

0 comments on commit e2eb995

Please sign in to comment.