Skip to content

Commit

Permalink
[PyCDE] Add fork, join, and merge channel functions
Browse files Browse the repository at this point in the history
- <ChannelSignal>.fork creates two new channels, waits until they are
both available, then accepts an input. Also buffer the output channels
to avoid combinational loops.
- Channel.join waits on two channels then creates a message on the one
output channel containing a struct with field 'a' equal to input channel
A's content and likewise for channel B.
- Channel.merge funnels two channels together into a single output
stream.

This is functionality which really should be handled by the DC dialect
but it's not ready for primetime.
  • Loading branch information
teqdruid committed Dec 19, 2024
1 parent 8f388d1 commit 801ef91
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 4 deletions.
101 changes: 101 additions & 0 deletions frontends/PyCDE/integration_test/esi_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# REQUIRES: esi-runtime, esi-cosim, rtl-sim
# RUN: rm -rf %t
# RUN: mkdir %t && cd %t
# RUN: %PYTHON% %s %t 2>&1
# RUN: esi-cosim.py -- %PYTHON% %S/test_software/esi_advanced.py cosim env

import sys

from pycde import generator, Clock, Module, Reset, System
from pycde.bsp import get_bsp
from pycde.common import InputChannel, OutputChannel, Output
from pycde.types import Bits, UInt
from pycde import esi


class Merge(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))
b = InputChannel(UInt(8))

x = OutputChannel(UInt(8))

@generator
def build(ports):
chan = ports.a.type.merge(ports.a, ports.b)
ports.x = chan


class Join(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))
b = InputChannel(UInt(8))

x = OutputChannel(UInt(9))

@generator
def build(ports):
joined = ports.a.type.join(ports.a, ports.b)
ports.x = joined.transform(lambda x: x.a + x.b)


class Fork(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))

x = OutputChannel(UInt(8))
y = OutputChannel(UInt(8))

@generator
def build(ports):
x, y = ports.a.fork(ports.clk, ports.rst)
ports.x = x
ports.y = y


class Top(Module):
clk = Clock()
rst = Reset()

@generator
def build(ports):
clk = ports.clk
rst = ports.rst
merge_a = esi.ChannelService.from_host(esi.AppID("merge_a"),
UInt(8)).buffer(clk, rst, 1)
merge_b = esi.ChannelService.from_host(esi.AppID("merge_b"),
UInt(8)).buffer(clk, rst, 1)
merge = Merge("merge_i8",
clk=ports.clk,
rst=ports.rst,
a=merge_a,
b=merge_b)
esi.ChannelService.to_host(esi.AppID("merge_x"),
merge.x.buffer(clk, rst, 1))

join_a = esi.ChannelService.from_host(esi.AppID("join_a"),
UInt(8)).buffer(clk, rst, 1)
join_b = esi.ChannelService.from_host(esi.AppID("join_b"),
UInt(8)).buffer(clk, rst, 1)
join = Join("join_i8", clk=ports.clk, rst=ports.rst, a=join_a, b=join_b)
esi.ChannelService.to_host(
esi.AppID("join_x"),
join.x.buffer(clk, rst, 1).transform(lambda x: x.as_uint(16)))

fork_a = esi.ChannelService.from_host(esi.AppID("fork_a"),
UInt(8)).buffer(clk, rst, 1)
fork = Fork("fork_i8", clk=ports.clk, rst=ports.rst, a=fork_a)
esi.ChannelService.to_host(esi.AppID("fork_x"), fork.x.buffer(clk, rst, 1))
esi.ChannelService.to_host(esi.AppID("fork_y"), fork.y.buffer(clk, rst, 1))


if __name__ == "__main__":
bsp = get_bsp(sys.argv[2] if len(sys.argv) > 2 else None)
s = System(bsp(Top), name="ESIAdvanced", output_directory=sys.argv[1])
s.generate()
s.run_passes()
s.compile()
s.package()
53 changes: 53 additions & 0 deletions frontends/PyCDE/integration_test/test_software/esi_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import esiaccel as esi
import sys

