[DC CAPI][PyCDE] Add cmerge and demux handshake ops

- Adds cmerge and demux functions to the handshake pycde module.
- Lowering them requires fixes to the conversion pass and the CAPI code.
This commit is contained in:
John Demme 2024-12-19 01:21:14 +00:00
parent 99826b8499
commit bcc1e01cdb
5 changed files with 59 additions and 21 deletions

View File

@ -3,13 +3,14 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from __future__ import annotations
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple
from .module import Module, ModuleLikeBuilderBase, PortError
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
from .signals import (BitsSignal, ChannelSignal, ClockSignal, Signal,
_FromCirctValue)
from .system import System
from .support import get_user_loc, obj_to_typed_attribute
from .types import Channel
from .support import clog2, get_user_loc
from .types import Bits, Channel
from .circt.dialects import handshake as raw_handshake
from .circt import ir
@ -82,7 +83,7 @@ class FuncBuilder(ModuleLikeBuilderBase):
# If the input is a channel signal, the types must match.
if signal.type.inner_type != port.type:
raise ValueError(
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
f"Wrong type on input signal '{name}'. Got '{signal.type.inner_type}',"
f" expected '{port.type}'")
assert port.idx is not None
circt_inputs[port.idx] = signal.value
@ -124,3 +125,24 @@ class Func(Module):
BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
_builder: FuncBuilder
def demux(cond: BitsSignal, data: Signal) -> Tuple[Signal, Signal]:
"""Demux a signal based on a condition."""
condbr = raw_handshake.ConditionalBranchOp(cond.value, data.value)
return (_FromCirctValue(condbr.trueResult),
_FromCirctValue(condbr.falseResult))
def cmerge(*args: Signal) -> Tuple[Signal, BitsSignal]:
"""Merge multiple signals into one and the index of the signal."""
if len(args) == 0:
raise ValueError("cmerge must have at least one argument")
first = args[0]
for a in args[1:]:
if a.type != first.type:
raise ValueError("All arguments to cmerge must have the same type")
idx_type = Bits(clog2(len(args)))
cm = raw_handshake.ControlMergeOp(a.type._type, idx_type._type,
[a.value for a in args])
return (_FromCirctValue(cm.result), BitsSignal(cm.index, idx_type))

View File

@ -264,8 +264,8 @@ class System:
# Then run all the passes to lower dialects which produce `hw.module`s.
"builtin.module(lower-handshake-to-dc)",
"builtin.module(dc-materialize-forks-sinks)",
"builtin.module(canonicalize)",
"builtin.module(lower-dc-to-hw)",
"builtin.module(map-arith-to-comb)",
# Run ESI manifest passes.
"builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})",
@ -275,7 +275,6 @@ class System:
# Instaniate hlmems, which could produce new esi connections.
"builtin.module(hw.module(lower-seq-hlmem))",
"builtin.module(lower-esi-to-physical)",
# TODO: support more than just cosim.
"builtin.module(lower-esi-bundles, lower-esi-ports)",
"builtin.module(lower-esi-to-hw{{platform={platform}}})",
"builtin.module(convert-fsm-to-sv)",

View File

@ -1,42 +1,50 @@
# RUN: %PYTHON% %s | FileCheck %s
from pycde import (Clock, Output, Input, generator, types, Module)
from pycde.handshake import Func
from pycde.handshake import Func, cmerge, demux
from pycde.testing import unittestmodule
from pycde.types import Bits, Channel
# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>)
# CHECK: [[R0:%.+]] = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8>
# CHECK: hw.output [[R0]] : !esi.channel<i8>
# CHECK: }
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8
# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, in %b : !esi.channel<i8>, out x : !esi.channel<i8>)
# CHECK: %0:2 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a, %b) : (!esi.channel<i8>, !esi.channel<i8>) -> (!esi.channel<i8>, !esi.channel<i8>)
# CHECK: hw.output %0#0 : !esi.channel<i8>
# CHECK: handshake.func @TestFunc(%arg0: i8, %arg1: i8, ...) -> (i8, i8)
# CHECK: %result, %index = control_merge %arg0, %arg1 : i8, i1
# CHECK: %c15_i8 = hw.constant 15 : i8
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8
# CHECK: return %0 : i8
# CHECK: }
# CHECK: [[R0:%.+]] = comb.and bin %result, %c15_i8 : i8
# CHECK: %trueResult, %falseResult = cond_br %index, [[R0]] : i8
# CHECK: return %trueResult, %falseResult : i8, i8
class TestFunc(Func):
a = Input(Bits(8))
b = Input(Bits(8))
x = Output(Bits(8))
y = Output(Bits(8))
@generator
def build(ports):
ports.x = ports.a & Bits(8)(0xF)
c, sel = cmerge(ports.a, ports.b)
z = c & Bits(8)(0xF)
x, y = demux(sel, z)
ports.x = x
ports.y = y
BarType = types.struct({"foo": types.i12}, "bar")
@unittestmodule(print=True, run_passes=True)
@unittestmodule(print=True)
class Top(Module):
clk = Clock()
rst = Input(Bits(1))
a = Input(Channel(Bits(8)))
b = Input(Channel(Bits(8)))
x = Output(Channel(Bits(8)))
@generator
def build(ports):
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a)
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a, b=ports.b)
ports.x = test.x

View File

@ -471,7 +471,12 @@ def HandshakeToDC : Pass<"lower-handshake-to-dc", "mlir::ModuleOp"> {
function with graph region behaviour. Thus, for now, we just use `hw.module`
as a container operation.
}];
let dependentDialects = ["dc::DCDialect", "mlir::func::FuncDialect", "hw::HWDialect"];
let dependentDialects = [
"dc::DCDialect",
"mlir::arith::ArithDialect",
"mlir::func::FuncDialect",
"hw::HWDialect"
];
let options = [
Option<"clkName", "clk-name", "std::string", "\"clk\"",
"Name of the clock signal to use in the generated DC module">,

View File

@ -10,9 +10,13 @@
#include "circt/Conversion/Passes.h"
#include "circt/Dialect/DC/DCDialect.h"
#include "circt/Dialect/DC/DCPasses.h"
#include "circt/Transforms/Passes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
void registerDCPasses() { circt::dc::registerPasses(); }
void registerDCPasses() {
circt::registerMapArithToCombPass();
circt::dc::registerPasses();
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DC, dc, circt::dc::DCDialect)