mirror of https://github.com/llvm/circt.git
[DC CAPI][PyCDE] Add cmerge and demux handshake ops
- Adds cmerge and demux functions to the handshake pycde module. - Lowering them requires fixes to the conversion pass and the CAPI code.
This commit is contained in:
parent
99826b8499
commit
bcc1e01cdb
|
@ -3,13 +3,14 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
|
||||
from .module import Module, ModuleLikeBuilderBase, PortError
|
||||
from .signals import BitsSignal, ChannelSignal, ClockSignal, Signal
|
||||
from .signals import (BitsSignal, ChannelSignal, ClockSignal, Signal,
|
||||
_FromCirctValue)
|
||||
from .system import System
|
||||
from .support import get_user_loc, obj_to_typed_attribute
|
||||
from .types import Channel
|
||||
from .support import clog2, get_user_loc
|
||||
from .types import Bits, Channel
|
||||
|
||||
from .circt.dialects import handshake as raw_handshake
|
||||
from .circt import ir
|
||||
|
@ -82,7 +83,7 @@ class FuncBuilder(ModuleLikeBuilderBase):
|
|||
# If the input is a channel signal, the types must match.
|
||||
if signal.type.inner_type != port.type:
|
||||
raise ValueError(
|
||||
f"Wrong type on input signal '{name}'. Got '{signal.type}',"
|
||||
f"Wrong type on input signal '{name}'. Got '{signal.type.inner_type}',"
|
||||
f" expected '{port.type}'")
|
||||
assert port.idx is not None
|
||||
circt_inputs[port.idx] = signal.value
|
||||
|
@ -124,3 +125,24 @@ class Func(Module):
|
|||
|
||||
BuilderType: type[ModuleLikeBuilderBase] = FuncBuilder
|
||||
_builder: FuncBuilder
|
||||
|
||||
|
||||
def demux(cond: BitsSignal, data: Signal) -> Tuple[Signal, Signal]:
|
||||
"""Demux a signal based on a condition."""
|
||||
condbr = raw_handshake.ConditionalBranchOp(cond.value, data.value)
|
||||
return (_FromCirctValue(condbr.trueResult),
|
||||
_FromCirctValue(condbr.falseResult))
|
||||
|
||||
|
||||
def cmerge(*args: Signal) -> Tuple[Signal, BitsSignal]:
|
||||
"""Merge multiple signals into one and the index of the signal."""
|
||||
if len(args) == 0:
|
||||
raise ValueError("cmerge must have at least one argument")
|
||||
first = args[0]
|
||||
for a in args[1:]:
|
||||
if a.type != first.type:
|
||||
raise ValueError("All arguments to cmerge must have the same type")
|
||||
idx_type = Bits(clog2(len(args)))
|
||||
cm = raw_handshake.ControlMergeOp(a.type._type, idx_type._type,
|
||||
[a.value for a in args])
|
||||
return (_FromCirctValue(cm.result), BitsSignal(cm.index, idx_type))
|
||||
|
|
|
@ -264,8 +264,8 @@ class System:
|
|||
# Then run all the passes to lower dialects which produce `hw.module`s.
|
||||
"builtin.module(lower-handshake-to-dc)",
|
||||
"builtin.module(dc-materialize-forks-sinks)",
|
||||
"builtin.module(canonicalize)",
|
||||
"builtin.module(lower-dc-to-hw)",
|
||||
"builtin.module(map-arith-to-comb)",
|
||||
|
||||
# Run ESI manifest passes.
|
||||
"builtin.module(esi-appid-hier{{top={tops} }}, esi-build-manifest{{top={tops} }})",
|
||||
|
@ -275,7 +275,6 @@ class System:
|
|||
# Instaniate hlmems, which could produce new esi connections.
|
||||
"builtin.module(hw.module(lower-seq-hlmem))",
|
||||
"builtin.module(lower-esi-to-physical)",
|
||||
# TODO: support more than just cosim.
|
||||
"builtin.module(lower-esi-bundles, lower-esi-ports)",
|
||||
"builtin.module(lower-esi-to-hw{{platform={platform}}})",
|
||||
"builtin.module(convert-fsm-to-sv)",
|
||||
|
|
|
@ -1,42 +1,50 @@
|
|||
# RUN: %PYTHON% %s | FileCheck %s
|
||||
|
||||
from pycde import (Clock, Output, Input, generator, types, Module)
|
||||
from pycde.handshake import Func
|
||||
from pycde.handshake import Func, cmerge, demux
|
||||
from pycde.testing import unittestmodule
|
||||
from pycde.types import Bits, Channel
|
||||
|
||||
# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, out x : !esi.channel<i8>)
|
||||
# CHECK: [[R0:%.+]] = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a) : (!esi.channel<i8>) -> !esi.channel<i8>
|
||||
# CHECK: hw.output [[R0]] : !esi.channel<i8>
|
||||
# CHECK: }
|
||||
# CHECK: handshake.func @TestFunc(%arg0: i8, ...) -> i8
|
||||
# CHECK: hw.module @Top(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel<i8>, in %b : !esi.channel<i8>, out x : !esi.channel<i8>)
|
||||
# CHECK: %0:2 = handshake.esi_instance @TestFunc "TestFunc" clk %clk rst %rst(%a, %b) : (!esi.channel<i8>, !esi.channel<i8>) -> (!esi.channel<i8>, !esi.channel<i8>)
|
||||
# CHECK: hw.output %0#0 : !esi.channel<i8>
|
||||
|
||||
# CHECK: handshake.func @TestFunc(%arg0: i8, %arg1: i8, ...) -> (i8, i8)
|
||||
# CHECK: %result, %index = control_merge %arg0, %arg1 : i8, i1
|
||||
# CHECK: %c15_i8 = hw.constant 15 : i8
|
||||
# CHECK: %0 = comb.and bin %arg0, %c15_i8 : i8
|
||||
# CHECK: return %0 : i8
|
||||
# CHECK: }
|
||||
# CHECK: [[R0:%.+]] = comb.and bin %result, %c15_i8 : i8
|
||||
# CHECK: %trueResult, %falseResult = cond_br %index, [[R0]] : i8
|
||||
# CHECK: return %trueResult, %falseResult : i8, i8
|
||||
|
||||
|
||||
class TestFunc(Func):
|
||||
a = Input(Bits(8))
|
||||
b = Input(Bits(8))
|
||||
x = Output(Bits(8))
|
||||
y = Output(Bits(8))
|
||||
|
||||
@generator
|
||||
def build(ports):
|
||||
ports.x = ports.a & Bits(8)(0xF)
|
||||
c, sel = cmerge(ports.a, ports.b)
|
||||
z = c & Bits(8)(0xF)
|
||||
x, y = demux(sel, z)
|
||||
ports.x = x
|
||||
ports.y = y
|
||||
|
||||
|
||||
BarType = types.struct({"foo": types.i12}, "bar")
|
||||
|
||||
|
||||
@unittestmodule(print=True, run_passes=True)
|
||||
@unittestmodule(print=True)
|
||||
class Top(Module):
|
||||
clk = Clock()
|
||||
rst = Input(Bits(1))
|
||||
|
||||
a = Input(Channel(Bits(8)))
|
||||
b = Input(Channel(Bits(8)))
|
||||
x = Output(Channel(Bits(8)))
|
||||
|
||||
@generator
|
||||
def build(ports):
|
||||
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a)
|
||||
test = TestFunc(clk=ports.clk, rst=ports.rst, a=ports.a, b=ports.b)
|
||||
ports.x = test.x
|
||||
|
|
|
@ -471,7 +471,12 @@ def HandshakeToDC : Pass<"lower-handshake-to-dc", "mlir::ModuleOp"> {
|
|||
function with graph region behaviour. Thus, for now, we just use `hw.module`
|
||||
as a container operation.
|
||||
}];
|
||||
let dependentDialects = ["dc::DCDialect", "mlir::func::FuncDialect", "hw::HWDialect"];
|
||||
let dependentDialects = [
|
||||
"dc::DCDialect",
|
||||
"mlir::arith::ArithDialect",
|
||||
"mlir::func::FuncDialect",
|
||||
"hw::HWDialect"
|
||||
];
|
||||
let options = [
|
||||
Option<"clkName", "clk-name", "std::string", "\"clk\"",
|
||||
"Name of the clock signal to use in the generated DC module">,
|
||||
|
|
|
@ -10,9 +10,13 @@
|
|||
#include "circt/Conversion/Passes.h"
|
||||
#include "circt/Dialect/DC/DCDialect.h"
|
||||
#include "circt/Dialect/DC/DCPasses.h"
|
||||
#include "circt/Transforms/Passes.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
|
||||
void registerDCPasses() { circt::dc::registerPasses(); }
|
||||
void registerDCPasses() {
|
||||
circt::registerMapArithToCombPass();
|
||||
circt::dc::registerPasses();
|
||||
}
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DC, dc, circt::dc::DCDialect)
|
||||
|
|
Loading…
Reference in New Issue