From 296d283af038710b83d241c9c7a3a8e688a39fec Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 3 Jul 2024 07:21:22 -0700 Subject: [PATCH] [PyCDE] Restrict slicing index widths to clog2(len) (#7277) Previously, PyCDE was somewhat more loose on indexes into arrays/bitvectors. Now that we have the pad_or_truncate convenience method, lets be a bit more restrictive. --- frontends/PyCDE/src/pycde/signals.py | 10 ++++---- frontends/PyCDE/test/test_muxing.py | 38 ++++------------------------ 2 files changed, 10 insertions(+), 38 deletions(-) diff --git a/frontends/PyCDE/src/pycde/signals.py b/frontends/PyCDE/src/pycde/signals.py index a4db5087f7e0..2d51af7a7304 100644 --- a/frontends/PyCDE/src/pycde/signals.py +++ b/frontends/PyCDE/src/pycde/signals.py @@ -216,12 +216,12 @@ def _validate_idx(size: int, idx: Union[int, BitVectorSignal]): if isinstance(idx, int): if idx >= size: raise ValueError("Subscript out-of-bounds") + elif isinstance(idx, BitVectorSignal): + if idx.type.width != (size - 1).bit_length(): + raise ValueError("Index must be exactly clog2 of the size of the array") else: - idx = support.get_value(idx) - if idx is None or not isinstance(support.type_to_pytype(idx.type), - ir.IntegerType): - raise TypeError("Subscript on array must be either int or int signal" - f" not {type(idx)}.") + raise TypeError("Subscript on array must be either int or int signal" + f" not {type(idx)}.") def get_slice_bounds(size, idxOrSlice: Union[int, slice]): diff --git a/frontends/PyCDE/test/test_muxing.py b/frontends/PyCDE/test/test_muxing.py index 27228d74bdcd..24cc901fb108 100644 --- a/frontends/PyCDE/test/test_muxing.py +++ b/frontends/PyCDE/test/test_muxing.py @@ -24,8 +24,10 @@ # CHECK: %c0_i3_2 = hw.constant 0 : i3 # CHECK: [[R8:%.+]] = hw.array_get %In[%c0_i3_2] {sv.namehint = "In__0"} : !hw.array<5xarray<4xi3>>, i3 # CHECK: [[R9:%.+]] = hw.array_get [[R8]][%c0_i2] {sv.namehint = "In__0__0"} : !hw.array<4xi3> -# CHECK: %c0_i2_3 = hw.constant 0 : i2 -# CHECK: [[R10:%.+]] = comb.concat %c0_i2_3, %Sel {sv.namehint = "Sel_padto_3"} : i2, i1 +# CHECK: %false = hw.constant false +# CHECK: [[RN9:%.+]] = comb.concat %false, %Sel {sv.namehint = "Sel_padto_2"} : i1, i1 +# CHECK: %false_3 = hw.constant false +# CHECK: [[R10:%.+]] = comb.concat %false_3, [[RN9]] {sv.namehint = "Sel_padto_2_padto_3"} : i1, i2 # CHECK: [[R11:%.+]] = comb.shru bin [[R9]], [[R10]] : i3 # CHECK: [[R12:%.+]] = comb.extract [[R11]] from 0 : (i3) -> i1 # CHECK: hw.output [[R3]], [[R6]], [[R12]], [[R7]] : !hw.array<4xi3>, !hw.array<2xarray<4xi3>>, i1, !hw.array<3xarray<4xi3>> @@ -49,41 +51,11 @@ def create(ports): ports.OutArr = Signal.create([ports.In[0], ports.In[1]]) ports.OutSlice = ports.In[0:3] - ports.OutInt = ports.In[0][0][ports.Sel] + ports.OutInt = ports.In[0][0][ports.Sel.pad_or_truncate(2)] # ----- -# CHECK-LABEL: hw.module @Slicing(in %In : !hw.array<5xarray<4xi8>>, in %Sel8 : i8, in %Sel2 : i2, out OutIntSlice : i2, out OutArrSlice8 : !hw.array<2xarray<4xi8>>, out OutArrSlice2 : !hw.array<2xarray<4xi8>>) -# CHECK: [[R0:%.+]] = hw.array_get %In[%c0_i3] {sv.namehint = "In__0"} : !hw.array<5xarray<4xi8>> -# CHECK: [[R1:%.+]] = hw.array_get %0[%c0_i2] {sv.namehint = "In__0__0"} : !hw.array<4xi8> -# CHECK: [[R2:%.+]] = comb.concat %c0_i6, %Sel2 {sv.namehint = "Sel2_padto_8"} : i6, i2 -# CHECK: [[R3:%.+]] = comb.shru bin [[R1]], [[R2]] : i8 -# CHECK: [[R4:%.+]] = comb.extract [[R3]] from 0 : (i8) -> i2 -# CHECK: [[R5:%.+]] = comb.concat %false, %Sel2 {sv.namehint = "Sel2_padto_3"} : i1, i2 -# CHECK: [[R6:%.+]] = hw.array_slice %In[[[R5]]] : (!hw.array<5xarray<4xi8>>) -> !hw.array<2xarray<4xi8>> -# CHECK: [[R7:%.+]] = comb.extract %Sel8 from 0 : (i8) -> i3 -# CHECK: [[R8:%.+]] = hw.array_slice %In[[[R7]]] : (!hw.array<5xarray<4xi8>>) -> !hw.array<2xarray<4xi8>> -# CHECK: hw.output %4, %8, %6 : i2, !hw.array<2xarray<4xi8>>, !hw.array<2xarray<4xi8>> - - -@unittestmodule() -class Slicing(Module): - In = Input(dim(8, 4, 5)) - Sel8 = Input(types.i8) - Sel2 = Input(types.i2) - - OutIntSlice = Output(types.i2) - OutArrSlice8 = Output(dim(8, 4, 2)) - OutArrSlice2 = Output(dim(8, 4, 2)) - - @generator - def create(ports): - i = ports.In[0][0] - ports.OutIntSlice = i.slice(ports.Sel2, 2) - ports.OutArrSlice2 = ports.In.slice(ports.Sel2, 2) - ports.OutArrSlice8 = ports.In.slice(ports.Sel8, 2) - # CHECK-LABEL: hw.module @SimpleMux2(in %op : i1, in %a : i32, in %b : i32, out out : i32) # CHECK-NEXT: [[r0:%.+]] = comb.mux bin %op, %b, %a