platform = sys.argv[1]
acc = esi.AcceleratorConnection(platform, sys.argv[2])

d = acc.build_accelerator()

merge_a = d.ports[esi.AppID("merge_a")].write_port("data")
merge_a.connect()
merge_b = d.ports[esi.AppID("merge_b")].write_port("data")
merge_b.connect()
merge_x = d.ports[esi.AppID("merge_x")].read_port("data")
merge_x.connect()

for i in range(10, 15):
merge_a.write(i)
merge_b.write(i + 10)
x1 = merge_x.read()
x2 = merge_x.read()
print(f"merge_a: {i}, merge_b: {i + 10}, "
f"merge_x 1: {x1}, merge_x 2: {x2}")
assert x1 == i + 10 or x1 == i
assert x2 == i + 10 or x2 == i
assert x1 != x2

join_a = d.ports[esi.AppID("join_a")].write_port("data")
join_a.connect()
join_b = d.ports[esi.AppID("join_b")].write_port("data")
join_b.connect()
join_x = d.ports[esi.AppID("join_x")].read_port("data")
join_x.connect()

for i in range(15, 27):
join_a.write(i)
join_b.write(i + 10)
x = join_x.read()
print(f"join_a: {i}, join_b: {i + 10}, join_x: {x}")
assert x == (i + i + 10) & 0xFFFF

fork_a = d.ports[esi.AppID("fork_a")].write_port("data")
fork_a.connect()
fork_x = d.ports[esi.AppID("fork_x")].read_port("data")
fork_x.connect()
fork_y = d.ports[esi.AppID("fork_y")].read_port("data")
fork_y.connect()

for i in range(27, 33):
fork_a.write(i)
x = fork_x.read()
y = fork_y.read()
print(f"fork_a: {i}, fork_x: {x}, fork_y: {y}")
assert x == y
18 changes: 18 additions & 0 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def name(self, new: str):
else:
self._name = new

def get_name(self, default: str = "") -> str:
return self.name if self.name is not None else default

@property
def appid(self) -> Optional[object]: # Optional AppID.
from .module import AppID
Expand Down Expand Up @@ -752,6 +755,21 @@ def transform(self, transform: Callable[[Signal], Signal]) -> ChannelSignal:
ready_wire.assign(ready)
return ret_chan

def fork(self, clk, rst) -> Tuple[ChannelSignal, ChannelSignal]:
"""Fork the channel into two channels, returning the two new channels."""
from .constructs import Wire
from .types import Bits
both_ready = Wire(Bits(1))
both_ready.name = self.get_name() + "_fork_both_ready"
data, valid = self.unwrap(both_ready)
valid_gate = both_ready & valid
a, a_rdy = self.type.wrap(data, valid_gate)
b, b_rdy = self.type.wrap(data, valid_gate)
abuf = a.buffer(clk, rst, 1)
bbuf = b.buffer(clk, rst, 1)
both_ready.assign(a_rdy & b_rdy)
return abuf, bbuf


class BundleSignal(Signal):
"""Signal for types.Bundle."""
Expand Down
62 changes: 58 additions & 4 deletions frontends/PyCDE/src/pycde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def inner(self):
return self.inner_type

def wrap(self, value,
valueOrEmpty) -> typing.Tuple["ChannelSignal", "BitsSignal"]:
validOrEmpty) -> typing.Tuple["ChannelSignal", "BitsSignal"]:
"""Wrap a data signal and valid signal into a data channel signal and a
ready signal."""

