From e50b6dc5326d5e95647ef510dcc23242c17d52a5 Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 20 Sep 2023 21:47:54 +0000 Subject: [PATCH] [PyCDE] Switch over to new !seq.clock type This change broke pycde to all hell, but it's something I've been meaning to do for a long time. --- frontends/PyCDE/src/fsm.py | 4 ++- frontends/PyCDE/src/module.py | 5 ++- frontends/PyCDE/src/types.py | 12 +++----- frontends/PyCDE/test/test_constructs.py | 4 +-- frontends/PyCDE/test/test_esi.py | 32 ++++++++++---------- frontends/PyCDE/test/test_esi_errors.py | 10 +++--- frontends/PyCDE/test/test_fsm.py | 4 +-- frontends/PyCDE/test/test_muxing.py | 2 +- frontends/PyCDE/test/test_syntactic_sugar.py | 6 ++-- lib/Bindings/Python/SeqModule.cpp | 9 ++++-- lib/Bindings/Python/support.py | 6 +++- 11 files changed, 50 insertions(+), 44 deletions(-) diff --git a/frontends/PyCDE/src/fsm.py b/frontends/PyCDE/src/fsm.py index 4659912b204b..fbde0bc4ef84 100644 --- a/frontends/PyCDE/src/fsm.py +++ b/frontends/PyCDE/src/fsm.py @@ -109,8 +109,10 @@ def scan_cls(self): f"Multiple initial states specified ({name}, {initial_state}).") initial_state = name + from .types import ClockType for name, v in self.inputs: - if v.width != 1: + if not (isinstance(v, ClockType) or + (hasattr(v, "width") and v.width == 1)): raise ValueError( f"Input port {name} has width {v.width}. For now, FSMs only " "support i1 inputs.") diff --git a/frontends/PyCDE/src/module.py b/frontends/PyCDE/src/module.py index 11593bdd0126..c169064fff8f 100644 --- a/frontends/PyCDE/src/module.py +++ b/frontends/PyCDE/src/module.py @@ -219,7 +219,6 @@ def go(self): def scan_cls(self): """Scan the class for input/output ports and generators. (Most `ModuleLike` will use these.) Store the results for later use.""" - from .types import Bits input_ports = [] output_ports = [] @@ -245,10 +244,10 @@ def scan_cls(self): if isinstance(attr, Clock): clock_ports.add(len(input_ports)) - input_ports.append((attr_name, Bits(1))) + input_ports.append((attr_name, attr.type)) elif isinstance(attr, Reset): reset_ports.add(len(input_ports)) - input_ports.append((attr_name, Bits(1))) + input_ports.append((attr_name, attr.type)) elif isinstance(attr, Input): input_ports.append((attr_name, attr.type)) elif isinstance(attr, Output): diff --git a/frontends/PyCDE/src/types.py b/frontends/PyCDE/src/types.py index 5dc35688a459..c451c7477077 100644 --- a/frontends/PyCDE/src/types.py +++ b/frontends/PyCDE/src/types.py @@ -7,7 +7,7 @@ from .support import get_user_loc from .circt import ir, support -from .circt.dialects import esi, hw, sv +from .circt.dialects import esi, hw, seq, sv from .circt.dialects.esi import ChannelSignaling import typing @@ -147,6 +147,8 @@ def _FromCirctType(type: typing.Union[ir.Type, Type]) -> Type: return Type.__new__(UInt, type) else: return Type.__new__(Bits, type) + if isinstance(type, seq.ClockType): + return Type.__new__(ClockType, type) if isinstance(type, esi.AnyType): return Type.__new__(Any, type) if isinstance(type, esi.ChannelType): @@ -476,16 +478,12 @@ def _from_obj(self, x: int, alias: typing.Optional[TypeAlias] = None): return hwarith.ConstantOp(circt_type, x) -class ClockType(Bits): +class ClockType(Type): """A special single bit to represent a clock. Can't do any special operations on it, except enter it as a implicit clock block.""" - # TODO: the 'clock' type isn't represented in CIRCT IR. It may be useful to - # have it there if for no other reason than being able to round trip this - # type. - def __new__(cls): - return super(ClockType, cls).__new__(cls, 1) + return super(ClockType, cls).__new__(cls, seq.ClockType.get()) def _get_value_class(self): from .signals import ClockSignal diff --git a/frontends/PyCDE/test/test_constructs.py b/frontends/PyCDE/test/test_constructs.py index e0b60d1b5063..321aa3cc0e88 100644 --- a/frontends/PyCDE/test/test_constructs.py +++ b/frontends/PyCDE/test/test_constructs.py @@ -6,7 +6,7 @@ from pycde.dialects import comb from pycde.testing import unittestmodule -# CHECK-LABEL: hw.module @WireAndRegTest(%In: i8, %InCE: i1, %clk: i1, %rst: i1) -> (Out: i8, OutReg: i8, OutRegRst: i8, OutRegCE: i8) +# CHECK-LABEL: hw.module @WireAndRegTest(%In: i8, %InCE: i1, %clk: !seq.clock, %rst: i1) -> (Out: i8, OutReg: i8, OutRegRst: i8, OutRegCE: i8) # CHECK: [[r0:%.+]] = comb.extract %In from 0 {sv.namehint = "In_0upto7"} : (i8) -> i7 # CHECK: [[r1:%.+]] = comb.extract %In from 7 {sv.namehint = "In_7upto8"} : (i8) -> i1 # CHECK: [[r2:%.+]] = comb.concat [[r1]], [[r0]] {sv.namehint = "w1"} : i1, i7 @@ -64,7 +64,7 @@ def create(ports): # CHECK: sv.read_inout %sum__reg1_0_0 : !hw.inout @unittestmodule(print=True, run_passes=True, print_after_passes=True) class SystolicArrayTest(Module): - clk = Input(types.i1) + clk = Clock() col_data = Input(dim(8, 2)) row_data = Input(dim(8, 3)) out = Output(dim(8, 2, 3)) diff --git a/frontends/PyCDE/test/test_esi.py b/frontends/PyCDE/test/test_esi.py index f71def80870b..f2870d8a97f9 100644 --- a/frontends/PyCDE/test/test_esi.py +++ b/frontends/PyCDE/test/test_esi.py @@ -6,7 +6,7 @@ from pycde import esi from pycde.common import Output from pycde.constructs import Wire -from pycde.types import Bits, Channel, ChannelSignaling, UInt +from pycde.types import Bits, Channel, ChannelSignaling, UInt, ClockType from pycde.testing import unittestmodule from pycde.signals import BitVectorSignal, ChannelSignal, Struct @@ -20,7 +20,7 @@ class HostComms: class Producer(Module): - clk = Input(types.i1) + clk = Clock() int_out = OutputChannel(types.i32) @generator @@ -30,7 +30,7 @@ def construct(ports): class Consumer(Module): - clk = Input(types.i1) + clk = Clock() int_in = InputChannel(types.i32) @generator @@ -38,15 +38,15 @@ def construct(ports): HostComms.to_host(ports.int_in, "loopback_out") -# CHECK-LABEL: hw.module @LoopbackTop(%clk: i1, %rst: i1) -# CHECK: %Producer.int_out = hw.instance "Producer" sym @Producer @Producer(clk: %clk: i1) -> (int_out: !esi.channel) -# CHECK: hw.instance "Consumer" sym @Consumer @Consumer(clk: %clk: i1, int_in: %Producer.int_out: !esi.channel) -> ( -# CHECK: esi.service.instance svc @HostComms impl as "cosim"(%clk, %rst) : (i1, i1) -> () +# CHECK-LABEL: hw.module @LoopbackTop(%clk: !seq.clock, %rst: i1) +# CHECK: %Producer.int_out = hw.instance "Producer" sym @Producer @Producer(clk: %clk: !seq.clock) -> (int_out: !esi.channel) +# CHECK: hw.instance "Consumer" sym @Consumer @Consumer(clk: %clk: !seq.clock, int_in: %Producer.int_out: !esi.channel) -> ( +# CHECK: esi.service.instance svc @HostComms impl as "cosim"(%clk, %rst) : (!seq.clock, i1) -> () # CHECK: hw.output -# CHECK-LABEL: hw.module @Producer(%clk: i1) -> (int_out: !esi.channel) +# CHECK-LABEL: hw.module @Producer(%clk: !seq.clock) -> (int_out: !esi.channel) # CHECK: [[R0:%.+]] = esi.service.req.to_client <@HostComms::@from_host>(["loopback_in"]) : !esi.channel # CHECK: hw.output [[R0]] : !esi.channel -# CHECK-LABEL: hw.module @Consumer(%clk: i1, %int_in: !esi.channel) +# CHECK-LABEL: hw.module @Consumer(%clk: !seq.clock, %int_in: !esi.channel) # CHECK: esi.service.req.to_server %int_in -> <@HostComms::@to_host>(["loopback_out"]) : !esi.channel # CHECK: hw.output # CHECK-LABEL: esi.service.decl @HostComms { @@ -56,7 +56,7 @@ def construct(ports): @unittestmodule(print=True) class LoopbackTop(Module): - clk = Clock(types.i1) + clk = Clock() rst = Input(types.i1) @generator @@ -67,8 +67,8 @@ def construct(ports): esi.Cosim(HostComms, ports.clk, ports.rst) -# CHECK-LABEL: hw.module @LoopbackInOutTop(%clk: i1, %rst: i1) -# CHECK: esi.service.instance svc @HostComms impl as "cosim"(%clk, %rst) : (i1, i1) -> () +# CHECK-LABEL: hw.module @LoopbackInOutTop(%clk: !seq.clock, %rst: i1) +# CHECK: esi.service.instance svc @HostComms impl as "cosim"(%clk, %rst) : (!seq.clock, i1) -> () # CHECK: %0 = esi.service.req.inout %chanOutput -> <@HostComms::@req_resp>(["loopback_inout"]) : !esi.channel -> !esi.channel # CHECK: %rawOutput, %valid = esi.unwrap.vr %0, %ready : i32 # CHECK: %1 = comb.extract %rawOutput from 0 : (i32) -> i16 @@ -76,7 +76,7 @@ def construct(ports): # CHECK: hw.output @unittestmodule(print=True) class LoopbackInOutTop(Module): - clk = Clock(types.i1) + clk = Clock() rst = Input(types.i1) @generator @@ -152,7 +152,7 @@ def unwrap_and_pad(ports, input_channel: ChannelSignal): ports.trunk_out_valid = valid -# CHECK-LABEL: hw.module @MultiplexerTop{{.*}}(%clk: i1, %rst: i1, %trunk_in: i256, %trunk_in_valid: i1, %trunk_out_ready: i1) -> (trunk_in_ready: i1, trunk_out: i256, trunk_out_valid: i1) +# CHECK-LABEL: hw.module @MultiplexerTop{{.*}}(%clk: !seq.clock, %rst: i1, %trunk_in: i256, %trunk_in_valid: i1, %trunk_out_ready: i1) -> (trunk_in_ready: i1, trunk_out: i256, trunk_out_valid: i1) # CHECK: %c0_i224 = hw.constant 0 : i224 # CHECK: [[r0:%.+]] = comb.concat %c0_i224, %Consumer.loopback_out : i224, i32 # CHECK: [[r1:%.+]] = comb.extract %trunk_in from 0 {sv.namehint = "trunk_in_0upto32"} : (i256) -> i32 @@ -167,7 +167,7 @@ def unwrap_and_pad(ports, input_channel: ChannelSignal): @unittestmodule(run_passes=True, print_after_passes=True, emit_outputs=True) class MultiplexerTop(Module): - clk = Clock(types.i1) + clk = Clock() rst = Input(types.i1) trunk_in = Input(types.i256) @@ -218,7 +218,7 @@ class PureTest(esi.PureModule): def construct(ports): PassUpService(None) - clk = esi.PureModule.input_port("clk", types.i1) + clk = esi.PureModule.input_port("clk", ClockType()) p = Producer(clk=clk) Consumer(clk=clk, int_in=p.int_out) p2 = Producer(clk=clk, instance_name="prod2") diff --git a/frontends/PyCDE/test/test_esi_errors.py b/frontends/PyCDE/test/test_esi_errors.py index 18dab35c1037..0d0af2357c03 100644 --- a/frontends/PyCDE/test/test_esi_errors.py +++ b/frontends/PyCDE/test/test_esi_errors.py @@ -14,7 +14,7 @@ class HostComms: class Producer(Module): - clk = Input(types.i1) + clk = Clock() int_out = OutputChannel(types.i32) @generator @@ -24,7 +24,7 @@ def construct(ports): class Consumer(Module): - clk = Input(types.i1) + clk = Clock() int_in = InputChannel(types.i32) @generator @@ -34,7 +34,7 @@ def construct(ports): @unittestmodule(print=True) class LoopbackTop(Module): - clk = Clock(types.i1) + clk = Clock() rst = Input(types.i1) @generator @@ -75,7 +75,7 @@ def generate(ports, channels): @unittestmodule(run_passes=True, print_after_passes=True) class MultiplexerTop(Module): - clk = Clock(types.i1) + clk = Clock() rst = Input(types.i1) @generator @@ -105,7 +105,7 @@ def generate(ports, channels): @unittestmodule(run_passes=True, print_after_passes=True) class BrokenTop(Module): - clk = Clock(types.i1) + clk = Clock() rst = Input(types.i1) @generator diff --git a/frontends/PyCDE/test/test_fsm.py b/frontends/PyCDE/test/test_fsm.py index 4f9f553ad181..90fbc3ce2923 100644 --- a/frontends/PyCDE/test/test_fsm.py +++ b/frontends/PyCDE/test/test_fsm.py @@ -8,7 +8,7 @@ from pycde.testing import unittestmodule # FSM state transitions example -# CHECK-LABEL: hw.module @FSMUser(%a: i1, %b: i1, %c: i1, %clk: i1, %rst: i1) -> (is_a: i1, is_b: i1, is_c: i1) +# CHECK-LABEL: hw.module @FSMUser(%a: i1, %b: i1, %c: i1, %clk: !seq.clock, %rst: i1) -> (is_a: i1, is_b: i1, is_c: i1) # CHECK-NEXT: %0:4 = fsm.hw_instance "F0" @F0(%a, %b, %c), clock %clk, reset %rst : (i1, i1, i1) -> (i1, i1, i1, i1) # CHECK-NEXT: hw.output %0#1, %0#2, %0#3 : i1, i1, i1 # CHECK-NEXT: } @@ -105,7 +105,7 @@ class FSMUser(Module): a = Input(types.i1) b = Input(types.i1) c = Input(types.i1) - clk = Input(types.i1) + clk = Clock() rst = Input(types.i1) is_a = Output(types.i1) is_b = Output(types.i1) diff --git a/frontends/PyCDE/test/test_muxing.py b/frontends/PyCDE/test/test_muxing.py index d1a3375e73c9..67f3f1ae69ef 100644 --- a/frontends/PyCDE/test/test_muxing.py +++ b/frontends/PyCDE/test/test_muxing.py @@ -6,7 +6,7 @@ from pycde.testing import unittestmodule from pycde.types import Bits -# CHECK-LABEL: hw.module @ComplexMux(%Clk: i1, %In: !hw.array<5xarray<4xi3>>, %Sel: i1) -> (Out: !hw.array<4xi3>, OutArr: !hw.array<2xarray<4xi3>>, OutInt: i1, OutSlice: !hw.array<3xarray<4xi3>>) +# CHECK-LABEL: hw.module @ComplexMux(%Clk: !seq.clock, %In: !hw.array<5xarray<4xi3>>, %Sel: i1) -> (Out: !hw.array<4xi3>, OutArr: !hw.array<2xarray<4xi3>>, OutInt: i1, OutSlice: !hw.array<3xarray<4xi3>>) # CHECK: %c3_i3 = hw.constant 3 : i3 # CHECK: %0 = hw.array_get %In[%c3_i3] {sv.namehint = "In__3"} : !hw.array<5xarray<4xi3>> # CHECK: %In__3__reg1 = seq.compreg sym @In__3__reg1 %0, %Clk : !hw.array<4xi3> diff --git a/frontends/PyCDE/test/test_syntactic_sugar.py b/frontends/PyCDE/test/test_syntactic_sugar.py index 4a6022327934..662940c67889 100644 --- a/frontends/PyCDE/test/test_syntactic_sugar.py +++ b/frontends/PyCDE/test/test_syntactic_sugar.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% %s | FileCheck %s -from pycde import (Output, Input, generator, types, dim, Module) +from pycde import (Clock, Output, Input, generator, types, dim, Module) from pycde.testing import unittestmodule # CHECK-LABEL: hw.module @Top() @@ -49,7 +49,7 @@ def build(_): # ----- -# CHECK: hw.module @ComplexPorts(%clk: i1, %data_in: !hw.array<3xi32>, %sel: i2, %struct_data_in: !hw.struct) -> (a: i32, b: i32, c: i32) +# CHECK: hw.module @ComplexPorts(%clk: !seq.clock, %data_in: !hw.array<3xi32>, %sel: i2, %struct_data_in: !hw.struct) -> (a: i32, b: i32, c: i32) # CHECK: %c0_i2 = hw.constant 0 : i2 # CHECK: [[REG0:%.+]] = hw.array_get %data_in[%c0_i2] {sv.namehint = "data_in__0"} : !hw.array<3xi32> # CHECK: [[REGR1:%data_in__0__reg1]] = seq.compreg sym @data_in__0__reg1 [[REG0]], %clk : i32 @@ -62,7 +62,7 @@ def build(_): @unittestmodule() class ComplexPorts(Module): - clk = Input(types.i1) + clk = Clock() data_in = Input(dim(32, 3)) sel = Input(types.i2) struct_data_in = Input(types.struct({"foo": types.i36})) diff --git a/lib/Bindings/Python/SeqModule.cpp b/lib/Bindings/Python/SeqModule.cpp index 21104107a6cc..d9ccf730e5f2 100644 --- a/lib/Bindings/Python/SeqModule.cpp +++ b/lib/Bindings/Python/SeqModule.cpp @@ -27,7 +27,10 @@ void circt::python::populateDialectSeqSubmodule(py::module &m) { m.doc() = "Seq dialect Python native extension"; mlir_type_subclass(m, "ClockType", seqTypeIsAClock) - .def_classmethod("get", [](py::object cls, MlirContext ctx) { - return cls(seqClockTypeGet(ctx)); - }); + .def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(seqClockTypeGet(ctx)); + }, + py::arg("cls"), py::arg("context") = py::none()); } diff --git a/lib/Bindings/Python/support.py b/lib/Bindings/Python/support.py index 2a3d9298144d..cf71aff28ad1 100644 --- a/lib/Bindings/Python/support.py +++ b/lib/Bindings/Python/support.py @@ -86,7 +86,7 @@ def type_to_pytype(t) -> ir.Type: if t.__class__ != ir.Type: return t - from .dialects import esi, hw + from .dialects import esi, hw, seq try: return ir.IntegerType(t) except ValueError: @@ -111,6 +111,10 @@ def type_to_pytype(t) -> ir.Type: return hw.InOutType(t) except ValueError: pass + try: + return seq.ClockType(t) + except ValueError: + pass try: return esi.ChannelType(t) except ValueError: