Skip to content

Commit

Permalink
[PyCDE] Fix ESI service implementations (#6545)
Browse files Browse the repository at this point in the history
Repair the ability to implement services in PyCDE. I broke this some
months ago and never got around to fixing it.
  • Loading branch information
teqdruid authored Jan 3, 2024
1 parent c1a38c7 commit e5bf882
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 106 deletions.
44 changes: 19 additions & 25 deletions frontends/PyCDE/src/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,52 +159,46 @@ def __init__(self, input_chan: ir.Value, client_name: List[str]):
super().__init__(input_chan, _FromCirctType(input_chan.type))


class _OutputChannelSetter:
class _OutputBundleSetter:
"""Return a list of these as a proxy for a 'request to client connection'.
Users should call the 'assign' method with the `ChannelValue` which they
have implemented for this request."""

def __init__(self, req: raw_esi.RequestToClientConnectionOp,
old_chan_to_replace: ChannelSignal):
self.type = Channel(_FromCirctType(req.toClient.type))
self.client_name = req.clientNamePath
self._chan_to_replace = old_chan_to_replace
def __init__(self, req: raw_esi.ServiceImplementConnReqOp,
old_value_to_replace: ir.OpResult):
self.type: Bundle = _FromCirctType(req.toClient.type)
self.client_name = req.relativeAppIDPath
self._bundle_to_replace: Optional[ir.OpResult] = old_value_to_replace

def assign(self, new_value: ChannelSignal):
"""Assign the generated channel to this request."""
if self._chan_to_replace is None:
if self._bundle_to_replace is None:
name_str = ".".join(self.client_name)
raise ValueError(f"{name_str} has already been connected.")
if new_value.type != self.type:
raise TypeError(
f"Channel type mismatch. Expected {self.type}, got {new_value.type}.")
msft.replaceAllUsesWith(self._chan_to_replace, new_value.value)
self._chan_to_replace = None
msft.replaceAllUsesWith(self._bundle_to_replace, new_value.value)
self._bundle_to_replace = None


class _ServiceGeneratorBundles:
"""Provide access to the bundles which the service generator is responsible
for connecting up."""

def __init__(self, mod: Module, req: raw_esi.ServiceImplementReqOp):
def __init__(self, mod: ModuleLikeBuilderBase,
req: raw_esi.ServiceImplementReqOp):
self._req = req
portReqsBlock = req.portReqs.blocks[0]

# Find the input channel requests and store named versions of the values.
self._input_reqs = [
NamedChannelValue(x.toServer, x.clientNamePath)
for x in portReqsBlock
if isinstance(x, raw_esi.RequestToServerConnectionOp)
]

# Find the output channel requests and store the settable proxies.
num_output_ports = len(mod.outputs)
to_client_reqs = [
req for req in portReqsBlock
if isinstance(req, raw_esi.RequestToClientConnectionOp)
if isinstance(req, raw_esi.ServiceImplementConnReqOp)
]
self._output_reqs = [
_OutputChannelSetter(req, self._req.results[num_output_ports + idx])
_OutputBundleSetter(req, self._req.results[num_output_ports + idx])
for idx, req in enumerate(to_client_reqs)
]
assert len(self._output_reqs) == len(req.results) - num_output_ports
Expand All @@ -216,12 +210,12 @@ def reqs(self) -> List[NamedChannelValue]:
return self._input_reqs

@property
def to_client_reqs(self) -> List[_OutputChannelSetter]:
def to_client_reqs(self) -> List[_OutputBundleSetter]:
return self._output_reqs

def check_unconnected_outputs(self):
for req in self._output_reqs:
if req._chan_to_replace is not None:
if req._bundle_to_replace is not None:
name_str = ".".join(req.client_name)
raise ValueError(f"{name_str} has not been connected.")

Expand All @@ -231,7 +225,7 @@ class ServiceImplementationModuleBuilder(ModuleLikeBuilderBase):
no distinction between definition and instance -- ESI service providers are
built where they are instantiated."""

def instantiate(self, impl, inputs: Dict[str, Signal], appid: AppID = None):
def instantiate(self, impl, inputs: Dict[str, Signal], appid: AppID):
# Each instantiation of the ServiceImplementation has its own
# registration.
opts = _service_generator_registry.register(impl)
Expand Down Expand Up @@ -260,14 +254,14 @@ def generate_svc_impl(self,
with self.GeneratorCtxt(self, ports, serviceReq, generator.loc):

# Run the generator.
channels = _ServiceGeneratorChannels(self, serviceReq)
rc = generator.gen_func(ports, channels=channels)
bundles = _ServiceGeneratorBundles(self, serviceReq)
rc = generator.gen_func(ports, bundles=bundles)
if rc is None:
rc = True
elif not isinstance(rc, bool):
raise ValueError("Generators must a return a bool or None")
ports._check_unconnected_outputs()
channels.check_unconnected_outputs()
bundles.check_unconnected_outputs()

# Replace the output values from the service implement request op with
# the generated values. Erase the service implement request op.
Expand Down
31 changes: 23 additions & 8 deletions frontends/PyCDE/src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations

from collections import OrderedDict
from functools import singledispatchmethod
from typing import Any

from numpy import single

from .support import get_user_loc

from .circt import ir, support
Expand Down Expand Up @@ -619,32 +619,47 @@ def __repr__(self):
class PackSignalResults:
"""Access the FROM channels of a packed bundle in a convenient way."""

def __init__(self, results: typing.List["ChannelSignal"],
bundle_type: "Bundle"):
def __init__(self, results: typing.List[ChannelSignal],
bundle_type: Bundle):
self.results = results
self.bundle_type = bundle_type
from_channels = [

self.from_channels = {
name: result for (name, result) in zip([
c.name
for c in self.bundle_type.channels
if c.direction == ChannelDirection.FROM
], results)
}

from_channels_idx = [
c.name
for c in self.bundle_type.channels
if c.direction == ChannelDirection.FROM
]
self._from_channels_idx = {
name: idx for idx, name in enumerate(from_channels)
name: idx for idx, name in enumerate(from_channels_idx)
}

@singledispatchmethod
def __getitem__(self, name: str) -> "ChannelSignal":
def __getitem__(self, name: str) -> ChannelSignal:
return self.results[self._from_channels_idx[name]]

@__getitem__.register(int)
def __getitem_int(self, idx: int) -> "ChannelSignal":
def __getitem_int(self, idx: int) -> ChannelSignal:
return self.results[idx]

def __getattr__(self, attrname: str):
if attrname in self._from_channels_idx:
return self.results[self._from_channels_idx[attrname]]
return super().__getattribute__(attrname)

def __iter__(self):
return iter(self.from_channels.items())

def __len__(self):
return len(self.from_channels)

def pack(
self, **kwargs: typing.Dict[str, "ChannelSignal"]
) -> ("BundleSignal", typing.Dict[str, "ChannelSignal"]):
Expand Down
135 changes: 62 additions & 73 deletions frontends/PyCDE/test/test_esi_servicegens.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# XFAIL: *
# RUN: rm -rf %t
# RUN: %PYTHON% %s %t 2>&1 | FileCheck %s

from pycde import (Clock, Input, InputChannel, OutputChannel, Module, generator,
types)
from pycde import (Clock, Input, Module, generator)
from pycde import esi
from pycde.common import AppID, Output, RecvBundle, SendBundle
from pycde.common import AppID, Output
from pycde.constructs import Wire
from pycde.types import (Bits, Bundle, BundledChannel, Channel,
ChannelDirection, ChannelSignaling, UInt, ClockType)
ChannelDirection)
from pycde.testing import unittestmodule
from pycde.signals import BitVectorSignal, ChannelSignal
from pycde.signals import BitsSignal, ChannelSignal

from typing import Dict

TestBundle = Bundle([
BundledChannel("resp", ChannelDirection.TO, Bits(16)),
Expand All @@ -27,11 +27,11 @@ class LoopbackInOut(Module):

@generator
def construct(self):
loopback = Wire(types.channel(types.i16))
loopback = Wire(Channel(Bits(16)))
call_bundle, froms = TestBundle.pack(resp=loopback)
from_host = froms['req']
HostComms.req_resp(call_bundle, AppID("loopback_inout", 0))
ready = Wire(types.i1)
ready = Wire(Bits(1))
wide_data, valid = from_host.unwrap(ready)
data = wide_data[0:16]
data_chan, data_ready = loopback.type.wrap(data, valid)
Expand All @@ -41,32 +41,46 @@ def construct(self):

class MultiplexerService(esi.ServiceImplementation):
clk = Clock()
rst = Input(types.i1)
rst = Input(Bits(1))

# Underlying channel is an untyped, 256-bit LI channel.
trunk_in = Input(types.i256)
trunk_in_valid = Input(types.i1)
trunk_in_ready = Output(types.i1)
trunk_out = Output(types.i256)
trunk_out_valid = Output(types.i1)
trunk_out_ready = Input(types.i1)
trunk_in = Input(Bits(256))
trunk_in_valid = Input(Bits(1))
trunk_in_ready = Output(Bits(1))
trunk_out = Output(Bits(256))
trunk_out_valid = Output(Bits(1))
trunk_out_ready = Input(Bits(1))

@generator
def generate(self, bundles):

input_reqs = channels.to_server_reqs
if len(input_reqs) > 1:
def generate(self, bundles: esi._ServiceGeneratorBundles):
assert len(
bundles.to_client_reqs) == 1, "Only one connection request supported"
bundle = bundles.to_client_reqs[0]
to_req_types = {}
for bundled_chan in bundle.type.channels:
if bundled_chan.direction == ChannelDirection.TO:
to_req_types[bundled_chan.name] = bundled_chan.channel

to_channels = MultiplexerService._generate_to(self, to_req_types)
bundle_sig, from_channels = bundle.type.pack(**to_channels)
bundle.assign(bundle_sig)
MultiplexerService._generate_from(self, from_channels)

def _generate_from(self, from_reqs):
if len(from_reqs) > 1:
raise Exception("Multiple to_server requests not supported")
MultiplexerService.unwrap_and_pad(self, input_reqs[0])

output_reqs = channels.to_client_reqs
if len(output_reqs) > 1:
raise Exception("Multiple to_client requests not supported")
output_req = output_reqs[0]
output_chan, ready = MultiplexerService.slice_and_wrap(
self, output_req.type)
output_req.assign(output_chan)
for _, chan in from_reqs:
MultiplexerService.unwrap_and_pad(self, chan)

def _generate_to(
self, to_req_types: Dict[str, Channel]) -> Dict[str, ChannelSignal]:
if len(to_req_types) > 1:
raise Exception("Multiple TO channels not supported")
chan_name = list(to_req_types.keys())[0]
output_type = to_req_types[chan_name]
output_chan, ready = MultiplexerService.slice_and_wrap(self, output_type)
self.trunk_in_ready = ready
return {chan_name: output_chan}

@staticmethod
def slice_and_wrap(ports, channel_type: Channel):
Expand All @@ -80,7 +94,7 @@ def unwrap_and_pad(ports, input_channel: ChannelSignal):
Unwrap the input channel and pad it to 256 bits.
"""
(data, valid) = input_channel.unwrap(ports.trunk_out_ready)
assert isinstance(data, BitVectorSignal)
assert isinstance(data, BitsSignal)
assert len(data) <= 256
ports.trunk_out = data.pad_or_truncate(256)
ports.trunk_out_valid = valid
Expand All @@ -89,14 +103,14 @@ def unwrap_and_pad(ports, input_channel: ChannelSignal):
@unittestmodule(run_passes=True, print_after_passes=True, emit_outputs=True)
class MultiplexerTop(Module):
clk = Clock()
rst = Input(types.i1)
rst = Input(Bits(1))

trunk_in = Input(types.i256)
trunk_in_valid = Input(types.i1)
trunk_in_ready = Output(types.i1)
trunk_out = Output(types.i256)
trunk_out_valid = Output(types.i1)
trunk_out_ready = Input(types.i1)
trunk_in = Input(Bits(256))
trunk_in_valid = Input(Bits(1))
trunk_in_ready = Output(Bits(1))
trunk_out = Output(Bits(256))
trunk_out_valid = Output(Bits(1))
trunk_out_ready = Input(Bits(1))

@generator
def construct(ports):
Expand All @@ -115,40 +129,15 @@ def construct(ports):
LoopbackInOut()


class PassUpService(esi.ServiceImplementation):

@generator
def generate(self, channels):
for req in channels.to_server_reqs:
name = "out_" + "_".join(req.client_name)
esi.PureModule.output_port(name, req)
for req in channels.to_client_reqs:
name = "in_" + "_".join(req.client_name)
req.assign(esi.PureModule.input_port(name, req.type))


# CHECK-LABEL: hw.module @PureTest<FOO: i5, STR: none>(in %in_Producer_loopback_in : i32, in %in_Producer_loopback_in_valid : i1, in %in_prod2_loopback_in : i32, in %in_prod2_loopback_in_valid : i1, in %clk : i1, in %out_Consumer_loopback_out_ready : i1, in %p2_int_ready : i1, out in_Producer_loopback_in_ready : i1, out in_prod2_loopback_in_ready : i1, out out_Consumer_loopback_out : i32, out out_Consumer_loopback_out_valid : i1, out p2_int : i32, out p2_int_valid : i1)
# CHECK-NEXT: %Producer.loopback_in_ready, %Producer.int_out, %Producer.int_out_valid = hw.instance "Producer" sym @Producer @Producer{{.*}}(clk: %clk: i1, loopback_in: %in_Producer_loopback_in: i32, loopback_in_valid: %in_Producer_loopback_in_valid: i1, int_out_ready: %Consumer.int_in_ready: i1) -> (loopback_in_ready: i1, int_out: i32, int_out_valid: i1)
# CHECK-NEXT: %Consumer.int_in_ready, %Consumer.loopback_out, %Consumer.loopback_out_valid = hw.instance "Consumer" sym @Consumer @Consumer{{.*}}(clk: %clk: i1, int_in: %Producer.int_out: i32, int_in_valid: %Producer.int_out_valid: i1, loopback_out_ready: %out_Consumer_loopback_out_ready: i1) -> (int_in_ready: i1, loopback_out: i32, loopback_out_valid: i1)
# CHECK-NEXT: %prod2.loopback_in_ready, %prod2.int_out, %prod2.int_out_valid = hw.instance "prod2" sym @prod2 @Producer{{.*}}(clk: %clk: i1, loopback_in: %in_prod2_loopback_in: i32, loopback_in_valid: %in_prod2_loopback_in_valid: i1, int_out_ready: %p2_int_ready: i1) -> (loopback_in_ready: i1, int_out: i32, int_out_valid: i1)
# CHECK-NEXT: hw.output %Producer.loopback_in_ready, %prod2.loopback_in_ready, %Consumer.loopback_out, %Consumer.loopback_out_valid, %prod2.int_out, %prod2.int_out_valid : i1, i1, i32, i1, i32, i1
@unittestmodule(run_passes=True, print_after_passes=True, emit_outputs=True)
class PureTest(esi.PureModule):

@generator
def construct(ports):
PassUpService(None)

clk = esi.PureModule.input_port("clk", ClockType())
p = Producer(clk=clk)
Consumer(clk=clk, int_in=p.int_out)
p2 = Producer(clk=clk, instance_name="prod2")
esi.PureModule.output_port("p2_int", p2.int_out)
esi.PureModule.param("FOO", Bits(5))
esi.PureModule.param("STR")


ExStruct = types.struct({
'a': Bits(4),
'b': UInt(32),
})
# CHECK-LABEL: hw.module @MultiplexerTop(in %clk : i1, in %rst : i1, in %trunk_in : i256, in %trunk_in_valid : i1, in %trunk_out_ready : i1, out trunk_in_ready : i1, out trunk_out : i256, out trunk_out_valid : i1) attributes {output_file = #hw.output_file<"MultiplexerTop.sv", includeReplicatedOps>} {
# CHECK: %c0_i240 = hw.constant 0 : i240
# CHECK: [[R0:%.+]] = comb.extract %trunk_in from 0 {sv.namehint = "trunk_in_0upto24"} : (i256) -> i24
# CHECK: [[R1:%.+]] = comb.concat %c0_i240, %LoopbackInOut.loopback_inout_0_resp : i240, i16
# CHECK: %LoopbackInOut.loopback_inout_0_req_ready, %LoopbackInOut.loopback_inout_0_resp, %LoopbackInOut.loopback_inout_0_resp_valid = hw.instance "LoopbackInOut" sym @LoopbackInOut @LoopbackInOut(loopback_inout_0_req: [[R0]]: i24, loopback_inout_0_req_valid: %trunk_in_valid: i1, loopback_inout_0_resp_ready: %trunk_out_ready: i1) -> (loopback_inout_0_req_ready: i1, loopback_inout_0_resp: i16, loopback_inout_0_resp_valid: i1)
# CHECK: hw.instance "__manifest" @__ESIManifest() -> ()
# CHECK: hw.output %LoopbackInOut.loopback_inout_0_req_ready, [[R1]], %LoopbackInOut.loopback_inout_0_resp_valid : i1, i256, i1
# CHECK: }
# CHECK-LABEL: hw.module @LoopbackInOut(in %loopback_inout_0_req : i24, in %loopback_inout_0_req_valid : i1, in %loopback_inout_0_resp_ready : i1, out loopback_inout_0_req_ready : i1, out loopback_inout_0_resp : i16, out loopback_inout_0_resp_valid : i1) attributes {output_file = #hw.output_file<"LoopbackInOut.sv", includeReplicatedOps>} {
# CHECK: [[R0:%.+]] = comb.extract %loopback_inout_0_req from 0 : (i24) -> i16
# CHECK: hw.output %loopback_inout_0_resp_ready, [[R0]], %loopback_inout_0_req_valid : i1, i16, i1
# CHECK: }

0 comments on commit e5bf882

Please sign in to comment.