mirror of https://github.com/llvm/circt.git
1460 lines
50 KiB
C++
1460 lines
50 KiB
C++
//===- HandshakeOps.cpp - Handshake MLIR Operations -----------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file contains the declaration of the Handshake operations struct.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/Handshake/HandshakeOps.h"
|
|
#include "circt/Support/LLVM.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IntegerSet.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Interfaces/FunctionImplementation.h"
|
|
#include "mlir/Transforms/InliningUtils.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
#include <set>
|
|
|
|
using namespace circt;
|
|
using namespace circt::handshake;
|
|
|
|
namespace circt {
namespace handshake {
#include "circt/Dialect/Handshake/HandshakeCanonicalization.h.inc"

/// Reports whether `op` is a control operation. Only operations implementing
/// SOSTInterface can be control operations; for those, the decision is
/// delegated to the interface's `sostIsControl()`.
bool isControlOpImpl(Operation *op) {
  auto sostInterface = dyn_cast<SOSTInterface>(op);
  if (!sostInterface)
    return false;
  return sostInterface.sostIsControl();
}

} // namespace handshake
} // namespace circt
|
|
|
|
/// Builds the default name for data operand number `idx`: "in0", "in1", ...
static std::string defaultOperandName(unsigned int idx) {
  std::string name = "in";
  name += std::to_string(idx);
  return name;
}
|
|
|
|
/// Parses an integer enclosed in square brackets (e.g. `[4]`) into `v`.
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, int &v) {
  bool hadError = parser.parseLSquare() || parser.parseInteger(v) ||
                  parser.parseRSquare();
  return hadError ? failure() : success();
}
|
|
|
|
/// Parses the textual form shared by SOST operations:
///   (`[` size `]`)? operand-list attr-dict `:` type
/// The bracketed size is parsed only when `explicitSize` is set. Otherwise
/// `size` is set to the number of parsed operands.
static ParseResult
parseSostOperation(OpAsmParser &parser,
                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
                   OperationState &result, int &size, Type &type,
                   bool explicitSize) {
  if (explicitSize && parseIntInSquareBrackets(parser, size))
    return failure();

  bool hadError = parser.parseOperandList(operands) ||
                  parser.parseOptionalAttrDict(result.attributes) ||
                  parser.parseColon() || parser.parseType(type);
  if (hadError)
    return failure();

  if (!explicitSize)
    size = static_cast<int>(operands.size());
  return success();
}
|
|
|
|
/// Verifies whether an indexing value is wide enough to index into a provided
|
|
/// number of operands.
|
|
static LogicalResult verifyIndexWideEnough(Operation *op, Value indexVal,
|
|
uint64_t numOperands) {
|
|
auto indexType = indexVal.getType();
|
|
unsigned indexWidth;
|
|
|
|
// Determine the bitwidth of the indexing value
|
|
if (auto integerType = indexType.dyn_cast<IntegerType>())
|
|
indexWidth = integerType.getWidth();
|
|
else if (indexType.isIndex())
|
|
indexWidth = IndexType::kInternalStorageBitWidth;
|
|
else
|
|
return op->emitError("unsupported type for indexing value: ") << indexType;
|
|
|
|
// Check whether the bitwidth can support the provided number of operands
|
|
if (indexWidth < 64) {
|
|
uint64_t maxNumOperands = (uint64_t)1 << indexWidth;
|
|
if (numOperands > maxNumOperands)
|
|
return op->emitError("bitwidth of indexing value is ")
|
|
<< indexWidth << ", which can index into " << maxNumOperands
|
|
<< " operands, but found " << numOperands << " operands";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Decides whether an operation is a control operation, based on its data type
/// and one of its operands. Either condition suffices:
///  - the data type is NoneType, or
///  - the operand is result #0 of a ControlMergeOp (i.e. it comes from the
///    control network).
static bool isControlCheckTypeAndOperand(Type dataType, Value operand) {
  if (dataType.isa<NoneType>())
    return true;

  auto cmerge = operand.getDefiningOp<ControlMergeOp>();
  return cmerge && operand == cmerge->getResult(0);
}
|
|
|
|
template <typename TMemOp>
|
|
llvm::SmallVector<handshake::MemLoadInterface> getLoadPorts(TMemOp op) {
|
|
llvm::SmallVector<MemLoadInterface> ports;
|
|
// Memory interface refresher:
|
|
// Operands:
|
|
// all stores (stdata1, staddr1, stdata2, staddr2, ...)
|
|
// then all loads (ldaddr1, ldaddr2,...)
|
|
// Outputs: load addresses (lddata1, lddata2, ...), followed by all none
|
|
// outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
|
|
unsigned stCount = op.getStCount();
|
|
unsigned ldCount = op.getLdCount();
|
|
for (unsigned i = 0, e = ldCount; i != e; ++i) {
|
|
MemLoadInterface ldif;
|
|
ldif.index = i;
|
|
ldif.addressIn = op.getInputs()[stCount * 2 + i];
|
|
ldif.dataOut = op.getResult(i);
|
|
ldif.doneOut = op.getResult(ldCount + stCount + i);
|
|
ports.push_back(ldif);
|
|
}
|
|
return ports;
|
|
}
|
|
|
|
template <typename TMemOp>
|
|
llvm::SmallVector<handshake::MemStoreInterface> getStorePorts(TMemOp op) {
|
|
llvm::SmallVector<MemStoreInterface> ports;
|
|
// Memory interface refresher:
|
|
// Operands:
|
|
// all stores (stdata1, staddr1, stdata2, staddr2, ...)
|
|
// then all loads (ldaddr1, ldaddr2,...)
|
|
// Outputs: load data (lddata1, lddata2, ...), followed by all none
|
|
// outputs, ordered as operands(stnone1, stnone2, ... ldnone1, ldnone2, ...)
|
|
unsigned ldCount = op.getLdCount();
|
|
for (unsigned i = 0, e = op.getStCount(); i != e; ++i) {
|
|
MemStoreInterface stif;
|
|
stif.index = i;
|
|
stif.dataIn = op.getInputs()[i * 2];
|
|
stif.addressIn = op.getInputs()[i * 2 + 1];
|
|
stif.doneOut = op.getResult(ldCount + i);
|
|
ports.push_back(stif);
|
|
}
|
|
return ports;
|
|
}
|
|
|
|
unsigned ForkOp::getSize() { return getResults().size(); }
|
|
|
|
/// Parses a fork-like operation: `[` size `]` operand attr-dict `:` type.
/// The single operand has the parsed type, and so does each of the `size`
/// results.
static ParseResult parseForkOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type type;
  int size;
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
  if (parseSostOperation(parser, allOperands, result, size, type,
                         /*explicitSize=*/true))
    return failure();

  result.addTypes(SmallVector<Type, 1>(size, type));
  if (parser.resolveOperands(allOperands, ArrayRef<Type>(type), allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
/// Custom assembly parser for ForkOp; delegates to the shared fork parser.
ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseForkOp(parser, result);
  return parsed;
}
|
|
|
|
/// Prints the fork through the shared SOST printer. The `true` flag mirrors
/// `explicitSize` in parseForkOp (presumably requesting the `[size]` prefix;
/// confirm against sostPrint).
void ForkOp::print(OpAsmPrinter &p) {
  const bool explicitSize = true;
  sostPrint(p, explicitSize);
}
|
|
|
|
namespace {
|
|
|
|
struct EliminateUnusedForkResultsPattern : mlir::OpRewritePattern<ForkOp> {
|
|
using mlir::OpRewritePattern<ForkOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ForkOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
std::set<unsigned> unusedIndexes;
|
|
|
|
for (auto res : llvm::enumerate(op.getResults()))
|
|
if (res.value().getUses().empty())
|
|
unusedIndexes.insert(res.index());
|
|
|
|
if (unusedIndexes.empty())
|
|
return failure();
|
|
|
|
// Create a new fork op, dropping the unused results.
|
|
rewriter.setInsertionPoint(op);
|
|
auto operand = op.getOperand();
|
|
auto newFork = rewriter.create<ForkOp>(
|
|
op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
|
|
rewriter.modifyOpInPlace(op, [&] {
|
|
unsigned i = 0;
|
|
for (auto oldRes : llvm::enumerate(op.getResults()))
|
|
if (unusedIndexes.count(oldRes.index()) == 0)
|
|
oldRes.value().replaceAllUsesWith(newFork.getResults()[i++]);
|
|
});
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct EliminateForkToForkPattern : mlir::OpRewritePattern<ForkOp> {
|
|
using mlir::OpRewritePattern<ForkOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ForkOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto parentForkOp = op.getOperand().getDefiningOp<ForkOp>();
|
|
if (!parentForkOp)
|
|
return failure();
|
|
|
|
/// Create the fork with as many outputs as the two source forks.
|
|
/// Keeping the op.operand() output may or may not be redundant (dependning
|
|
/// on if op is the single user of the value), but we'll let
|
|
/// EliminateUnusedForkResultsPattern apply in that case.
|
|
unsigned totalNumOuts = op.getSize() + parentForkOp.getSize();
|
|
rewriter.modifyOpInPlace(parentForkOp, [&] {
|
|
/// Create a new parent fork op which produces all of the fork outputs and
|
|
/// replace all of the uses of the old results.
|
|
auto newParentForkOp = rewriter.create<ForkOp>(
|
|
parentForkOp.getLoc(), parentForkOp.getOperand(), totalNumOuts);
|
|
|
|
for (auto it :
|
|
llvm::zip(parentForkOp->getResults(), newParentForkOp.getResults()))
|
|
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
|
|
|
|
/// Replace the results of the matches fork op with the corresponding
|
|
/// results of the new parent fork op.
|
|
rewriter.replaceOp(op,
|
|
newParentForkOp.getResults().take_back(op.getSize()));
|
|
});
|
|
rewriter.eraseOp(parentForkOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the ForkOp canonicalizations:
///  - EliminateSimpleForksPattern (presumably declared by the included
///    HandshakeCanonicalization.h.inc),
///  - EliminateForkToForkPattern, which merges chained forks,
///  - EliminateUnusedForkResultsPattern, which drops unused fork outputs.
void handshake::ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.insert<circt::handshake::EliminateSimpleForksPattern,
                 EliminateForkToForkPattern,
                 EliminateUnusedForkResultsPattern>(context);
}
|
|
|
|
unsigned LazyForkOp::getSize() { return getResults().size(); }
|
|
|
|
bool LazyForkOp::sostIsControl() {
|
|
return isControlCheckTypeAndOperand(getDataType(), getOperand());
|
|
}
|
|
|
|
/// A lazy fork uses the same textual format as a regular fork.
ParseResult LazyForkOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseForkOp(parser, result);
  return parsed;
}
|
|
|
|
/// Prints the lazy fork with the same SOST flag as ForkOp::print, so that
/// parseForkOp can read it back.
void LazyForkOp::print(OpAsmPrinter &p) {
  const bool explicitSize = true;
  sostPrint(p, explicitSize);
}
|
|
|
|
/// Parses a merge: operand-list attr-dict `:` type.
/// Every data operand and the single result share the parsed type.
ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type type;
  SmallVector<Type, 1> dataOperandsTypes;
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
  int size;
  if (parseSostOperation(parser, allOperands, result, size, type,
                         /*explicitSize=*/false))
    return failure();

  dataOperandsTypes.assign(size, type);
  result.addTypes(type);
  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints the merge in SOST form without an explicit size; the operand count
/// is recovered from the operand list when parsing.
void MergeOp::print(OpAsmPrinter &p) {
  const bool explicitSize = false;
  sostPrint(p, explicitSize);
}
|
|
|
|
/// Registers EliminateSimpleMergesPattern (presumably declared by the included
/// HandshakeCanonicalization.h.inc) as the MergeOp canonicalization.
void MergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.insert<circt::handshake::EliminateSimpleMergesPattern>(context);
}
|
|
|
|
/// Returns a dematerialized version of the value 'v', defined as the source of
|
|
/// the value before passing through a buffer or fork operation.
|
|
static Value getDematerialized(Value v) {
|
|
Operation *parentOp = v.getDefiningOp();
|
|
if (!parentOp)
|
|
return v;
|
|
|
|
return llvm::TypeSwitch<Operation *, Value>(parentOp)
|
|
.Case<ForkOp>(
|
|
[&](ForkOp op) { return getDematerialized(op.getOperand()); })
|
|
.Case<BufferOp>(
|
|
[&](BufferOp op) { return getDematerialized(op.getOperand()); })
|
|
.Default([&](auto) { return v; });
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Eliminates muxes with identical data inputs. Data inputs are inspected as
|
|
/// their dematerialized versions. This has the side effect of any subsequently
|
|
/// unused buffers are DCE'd and forks are optimized to be narrower.
|
|
struct EliminateSimpleMuxesPattern : mlir::OpRewritePattern<MuxOp> {
|
|
using mlir::OpRewritePattern<MuxOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(MuxOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
Value firstDataOperand = getDematerialized(op.getDataOperands()[0]);
|
|
if (!llvm::all_of(op.getDataOperands(), [&](Value operand) {
|
|
return getDematerialized(operand) == firstDataOperand;
|
|
}))
|
|
return failure();
|
|
rewriter.replaceOp(op, firstDataOperand);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct EliminateUnaryMuxesPattern : OpRewritePattern<MuxOp> {
|
|
using mlir::OpRewritePattern<MuxOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(MuxOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (op.getDataOperands().size() != 1)
|
|
return failure();
|
|
|
|
rewriter.replaceOp(op, op.getDataOperands()[0]);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct EliminateCBranchIntoMuxPattern : OpRewritePattern<MuxOp> {
|
|
using mlir::OpRewritePattern<MuxOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(MuxOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
auto dataOperands = op.getDataOperands();
|
|
if (dataOperands.size() != 2)
|
|
return failure();
|
|
|
|
// Both data operands must originate from the same cbranch
|
|
ConditionalBranchOp firstParentCBranch =
|
|
dataOperands[0].getDefiningOp<ConditionalBranchOp>();
|
|
if (!firstParentCBranch)
|
|
return failure();
|
|
auto secondParentCBranch =
|
|
dataOperands[1].getDefiningOp<ConditionalBranchOp>();
|
|
if (!secondParentCBranch || firstParentCBranch != secondParentCBranch)
|
|
return failure();
|
|
|
|
rewriter.modifyOpInPlace(firstParentCBranch, [&] {
|
|
// Replace uses of the mux's output with cbranch's data input
|
|
rewriter.replaceOp(op, firstParentCBranch.getDataOperand());
|
|
});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the MuxOp canonicalizations: identical-input, single-input and
/// cbranch-into-mux elimination.
void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.insert<EliminateSimpleMuxesPattern, EliminateUnaryMuxesPattern,
                 EliminateCBranchIntoMuxPattern>(context);
}
|
|
|
|
/// Infers the mux result type. Operand 0 is the select, and at least one data
/// operand must follow it. The result takes the type of the first data
/// operand.
LogicalResult
MuxOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
                        ValueRange operands, DictionaryAttr attributes,
                        mlir::OpaqueProperties properties,
                        mlir::RegionRange regions,
                        SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
  const bool hasDataOperand = operands.size() >= 2;
  if (!hasDataOperand)
    return failure();

  Type dataType = operands[1].getType();
  inferredReturnTypes.push_back(dataType);
  return success();
}
|
|
|
|
bool MuxOp::isControl() { return getResult().getType().isa<NoneType>(); }
|
|
|
|
std::string handshake::MuxOp::getOperandName(unsigned int idx) {
|
|
return idx == 0 ? "select" : defaultOperandName(idx - 1);
|
|
}
|
|
|
|
/// Parses: select `[` data-operands `]` attr-dict `:` select-type `,` data-type
/// All data operands and the single result share `data-type`.
ParseResult MuxOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
  OpAsmParser::UnresolvedOperand selectOperand;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type selectType, dataType;

  bool hadError =
      parser.parseOperand(selectOperand) || parser.parseLSquare() ||
      parser.parseOperandList(allOperands) || parser.parseRSquare() ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(selectType) || parser.parseComma() ||
      parser.parseType(dataType);
  if (hadError)
    return failure();

  // One entry per data operand, computed before the select is prepended.
  SmallVector<Type, 1> dataOperandsTypes(allOperands.size(), dataType);
  result.addTypes(dataType);

  // Resolve the select operand first, then the data operands.
  allOperands.insert(allOperands.begin(), selectOperand);
  auto operandTypes = llvm::concat<const Type>(
      ArrayRef<Type>(selectType), ArrayRef<Type>(dataOperandsTypes));
  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints the mux in the form accepted by MuxOp::parse:
///   select `[` data-operands `]` attr-dict `:` select-type `,` result-type
void MuxOp::print(OpAsmPrinter &p) {
  auto operands = getOperands();
  p << ' ' << operands.front() << " [";
  p.printOperands(operands.drop_front());
  p << "]";
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSelectOperand().getType() << ", " << getResult().getType();
}
|
|
|
|
/// Verifies that the select operand is wide enough to address every data
/// operand.
LogicalResult MuxOp::verify() {
  uint64_t numDataOperands = getDataOperands().size();
  return verifyIndexWideEnough(*this, getSelectOperand(), numDataOperands);
}
|
|
|
|
/// Result 0 is the merged data ("dataOut"), result 1 is the index of the input
/// that was selected ("index").
std::string handshake::ControlMergeOp::getResultName(unsigned int idx) {
  assert(idx == 0 || idx == 1);
  if (idx == 0)
    return "dataOut";
  return "index";
}
|
|
|
|
/// Parses: operand-list attr-dict `:` data-type `,` index-type
/// Every data operand shares `data-type`. The results are (data-type,
/// index-type).
ParseResult ControlMergeOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type resultType, indexType;
  int size;
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();

  // Shared SOST prefix, followed by the type of the index result.
  if (parseSostOperation(parser, allOperands, result, size, resultType,
                         /*explicitSize=*/false) ||
      parser.parseComma() || parser.parseType(indexType))
    return failure();

  SmallVector<Type> dataOperandsTypes(size, resultType);
  result.addTypes({resultType, indexType});
  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints the SOST form followed by `,` and the index result type, matching
/// ControlMergeOp::parse.
void ControlMergeOp::print(OpAsmPrinter &p) {
  sostPrint(p, false);
  Type indexType = getIndex().getType();
  p << ", " << indexType;
}
|
|
|
|
/// Verifies that the control merge:
///  - has at least one operand,
///  - has a data result whose type matches the first operand's type,
///  - has an index result wide enough to address every operand.
LogicalResult ControlMergeOp::verify() {
  if (getNumOperands() == 0)
    return emitOpError("operation must have at least one operand");

  Type firstOperandType = (*this)->getOperand(0).getType();
  if (firstOperandType != getResult().getType())
    return emitOpError("type of first result should match type of operands");

  return verifyIndexWideEnough(*this, getIndex(), getNumOperands());
}
|
|
|
|
/// Verifies a non-external handshake function:
///  - entry block argument types match the signature,
///  - `argNames` / `resNames` are string arrays with one entry per port,
///  - each memref argument's first user is an 'extmemory' operation.
LogicalResult FuncOp::verify() {
  // External functions have no body to check.
  if (isExternal())
    return success();

  // The function-like trait has already checked that the argument counts of
  // the signature and the entry block agree; only the types remain.
  auto fnInputTypes = getArgumentTypes();
  Block &entryBlock = front();
  for (BlockArgument blockArg : entryBlock.getArguments()) {
    unsigned i = blockArg.getArgNumber();
    Type blockArgType = blockArg.getType();
    if (fnInputTypes[i] == blockArgType)
      continue;
    return emitOpError("type of entry block argument #")
           << i << '(' << blockArgType
           << ") must match the type of the corresponding argument in "
           << "function signature(" << fnInputTypes[i] << ')';
  }

  // Checks that `attrName` is an array of `numIOs` string attributes.
  auto verifyPortNameAttr = [&](StringRef attrName,
                                unsigned numIOs) -> LogicalResult {
    auto portNamesAttr = (*this)->getAttrOfType<ArrayAttr>(attrName);
    if (!portNamesAttr)
      return emitOpError() << "expected attribute '" << attrName << "'.";

    ArrayRef<Attribute> portNames = portNamesAttr.getValue();
    if (portNames.size() != numIOs)
      return emitOpError() << "attribute '" << attrName << "' has "
                           << portNames.size()
                           << " entries but is expected to have " << numIOs
                           << ".";

    bool allStrings = llvm::all_of(
        portNames, [](Attribute attr) { return attr.isa<StringAttr>(); });
    if (!allStrings)
      return emitOpError() << "expected all entries in attribute '" << attrName
                           << "' to be strings.";
    return success();
  };

  if (failed(verifyPortNameAttr("argNames", getNumArguments())) ||
      failed(verifyPortNameAttr("resNames", getNumResults())))
    return failure();

  // Every memref argument must be consumed by an 'extmemory' operation; only
  // the first user of the argument is inspected.
  for (BlockArgument arg : entryBlock.getArguments()) {
    if (!arg.getType().isa<MemRefType>())
      continue;
    auto users = arg.getUsers();
    bool usedByExtMemory =
        !users.empty() && isa<ExternalMemoryOp>(*users.begin());
    if (!usedByExtMemory)
      return emitOpError("expected that block argument #")
             << arg.getArgNumber() << " is used by an 'extmemory' operation";
  }

  return success();
}
|
|
|
|
/// Parses a FuncOp signature using
|
|
/// mlir::function_interface_impl::parseFunctionSignature while getting access
|
|
/// to the parsed SSA names to store as attributes.
|
|
static ParseResult
|
|
parseFuncOpArgs(OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
|
|
SmallVectorImpl<Type> &resTypes,
|
|
SmallVectorImpl<DictionaryAttr> &resAttrs) {
|
|
bool isVariadic;
|
|
if (mlir::function_interface_impl::parseFunctionSignature(
|
|
parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes,
|
|
resAttrs)
|
|
.failed())
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Generates names for a handshake.func input and output arguments, based on
|
|
/// the number of args as well as a prefix.
|
|
static SmallVector<Attribute> getFuncOpNames(Builder &builder, unsigned cnt,
|
|
StringRef prefix) {
|
|
SmallVector<Attribute> resNames;
|
|
for (unsigned i = 0; i < cnt; ++i)
|
|
resNames.push_back(builder.getStringAttr(prefix + std::to_string(i)));
|
|
return resNames;
|
|
}
|
|
|
|
/// Builds a handshake.func named `name` with type `type` and the extra
/// attributes `attrs`. Empty `argNames` / `resNames` arrays are added unless
/// `attrs` already provides them. A single (empty) body region is added.
void handshake::FuncOp::build(OpBuilder &builder, OperationState &state,
                              StringRef name, FunctionType type,
                              ArrayRef<NamedAttribute> attrs) {
  state.addAttribute(SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());

  auto hasUserAttr = [&](StringRef attrName) {
    return llvm::any_of(attrs, [&](NamedAttribute attr) {
      return attr.getName().getValue() == attrName;
    });
  };
  if (!hasUserAttr("argNames"))
    state.addAttribute("argNames", builder.getArrayAttr({}));
  if (!hasUserAttr("resNames"))
    state.addAttribute("resNames", builder.getArrayAttr({}));

  state.addRegion();
}
|
|
|
|
/// Helper function for appending a string to an array attribute, and
|
|
/// rewriting the attribute back to the operation.
|
|
static void addStringToStringArrayAttr(Builder &builder, Operation *op,
|
|
StringRef attrName, StringAttr str) {
|
|
llvm::SmallVector<Attribute> attrs;
|
|
llvm::copy(op->getAttrOfType<ArrayAttr>(attrName).getValue(),
|
|
std::back_inserter(attrs));
|
|
attrs.push_back(str);
|
|
op->setAttr(attrName, builder.getArrayAttr(attrs));
|
|
}
|
|
|
|
void handshake::FuncOp::resolveArgAndResNames() {
|
|
Builder builder(getContext());
|
|
|
|
/// Generate a set of fallback names. These are used in case names are
|
|
/// missing from the currently set arg- and res name attributes.
|
|
auto fallbackArgNames = getFuncOpNames(builder, getNumArguments(), "in");
|
|
auto fallbackResNames = getFuncOpNames(builder, getNumResults(), "out");
|
|
auto argNames = getArgNames().getValue();
|
|
auto resNames = getResNames().getValue();
|
|
|
|
/// Use fallback names where actual names are missing.
|
|
auto resolveNames = [&](auto &fallbackNames, auto &actualNames,
|
|
StringRef attrName) {
|
|
for (auto fallbackName : llvm::enumerate(fallbackNames)) {
|
|
if (actualNames.size() <= fallbackName.index())
|
|
addStringToStringArrayAttr(
|
|
builder, this->getOperation(), attrName,
|
|
fallbackName.value().template cast<StringAttr>());
|
|
}
|
|
};
|
|
resolveNames(fallbackArgNames, argNames, "argNames");
|
|
resolveNames(fallbackResNames, resNames, "resNames");
|
|
}
|
|
|
|
/// Parses a handshake.func:
///   visibility? @name signature (attributes attr-dict)? region?
/// If not given explicitly, `argNames` is derived from the SSA names of the
/// arguments and `resNames` is generated ("out0", "out1", ...).
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> args;
  SmallVector<Type> resTypes;
  SmallVector<DictionaryAttr> resAttributes;

  // Optional visibility keyword; its absence is not an error.
  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);

  // Symbol name followed by the function signature.
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      parseFuncOpArgs(parser, args, resTypes, resAttributes))
    return failure();
  mlir::function_interface_impl::addArgAndResultAttrs(
      builder, result, args, resAttributes,
      handshake::FuncOp::getArgAttrsAttrName(result.name),
      handshake::FuncOp::getResAttrsAttrName(result.name));

  // Function type, built from the parsed argument and result types.
  SmallVector<Type> argTypes;
  argTypes.reserve(args.size());
  for (const OpAsmParser::Argument &arg : args)
    argTypes.push_back(arg.type);
  result.addAttribute(
      handshake::FuncOp::getFunctionTypeAttrName(result.name),
      TypeAttr::get(builder.getFunctionType(argTypes, resTypes)));

  // Argument names come from the SSA names (without the leading sigil) when
  // every argument has one. Otherwise all arguments get "in0", "in1", ...
  SmallVector<Attribute> argNames;
  bool allHaveSSANames =
      llvm::none_of(args, [](const OpAsmParser::Argument &arg) {
        return arg.ssaName.name.empty();
      });
  if (allHaveSSANames) {
    for (const OpAsmParser::Argument &arg : args)
      argNames.push_back(builder.getStringAttr(arg.ssaName.name.drop_front()));
  } else {
    argNames = getFuncOpNames(builder, args.size(), "in");
  }

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Explicitly written argNames/resNames take precedence over inferred ones.
  if (!result.attributes.get("argNames"))
    result.addAttribute("argNames", builder.getArrayAttr(argNames));
  if (!result.attributes.get("resNames")) {
    auto resNames = getFuncOpNames(builder, resTypes.size(), "out");
    result.addAttribute("resNames", builder.getArrayAttr(resNames));
  }

  // Optional body. The printer never prints an empty body, so a parsed body
  // that turns out empty is rejected.
  auto *body = result.addRegion();
  llvm::SMLoc loc = parser.getCurrentLocation();
  auto parseResult = parser.parseOptionalRegion(*body, args,
                                                /*enableNameShadowing=*/false);
  if (!parseResult.has_value())
    return success();
  if (failed(*parseResult))
    return failure();
  if (body->empty())
    return parser.emitError(loc, "expected non-empty function body");
  return success();
}
|
|
|
|
/// Prints the function with MLIR's generic function-op printer, allowing a
/// variadic signature and using this op's type/arg-attr/res-attr names.
void FuncOp::print(OpAsmPrinter &p) {
  const bool isVariadic = true;
  mlir::function_interface_impl::printFunctionOp(
      p, *this, isVariadic, getFunctionTypeAttrName(), getArgAttrsAttrName(),
      getResAttrsAttrName());
}
|
|
|
|
namespace {
|
|
struct EliminateSimpleControlMergesPattern
|
|
: mlir::OpRewritePattern<ControlMergeOp> {
|
|
using mlir::OpRewritePattern<ControlMergeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ControlMergeOp op,
|
|
PatternRewriter &rewriter) const override;
|
|
};
|
|
} // namespace
|
|
|
|
LogicalResult EliminateSimpleControlMergesPattern::matchAndRewrite(
    ControlMergeOp op, PatternRewriter &rewriter) const {
  Value dataResult = op.getResult();
  Value choiceResult = op.getIndex();

  // The index output must be unused or feed exactly one SinkOp.
  Operation *choiceSink = nullptr;
  if (!choiceResult.use_empty()) {
    if (!choiceResult.hasOneUse())
      return failure();
    choiceSink = *choiceResult.getUsers().begin();
    if (!isa<SinkOp>(choiceSink))
      return failure();
  }

  auto merge = rewriter.create<MergeOp>(op.getLoc(), op.getDataOperands());

  // Redirect every consumer of the data output to the new merge.
  for (OpOperand &use : llvm::make_early_inc_range(dataResult.getUses())) {
    Operation *user = use.getOwner();
    unsigned operandIdx = use.getOperandNumber();
    rewriter.modifyOpInPlace(user,
                             [&]() { user->setOperand(operandIdx, merge); });
  }

  // The sink still uses the control merge's index, so it must be erased first.
  if (choiceSink)
    rewriter.eraseOp(choiceSink);
  rewriter.eraseOp(op);
  return success();
}
|
|
|
|
/// Registers the control-merge-to-merge simplification.
void ControlMergeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.insert<EliminateSimpleControlMergesPattern>(context);
}
|
|
|
|
bool BranchOp::sostIsControl() {
|
|
return isControlCheckTypeAndOperand(getDataType(), getOperand());
|
|
}
|
|
|
|
/// Registers EliminateSimpleBranchesPattern (presumably declared by the
/// included HandshakeCanonicalization.h.inc) as the BranchOp canonicalization.
void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<circt::handshake::EliminateSimpleBranchesPattern>(context);
}
|
|
|
|
/// Parses a branch: operand-list attr-dict `:` type.
/// The operand(s) and the single result share the parsed type.
ParseResult BranchOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type type;
  SmallVector<Type, 1> dataOperandsTypes;
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
  int size;
  if (parseSostOperation(parser, allOperands, result, size, type,
                         /*explicitSize=*/false))
    return failure();

  dataOperandsTypes.assign(size, type);
  result.addTypes({type});
  if (parser.resolveOperands(allOperands, dataOperandsTypes, allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints the branch in SOST form without an explicit size.
void BranchOp::print(OpAsmPrinter &p) {
  const bool explicitSize = false;
  sostPrint(p, explicitSize);
}
|
|
|
|
/// Parses a conditional branch: cond `,` data attr-dict `:` data-type.
/// The condition is always resolved as i1; the data operand and both results
/// take `data-type`.
ParseResult ConditionalBranchOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type dataType;
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(allOperands) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(dataType))
    return failure();

  // Point the arity diagnostic at the operand list, not past the type.
  if (allOperands.size() != 2)
    return parser.emitError(allOperandLoc, "Expected exactly 2 operands");

  result.addTypes({dataType, dataType});
  SmallVector<Type, 2> operandTypes;
  operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
  operandTypes.push_back(dataType);
  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                             result.operands))
    return failure();

  return success();
}
|
|
|
|
/// Prints `cond, data attr-dict : data-type`, matching the parser.
void ConditionalBranchOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getDataOperand().getType();
}
|
|
|
|
/// Operand 0 is the condition ("cond"), operand 1 the data ("data").
std::string handshake::ConditionalBranchOp::getOperandName(unsigned int idx) {
  assert(idx == 0 || idx == 1);
  if (idx == 0)
    return "cond";
  return "data";
}
|
|
|
|
std::string handshake::ConditionalBranchOp::getResultName(unsigned int idx) {
|
|
assert(idx == 0 || idx == 1);
|
|
return idx == ConditionalBranchOp::falseIndex ? "outFalse" : "outTrue";
|
|
}
|
|
|
|
bool ConditionalBranchOp::isControl() {
|
|
return isControlCheckTypeAndOperand(getDataOperand().getType(),
|
|
getDataOperand());
|
|
}
|
|
|
|
/// Parses a sink: operand-list attr-dict `:` type. The sink has no results,
/// and its operand is resolved with the parsed type.
ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  Type type;
  int size; // Filled in by parseSostOperation; not needed afterwards.
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
  if (parseSostOperation(parser, allOperands, result, size, type,
                         /*explicitSize=*/false) ||
      parser.resolveOperands(allOperands, ArrayRef<Type>(type), allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints the sink in SOST form without an explicit size.
void SinkOp::print(OpAsmPrinter &p) {
  const bool explicitSize = false;
  sostPrint(p, explicitSize);
}
|
|
|
|
/// The constant has a single operand, named "ctrl".
std::string handshake::ConstantOp::getOperandName(unsigned int idx) {
  assert(idx == 0);
  std::string name = "ctrl";
  return name;
}
|
|
|
|
/// The data type of a source is the type of its result.
Type SourceOp::getDataType() {
  Value produced = getResult();
  return produced.getType();
}
|
|
/// A source always reports a size of one.
unsigned SourceOp::getSize() {
  const unsigned numOutputs = 1;
  return numOutputs;
}
|
|
|
|
/// Parses a source: just an optional attribute dictionary. The single result
/// is always of NoneType.
ParseResult SourceOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parser.parseOptionalAttrDict(result.attributes)))
    return failure();
  Type noneType = NoneType::get(result.getContext());
  result.addTypes(noneType);
  return success();
}
|
|
|
|
/// Prints only the optional attribute dictionary, matching SourceOp::parse.
void SourceOp::print(OpAsmPrinter &p) {
  auto attrs = (*this)->getAttrs();
  p.printOptionalAttrDict(attrs);
}
|
|
|
|
/// Verifies that the constant's value is a typed attribute whose type equals
/// the operation's result type.
LogicalResult ConstantOp::verify() {
  auto value = getValue();
  auto typedValue = value.dyn_cast<mlir::TypedAttr>();
  if (!typedValue)
    return emitOpError("constant value must be a typed attribute; value is ")
           << value;

  Type valueType = typedValue.getType();
  Type resultType = getResult().getType();
  if (valueType == resultType)
    return success();

  return emitOpError() << "constant value type " << valueType
                       << " differs from operation result type " << resultType;
}
|
|
|
|
/// Registers the TableGen-generated EliminateSunkConstantsPattern (declared in
/// HandshakeCanonicalization.h.inc) as a canonicalization for constants.
void handshake::ConstantOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<circt::handshake::EliminateSunkConstantsPattern>(context);
}
|
|
|
|
/// When an initializer list is present, the buffer must be sequential and
/// must provide exactly `getSize()` initial values.
LogicalResult BufferOp::verify() {
  auto initVals = getInitValues();
  // No initializer list: nothing to check.
  if (!initVals)
    return success();

  if (!isSequential())
    return emitOpError()
           << "only bufferType buffers are allowed to have initial values.";

  auto nInits = initVals->size();
  if (nInits == getSize())
    return success();

  return emitOpError() << "expected " << getSize() << " init values but got "
                       << nInits << ".";
}
|
|
|
|
/// Registers the TableGen-generated EliminateSunkBuffersPattern (declared in
/// HandshakeCanonicalization.h.inc) as a canonicalization for buffers.
void handshake::BufferOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<circt::handshake::EliminateSunkBuffersPattern>(context);
}
|
|
|
|
unsigned BufferOp::getSize() {
|
|
return (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();
|
|
}
|
|
|
|
/// Parses `[slots] bufferType operands attr-dict : type`.
///
/// `slots` is stored as an i32 "slots" attribute and the buffer type as the
/// "bufferType" attribute. The parsed type becomes the result type and is
/// also the single type the operands are resolved against.
ParseResult BufferOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();

  int slots;
  if (parseIntInSquareBrackets(parser, slots))
    return failure();

  auto bufferTypeAttr = BufferTypeEnumAttr::parse(parser, {});
  if (!bufferTypeAttr)
    return failure();

  MLIRContext *ctx = result.getContext();
  result.addAttribute("slots",
                      IntegerAttr::get(IntegerType::get(ctx, 32), slots));
  result.addAttribute("bufferType", bufferTypeAttr);

  Type dataType;
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(dataType))
    return failure();

  result.addTypes({dataType});

  // Built only after `dataType` has been parsed.
  ArrayRef<Type> operandTypes(dataType);
  return parser.resolveOperands(operands, operandTypes, operandsLoc,
                                result.operands);
}
|
|
|
|
/// Prints `[slots] bufferType operands attr-dict : type`. This is the inverse
/// of BufferOp::parse, so "slots" and "bufferType" are elided from the
/// attribute dictionary.
void BufferOp::print(OpAsmPrinter &p) {
  int slots =
      (*this)->getAttrOfType<IntegerAttr>("slots").getValue().getZExtValue();

  p << " [" << slots << "] " << stringifyEnum(getBufferType()) << " "
    << getOperation()->getOperands();
  p.printOptionalAttrDict((*this)->getAttrs(), {"slots", "bufferType"});
  p << " : " << getDataType();
}
|
|
|
|
/// Names a memory-port operand.
///
/// The first `2 * nStores` operands alternate store data / store address
/// ("stData<i>", "stAddr<i>"). All following operands are load addresses
/// ("ldAddr<i>"), numbered from 0.
static std::string getMemoryOperandName(unsigned nStores, unsigned idx) {
  unsigned storeOperands = nStores * 2;
  if (idx >= storeOperands)
    return "ldAddr" + std::to_string(idx - storeOperands);

  std::string port = std::to_string(idx / 2);
  return (idx % 2 == 0 ? "stData" : "stAddr") + port;
}
|
|
|
|
std::string handshake::MemoryOp::getOperandName(unsigned int idx) {
|
|
return getMemoryOperandName(getStCount(), idx);
|
|
}
|
|
|
|
/// Names a memory-port result.
///
/// Results are laid out as [ldData..., stDone..., ldDone...]: `nLoads` load
/// data results, then `nStores` store-done results, then the load-done
/// results. Each group is numbered from 0.
static std::string getMemoryResultName(unsigned nLoads, unsigned nStores,
                                       unsigned idx) {
  if (idx < nLoads)
    return "ldData" + std::to_string(idx);
  if (idx < nLoads + nStores)
    return "stDone" + std::to_string(idx - nLoads);
  return "ldDone" + std::to_string(idx - nLoads - nStores);
}
|
|
|
|
std::string handshake::MemoryOp::getResultName(unsigned int idx) {
|
|
return getMemoryResultName(getLdCount(), getStCount(), idx);
|
|
}
|
|
|
|
/// Verifies a handshake.memory op.
///
/// The memref must be statically shaped and one-dimensional. Operands are
/// laid out as [stData0, stAddr0, stData1, stAddr1, ..., ldAddr0, ...] and
/// results as [ldData..., stDone..., ldDone...]. This checks the operand and
/// result counts, that store data and load data match the memref element
/// type, that every address operand has the same type as the first address
/// operand, and that all done/sync results are `none`.
LogicalResult MemoryOp::verify() {
  auto memrefType = getMemRefType();

  if (memrefType.getNumDynamicDims() != 0)
    return emitOpError()
           << "memref dimensions for handshake.memory must be static.";
  if (memrefType.getShape().size() != 1)
    return emitOpError() << "memref must have only a single dimension.";

  unsigned opStCount = getStCount();
  unsigned opLdCount = getLdCount();
  // One address operand per memref dimension. The rank check above pins this
  // to 1, which the per-port indexing below relies on.
  unsigned addressCount = memrefType.getShape().size();

  auto inputType = getInputs().getType();
  auto outputType = getOutputs().getType();
  Type dataType = memrefType.getElementType();

  unsigned numOperands = getInputs().size();
  unsigned numResults = getOutputs().size();

  // Each store port takes a data operand plus its addresses; each load port
  // takes only its addresses. Report exactly the count that is checked.
  unsigned expectedOperands =
      (1 + addressCount) * opStCount + addressCount * opLdCount;
  if (numOperands != expectedOperands)
    return emitOpError("number of operands ")
           << numOperands << " does not match number expected of "
           << expectedOperands << " with " << addressCount
           << " address inputs per port";

  unsigned expectedResults = opStCount + 2 * opLdCount;
  if (numResults != expectedResults)
    return emitOpError("number of results ")
           << numResults << " does not match number expected of "
           << expectedResults << " with " << addressCount
           << " address inputs per port";

  // A memory with no ports has no operands, so there is no address type to
  // compare against. Reading inputType[0] here would be out of bounds.
  if (numOperands == 0)
    return success();

  // The first address operand defines the address type all ports must use.
  Type addressType = opStCount > 0 ? inputType[1] : inputType[0];

  for (unsigned i = 0; i < opStCount; i++) {
    if (inputType[2 * i] != dataType)
      return emitOpError("data type for store port ")
             << i << ":" << inputType[2 * i] << " doesn't match memory type "
             << dataType;
    if (inputType[2 * i + 1] != addressType)
      return emitOpError("address type for store port ")
             << i << ":" << inputType[2 * i + 1]
             << " doesn't match address type " << addressType;
  }

  for (unsigned i = 0; i < opLdCount; i++) {
    Type ldAddressType = inputType[2 * opStCount + i];
    if (ldAddressType != addressType)
      return emitOpError("address type for load port ")
             << i << ":" << ldAddressType << " doesn't match address type "
             << addressType;
  }

  for (unsigned i = 0; i < opLdCount; i++) {
    if (outputType[i] != dataType)
      return emitOpError("data type for load port ")
             << i << ":" << outputType[i] << " doesn't match memory type "
             << dataType;
  }

  for (unsigned i = 0; i < opStCount; i++) {
    Type syncType = outputType[opLdCount + i];
    if (!syncType.isa<NoneType>())
      return emitOpError("data type for sync port for store port ")
             << i << ":" << syncType << " is not 'none'";
  }

  for (unsigned i = 0; i < opLdCount; i++) {
    Type syncType = outputType[opLdCount + opStCount + i];
    if (!syncType.isa<NoneType>())
      return emitOpError("data type for sync port for load port ")
             << i << ":" << syncType << " is not 'none'";
  }

  return success();
}
|
|
|
|
std::string handshake::ExternalMemoryOp::getOperandName(unsigned int idx) {
|
|
if (idx == 0)
|
|
return "extmem";
|
|
|
|
return getMemoryOperandName(getStCount(), idx - 1);
|
|
}
|
|
|
|
std::string handshake::ExternalMemoryOp::getResultName(unsigned int idx) {
|
|
return getMemoryResultName(getLdCount(), getStCount(), idx);
|
|
}
|
|
|
|
/// Builds an external memory op.
///
/// Operands are the memref followed by `inputs`. Results are `ldCount` data
/// outputs typed by the memref element type, followed by `stCount + ldCount`
/// none-typed control outputs. The memory id and the port counts are stored
/// as i32 attributes.
void ExternalMemoryOp::build(OpBuilder &builder, OperationState &result,
                             Value memref, ValueRange inputs, int ldCount,
                             int stCount, int id) {
  result.addOperands(memref);
  result.addOperands(inputs);

  auto memrefType = memref.getType().cast<MemRefType>();
  result.types.append(ldCount, memrefType.getElementType());
  result.types.append(stCount + ldCount, builder.getNoneType());

  Type i32 = builder.getIntegerType(32);
  result.addAttribute("id", builder.getIntegerAttr(i32, id));
  result.addAttribute("ldCount", builder.getIntegerAttr(i32, ldCount));
  result.addAttribute("stCount", builder.getIntegerAttr(i32, stCount));
}
|
|
|
|
/// Returns the load ports, as collected by the file-level ::getLoadPorts
/// helper.
llvm::SmallVector<handshake::MemLoadInterface>
ExternalMemoryOp::getLoadPorts() {
  auto self = *this;
  return ::getLoadPorts(self);
}
|
|
|
|
/// Returns the store ports, as collected by the file-level ::getStorePorts
/// helper.
llvm::SmallVector<handshake::MemStoreInterface>
ExternalMemoryOp::getStorePorts() {
  auto self = *this;
  return ::getStorePorts(self);
}
|
|
|
|
/// Builds a memory op.
///
/// Results are `outputs` data results typed by the memref element type,
/// followed by `controlOutputs` none-typed control results. The "lsq",
/// "memRefType" and i32 "id" attributes are always set. "ldCount" and
/// "stCount" are only set when `lsq` is false: the load count is `outputs`,
/// and the store count is `controlOutputs - outputs`.
void MemoryOp::build(OpBuilder &builder, OperationState &result,
                     ValueRange operands, int outputs, int controlOutputs,
                     bool lsq, int id, Value memref) {
  result.addOperands(operands);

  auto memrefType = memref.getType().cast<MemRefType>();
  result.types.append(outputs, memrefType.getElementType());
  result.types.append(controlOutputs, builder.getNoneType());

  result.addAttribute("lsq", builder.getBoolAttr(lsq));
  result.addAttribute("memRefType", TypeAttr::get(memrefType));

  Type i32 = builder.getIntegerType(32);
  result.addAttribute("id", builder.getIntegerAttr(i32, id));

  if (lsq)
    return;

  result.addAttribute("ldCount", builder.getIntegerAttr(i32, outputs));
  result.addAttribute("stCount",
                      builder.getIntegerAttr(i32, controlOutputs - outputs));
}
|
|
|
|
/// Returns the load ports, as collected by the file-level ::getLoadPorts
/// helper.
llvm::SmallVector<handshake::MemLoadInterface> MemoryOp::getLoadPorts() {
  auto self = *this;
  return ::getLoadPorts(self);
}
|
|
|
|
/// Returns the store ports, as collected by the file-level ::getStorePorts
/// helper.
llvm::SmallVector<handshake::MemStoreInterface> MemoryOp::getStorePorts() {
  auto self = *this;
  return ::getStorePorts(self);
}
|
|
|
|
/// Allocates backing storage for this memory.
///
/// If `memoryMap` already has an entry for this op's id, nothing is done and
/// false is returned. Otherwise a new slot is appended to `store`, sized to
/// the memref's element count and filled with zero values (an APInt of the
/// element bit width, or an APFloat 0.0). The matching `storeTimes` entry is
/// set to 0.0, the slot index is recorded in `memoryMap`, and true is
/// returned.
bool handshake::MemoryOp::allocateMemory(
    llvm::DenseMap<unsigned, unsigned> &memoryMap,
    std::vector<std::vector<llvm::Any>> &store,
    std::vector<double> &storeTimes) {
  if (memoryMap.count(getId()))
    return false;

  auto type = getMemRefType();

  // MemoryOp::verify rejects dynamic dimensions, so the element count follows
  // from the static shape alone. Previously any non-positive dimension
  // (including a legal static 0) was treated as dynamic and read from an
  // always-empty operand list.
  assert(type.hasStaticShape() &&
         "handshake.memory requires a statically shaped memref");
  int64_t allocationSize = type.getNumElements();

  // Select the zero value once rather than per element. The bit width is only
  // queried for integer types, where it is meaningful.
  mlir::Type elementType = type.getElementType();
  llvm::Any zero;
  if (elementType.isa<mlir::IntegerType>())
    zero = APInt(elementType.getIntOrFloatBitWidth(), 0);
  else if (elementType.isa<mlir::FloatType>())
    zero = APFloat(0.0);
  else
    llvm_unreachable("Unknown result type!\n");

  unsigned ptr = store.size();
  store.resize(ptr + 1);
  storeTimes.resize(ptr + 1);
  store[ptr].assign(allocationSize, zero);
  storeTimes[ptr] = 0.0;

  memoryMap[getId()] = ptr;
  return true;
}
|
|
|
|
std::string handshake::LoadOp::getOperandName(unsigned int idx) {
|
|
unsigned nAddresses = getAddresses().size();
|
|
std::string opName;
|
|
if (idx < nAddresses)
|
|
opName = "addrIn" + std::to_string(idx);
|
|
else if (idx == nAddresses)
|
|
opName = "dataFromMem";
|
|
else
|
|
opName = "ctrl";
|
|
return opName;
|
|
}
|
|
|
|
/// Result 0 is "dataOut"; the remaining results are the address outputs
/// "addrOut<i>", numbered from 0.
std::string handshake::LoadOp::getResultName(unsigned int idx) {
  return idx == 0 ? std::string("dataOut")
                  : "addrOut" + std::to_string(idx - 1);
}
|
|
|
|
/// Builds a load.
///
/// Only the address `indices` become operands; `memref` is used solely to
/// derive the type of the data result. Results are the loaded data, followed
/// by one index-typed address output per index.
void handshake::LoadOp::build(OpBuilder &builder, OperationState &result,
                              Value memref, ValueRange indices) {
  result.addOperands(indices);

  Type elementType = memref.getType().cast<MemRefType>().getElementType();
  result.types.push_back(elementType);
  result.types.append(indices.size(), builder.getIndexType());
}
|
|
|
|
/// Parses the shared load/store syntax:
///   `[` addresses `]` data, ctrl `:` address-types..., data-type
///
/// The result types are the data type followed by the address types. The
/// operands are resolved against the address types, then the data type, then
/// `none` for the control operand.
static ParseResult parseMemoryAccessOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> addressOperands,
      remainingOperands;
  SmallVector<Type, 1> parsedTypes;
  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();

  if (parser.parseLSquare() || parser.parseOperandList(addressOperands) ||
      parser.parseRSquare() || parser.parseOperandList(remainingOperands) ||
      parser.parseColon() || parser.parseTypeList(parsedTypes))
    return failure();

  // The trailing type is the data type; everything before it is an address
  // type.
  ArrayRef<Type> typesRef(parsedTypes);
  result.addTypes(typesRef.back());
  result.addTypes(typesRef.drop_back());

  SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
  llvm::append_range(allOperands, addressOperands);
  llvm::append_range(allOperands, remainingOperands);

  SmallVector<Type, 1> allTypes(parsedTypes.begin(), parsedTypes.end());
  allTypes.push_back(NoneType::get(result.getContext()));

  return parser.resolveOperands(allOperands, allTypes, allOperandLoc,
                                result.operands);
}
|
|
|
|
template <typename MemOp>
|
|
static void printMemoryAccessOp(OpAsmPrinter &p, MemOp op) {
|
|
p << " [";
|
|
p << op.getAddresses();
|
|
p << "] " << op.getData() << ", " << op.getCtrl() << " : ";
|
|
llvm::interleaveComma(op.getAddresses(), p,
|
|
[&](Value v) { p << v.getType(); });
|
|
p << ", " << op.getData().getType();
|
|
}
|
|
|
|
/// Loads use the syntax shared with stores.
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseMemoryAccessOp(parser, result);
  return parsed;
}
|
|
|
|
/// Loads print with the syntax shared with stores.
void LoadOp::print(OpAsmPrinter &p) {
  printMemoryAccessOp(p, *this);
}
|
|
|
|
std::string handshake::StoreOp::getOperandName(unsigned int idx) {
|
|
unsigned nAddresses = getAddresses().size();
|
|
std::string opName;
|
|
if (idx < nAddresses)
|
|
opName = "addrIn" + std::to_string(idx);
|
|
else if (idx == nAddresses)
|
|
opName = "dataIn";
|
|
else
|
|
opName = "ctrl";
|
|
return opName;
|
|
}
|
|
|
|
template <typename TMemoryOp>
|
|
static LogicalResult verifyMemoryAccessOp(TMemoryOp op) {
|
|
if (op.getAddresses().size() == 0)
|
|
return op.emitOpError() << "No addresses were specified";
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Loads use the shared memory-access verifier.
LogicalResult LoadOp::verify() {
  return verifyMemoryAccessOp(*this);
}
|
|
|
|
/// Result 0 is "dataToMem"; the remaining results are the address outputs
/// "addrOut<i>", numbered from 0.
std::string handshake::StoreOp::getResultName(unsigned int idx) {
  return idx == 0 ? std::string("dataToMem")
                  : "addrOut" + std::to_string(idx - 1);
}
|
|
|
|
/// Builds a store.
///
/// Operands are the address `indices` followed by `valueToStore`. Results are
/// a data output with the stored value's type, followed by one index-typed
/// address output per index.
void handshake::StoreOp::build(OpBuilder &builder, OperationState &result,
                               Value valueToStore, ValueRange indices) {
  result.addOperands(indices);
  result.addOperands(valueToStore);

  result.types.push_back(valueToStore.getType());
  result.types.append(indices.size(), builder.getIndexType());
}
|
|
|
|
/// Stores use the shared memory-access verifier.
LogicalResult StoreOp::verify() {
  return verifyMemoryAccessOp(*this);
}
|
|
|
|
/// Stores use the syntax shared with loads.
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseMemoryAccessOp(parser, result);
  return parsed;
}
|
|
|
|
/// Stores print with the syntax shared with loads.
void StoreOp::print(OpAsmPrinter &p) {
  printMemoryAccessOp(p, *this);
}
|
|
|
|
/// A join is unconditionally treated as a control operation.
bool JoinOp::isControl() {
  return true;
}
|
|
|
|
/// Parses `operands attr-dict : operand-types`. The join's single result is
/// always of type `none`.
ParseResult JoinOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  SmallVector<Type> operandTypes;
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseTypeList(operandTypes) ||
      parser.resolveOperands(operands, operandTypes, operandsLoc,
                             result.operands))
    return failure();

  result.addTypes(NoneType::get(result.getContext()));
  return success();
}
|
|
|
|
/// Prints `operands attr-dict : operand-types`, eliding a "control" attribute
/// if one is present.
void JoinOp::print(OpAsmPrinter &p) {
  auto data = getData();
  p << " " << data;
  p.printOptionalAttrDict((*this)->getAttrs(), {"control"});
  p << " : " << data.getTypes();
}
|
|
|
|
/// Based on mlir::func::CallOp::verifySymbolUses
|
|
LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|
// Check that the module attribute was specified.
|
|
auto fnAttr = this->getModuleAttr();
|
|
assert(fnAttr && "requires a 'module' symbol reference attribute");
|
|
|
|
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
|
|
if (!fn)
|
|
return emitOpError() << "'" << fnAttr.getValue()
|
|
<< "' does not reference a valid handshake function";
|
|
|
|
// Verify that the operand and result types match the callee.
|
|
auto fnType = fn.getFunctionType();
|
|
if (fnType.getNumInputs() != getNumOperands())
|
|
return emitOpError(
|
|
"incorrect number of operands for the referenced handshake function");
|
|
|
|
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
|
if (getOperand(i).getType() != fnType.getInput(i))
|
|
return emitOpError("operand type mismatch: expected operand type ")
|
|
<< fnType.getInput(i) << ", but provided "
|
|
<< getOperand(i).getType() << " for operand number " << i;
|
|
|
|
if (fnType.getNumResults() != getNumResults())
|
|
return emitOpError(
|
|
"incorrect number of results for the referenced handshake function");
|
|
|
|
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
|
if (getResult(i).getType() != fnType.getResult(i))
|
|
return emitOpError("result type mismatch: expected result type ")
|
|
<< fnType.getResult(i) << ", but provided "
|
|
<< getResult(i).getType() << " for result number " << i;
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Forms a function type from this instance's own operand and result types.
FunctionType InstanceOp::getModuleType() {
  auto inputs = getOperandTypes();
  auto outputs = getResultTypes();
  return FunctionType::get(getContext(), inputs, outputs);
}
|
|
|
|
/// Parses `operand attr-dict : tuple-type`. The operand takes the tuple type,
/// and there is one result per tuple element, in order.
ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand tuple;
  TupleType tupleType;

  if (parser.parseOperand(tuple) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(tupleType) ||
      parser.resolveOperand(tuple, tupleType, result.operands))
    return failure();

  result.addTypes(tupleType.getTypes());
  return success();
}
|
|
|
|
/// Prints `operand attr-dict : tuple-type`.
void UnpackOp::print(OpAsmPrinter &p) {
  auto input = getInput();
  p << " " << input;
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << input.getType();
}
|
|
|
|
/// Parses `operands attr-dict : tuple-type`. The operands are resolved
/// against the tuple's element types, and the tuple type is the sole result.
ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  TupleType tupleType;

  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(tupleType) ||
      parser.resolveOperands(operands, tupleType.getTypes(), operandsLoc,
                             result.operands))
    return failure();

  result.addTypes(tupleType);
  return success();
}
|
|
|
|
/// Prints `operands attr-dict : tuple-type`.
void PackOp::print(OpAsmPrinter &p) {
  auto inputs = getInputs();
  p << " " << inputs;
  p.printOptionalAttrDict((*this)->getAttrs());
  Type tupleType = getResult().getType();
  p << " : " << tupleType;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that a handshake.return sits directly inside a handshake.func and
/// that its operands match the function's result types in number and order.
LogicalResult ReturnOp::verify() {
  // dyn_cast_or_null: a detached op has no parent, and that case should
  // produce the diagnostic below rather than trip an assertion in dyn_cast.
  auto function = dyn_cast_or_null<handshake::FuncOp>((*this)->getParentOp());
  if (!function)
    return emitOpError("must have a handshake.func parent");

  // The operand number and types must match the function signature.
  const auto &results = function.getResultTypes();
  if (getNumOperands() != results.size())
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function returns "
           << results.size();

  // Use emitOpError here too, so every diagnostic from this verifier is
  // prefixed with the op name.
  for (unsigned i = 0, e = results.size(); i != e; ++i)
    if (getOperand(i).getType() != results[i])
      return emitOpError() << "type of return operand " << i << " ("
                           << getOperand(i).getType()
                           << ") doesn't match function result type ("
                           << results[i] << ")";

  return success();
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "circt/Dialect/Handshake/Handshake.cpp.inc"
|