Expand All @@ -608,21 +608,75 @@ def wrap(self, value,
# one.

from .dialects import esi
from .signals import Signal
signaling = self.signaling
if signaling == ChannelSignaling.ValidReady:
value = self.inner_type(value)
valid = types.i1(valueOrEmpty)
if not isinstance(value, Signal):
value = self.inner_type(value)
elif value.type != self.inner_type:
raise TypeError(
f"Expected signal of type {self.inner_type}, got {value.type}")
valid = Bits(1)(validOrEmpty)
wrap_op = esi.WrapValidReadyOp(self._type, types.i1, value.value,
valid.value)
return wrap_op[0], wrap_op[1]
elif signaling == ChannelSignaling.FIFO:
value = self.inner_type(value)
empty = types.i1(valueOrEmpty)
empty = Bits(1)(validOrEmpty)
wrap_op = esi.WrapFIFOOp(self._type, types.i1, value.value, empty.value)
return wrap_op[0], wrap_op[1]
else:
raise TypeError("Unknown signaling standard")

def _join(self, a: "ChannelSignal", b: "ChannelSignal") -> "ChannelSignal":
"""Join two channels into a single channel. The resulting type is a struct
with two fields, 'a' and 'b' wherein 'a' is the data from channel a and 'b'
is the data from channel b."""

from .constructs import Wire
both_ready = Wire(Bits(1))
a_data, a_valid = a.unwrap(both_ready)
b_data, b_valid = b.unwrap(both_ready)
both_valid = a_valid & b_valid
result_data = self.inner_type({"a": a_data, "b": b_data})
result_chan, result_ready = self.wrap(result_data, both_valid)
both_ready.assign(result_ready & both_valid)
return result_chan

@staticmethod
def join(a: "ChannelSignal", b: "ChannelSignal") -> "ChannelSignal":
"""Join two channels into a single channel. The resulting type is a struct
with two fields, 'a' and 'b' wherein 'a' is the data from channel a and 'b'
is the data from channel b."""

from .types import Channel, StructType
return Channel(
StructType([("a", a.type.inner_type),
("b", b.type.inner_type)]))._join(a, b)

def merge(self, a: "ChannelSignal", b: "ChannelSignal") -> "ChannelSignal":
"""Merge two channels into a single channel, selecting a message from either
one. May implement an sort of fairness policy. Both channels must be of the
same type. Returns both the merged channel."""

from .constructs import Mux, Wire
a_ready = Wire(Bits(1))
b_ready = Wire(Bits(1))
a_data, a_valid = a.unwrap(a_ready)
b_data, b_valid = b.unwrap(b_ready)

sel_a = a_valid
sel_b = ~sel_a
out_ready = Wire(Bits(1))
a_ready.assign(sel_a & out_ready)
b_ready.assign(sel_b & out_ready)

valid = (sel_a & a_valid) | (sel_b & b_valid)
data = Mux(sel_a, b_data, a_data)
chan, ready = self.wrap(data, valid)
out_ready.assign(ready)
return chan


@dataclass
class BundledChannel:
Expand Down
88 changes: 88 additions & 0 deletions frontends/PyCDE/test/test_esi_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# RUN: %PYTHON% %s | FileCheck %s

from pycde import generator, Clock, Module, Reset
from pycde.common import InputChannel, OutputChannel
from pycde.testing import unittestmodule
from pycde.types import Bits, UInt

# CHECK-LABEL: hw.module @Merge(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, in %b : !esi.channel<i8>, out x : !esi.channel<i8>)
# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R1:%.+]] : i8
# CHECK-NEXT: %rawOutput_0, %valid_1 = esi.unwrap.vr %b, [[R2:%.+]] : i8
# CHECK-NEXT: %true = hw.constant true
# CHECK-NEXT: [[R0:%.+]] = comb.xor bin %valid, %true : i1
# CHECK-NEXT: [[R1]] = comb.and bin %valid, %ready : i1
# CHECK-NEXT: [[R2]] = comb.and bin [[R0]], %ready : i1
# CHECK-NEXT: [[R3:%.+]] = comb.and bin %valid, %valid : i1
# CHECK-NEXT: [[R4:%.+]] = comb.and bin [[R0]], %valid_1 : i1
# CHECK-NEXT: [[R5:%.+]] = comb.or bin [[R3]], [[R4]] : i1
# CHECK-NEXT: [[R6:%.+]] = comb.mux bin %valid, %rawOutput, %rawOutput_0
# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr [[R6]], [[R5]] : i8
# CHECK-NEXT: hw.output %chanOutput : !esi.channel<i8>


@unittestmodule()
class Merge(Module):
clk = Clock()
rst = Reset()
a = InputChannel(Bits(8))
b = InputChannel(Bits(8))

x = OutputChannel(Bits(8))

@generator
def build(ports):
chan = ports.a.type.merge(ports.a, ports.b)
ports.x = chan


# CHECK-LABEL: hw.module @Join(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<ui8>, in %b : !esi.channel<ui8>, out x : !esi.channel<ui9>)
# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R2:%.+]] : ui8
# CHECK-NEXT: %rawOutput_0, %valid_1 = esi.unwrap.vr %b, [[R2]] : ui8
# CHECK-NEXT: [[R0:%.+]] = comb.and bin %valid, %valid_1 : i1
# CHECK-NEXT: [[R1:%.+]] = hw.struct_create (%rawOutput, %rawOutput_0) : !hw.struct<a: ui8, b: ui8>
# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr [[R1]], [[R0]] : !hw.struct<a: ui8, b: ui8>
# CHECK-NEXT: [[R2]] = comb.and bin %ready, [[R0]] : i1
# CHECK-NEXT: %rawOutput_2, %valid_3 = esi.unwrap.vr %chanOutput, %ready_7 : !hw.struct<a: ui8, b: ui8>
# CHECK-NEXT: %a_4 = hw.struct_extract %rawOutput_2["a"] : !hw.struct<a: ui8, b: ui8>
# CHECK-NEXT: %b_5 = hw.struct_extract %rawOutput_2["b"] : !hw.struct<a: ui8, b: ui8>
# CHECK-NEXT: [[R3:%.+]] = hwarith.add %a_4, %b_5 : (ui8, ui8) -> ui9
# CHECK-NEXT: %chanOutput_6, %ready_7 = esi.wrap.vr [[R3]], %valid_3 : ui9
# CHECK-NEXT: hw.output %chanOutput_6 : !esi.channel<ui9>
@unittestmodule(run_passes=True, emit_outputs=True)
class Join(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))
b = InputChannel(UInt(8))

