diff --git a/frontends/PyCDE/src/pycde/handshake.py b/frontends/PyCDE/src/pycde/handshake.py index 80ca25f6b983..14397890aac6 100644 --- a/frontends/PyCDE/src/pycde/handshake.py +++ b/frontends/PyCDE/src/pycde/handshake.py @@ -3,13 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from __future__ import annotations -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Tuple from .module import Module, ModuleLikeBuilderBase, PortError -from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal +from .signals import (BitsSignal, ChannelSignal, ClockSignal, Signal, + _FromCirctValue) from .system import System -from .support import get_user_loc, obj_to_typed_attribute -from .types import Channel +from .support import clog2, get_user_loc +from .types import Bits, Channel from .circt.dialects import handshake as raw_handshake from .circt import ir @@ -82,7 +83,7 @@ def instantiate(self, module_inst, inputs, instance_name: str): # If the input is a channel signal, the types must match. if signal.type.inner_type != port.type: raise ValueError( - f"Wrong type on input signal '{name}'. Got '{signal.type}'," + f"Wrong type on input signal '{name}'. Got '{signal.type.inner_type}'," f" expected '{port.type}'") assert port.idx is not None circt_inputs[port.idx] = signal.value @@ -124,3 +125,24 @@ class Func(Module): BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder _builder: FuncBuilder + + +def demux(cond: BitsSignal, data: Signal) -> Tuple[Signal, Signal]: + """Demux a signal based on a condition.""" + condbr = raw_handshake.ConditionalBranchOp(cond.value, data.value) + return (_FromCirctValue(condbr.trueResult), + _FromCirctValue(condbr.falseResult)) + + +def cmerge(*args: Signal) -> Tuple[Signal, BitsSignal]: + """Merge multiple signals into one and the index of the signal.""" + if len(args) == 0: + raise ValueError("cmerge must have at least one argument") + first = args[0] + for a in args[1:]: + if a.type != first.type: + raise ValueError("All arguments to cmerge must have the same type") + idx_type = Bits(clog2(len(args))) + cm = raw_handshake.ControlMergeOp(a.type._type, idx_type._type, + [a.value for a in args]) + return (_FromCirctValue(cm.result), BitsSignal(cm.index, idx_type)) diff --git a/frontends/PyCDE/src/pycde/system.py b/frontends/PyCDE/src/pycde/system.py index 5ba2be1fbf34..ca5775fac0fc 100644 --- a/frontends/PyCDE/src/pycde/system.py +++ b/frontends/PyCDE/src/pycde/system.py @@ -264,8 +264,8 @@ def get_instance(self, # Then run all the passes to lower dialects which produce `hw.module`s. "builtin.module(lower-handshake-to-dc)", "builtin.module(dc-materialize-forks-sinks)", - "builtin.module(canonicalize)", "builtin.module(lower-dc-to-hw)", + "builtin.module(map-arith-to-comb)", # Run ESI manifest passes. "builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})", @@ -275,7 +275,6 @@ def get_instance(self, # Instaniate hlmems, which could produce new esi connections. "builtin.module(hw.module(lower-seq-hlmem))", "builtin.module(lower-esi-to-physical)", - # TODO: support more than just cosim. "builtin.module(lower-esi-bundles, lower-esi-ports)", "builtin.module(lower-esi-to-hw{{platform={platform}}})", "builtin.module(convert-fsm-to-sv)", diff --git a/frontends/PyCDE/test/test_handshake.py b/frontends/PyCDE/test/test_handshake.py index cf3c3ffe8974..d8fc37ac24d4 100644 --- a/frontends/PyCDE/test/test_handshake.py +++ b/frontends/PyCDE/test/test_handshake.py @@ -1,42 +1,50 @@ # RUN: %PYTHON% %s | FileCheck %s from pycde import (Clock, Output, Input, generator, types, Module) -from pycde.handshake import Func +from pycde.handshake import Func, cmerge, demux from pycde.testing import unittestmodule from pycde.types import Bits, Channel -# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel, out x : !esi.channel) -# CHECK: [[R0:%.+]] = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel) -> !esi.channel -# CHECK: hw.output [[R0]] : !esi.channel -# CHECK: } -# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8 +# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel, in %b : !esi.channel, out x : !esi.channel) +# CHECK: %0:2 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a, %b) : (!esi.channel, !esi.channel) -> (!esi.channel, !esi.channel) +# CHECK: hw.output %0#0 : !esi.channel + +# CHECK: handshake.func @TestFunc(%arg0: i8, %arg1: i8, ...) -> (i8, i8) +# CHECK: %result, %index = control_merge %arg0, %arg1 : i8, i1 # CHECK: %c15_i8 = hw.constant 15 : i8 -# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8 -# CHECK: return %0 : i8 -# CHECK: } +# CHECK: [[R0:%.+]] = comb.and bin %result, %c15_i8 : i8 +# CHECK: %trueResult, %falseResult = cond_br %index, [[R0]] : i8 +# CHECK: return %trueResult, %falseResult : i8, i8 class TestFunc(Func): a = Input(Bits(8)) + b = Input(Bits(8)) x = Output(Bits(8)) + y = Output(Bits(8)) @generator def build(ports): - ports.x = ports.a & Bits(8)(0xF) + c, sel = cmerge(ports.a, ports.b) + z = c & Bits(8)(0xF) + x, y = demux(sel, z) + ports.x = x + ports.y = y BarType = types.struct({"foo": types.i12}, "bar") -@unittestmodule(print=True, run_passes=True) +@unittestmodule(print=True) class Top(Module): clk = Clock() rst = Input(Bits(1)) a = Input(Channel(Bits(8))) + b = Input(Channel(Bits(8))) x = Output(Channel(Bits(8))) @generator def build(ports): - test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a) + test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a, b=ports.b) ports.x = test.x diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 0f3dc53a7b64..5f6ab39d4c23 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -471,7 +471,12 @@ def HandshakeToDC : Pass<"lower-handshake-to-dc", "mlir::ModuleOp"> { function with graph region behaviour. Thus, for now, we just use `hw.module` as a container operation. }]; - let dependentDialects = ["dc::DCDialect", "mlir::func::FuncDialect", "hw::HWDialect"]; + let dependentDialects = [ + "dc::DCDialect", + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect", + "hw::HWDialect" + ]; let options = [ Option<"clkName", "clk-name", "std::string", "\"clk\"", "Name of the clock signal to use in the generated DC module">, diff --git a/lib/CAPI/Dialect/DC.cpp b/lib/CAPI/Dialect/DC.cpp index 40dd6e04d4a0..a5285e28e4c3 100644 --- a/lib/CAPI/Dialect/DC.cpp +++ b/lib/CAPI/Dialect/DC.cpp @@ -10,9 +10,13 @@ #include "circt/Conversion/Passes.h" #include "circt/Dialect/DC/DCDialect.h" #include "circt/Dialect/DC/DCPasses.h" +#include "circt/Transforms/Passes.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Support.h" -void registerDCPasses() { circt::dc::registerPasses(); } +void registerDCPasses() { + circt::registerMapArithToCombPass(); + circt::dc::registerPasses(); +} MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DC, dc, circt::dc::DCDialect)