Skip to content

Commit

Permalink
[DC CAPI][PyCDE] Add cmerge and demux handshake ops
Browse files Browse the repository at this point in the history
- Adds cmerge and demux functions to the handshake pycde module.
- Lowering them requires fixes to the conversion pass and the CAPI code.
  • Loading branch information
teqdruid committed Dec 19, 2024
1 parent 99826b8 commit bcc1e01
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 21 deletions.
32 changes: 27 additions & 5 deletions frontends/PyCDE/src/pycde/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple

from .module import Module, ModuleLikeBuilderBase, PortError
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
from .signals import (BitsSignal, ChannelSignal, ClockSignal, Signal,
_FromCirctValue)
from .system import System
from .support import get_user_loc, obj_to_typed_attribute
from .types import Channel
from .support import clog2, get_user_loc
from .types import Bits, Channel

from .circt.dialects import handshake as raw_handshake
from .circt import ir
Expand Down Expand Up @@ -82,7 +83,7 @@ def instantiate(self, module_inst, inputs, instance_name: str):
# If the input is a channel signal, the types must match.
if signal.type.inner_type != port.type:
raise ValueError(
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
f"Wrong type on input signal '{name}'. Got '{signal.type.inner_type}',"
f" expected '{port.type}'")
assert port.idx is not None
circt_inputs[port.idx] = signal.value
Expand Down Expand Up @@ -124,3 +125,24 @@ class Func(Module):

BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
_builder: FuncBuilder


def demux(cond: BitsSignal, data: Signal) -> Tuple[Signal, Signal]:
"""Demux a signal based on a condition."""
condbr = raw_handshake.ConditionalBranchOp(cond.value, data.value)
return (_FromCirctValue(condbr.trueResult),
_FromCirctValue(condbr.falseResult))


def cmerge(*args: Signal) -> Tuple[Signal, BitsSignal]:
"""Merge multiple signals into one and the index of the signal."""
if len(args) == 0:
raise ValueError("cmerge must have at least one argument")
first = args[0]
for a in args[1:]:
if a.type != first.type:
raise ValueError("All arguments to cmerge must have the same type")
idx_type = Bits(clog2(len(args)))
cm = raw_handshake.ControlMergeOp(a.type._type, idx_type._type,
[a.value for a in args])
return (_FromCirctValue(cm.result), BitsSignal(cm.index, idx_type))
3 changes: 1 addition & 2 deletions frontends/PyCDE/src/pycde/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def get_instance(self,
# Then run all the passes to lower dialects which produce `hw.module`s.
"builtin.module(lower-handshake-to-dc)",
"builtin.module(dc-materialize-forks-sinks)",
"builtin.module(canonicalize)",
"builtin.module(lower-dc-to-hw)",
"builtin.module(map-arith-to-comb)",

# Run ESI manifest passes.
"builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})",
Expand All @@ -275,7 +275,6 @@ def get_instance(self,
# Instaniate hlmems, which could produce new esi connections.
"builtin.module(hw.module(lower-seq-hlmem))",
"builtin.module(lower-esi-to-physical)",
# TODO: support more than just cosim.
"builtin.module(lower-esi-bundles, lower-esi-ports)",
"builtin.module(lower-esi-to-hw{{platform={platform}}})",
"builtin.module(convert-fsm-to-sv)",
Expand Down
32 changes: 20 additions & 12 deletions frontends/PyCDE/test/test_handshake.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,50 @@
# RUN: %PYTHON% %s | FileCheck %s

from pycde import (Clock, Output, Input, generator, types, Module)
from pycde.handshake import Func
from pycde.handshake import Func, cmerge, demux
from pycde.testing import unittestmodule
from pycde.types import Bits, Channel

# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>)
# CHECK: [[R0:%.+]] = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8>
# CHECK: hw.output [[R0]] : !esi.channel<i8>
# CHECK: }
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8
# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, in %b : !esi.channel<i8>, out x : !esi.channel<i8>)
# CHECK: %0:2 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a, %b) : (!esi.channel<i8>, !esi.channel<i8>) -> (!esi.channel<i8>, !esi.channel<i8>)
# CHECK: hw.output %0#0 : !esi.channel<i8>

# CHECK: handshake.func @TestFunc(%arg0: i8, %arg1: i8, ...) -> (i8, i8)
# CHECK: %result, %index = control_merge %arg0, %arg1 : i8, i1
# CHECK: %c15_i8 = hw.constant 15 : i8
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8
# CHECK: return %0 : i8
# CHECK: }
# CHECK: [[R0:%.+]] = comb.and bin %result, %c15_i8 : i8
# CHECK: %trueResult, %falseResult = cond_br %index, [[R0]] : i8
# CHECK: return %trueResult, %falseResult : i8, i8


class TestFunc(Func):
a = Input(Bits(8))
b = Input(Bits(8))
x = Output(Bits(8))
y = Output(Bits(8))

@generator
def build(ports):
ports.x = ports.a & Bits(8)(0xF)
c, sel = cmerge(ports.a, ports.b)
z = c & Bits(8)(0xF)
x, y = demux(sel, z)
ports.x = x
ports.y = y


BarType = types.struct({"foo": types.i12}, "bar")


@unittestmodule(print=True, run_passes=True)
@unittestmodule(print=True)
class Top(Module):
clk = Clock()
rst = Input(Bits(1))

a = Input(Channel(Bits(8)))
b = Input(Channel(Bits(8)))
x = Output(Channel(Bits(8)))

@generator
def build(ports):
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a)
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a, b=ports.b)
ports.x = test.x
7 changes: 6 additions & 1 deletion include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,12 @@ def HandshakeToDC : Pass<"lower-handshake-to-dc", "mlir::ModuleOp"> {
function with graph region behaviour. Thus, for now, we just use `hw.module`
as a container operation.
}];
let dependentDialects = ["dc::DCDialect", "mlir::func::FuncDialect", "hw::HWDialect"];
let dependentDialects = [
"dc::DCDialect",
"mlir::arith::ArithDialect",
"mlir::func::FuncDialect",
"hw::HWDialect"
];
let options = [
Option<"clkName", "clk-name", "std::string", "\"clk\"",
"Name of the clock signal to use in the generated DC module">,
Expand Down
6 changes: 5 additions & 1 deletion lib/CAPI/Dialect/DC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
#include "circt/Conversion/Passes.h"
#include "circt/Dialect/DC/DCDialect.h"
#include "circt/Dialect/DC/DCPasses.h"
#include "circt/Transforms/Passes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"

void registerDCPasses() { circt::dc::registerPasses(); }
void registerDCPasses() {
circt::registerMapArithToCombPass();
circt::dc::registerPasses();
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DC, dc, circt::dc::DCDialect)

0 comments on commit bcc1e01

Please sign in to comment.