x = OutputChannel(UInt(9))

@generator
def build(ports):
joined = ports.a.type.join(ports.a, ports.b)
ports.x = joined.transform(lambda x: x.a + x.b)


# CHECK-LABEL: hw.module @Fork(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<ui8>, out x : !esi.channel<ui8>, out y : !esi.channel<ui8>)
# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R3:%.+]] : ui8
# CHECK-NEXT: [[R0:%.+]] = comb.and bin [[R3]], %valid : i1
# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr %rawOutput, [[R0]] : ui8
# CHECK-NEXT: %chanOutput_0, %ready_1 = esi.wrap.vr %rawOutput, [[R0]] : ui8
# CHECK-NEXT: [[R1:%.+]] = esi.buffer %clk, %rst, %chanOutput {stages = 1 : i64} : ui8
# CHECK-NEXT: [[R2:%.+]] = esi.buffer %clk, %rst, %chanOutput_0 {stages = 1 : i64} : ui8
# CHECK-NEXT: [[R3]] = comb.and bin %ready, %ready_1 : i1
# CHECK-NEXT: hw.output [[R1]], [[R2]] : !esi.channel<ui8>, !esi.channel<ui8>
@unittestmodule(run_passes=True, emit_outputs=True)
class Fork(Module):
clk = Clock()
rst = Reset()
a = InputChannel(UInt(8))

x = OutputChannel(UInt(8))
y = OutputChannel(UInt(8))

@generator
def build(ports):
x, y = ports.a.fork(ports.clk, ports.rst)
ports.x = x
ports.y = y

0 comments on commit 801ef91

Please sign in to comment.