mirror of https://github.com/llvm/circt.git
3322 lines
118 KiB
C++
3322 lines
118 KiB
C++
//===- HWOps.cpp - Implement the HW operations ----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the HW ops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/HW/HWOps.h"
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/HW/CustomDirectiveImpl.h"
|
|
#include "circt/Dialect/HW/HWAttributes.h"
|
|
#include "circt/Dialect/HW/HWInstanceImplementation.h"
|
|
#include "circt/Dialect/HW/HWSymCache.h"
|
|
#include "circt/Dialect/HW/HWVisitors.h"
|
|
#include "circt/Dialect/HW/ModuleImplementation.h"
|
|
#include "circt/Support/CustomDirectiveImpl.h"
|
|
#include "circt/Support/Namespace.h"
|
|
#include "circt/Support/Naming.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/FunctionImplementation.h"
|
|
#include "llvm/ADT/BitVector.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
|
|
using namespace circt;
|
|
using namespace hw;
|
|
using mlir::TypedAttr;
|
|
|
|
/// Flip a port direction.
|
|
ModulePort::Direction hw::flip(ModulePort::Direction direction) {
|
|
switch (direction) {
|
|
case ModulePort::Direction::Input:
|
|
return ModulePort::Direction::Output;
|
|
case ModulePort::Direction::Output:
|
|
return ModulePort::Direction::Input;
|
|
case ModulePort::Direction::InOut:
|
|
return ModulePort::Direction::InOut;
|
|
}
|
|
llvm_unreachable("unknown PortDirection");
|
|
}
|
|
|
|
/// Check that `index` is exactly as wide as needed to address every element of
/// `array`, i.e. ceil(log2(numElements)) bits. When that requirement is zero
/// bits (a single-element array), an index of width 0 or 1 is accepted.
/// `array` must have an `hw::ArrayType` (after canonicalization).
bool hw::isValidIndexBitWidth(Value index, Value array) {
  auto arrayType =
      hw::getCanonicalType(array.getType()).dyn_cast<hw::ArrayType>();
  assert(arrayType && "expected array type");

  const unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
  const auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
  if (requiredWidth != 0)
    return indexWidth == requiredWidth;
  // Zero address bits are needed; tolerate a 0- or 1-bit index.
  return indexWidth <= 1;
}
|
|
|
|
/// Return true if the specified operation is a combinational logic op.
|
|
bool hw::isCombinational(Operation *op) {
|
|
struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
|
|
bool visitInvalidTypeOp(Operation *op) { return false; }
|
|
bool visitUnhandledTypeOp(Operation *op) { return true; }
|
|
};
|
|
|
|
return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
|
|
IsCombClassifier().dispatchTypeOpVisitor(op);
|
|
}
|
|
|
|
/// Try to fold the extraction of field `fieldIndex` from a struct produced by
/// `inputOp` (which may be null).
///
/// Returns the corresponding operand of a `StructCreateOp`. For a
/// `StructInjectOp` that writes exactly this field, returns the injected
/// value. Otherwise returns a null Value.
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
  if (auto create = dyn_cast_or_null<StructCreateOp>(inputOp))
    return create.getOperand(fieldIndex);

  auto inject = dyn_cast_or_null<StructInjectOp>(inputOp);
  if (inject && inject.getFieldIndex() == fieldIndex)
    return inject.getNewValue();

  return {};
}
|
|
|
|
/// Wrap `attrs` in an ArrayAttr. The result collapses to an empty ArrayAttr
/// when no element is a non-empty dictionary. Null entries count as empty;
/// non-null entries must be DictionaryAttrs.
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
                              ArrayRef<Attribute> attrs) {
  auto isNonEmptyDict = [](Attribute attr) {
    return attr && !cast<DictionaryAttr>(attr).empty();
  };
  if (llvm::none_of(attrs, isNonEmptyDict))
    return ArrayAttr::get(context, {});
  return ArrayAttr::get(context, attrs);
}
|
|
|
|
/// Get a special name to use when printing the entry block arguments of the
|
|
/// region contained by an operation in this dialect.
|
|
static void getAsmBlockArgumentNamesImpl(mlir::Region ®ion,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
if (region.empty())
|
|
return;
|
|
// Assign port names to the bbargs.
|
|
auto module = cast<HWModuleOp>(region.getParentOp());
|
|
|
|
auto *block = ®ion.front();
|
|
for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
|
|
auto name = module.getInputName(i);
|
|
// Let mlir deterministically convert names to valid identifiers
|
|
setNameFn(block->getArgument(i), name);
|
|
}
|
|
}
|
|
|
|
/// Delimiters that may surround a list.
enum class Delimiter {
  /// No delimiters.
  None,
  /// List enclosed in parentheses: `( ... )`.
  Paren,
  /// List enclosed in angle brackets `< ... >`, or absent entirely.
  OptionalLessGreater,
};
|
|
|
|
/// Check parameter specified by `value` to see if it is valid according to the
|
|
/// module's parameters. If not, emit an error to the diagnostic provided as an
|
|
/// argument to the lambda 'instanceError' and return failure, otherwise return
|
|
/// success.
|
|
///
|
|
/// If `disallowParamRefs` is true, then parameter references are not allowed.
|
|
LogicalResult hw::checkParameterInContext(
|
|
Attribute value, ArrayAttr moduleParameters,
|
|
const instance_like_impl::EmitErrorFn &instanceError,
|
|
bool disallowParamRefs) {
|
|
// Literals are always ok. Their types are already known to match
|
|
// expectations.
|
|
if (value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
|
|
value.isa<StringAttr>() || value.isa<ParamVerbatimAttr>())
|
|
return success();
|
|
|
|
// Check both subexpressions of an expression.
|
|
if (auto expr = value.dyn_cast<ParamExprAttr>()) {
|
|
for (auto op : expr.getOperands())
|
|
if (failed(checkParameterInContext(op, moduleParameters, instanceError,
|
|
disallowParamRefs)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
// Parameter references need more analysis to make sure they are valid within
|
|
// this module.
|
|
if (auto parameterRef = value.dyn_cast<ParamDeclRefAttr>()) {
|
|
auto nameAttr = parameterRef.getName();
|
|
|
|
// Don't allow references to parameters from the default values of a
|
|
// parameter list.
|
|
if (disallowParamRefs) {
|
|
instanceError([&](auto &diag) {
|
|
diag << "parameter " << nameAttr
|
|
<< " cannot be used as a default value for a parameter";
|
|
return false;
|
|
});
|
|
return failure();
|
|
}
|
|
|
|
// Find the corresponding attribute in the module.
|
|
for (auto param : moduleParameters) {
|
|
auto paramAttr = param.cast<ParamDeclAttr>();
|
|
if (paramAttr.getName() != nameAttr)
|
|
continue;
|
|
|
|
// If the types match then the reference is ok.
|
|
if (paramAttr.getType() == parameterRef.getType())
|
|
return success();
|
|
|
|
instanceError([&](auto &diag) {
|
|
diag << "parameter " << nameAttr << " used with type "
|
|
<< parameterRef.getType() << "; should have type "
|
|
<< paramAttr.getType();
|
|
return true;
|
|
});
|
|
return failure();
|
|
}
|
|
|
|
instanceError([&](auto &diag) {
|
|
diag << "use of unknown parameter " << nameAttr;
|
|
return true;
|
|
});
|
|
return failure();
|
|
}
|
|
|
|
instanceError([&](auto &diag) {
|
|
diag << "invalid parameter value " << value;
|
|
return false;
|
|
});
|
|
return failure();
|
|
}
|
|
|
|
/// Check parameter specified by `value` to see if it is valid within the scope
|
|
/// of the specified module `module`. If not, emit an error at the location of
|
|
/// `usingOp` and return failure, otherwise return success. If `usingOp` is
|
|
/// null, then no diagnostic is generated.
|
|
///
|
|
/// If `disallowParamRefs` is true, then parameter references are not allowed.
|
|
LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
|
|
Operation *usingOp,
|
|
bool disallowParamRefs) {
|
|
instance_like_impl::EmitErrorFn emitError =
|
|
[&](const std::function<bool(InFlightDiagnostic &)> &fn) {
|
|
if (usingOp) {
|
|
auto diag = usingOp->emitOpError();
|
|
if (fn(diag))
|
|
diag.attachNote(module->getLoc()) << "module declared here";
|
|
}
|
|
};
|
|
|
|
return checkParameterInContext(value,
|
|
module->getAttrOfType<ArrayAttr>("parameters"),
|
|
emitError, disallowParamRefs);
|
|
}
|
|
|
|
/// Return true if the specified attribute tree is made up of nodes that are
|
|
/// valid in a parameter expression.
|
|
bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
|
|
return succeeded(checkParameterInContext(attr, module, nullptr, false));
|
|
}
|
|
|
|
/// Build an accessor over the entry block of a module under construction.
///
/// Input port names are mapped, in input order, to the entry block arguments
/// of `bodyRegion`. Output port names are mapped to slots in `outputOperands`,
/// which start out null and are filled through `setOutput`.
HWModulePortAccessor::HWModulePortAccessor(Location loc,
                                           const ModulePortInfo &info,
                                           Region &bodyRegion)
    : info(info) {
  inputArgs.resize(info.sizeInputs());
  // Pair each block argument with the *input* port it was created for.
  // `info.at(i)` indexes the merged port list, which may interleave outputs
  // with inputs, so it could pick the wrong name. zip() also stops at the
  // shorter range, so we never write past `inputArgs`.
  unsigned inputNo = 0;
  for (auto [port, barg] :
       llvm::zip(info.getInputs(), bodyRegion.getArguments())) {
    inputIdx[port.name.str()] = inputNo;
    inputArgs[inputNo] = barg;
    ++inputNo;
  }

  outputOperands.resize(info.sizeOutputs());
  for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs()))
    outputIdx[outputInfo.name.str()] = i;
}
|
|
|
|
/// Record `v` as the value driving output port number `i`. Each output slot
/// may only be assigned once.
void HWModulePortAccessor::setOutput(unsigned i, Value v) {
  assert(i < outputOperands.size() && "invalid output index");
  assert(!outputOperands[i] && "output already set");
  outputOperands[i] = v;
}
|
|
|
|
/// Return the block argument for input port number `i`.
Value HWModulePortAccessor::getInput(unsigned i) {
  assert(i < inputArgs.size() && "invalid input index");
  Value arg = inputArgs[i];
  return arg;
}
|
|
/// Return the block argument for the input port named `name`.
/// Asserts that the name exists; in release builds an unknown name yields a
/// null Value instead of dereferencing an end iterator.
Value HWModulePortAccessor::getInput(StringRef name) {
  auto it = inputIdx.find(name.str());
  assert(it != inputIdx.end() && "invalid input name");
  if (it == inputIdx.end())
    return {};
  return getInput(it->second);
}
|
|
/// Record `v` as the value driving the output port named `name`.
/// Asserts that the name exists; in release builds an unknown name is ignored
/// rather than dereferencing an end iterator.
void HWModulePortAccessor::setOutput(StringRef name, Value v) {
  auto it = outputIdx.find(name.str());
  assert(it != outputIdx.end() && "invalid output name");
  if (it == outputIdx.end())
    return;
  setOutput(it->second, v);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Print the value attribute, then any remaining attributes. `value` itself is
/// left out of the trailing dictionary.
void ConstantOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printAttribute(getValueAttr());
  StringRef elided[] = {"value"};
  p.printOptionalAttrDict((*this)->getAttrs(), elided);
}
|
|
|
|
/// Parse an integer attribute (stored as `value`) followed by an optional
/// attribute dictionary. The result type is the attribute's type.
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
  IntegerAttr valueAttr;
  if (parser.parseAttribute(valueAttr, "value", result.attributes))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  result.addTypes(valueAttr.getType());
  return success();
}
|
|
|
|
/// The stored APInt must have exactly the width of the integer result type.
LogicalResult ConstantOp::verify() {
  auto resultWidth = getType().cast<IntegerType>().getWidth();
  if (getValue().getBitWidth() == resultWidth)
    return success();
  return emitError(
      "hw.constant attribute bitwidth doesn't match return type");
}
|
|
|
|
/// Build a ConstantOp from an APInt, infering the result type from the
|
|
/// width of the APInt.
|
|
void ConstantOp::build(OpBuilder &builder, OperationState &result,
|
|
const APInt &value) {
|
|
|
|
auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
|
|
auto attr = builder.getIntegerAttr(type, value);
|
|
return build(builder, result, type, attr);
|
|
}
|
|
|
|
/// Build a ConstantOp from an APInt, infering the result type from the
|
|
/// width of the APInt.
|
|
void ConstantOp::build(OpBuilder &builder, OperationState &result,
|
|
IntegerAttr value) {
|
|
return build(builder, result, value.getType(), value);
|
|
}
|
|
|
|
/// This builder allows construction of small signed integers like 0, 1, -1
|
|
/// matching a specified MLIR IntegerType. This shouldn't be used for general
|
|
/// constant folding because it only works with values that can be expressed in
|
|
/// an int64_t. Use APInt's instead.
|
|
void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
|
|
int64_t value) {
|
|
auto numBits = type.cast<IntegerType>().getWidth();
|
|
build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true));
|
|
}
|
|
|
|
/// Suggest an SSA name for the result. i1 constants are named `true` or
/// `false`. Other constants are named `c<value>_<type>`, e.g. `c42_i32`.
void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto resultTy = getType();
  auto cst = getValue();

  if (resultTy.cast<IntegerType>().getWidth() == 1) {
    setNameFn(getResult(), cst.isZero() ? "false" : "true");
    return;
  }

  SmallVector<char, 32> nameStorage;
  llvm::raw_svector_ostream os(nameStorage);
  os << 'c' << cst << '_' << resultTy;
  setNameFn(getResult(), os.str());
}
|
|
|
|
/// A constant folds to its own value attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  Attribute folded = getValueAttr();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WireOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Check whether an operation has any additional attributes set beyond its
|
|
/// standard list of attributes returned by `getAttributeNames`.
|
|
template <class Op>
|
|
static bool hasAdditionalAttributes(Op op,
|
|
ArrayRef<StringRef> ignoredAttrs = {}) {
|
|
auto names = op.getAttributeNames();
|
|
llvm::SmallDenseSet<StringRef> nameSet;
|
|
nameSet.reserve(names.size() + ignoredAttrs.size());
|
|
nameSet.insert(names.begin(), names.end());
|
|
nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
|
|
return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
|
|
return !nameSet.contains(namedAttr.getName());
|
|
});
|
|
}
|
|
|
|
/// Name the result after the wire's optional `name` attribute, if it is
/// present and non-empty.
void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
  if (!nameAttr || nameAttr.getValue().empty())
    return;
  setNameFn(getResult(), nameAttr.getValue());
}
|
|
|
|
/// Always report result #0 as the target result.
std::optional<size_t> WireOp::getTargetResultIndex() {
  return std::optional<size_t>(0);
}
|
|
|
|
/// Fold the wire to its input, unless it carries a name, an inner symbol, or
/// any attribute besides `sv.namehint`.
OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
  if (getNameAttr() || getInnerSymAttr())
    return {};
  if (hasAdditionalAttributes(*this, {"sv.namehint"}))
    return {};
  return getInput();
}
|
|
|
|
/// Replace a wire by its input. The wire is kept if it has an inner symbol or
/// attributes beyond `sv.namehint`. The wire's name or name hint is
/// propagated to the defining op of the input as `sv.namehint`, when
/// `chooseName` picks one.
LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
  if (hasAdditionalAttributes(wire, {"sv.namehint"}) || wire.getInnerSymAttr())
    return failure();

  Value input = wire.getInput();
  if (Operation *producer = input.getDefiningOp()) {
    if (auto hint = chooseName(wire, producer))
      rewriter.modifyOpInPlace(
          producer, [&] { producer->setAttr("sv.namehint", hint); });
  }

  rewriter.replaceOp(wire, input);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AggregateConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Recursively verify that `attr` is a well-formed constant value of `type`.
/// Diagnostics are reported on `op`.
///
/// - Struct types take an ArrayAttr with one entry per field.
/// - Packed and unpacked arrays take an ArrayAttr with one entry per element.
/// - Enum types take a StringAttr.
/// - Integer types take an IntegerAttr of matching width.
///
/// Type aliases are checked against their canonical type. Any other type is
/// rejected.
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
  // If this is a type alias, get the underlying type.
  if (auto typeAlias = type.dyn_cast<TypeAliasType>())
    type = typeAlias.getCanonicalType();

  if (auto structType = type.dyn_cast<StructType>()) {
    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
    if (!arrayAttr)
      return op->emitOpError("expected array attribute for constant of type ")
             << type;
    if (structType.getElements().size() != arrayAttr.size())
      return op->emitOpError("array attribute (")
             << arrayAttr.size() << ") has wrong size for struct constant ("
             << structType.getElements().size() << ")";

    for (auto [fieldAttr, fieldInfo] :
         llvm::zip(arrayAttr.getValue(), structType.getElements())) {
      if (failed(checkAttributes(op, fieldAttr, fieldInfo.type)))
        return failure();
    }
  } else if (auto arrayType = type.dyn_cast<ArrayType>()) {
    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
    if (!arrayAttr)
      return op->emitOpError("expected array attribute for constant of type ")
             << type;
    if (arrayType.getNumElements() != arrayAttr.size())
      return op->emitOpError("array attribute (")
             << arrayAttr.size() << ") has wrong size for array constant ("
             << arrayType.getNumElements() << ")";

    auto elementType = arrayType.getElementType();
    for (auto elementAttr : arrayAttr.getValue()) {
      if (failed(checkAttributes(op, elementAttr, elementType)))
        return failure();
    }
  } else if (auto arrayType = type.dyn_cast<UnpackedArrayType>()) {
    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
    if (!arrayAttr)
      return op->emitOpError("expected array attribute for constant of type ")
             << type;
    auto elementType = arrayType.getElementType();
    if (arrayType.getNumElements() != arrayAttr.size())
      return op->emitOpError("array attribute (")
             << arrayAttr.size()
             << ") has wrong size for unpacked array constant ("
             << arrayType.getNumElements() << ")";

    for (auto elementAttr : arrayAttr.getValue()) {
      if (failed(checkAttributes(op, elementAttr, elementType)))
        return failure();
    }
  } else if (auto enumType = type.dyn_cast<EnumType>()) {
    auto stringAttr = attr.dyn_cast<StringAttr>();
    if (!stringAttr)
      return op->emitOpError("expected string attribute for constant of type ")
             << type;
  } else if (auto intType = type.dyn_cast<IntegerType>()) {
    // Check the attribute kind is correct.
    auto intAttr = attr.dyn_cast<IntegerAttr>();
    if (!intAttr)
      return op->emitOpError("expected integer attribute for constant of type ")
             << type;
    // Check the bitwidth is correct.
    if (intAttr.getValue().getBitWidth() != intType.getWidth())
      return op->emitOpError("hw.constant attribute bitwidth "
                             "doesn't match return type");
  } else {
    // The trailing space keeps the streamed type from being glued to the text.
    return op->emitOpError("unknown element type ") << type;
  }
  return success();
}
|
|
|
|
/// Check the `fields` attribute recursively against the result type.
LogicalResult AggregateConstantOp::verify() {
  Attribute fields = getFieldsAttr();
  return checkAttributes(getOperation(), fields, getType());
}
|
|
|
|
/// An aggregate constant folds to its `fields` attribute.
OpFoldResult AggregateConstantOp::fold(FoldAdaptor) {
  Attribute fields = getFieldsAttr();
  return fields;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ParamValueOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse `<type> = <attribute>`. The attribute is parsed using the type that
/// was just read.
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
                                   Type &resultType) {
  if (p.parseType(resultType))
    return failure();
  if (p.parseEqual())
    return failure();
  if (p.parseAttribute(value, resultType))
    return failure();
  return success();
}
|
|
|
|
/// Print `<type> = <attribute>`. The attribute is printed without its type,
/// since the leading type already conveys it.
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
                            Type resultType) {
  p << resultType;
  p << " = ";
  p.printAttributeWithoutType(value);
}
|
|
|
|
/// The parameter expression must be valid in the enclosing `hw.module`.
LogicalResult ParamValueOp::verify() {
  auto parentModule = (*this)->getParentOfType<hw::HWModuleOp>();
  return checkParameterInContext(getValue(), parentModule, getOperation());
}
|
|
|
|
/// A parameter value op folds to its value attribute.
OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
  Attribute folded = getValueAttr();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// HWModuleOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return true if isAnyModule or instance.
|
|
bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
|
|
return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
|
|
}
|
|
|
|
/// Return the signature for a module as a function type from the module itself
|
|
/// or from an hw::InstanceOp.
|
|
FunctionType hw::getModuleType(Operation *moduleOrInstance) {
|
|
return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
|
|
.Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
|
|
SmallVector<Type> inputs(instance->getOperandTypes());
|
|
SmallVector<Type> results(instance->getResultTypes());
|
|
return FunctionType::get(instance->getContext(), inputs, results);
|
|
})
|
|
.Case<HWModuleLike>(
|
|
[](auto mod) { return mod.getHWModuleType().getFuncType(); })
|
|
.Default([](Operation *op) {
|
|
return cast<mlir::FunctionOpInterface>(op)
|
|
.getFunctionType()
|
|
.cast<FunctionType>();
|
|
});
|
|
}
|
|
|
|
/// Return the name to use for the Verilog module that we're referencing
|
|
/// here. This is typically the symbol, but can be overridden with the
|
|
/// verilogName attribute.
|
|
StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
|
|
auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
|
|
if (nameAttr)
|
|
return nameAttr;
|
|
|
|
return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
|
}
|
|
|
|
// Selects which flavor of module is being parsed. Judging by the names, these
// are a plain module, an external module, and a generated module.
enum ExternModKind {
  PlainMod,
  ExternMod,
  GenMod
};
|
|
|
|
template <typename ModuleTy>
|
|
static void
|
|
buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
|
|
const ModulePortInfo &ports, ArrayAttr parameters,
|
|
ArrayRef<NamedAttribute> attributes, StringAttr comment) {
|
|
using namespace mlir::function_interface_impl;
|
|
LocationAttr unknownLoc = builder.getUnknownLoc();
|
|
|
|
// Add an attribute for the name.
|
|
result.addAttribute(SymbolTable::getSymbolAttrName(), name);
|
|
|
|
SmallVector<Attribute> perPortAttrs;
|
|
SmallVector<Attribute> portLocs;
|
|
SmallVector<ModulePort> portTypes;
|
|
|
|
for (auto elt : ports) {
|
|
portTypes.push_back(elt);
|
|
portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
|
|
llvm::SmallVector<NamedAttribute> portAttrs;
|
|
if (elt.attrs)
|
|
llvm::copy(elt.attrs, std::back_inserter(portAttrs));
|
|
perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
|
|
}
|
|
|
|
// Allow clients to pass in null for the parameters list.
|
|
if (!parameters)
|
|
parameters = builder.getArrayAttr({});
|
|
|
|
// Record the argument and result types as an attribute.
|
|
auto type = ModuleType::get(builder.getContext(), portTypes);
|
|
result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
|
|
TypeAttr::get(type));
|
|
result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
|
|
result.addAttribute("per_port_attrs",
|
|
arrayOrEmpty(builder.getContext(), perPortAttrs));
|
|
result.addAttribute("parameters", parameters);
|
|
if (!comment)
|
|
comment = builder.getStringAttr("");
|
|
result.addAttribute("comment", comment);
|
|
result.addAttributes(attributes);
|
|
result.addRegion();
|
|
}
|
|
|
|
/// Internal implementation of argument/result insertion and removal on modules.
|
|
static void modifyModuleArgs(
|
|
MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
|
|
ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
|
|
ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
|
|
ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
|
|
SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
|
|
SmallVector<Location> &newArgLocs, Block *body = nullptr) {
|
|
|
|
#ifndef NDEBUG
|
|
// Check that the `insertArgs` and `removeArgs` indices are in ascending
|
|
// order.
|
|
assert(llvm::is_sorted(insertArgs,
|
|
[](auto &a, auto &b) { return a.first < b.first; }) &&
|
|
"insertArgs must be in ascending order");
|
|
assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
|
|
"removeArgs must be in ascending order");
|
|
#endif
|
|
|
|
auto oldArgCount = oldArgTypes.size();
|
|
auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
|
|
assert((int)newArgCount >= 0);
|
|
|
|
newArgNames.reserve(newArgCount);
|
|
newArgTypes.reserve(newArgCount);
|
|
newArgAttrs.reserve(newArgCount);
|
|
newArgLocs.reserve(newArgCount);
|
|
|
|
auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
|
|
auto emptyDictAttr = DictionaryAttr::get(context, {});
|
|
auto unknownLoc = UnknownLoc::get(context);
|
|
|
|
BitVector erasedIndices;
|
|
if (body)
|
|
erasedIndices.resize(oldArgCount + insertArgs.size());
|
|
|
|
for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
|
|
// Insert new ports at this position.
|
|
while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
|
|
auto port = insertArgs[0].second;
|
|
if (port.dir == ModulePort::Direction::InOut &&
|
|
!port.type.isa<InOutType>())
|
|
port.type = InOutType::get(port.type);
|
|
auto sym = port.getSym();
|
|
Attribute attr =
|
|
(sym && !sym.empty())
|
|
? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
|
|
: emptyDictAttr;
|
|
newArgNames.push_back(port.name);
|
|
newArgTypes.push_back(port.type);
|
|
newArgAttrs.push_back(attr);
|
|
insertArgs = insertArgs.drop_front();
|
|
LocationAttr loc = port.loc ? port.loc : unknownLoc;
|
|
newArgLocs.push_back(loc);
|
|
if (body)
|
|
body->insertArgument(idx++, port.type, loc);
|
|
}
|
|
if (argIdx == oldArgCount)
|
|
break;
|
|
|
|
// Migrate the old port at this position.
|
|
bool removed = false;
|
|
while (!removeArgs.empty() && removeArgs[0] == argIdx) {
|
|
removeArgs = removeArgs.drop_front();
|
|
removed = true;
|
|
}
|
|
|
|
if (removed) {
|
|
if (body)
|
|
erasedIndices.set(idx);
|
|
} else {
|
|
newArgNames.push_back(oldArgNames[argIdx]);
|
|
newArgTypes.push_back(oldArgTypes[argIdx]);
|
|
newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
|
|
: oldArgAttrs[argIdx]);
|
|
newArgLocs.push_back(oldArgLocs[argIdx]);
|
|
}
|
|
}
|
|
|
|
if (body)
|
|
body->eraseArguments(erasedIndices);
|
|
|
|
assert(newArgNames.size() == newArgCount);
|
|
assert(newArgTypes.size() == newArgCount);
|
|
assert(newArgAttrs.size() == newArgCount);
|
|
assert(newArgLocs.size() == newArgCount);
|
|
}
|
|
|
|
/// Insert and remove ports of a module. The insertion and removal indices must
|
|
/// be in ascending order. The indices refer to the port positions before any
|
|
/// insertion or removal occurs. Ports inserted at the same index will appear in
|
|
/// the module in the same order as they were listed in the `insert*` array.
|
|
///
|
|
/// The operation must be any of the module-like operations.
|
|
void hw::modifyModulePorts(
|
|
Operation *op, ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
|
|
ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
|
|
ArrayRef<unsigned> removeInputs, ArrayRef<unsigned> removeOutputs,
|
|
Block *body) {
|
|
auto moduleOp = cast<HWModuleLike>(op);
|
|
auto *context = moduleOp.getContext();
|
|
|
|
// Dig up the old argument and result data.
|
|
auto oldArgNames = moduleOp.getInputNames();
|
|
auto oldArgTypes = moduleOp.getInputTypes();
|
|
auto oldArgAttrs = moduleOp.getAllInputAttrs();
|
|
auto oldArgLocs = moduleOp.getInputLocs();
|
|
|
|
auto oldResultNames = moduleOp.getOutputNames();
|
|
auto oldResultTypes = moduleOp.getOutputTypes();
|
|
auto oldResultAttrs = moduleOp.getAllOutputAttrs();
|
|
auto oldResultLocs = moduleOp.getOutputLocs();
|
|
|
|
// Modify the ports.
|
|
SmallVector<Attribute> newArgNames, newResultNames;
|
|
SmallVector<Type> newArgTypes, newResultTypes;
|
|
SmallVector<Attribute> newArgAttrs, newResultAttrs;
|
|
SmallVector<Location> newArgLocs, newResultLocs;
|
|
|
|
modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
|
|
oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
|
|
newArgTypes, newArgAttrs, newArgLocs, body);
|
|
|
|
modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
|
|
oldResultTypes, oldResultAttrs, oldResultLocs,
|
|
newResultNames, newResultTypes, newResultAttrs,
|
|
newResultLocs);
|
|
|
|
// Update the module operation types and attributes.
|
|
auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
|
|
auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
|
|
moduleOp.setHWModuleType(modty);
|
|
moduleOp.setAllInputAttrs(newArgAttrs);
|
|
moduleOp.setAllOutputAttrs(newResultAttrs);
|
|
|
|
newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
|
|
moduleOp.setAllPortLocs(newArgLocs);
|
|
}
|
|
|
|
/// Build an `hw.module`. The body region gets one entry block with an
/// argument per input port. Inout ports become `InOutType` arguments.
/// If `shouldEnsureTerminator` is set, the block is given a terminator.
void HWModuleOp::build(OpBuilder &builder, OperationState &result,
                       StringAttr name, const ModulePortInfo &ports,
                       ArrayAttr parameters,
                       ArrayRef<NamedAttribute> attributes, StringAttr comment,
                       bool shouldEnsureTerminator) {
  buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
                          comment);

  Region &bodyRegion = *result.regions[0];
  Block *entry = new Block();
  bodyRegion.push_back(entry);

  Location fallbackLoc = builder.getUnknownLoc();
  for (auto port : ports.getInputs()) {
    Type argType = port.type;
    if (port.isInOut() && !argType.isa<InOutType>())
      argType = InOutType::get(argType);
    entry->addArgument(argType, port.loc ? Location(port.loc) : fallbackLoc);
  }

  if (shouldEnsureTerminator)
    HWModuleOp::ensureTerminator(bodyRegion, builder, result.location);
}
|
|
|
|
/// Convenience overload that takes a flat list of ports.
void HWModuleOp::build(OpBuilder &builder, OperationState &result,
                       StringAttr name, ArrayRef<PortInfo> ports,
                       ArrayAttr parameters,
                       ArrayRef<NamedAttribute> attributes,
                       StringAttr comment) {
  ModulePortInfo portInfo(ports);
  build(builder, result, name, portInfo, parameters, attributes, comment);
}
|
|
|
|
/// Build an `hw.module` whose body is populated by `modBuilder`. The callback
/// receives a port accessor to read inputs and assign outputs. An `hw.output`
/// is then created from the assigned output values. The caller's insertion
/// point is restored afterwards.
void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
                       StringAttr name, const ModulePortInfo &ports,
                       HWModuleBuilder modBuilder, ArrayAttr parameters,
                       ArrayRef<NamedAttribute> attributes,
                       StringAttr comment) {
  // Skip the default terminator; hw.output is created explicitly below.
  build(builder, odsState, name, ports, parameters, attributes, comment,
        /*shouldEnsureTerminator=*/false);

  Region &body = *odsState.regions[0];
  OpBuilder::InsertionGuard restoreInsertionPoint(builder);
  HWModulePortAccessor accessor(odsState.location, ports, body);
  builder.setInsertionPointToEnd(&body.front());
  modBuilder(builder, accessor);

  llvm::SmallVector<Value> outputs = accessor.getOutputOperands();
  builder.create<hw::OutputOp>(odsState.location, outputs);
}
|
|
|
|
/// Insert/erase ports via `hw::modifyModulePorts`. No body block is passed
/// along, so the body's block arguments are not changed here.
void HWModuleOp::modifyPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
    ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
    ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
  Operation *self = getOperation();
  hw::modifyModulePorts(self, insertInputs, insertOutputs, eraseInputs,
                        eraseOutputs);
}
|
|
|
|
/// Return the name to use for the Verilog module that we're referencing
|
|
/// here. This is typically the symbol, but can be overridden with the
|
|
/// verilogName attribute.
|
|
StringAttr HWModuleExternOp::getVerilogModuleNameAttr() {
|
|
if (auto vName = getVerilogNameAttr())
|
|
return vName;
|
|
|
|
return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
|
}
|
|
|
|
/// Return the Verilog name of this generated module: the `verilogName`
/// attribute if set, otherwise the symbol name.
StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() {
  StringAttr explicitName = getVerilogNameAttr();
  if (!explicitName)
    return (*this)->getAttrOfType<StringAttr>(
        ::mlir::SymbolTable::getSymbolAttrName());
  return explicitName;
}
|
|
|
|
/// Build an external module declaration. `verilogName` is stored only when
/// non-empty. No comment is set, so the shared builder stores an empty one.
void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
                             StringAttr name, const ModulePortInfo &ports,
                             StringRef verilogName, ArrayAttr parameters,
                             ArrayRef<NamedAttribute> attributes) {
  buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
                                attributes, /*comment=*/StringAttr());
  if (verilogName.empty())
    return;
  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
}
|
|
|
|
/// Convenience overload that takes a flat list of ports.
void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
                             StringAttr name, ArrayRef<PortInfo> ports,
                             StringRef verilogName, ArrayAttr parameters,
                             ArrayRef<NamedAttribute> attributes) {
  ModulePortInfo portInfo(ports);
  build(builder, result, name, portInfo, verilogName, parameters, attributes);
}
|
|
|
|
/// Insert/erase ports via `hw::modifyModulePorts` (no body block is passed).
void HWModuleExternOp::modifyPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
    ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
    ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
  Operation *self = getOperation();
  hw::modifyModulePorts(self, insertInputs, insertOutputs, eraseInputs,
                        eraseOutputs);
}
|
|
|
|
/// Appending outputs is a no-op for external modules: the body is empty and
/// `outputs` is ignored. No ports are added.
void HWModuleExternOp::appendOutputs(
    ArrayRef<std::pair<StringAttr, Value>> outputs) {
  // Intentionally empty.
}
|
|
|
|
/// Build a generated module declaration whose generator kind is `genKind`.
/// A non-empty `verilogName` is stored as the `verilogName` attribute.
void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
                                FlatSymbolRefAttr genKind, StringAttr name,
                                const ModulePortInfo &ports,
                                StringRef verilogName, ArrayAttr parameters,
                                ArrayRef<NamedAttribute> attributes) {
  buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
                                   attributes, {});
  result.addAttribute("generatorKind", genKind);
  if (verilogName.empty())
    return;
  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
}
|
|
|
|
/// Convenience overload that takes a flat list of ports. It wraps them in a
/// `ModulePortInfo` and forwards to the primary builder.
void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
                                FlatSymbolRefAttr genKind, StringAttr name,
                                ArrayRef<PortInfo> ports, StringRef verilogName,
                                ArrayAttr parameters,
                                ArrayRef<NamedAttribute> attributes) {
  ModulePortInfo portInfo(ports);
  build(builder, result, genKind, name, portInfo, verilogName, parameters,
        attributes);
}
|
|
|
|
/// Insert and erase ports on this generated module. All of the work is done
/// by `hw::modifyModulePorts`; no body block argument is supplied here.
void HWModuleGeneratedOp::modifyPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
    ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
    ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
  hw::modifyModulePorts(*this, insertInputs, insertOutputs,
                        eraseInputs, eraseOutputs);
}
|
|
|
|
/// Appending outputs is a no-op for generated modules: `outputs` is ignored
/// and no ports are added.
void HWModuleGeneratedOp::appendOutputs(
    ArrayRef<std::pair<StringAttr, Value>> outputs) {
  // Intentionally empty.
}
|
|
|
|
/// Return true if any attribute in `attrs` is named `name`.
static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
  return llvm::any_of(attrs, [&](const NamedAttribute &candidate) {
    return candidate.getName() == name;
  });
}
|
|
|
|
static void
|
|
addPortAttrsAndLocs(Builder &builder, OperationState &result,
|
|
SmallVectorImpl<module_like_impl::PortParse> &ports,
|
|
StringAttr portAttrsName, StringAttr portLocsName) {
|
|
auto unknownLoc = builder.getUnknownLoc();
|
|
auto nonEmptyAttrsFn = [](Attribute attr) {
|
|
return attr && !cast<DictionaryAttr>(attr).empty();
|
|
};
|
|
auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
|
|
return attr && cast<Location>(attr) != unknownLoc;
|
|
};
|
|
|
|
// Convert the specified array of dictionary attrs (which may have null
|
|
// entries) to an ArrayAttr of dictionaries.
|
|
SmallVector<Attribute> attrs;
|
|
SmallVector<Attribute> locs;
|
|
for (auto &port : ports) {
|
|
attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
|
|
locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
|
|
}
|
|
|
|
// Add the attributes to the ports.
|
|
if (llvm::any_of(attrs, nonEmptyAttrsFn))
|
|
result.addAttribute(portAttrsName, builder.getArrayAttr(attrs));
|
|
|
|
if (llvm::any_of(locs, nonEmptyLocsFn))
|
|
result.addAttribute(portLocsName, builder.getArrayAttr(locs));
|
|
}
|
|
|
|
/// Shared custom parser for HWModuleOp, HWModuleExternOp and
/// HWModuleGeneratedOp. It parses the following pieces, in order:
///   1. an optional visibility keyword;
///   2. the symbol name;
///   3. for `GenMod` only, a comma and the `generatorKind` symbol;
///   4. an optional parameter list;
///   5. the module signature (ports);
///   6. an optional `attributes` dictionary;
///   7. for `PlainMod` only, the body region.
///
/// Port attributes and locations from the signature are attached through
/// addPortAttrsAndLocs. `ModuleTy` supplies the attribute names used for the
/// module type, per-port attributes and port locations.
template <typename ModuleTy>
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result,
                                   ExternModKind modKind = PlainMod) {

  using namespace mlir::function_interface_impl;
  // Remembered for the `parameters` diagnostic below.
  auto loc = parser.getCurrentLocation();

  // Parse the visibility attribute. The keyword is optional, so the result is
  // deliberately discarded.
  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);

  // Parse the name as a symbol.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Parse the generator information (generated modules only):
  // `, @generatorKind`.
  FlatSymbolRefAttr kindAttr;
  if (modKind == GenMod) {
    if (parser.parseComma() ||
        parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
      return failure();
    }
  }

  // Parse the parameters.
  ArrayAttr parameters;
  if (parseOptionalParameterList(parser, parameters))
    return failure();

  // Parse the port list; this also yields the module type attribute.
  SmallVector<module_like_impl::PortParse> ports;
  TypeAttr modType;
  if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
    return failure();

  // Parse the attribute dict.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // `parameters` is populated from the parameter list parsed above, so
  // spelling it in the attribute dictionary is rejected.
  if (hasAttribute("parameters", result.attributes)) {
    parser.emitError(loc, "explicit `parameters` attributes not allowed");
    return failure();
  }

  result.addAttribute("parameters", parameters);
  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
  addPortAttrsAndLocs(parser.getBuilder(), result, ports,
                      ModuleTy::getPerPortAttrsAttrName(result.name),
                      ModuleTy::getPortLocsAttrName(result.name));

  // Only non-output ports become entry block arguments of the body.
  SmallVector<OpAsmParser::Argument, 4> entryArgs;
  for (auto &port : ports)
    if (port.direction != ModulePort::Direction::Output)
      entryArgs.push_back(port);

  // Parse the optional function body. The region is always added, but it is
  // only parsed (and given a terminator) for plain modules; extern and
  // generated modules keep it empty.
  auto *body = result.addRegion();
  if (modKind == PlainMod) {
    if (parser.parseRegion(*body, entryArgs))
      return failure();

    HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  }
  return success();
}
|
|
|
|
/// `hw.module`: a module with a body region.
ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseHWModuleOp<HWModuleOp>(parser, result, PlainMod);
}

/// `hw.module.extern`: a module declaration without a body.
ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  const ExternModKind kind = ExternMod;
  return parseHWModuleOp<HWModuleExternOp>(parser, result, kind);
}

/// Generated module: a declaration that also names its generator kind.
ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  const ExternModKind kind = GenMod;
  return parseHWModuleOp<HWModuleGeneratedOp>(parser, result, kind);
}
|
|
|
|
/// Return the function type of `op`. HW module-like ops report their module
/// type converted to a FunctionType. Any other op must implement
/// FunctionOpInterface; the `cast` asserts if it does not.
FunctionType getHWModuleOpType(Operation *op) {
  auto hwModule = dyn_cast<HWModuleLike>(op);
  if (!hwModule) {
    auto funcOp = cast<mlir::FunctionOpInterface>(op);
    return funcOp.getFunctionType().cast<FunctionType>();
  }
  return hwModule.getHWModuleType().getFuncType();
}
|
|
|
|
template <typename ModuleTy>
|
|
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
|
|
p << ' ';
|
|
// Print the visibility of the module.
|
|
StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
|
|
if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
|
|
visibilityAttrName))
|
|
p << visibility.getValue() << ' ';
|
|
|
|
// Print the operation and the function name.
|
|
p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
|
|
if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
|
|
p << ", ";
|
|
p.printSymbolName(gen.getGeneratorKind());
|
|
}
|
|
|
|
// Print the parameter list if present.
|
|
printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
|
|
|
|
module_like_impl::printModuleSignatureNew(p, mod.getOperation());
|
|
|
|
SmallVector<StringRef, 3> omittedAttrs;
|
|
if (isa<HWModuleGeneratedOp>(mod.getOperation()))
|
|
omittedAttrs.push_back("generatorKind");
|
|
omittedAttrs.push_back(mod.getPortLocsAttrName());
|
|
omittedAttrs.push_back(mod.getModuleTypeAttrName());
|
|
omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
|
|
omittedAttrs.push_back(mod.getParametersAttrName());
|
|
omittedAttrs.push_back(visibilityAttrName);
|
|
if (auto cmt =
|
|
mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
|
|
if (cmt.getValue().empty())
|
|
omittedAttrs.push_back("comment");
|
|
|
|
mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
|
|
omittedAttrs);
|
|
}
|
|
|
|
void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
|
|
void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
|
|
|
|
/// Print the shared module header, then the body region if it is non-empty.
/// Entry block arguments are not printed, because the signature already
/// names them. Terminators are printed.
void HWModuleOp::print(OpAsmPrinter &p) {
  printModuleOp(p, *this);

  Region &body = getBody();
  if (body.empty())
    return;

  p << " ";
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
|
|
|
|
/// Checks shared by every HW module-like op:
///   - the recorded input/output port locations match the counts in the
///     module type;
///   - each parameter declaration has a unique name;
///   - when a parameter has a default value, that value is a TypedAttr of the
///     declared type and does not reference other parameters.
static LogicalResult verifyModuleCommon(HWModuleLike module) {
  assert(isa<HWModuleLike>(module) &&
         "verifier hook should only be called on modules");

  auto moduleType = module.getHWModuleType();
  if (module.getInputLocs().size() != moduleType.getNumInputs())
    return module->emitOpError("incorrect number of argument locations");
  if (module.getOutputLocs().size() != moduleType.getNumOutputs())
    return module->emitOpError("incorrect number of result locations");

  // Parameters are looked up by name, so reusing a name would be ambiguous.
  SmallPtrSet<Attribute, 4> seenParamNames;
  for (Attribute entry : module->getAttrOfType<ArrayAttr>("parameters")) {
    auto decl = entry.cast<ParamDeclAttr>();
    if (!seenParamNames.insert(decl.getName()).second)
      return module->emitOpError("parameter ")
             << decl << " has the same name as a previous parameter";

    // Default values are optional; only validate one that is present.
    auto defaultValue = decl.getValue();
    if (!defaultValue)
      continue;

    auto typedDefault = defaultValue.dyn_cast<TypedAttr>();
    if (!typedDefault)
      return module->emitOpError("parameter ")
             << decl << " should have a typed value; has value "
             << defaultValue;

    if (typedDefault.getType() != decl.getType())
      return module->emitOpError("parameter ")
             << decl << " should have type " << decl.getType()
             << "; has type " << typedDefault.getType();

    // Defaults may not refer to other parameters. Lexical ordering could be
    // introduced later if such references are ever needed.
    if (failed(checkParameterInContext(defaultValue, module, module,
                                       /*disallowParamRefs=*/true)))
      return failure();
  }
  return success();
}
|
|
|
|
/// Verify an HWModuleOp. First run the checks shared by all module-like ops,
/// then check that the entry block's arguments match the module signature in
/// count, type and location.
LogicalResult HWModuleOp::verify() {
  if (failed(verifyModuleCommon(*this)))
    return failure();

  auto type = getModuleType();
  auto *body = getBodyBlock();

  // Verify the number of block arguments. The trailing space in the first
  // fragment is needed so the count is not glued to the word "have".
  auto numInputs = type.getNumInputs();
  if (body->getNumArguments() != numInputs)
    return emitOpError("entry block must have ")
           << numInputs << " arguments to match module signature";

  // Verify that each block argument matches the declared input type and
  // location. `argType` avoids shadowing the module type `type` above.
  for (auto [arg, argType, loc] :
       llvm::zip(body->getArguments(), getInputTypes(), getInputLocs())) {
    if (arg.getType() != argType)
      return emitOpError("block argument types should match signature types");
    if (arg.getLoc() != loc.cast<LocationAttr>())
      return emitOpError(
          "block argument locations should match signature locations");
  }

  return success();
}
|
|
|
|
/// External modules have no body, so only the shared module checks apply.
LogicalResult HWModuleExternOp::verify() {
  LogicalResult common = verifyModuleCommon(*this);
  return common;
}
|
|
|
|
std::pair<StringAttr, BlockArgument>
|
|
HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
|
|
// Find a unique name for the wire.
|
|
Namespace ns;
|
|
auto ports = getPortList();
|
|
for (auto port : ports)
|
|
ns.newName(port.name.getValue());
|
|
auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
|
|
|
|
Block *body = getBodyBlock();
|
|
|
|
// Create a new port for the host clock.
|
|
PortInfo port;
|
|
port.name = nameAttr;
|
|
port.dir = ModulePort::Direction::Input;
|
|
port.type = ty;
|
|
hw::modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {},
|
|
{}, body);
|
|
|
|
// Add a new argument.
|
|
return {nameAttr, body->getArgument(index)};
|
|
}
|
|
|
|
void HWModuleOp::insertOutputs(unsigned index,
|
|
ArrayRef<std::pair<StringAttr, Value>> outputs) {
|
|
|
|
auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
|
|
assert(index <= output->getNumOperands() && "invalid output index");
|
|
|
|
// Rewrite the port list of the module.
|
|
SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
|
|
for (auto &[name, value] : outputs) {
|
|
PortInfo port;
|
|
port.name = name;
|
|
port.dir = ModulePort::Direction::Output;
|
|
port.type = value.getType();
|
|
indexedNewPorts.emplace_back(index, port);
|
|
}
|
|
hw::modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
|
|
getBodyBlock());
|
|
|
|
// Rewrite the output op.
|
|
for (auto &[name, value] : outputs)
|
|
output->insertOperands(index++, value);
|
|
}
|
|
|
|
/// Append `outputs` as new output ports after all existing outputs.
void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
  const auto endIndex = getNumOutputPorts();
  insertOutputs(endIndex, outputs);
}
|
|
|
|
void HWModuleOp::getAsmBlockArgumentNames(mlir::Region ®ion,
|
|
mlir::OpAsmSetValueNameFn setNameFn) {
|
|
getAsmBlockArgumentNamesImpl(region, setNameFn);
|
|
}
|
|
|
|
void HWModuleExternOp::getAsmBlockArgumentNames(
|
|
mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) {
|
|
getAsmBlockArgumentNamesImpl(region, setNameFn);
|
|
}
|
|
|
|
template <typename ModTy>
|
|
static SmallVector<Location> getAllPortLocs(ModTy module) {
|
|
auto locs = module.getPortLocs();
|
|
if (locs) {
|
|
SmallVector<Location> retval;
|
|
for (auto l : *locs)
|
|
retval.push_back(cast<Location>(l));
|
|
// Either we have a length of 0 or the correct length
|
|
assert(!locs->size() || locs->size() == module.getNumPorts());
|
|
return retval;
|
|
}
|
|
return SmallVector<Location>(module.getNumPorts(),
|
|
UnknownLoc::get(module.getContext()));
|
|
}
|
|
|
|
/// Return one location per port (see the static getAllPortLocs helper).
SmallVector<Location> HWModuleOp::getAllPortLocs() {
  return ::getAllPortLocs(*this);
}

/// Return one location per port (see the static getAllPortLocs helper).
SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
  return ::getAllPortLocs(*this);
}

/// Return one location per port (see the static getAllPortLocs helper).
SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
  return ::getAllPortLocs(*this);
}
|
|
|
|
template <typename ModTy>
|
|
static void setAllPortLocs(ArrayRef<Location> locs, ModTy module) {
|
|
std::vector<Attribute> nLocs(locs.begin(), locs.end());
|
|
module.setPortLocsAttr(ArrayAttr::get(module.getContext(), nLocs));
|
|
}
|
|
|
|
void HWModuleOp::setAllPortLocs(ArrayRef<Location> locs) {
|
|
::setAllPortLocs(locs, *this);
|
|
}
|
|
|
|
void HWModuleExternOp::setAllPortLocs(ArrayRef<Location> locs) {
|
|
::setAllPortLocs(locs, *this);
|
|
}
|
|
|
|
void HWModuleGeneratedOp::setAllPortLocs(ArrayRef<Location> locs) {
|
|
::setAllPortLocs(locs, *this);
|
|
}
|
|
|
|
template <typename ModTy>
|
|
static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
|
|
auto numInputs = module.getNumInputPorts();
|
|
SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
|
|
SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
|
|
auto oldType = module.getModuleType();
|
|
SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
|
|
oldType.getPorts().end());
|
|
for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
|
|
newPorts[i].name = cast<StringAttr>(names[i]);
|
|
auto newType = ModuleType::get(module.getContext(), newPorts);
|
|
module.setModuleType(newType);
|
|
}
|
|
|
|
void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
|
|
::setAllPortNames(names, *this);
|
|
}
|
|
|
|
void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
|
|
::setAllPortNames(names, *this);
|
|
}
|
|
|
|
void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
|
|
::setAllPortNames(names, *this);
|
|
}
|
|
|
|
/// Return the per-port attributes, or an empty list when the attribute is
/// absent or empty.
ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
  auto attrs = getPerPortAttrs();
  if (!attrs || attrs->empty())
    return {};
  return attrs->getValue();
}

/// Return the per-port attributes, or an empty list when the attribute is
/// absent or empty.
ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
  auto attrs = getPerPortAttrs();
  if (!attrs || attrs->empty())
    return {};
  return attrs->getValue();
}

/// Return the per-port attributes, or an empty list when the attribute is
/// absent or empty.
ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
  auto attrs = getPerPortAttrs();
  if (!attrs || attrs->empty())
    return {};
  return attrs->getValue();
}
|
|
|
|
void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
|
|
setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
|
|
}
|
|
|
|
void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
|
|
setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
|
|
}
|
|
|
|
void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
|
|
setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
|
|
}
|
|
|
|
void HWModuleOp::removeAllPortAttrs() {
|
|
setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
|
|
}
|
|
|
|
void HWModuleExternOp::removeAllPortAttrs() {
|
|
setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
|
|
}
|
|
|
|
void HWModuleGeneratedOp::removeAllPortAttrs() {
|
|
setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
|
|
}
|
|
|
|
// This probably does really unexpected stuff when you change the number of
|
|
|
|
template <typename ModTy>
|
|
static void setHWModuleType(ModTy &mod, ModuleType type) {
|
|
auto argAttrs = mod.getAllInputAttrs();
|
|
auto resAttrs = mod.getAllOutputAttrs();
|
|
mod.setModuleTypeAttr(TypeAttr::get(type));
|
|
unsigned newNumArgs = type.getNumInputs();
|
|
unsigned newNumResults = type.getNumOutputs();
|
|
|
|
auto emptyDict = DictionaryAttr::get(mod.getContext());
|
|
argAttrs.resize(newNumArgs, emptyDict);
|
|
resAttrs.resize(newNumResults, emptyDict);
|
|
|
|
SmallVector<Attribute> attrs;
|
|
attrs.append(argAttrs.begin(), argAttrs.end());
|
|
attrs.append(resAttrs.begin(), resAttrs.end());
|
|
|
|
if (attrs.empty())
|
|
return mod.removeAllPortAttrs();
|
|
mod.setAllPortAttrs(attrs);
|
|
}
|
|
|
|
void HWModuleOp::setHWModuleType(ModuleType type) {
|
|
return ::setHWModuleType(*this, type);
|
|
}
|
|
|
|
void HWModuleExternOp::setHWModuleType(ModuleType type) {
|
|
return ::setHWModuleType(*this, type);
|
|
}
|
|
|
|
void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
|
|
return ::setHWModuleType(*this, type);
|
|
}
|
|
|
|
/// Lookup the generator for the symbol. This returns null on
|
|
/// invalid IR.
|
|
Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
|
|
auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
|
|
return topLevelModuleOp.lookupSymbol(getGeneratorKind());
|
|
}
|
|
|
|
/// Check that `generatorKind` resolves to an HWGeneratorSchemaOp, and that
/// every attribute name the schema lists as required is present on this op.
LogicalResult
HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto *referencedKind =
      symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
  if (!referencedKind)
    return emitError("Cannot find generator definition '")
           << getGeneratorKind() << "'";

  auto schema = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
  if (!schema)
    return emitError("Symbol resolved to '")
           << referencedKind->getName()
           << "' which is not a HWGeneratorSchemaOp";

  // Each required entry must be a string naming an attribute of this op.
  auto ownAttrs = (*this)->getAttrDictionary();
  for (auto required : schema.getRequiredAttrs()) {
    auto requiredName = required.dyn_cast<StringAttr>();
    if (!requiredName)
      return emitError("Unknown attribute type, expected a string");
    if (!ownAttrs.get(requiredName.getValue()))
      return emitError("Missing attribute '") << requiredName.getValue() << "'";
  }
  return success();
}
|
|
|
|
/// Generated modules have no body, so only the shared module checks apply.
LogicalResult HWModuleGeneratedOp::verify() {
  LogicalResult common = verifyModuleCommon(*this);
  return common;
}
|
|
|
|
void HWModuleGeneratedOp::getAsmBlockArgumentNames(
|
|
mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) {
|
|
getAsmBlockArgumentNamesImpl(region, setNameFn);
|
|
}
|
|
|
|
/// Body verification hook. It currently performs no additional checks.
LogicalResult HWModuleOp::verifyBody() {
  return success();
}
|
|
|
|
template <typename ModuleTy>
|
|
static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
|
|
auto modTy = mod.getHWModuleType();
|
|
auto emptyDict = DictionaryAttr::get(mod.getContext());
|
|
SmallVector<PortInfo> retval;
|
|
auto locs = mod.getAllPortLocs();
|
|
for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
|
|
LocationAttr loc = locs[i];
|
|
DictionaryAttr attrs =
|
|
dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
|
|
if (!attrs)
|
|
attrs = emptyDict;
|
|
retval.push_back({modTy.getPorts()[i],
|
|
modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
|
|
: modTy.getInputIdForPortId(i),
|
|
attrs, loc});
|
|
}
|
|
return retval;
|
|
}
|
|
|
|
template <typename ModuleTy>
|
|
static PortInfo getPort(ModuleTy &mod, size_t idx) {
|
|
auto modTy = mod.getHWModuleType();
|
|
auto emptyDict = DictionaryAttr::get(mod.getContext());
|
|
LocationAttr loc = mod.getPortLoc(idx);
|
|
DictionaryAttr attrs =
|
|
dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
|
|
if (!attrs)
|
|
attrs = emptyDict;
|
|
return {modTy.getPorts()[idx],
|
|
modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
|
|
: modTy.getInputIdForPortId(idx),
|
|
attrs, loc};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// InstanceOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Create a instance that refers to a known module.
|
|
void InstanceOp::build(OpBuilder &builder, OperationState &result,
|
|
Operation *module, StringAttr name,
|
|
ArrayRef<Value> inputs, ArrayAttr parameters,
|
|
InnerSymAttr innerSym) {
|
|
if (!parameters)
|
|
parameters = builder.getArrayAttr({});
|
|
|
|
auto mod = cast<hw::HWModuleLike>(module);
|
|
auto argNames = builder.getArrayAttr(mod.getInputNames());
|
|
auto resultNames = builder.getArrayAttr(mod.getOutputNames());
|
|
|
|
// Try to resolve the parameterized module type. If failed, use the module's
|
|
// parmeterized type. If the client doesn't fix this error, the verifier will
|
|
// fail.
|
|
ModuleType modType = mod.getHWModuleType();
|
|
FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
|
|
parameters, result.location, /*emitErrors=*/false);
|
|
if (succeeded(resolvedModType))
|
|
modType = *resolvedModType;
|
|
FunctionType funcType = resolvedModType->getFuncType();
|
|
build(builder, result, funcType.getResults(), name,
|
|
FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
|
|
argNames, resultNames, parameters, innerSym);
|
|
}
|
|
|
|
/// An inner symbol on an instance names the op itself, never one of its
/// results, so there is no result index to report.
std::optional<size_t> InstanceOp::getTargetResultIndex() {
  return {};
}
|
|
|
|
/// Delegate checking of the referenced module to the shared
/// instance_like_impl::verifyInstanceOfHWModule helper. It is given this
/// instance's inputs, result types, port names and parameters.
LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto moduleRef = getModuleNameAttr();
  return instance_like_impl::verifyInstanceOfHWModule(
      *this, moduleRef, getInputs(), getResultTypes(), getArgNames(),
      getResultNames(), getParameters(), symbolTable);
}
|
|
|
|
/// When nested in an HWModuleOp, check this instance's parameters against
/// the enclosing module's parameter declarations. Diagnostics get a note
/// pointing at the enclosing module when the callback asks for one.
/// Instances outside an HWModuleOp are accepted without further checks here.
LogicalResult InstanceOp::verify() {
  auto parentModule = (*this)->getParentOfType<HWModuleOp>();
  if (!parentModule)
    return success();

  auto parentParams = parentModule->getAttrOfType<ArrayAttr>("parameters");
  instance_like_impl::EmitErrorFn emitError =
      [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
        auto diag = emitOpError();
        if (fn(diag))
          diag.attachNote(parentModule->getLoc()) << "module declared here";
      };

  return instance_like_impl::verifyParameterStructure(getParameters(),
                                                      parentParams, emitError);
}
|
|
|
|
/// Parse an `hw.instance`. The syntax is:
///   "name" [sym <inner-sym>] @module [<params>] (inputs) -> (outputs) attr-dict
///
/// The location of the input port list is recorded before it is parsed, so
/// operand-resolution diagnostics point at it. Previously the location was
/// never set and such diagnostics used an empty location.
ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
  StringAttr instanceNameAttr;
  InnerSymAttr innerSym;
  FlatSymbolRefAttr moduleNameAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
  SmallVector<Type, 1> inputsTypes, allResultTypes;
  ArrayAttr argNames, resultNames, parameters;
  auto noneType = parser.getBuilder().getType<NoneType>();

  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
                            result.attributes))
    return failure();

  // An optional `sym` keyword introduces the inner symbol. Once the keyword
  // is present, the symbol attribute itself is mandatory.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    if (parser.parseCustomAttributeWithFallback(innerSym))
      return failure();
    result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
  }

  llvm::SMLoc inputsOperandsLoc;
  if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
                            result.attributes) ||
      parseOptionalParameterList(parser, parameters) ||
      parser.getCurrentLocation(&inputsOperandsLoc) ||
      parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
      parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.parseArrow() ||
      parseOutputPortList(parser, allResultTypes, resultNames) ||
      parser.parseOptionalAttrDict(result.attributes)) {
    return failure();
  }

  result.addAttribute("argNames", argNames);
  result.addAttribute("resultNames", resultNames);
  result.addAttribute("parameters", parameters);
  result.addTypes(allResultTypes);
  return success();
}
|
|
|
|
/// Print an `hw.instance` in the syntax accepted by InstanceOp::parse.
/// Attributes already encoded in that syntax are elided from the trailing
/// attribute dictionary.
void InstanceOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printAttributeWithoutType(getInstanceNameAttr());

  auto sym = getInnerSymAttr();
  if (sym) {
    p << " sym ";
    sym.print(p);
  }

  p << ' ';
  p.printAttributeWithoutType(getModuleNameAttr());
  printOptionalParameterList(p, *this, getParameters());

  auto inputs = getInputs();
  printInputPortList(p, *this, inputs, inputs.getTypes(), getArgNames());
  p << " -> ";
  printOutputPortList(p, *this, getResultTypes(), getResultNames());

  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"instanceName",
                       InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
                       "argNames", "resultNames", "parameters"});
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// InstanceChoiceOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// An inner symbol on an instance choice names the op itself, never one of
/// its results, so there is no result index to report.
std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
  return {};
}
|
|
|
|
/// Check every alternative module against this op's inputs, result types,
/// port names and parameters. Checking stops at the first alternative that
/// fails.
LogicalResult
InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto isValidTarget = [&](Attribute moduleRef) {
    return succeeded(instance_like_impl::verifyInstanceOfHWModule(
        *this, moduleRef.cast<FlatSymbolRefAttr>(), getInputs(),
        getResultTypes(), getArgNames(), getResultNames(), getParameters(),
        symbolTable));
  };
  return success(llvm::all_of(getModuleNamesAttr(), isValidTarget));
}
|
|
|
|
/// When nested in an HWModuleOp, check this op's parameters against the
/// enclosing module's parameter declarations. Diagnostics get a note
/// pointing at the enclosing module when the callback asks for one.
LogicalResult InstanceChoiceOp::verify() {
  auto parentModule = (*this)->getParentOfType<HWModuleOp>();
  if (!parentModule)
    return success();

  auto parentParams = parentModule->getAttrOfType<ArrayAttr>("parameters");
  instance_like_impl::EmitErrorFn emitError =
      [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
        auto diag = emitOpError();
        if (fn(diag))
          diag.attachNote(parentModule->getLoc()) << "module declared here";
      };

  return instance_like_impl::verifyParameterStructure(getParameters(),
                                                      parentParams, emitError);
}
|
|
|
|
/// Parse an `hw.instance_choice`. The syntax is:
///   "name" [sym <inner-sym>] @default (or @alt if "target")*
///     [<params>] (inputs) -> (outputs) attr-dict
///
/// Two fixes relative to the previous version:
///   - The `if` keyword in each alternative is mandatory. It was parsed with
///     parseOptionalKeyword, so a missing `if` made the parse fail without
///     any diagnostic; it is now parsed with parseKeyword.
///   - The location of the input port list is recorded before it is parsed,
///     so operand-resolution diagnostics point at it.
ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  StringAttr instanceNameAttr;
  InnerSymAttr innerSym;
  SmallVector<Attribute> moduleNames;
  SmallVector<Attribute> targetNames;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
  SmallVector<Type, 1> inputsTypes, allResultTypes;
  ArrayAttr argNames, resultNames, parameters;
  auto noneType = parser.getBuilder().getType<NoneType>();

  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
                            result.attributes))
    return failure();

  // An optional `sym` keyword introduces the inner symbol. Once the keyword
  // is present, the symbol attribute itself is mandatory.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    if (parser.parseCustomAttributeWithFallback(innerSym))
      return failure();
    result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
  }

  // The first module is the default choice and has no target name.
  FlatSymbolRefAttr defaultModuleName;
  if (parser.parseAttribute(defaultModuleName))
    return failure();
  moduleNames.push_back(defaultModuleName);

  // Each alternative has the form `or @module if "target"`.
  while (succeeded(parser.parseOptionalKeyword("or"))) {
    FlatSymbolRefAttr moduleName;
    StringAttr targetName;
    if (parser.parseAttribute(moduleName) || parser.parseKeyword("if") ||
        parser.parseAttribute(targetName))
      return failure();
    moduleNames.push_back(moduleName);
    targetNames.push_back(targetName);
  }

  llvm::SMLoc inputsOperandsLoc;
  if (parseOptionalParameterList(parser, parameters) ||
      parser.getCurrentLocation(&inputsOperandsLoc) ||
      parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
      parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.parseArrow() ||
      parseOutputPortList(parser, allResultTypes, resultNames) ||
      parser.parseOptionalAttrDict(result.attributes)) {
    return failure();
  }

  result.addAttribute("moduleNames",
                      ArrayAttr::get(parser.getContext(), moduleNames));
  result.addAttribute("targetNames",
                      ArrayAttr::get(parser.getContext(), targetNames));
  result.addAttribute("argNames", argNames);
  result.addAttribute("resultNames", resultNames);
  result.addAttribute("parameters", parameters);
  result.addTypes(allResultTypes);
  return success();
}
|
|
|
|
/// Print an `hw.instance_choice` in the syntax accepted by
/// InstanceChoiceOp::parse. The default module is printed first, followed by
/// each `or <module> if <target>` alternative.
void InstanceChoiceOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printAttributeWithoutType(getInstanceNameAttr());

  auto sym = getInnerSymAttr();
  if (sym) {
    p << " sym ";
    sym.print(p);
  }
  p << ' ';

  auto moduleNames = getModuleNamesAttr();
  auto targetNames = getTargetNamesAttr();
  // The default module has no target, so there is exactly one more module
  // than there are targets.
  assert(moduleNames.size() == targetNames.size() + 1);

  p.printAttributeWithoutType(moduleNames[0]);
  size_t alt = 1;
  for (auto target : targetNames) {
    p << " or ";
    p.printAttributeWithoutType(moduleNames[alt++]);
    p << " if ";
    p.printAttributeWithoutType(target);
  }

  printOptionalParameterList(p, *this, getParameters());
  auto inputs = getInputs();
  printInputPortList(p, *this, inputs, inputs.getTypes(), getArgNames());
  p << " -> ";
  printOutputPortList(p, *this, getResultTypes(), getResultNames());

  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"instanceName",
                       InnerSymbolTable::getInnerSymbolAttrName(),
                       "moduleNames", "targetNames", "argNames", "resultNames",
                       "parameters"});
}
|
|
|
|
/// Return the root names (StringAttrs) of all referenced modules, in the same
/// order as `moduleNames`.
ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
  auto refs = getModuleNamesAttr();
  SmallVector<Attribute> names;
  names.reserve(refs.size());
  for (Attribute ref : refs)
    names.push_back(ref.cast<FlatSymbolRefAttr>().getAttr());
  return ArrayAttr::get(getContext(), names);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// HWOutputOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verify that the num of operands and types fit the declared results.
|
|
LogicalResult OutputOp::verify() {
|
|
// Check that the we (hw.output) have the same number of operands as our
|
|
// region has results.
|
|
ModuleType modType;
|
|
if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
|
|
modType = mod.getHWModuleType();
|
|
else {
|
|
emitOpError("must have a module parent");
|
|
return failure();
|
|
}
|
|
auto modResults = modType.getOutputTypes();
|
|
OperandRange outputValues = getOperands();
|
|
if (modResults.size() != outputValues.size()) {
|
|
emitOpError("must have same number of operands as region results.");
|
|
return failure();
|
|
}
|
|
|
|
// Check that the types of our operands and the region's results match.
|
|
for (size_t i = 0, e = modResults.size(); i < e; ++i) {
|
|
if (modResults[i] != outputValues[i].getType()) {
|
|
emitOpError("output types must match module. In "
|
|
"operand ")
|
|
<< i << ", expected " << modResults[i] << ", but got "
|
|
<< outputValues[i].getType() << ".";
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Other Operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse the source type of an array slice. Only the (array) source type is
/// written in the assembly; the index type is derived from it as an integer
/// of clog2(element count) bits.
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
                                   Type &idxType) {
  Type parsed;
  if (p.parseType(parsed))
    return p.emitError(p.getCurrentLocation(), "Expected type");

  auto asArray = type_dyn_cast<ArrayType>(parsed);
  if (!asArray)
    return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");

  srcType = parsed;
  auto *ctx = p.getBuilder().getContext();
  idxType =
      IntegerType::get(ctx, llvm::Log2_64_Ceil(asArray.getNumElements()));
  return success();
}
|
|
|
|
/// Print only the slice source type; the index type is implied by it and is
/// not part of the assembly.
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
                            Type idxType) {
  p << srcType;
}
|
|
|
|
/// Parse `%a, %b, ... {attrs} : elemType`. The result is an !hw.array of
/// `elemType` whose length is the operand count; each operand is resolved
/// against `elemType`. Empty operand lists are rejected.
ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elements;
  Type elementTy;

  bool parseErr = parser.parseOperandList(elements) ||
                  parser.parseOptionalAttrDict(result.attributes) ||
                  parser.parseColon() || parser.parseType(elementTy);
  if (parseErr)
    return failure();

  if (elements.empty())
    return parser.emitError(operandsLoc,
                            "Cannot construct an array of length 0");
  result.addTypes({ArrayType::get(elementTy, elements.size())});

  for (const auto &element : elements) {
    if (parser.resolveOperand(element, elementTy, result.operands))
      return failure();
  }
  return success();
}
|
|
|
|
/// Print as `%a, %b, ... {attrs} : elemType`. The element type is taken from
/// the first input, since all inputs share one type.
void ArrayCreateOp::print(OpAsmPrinter &p) {
  auto inputs = getInputs();
  p << " ";
  p.printOperands(inputs);
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << inputs[0].getType();
}
|
|
|
|
/// Build an array_create from a non-empty list of values that all share one
/// type. The result is an !hw.array of that type with one slot per value.
void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
                          ValueRange values) {
  assert(!values.empty() && "Cannot build array of zero elements");
  Type elemType = values.front().getType();
  assert(llvm::all_of(values,
                      [&](Value v) { return v.getType() == elemType; }) &&
         "All values must have same type.");
  auto arrayType = ArrayType::get(elemType, values.size());
  build(b, state, arrayType, values);
}
|
|
|
|
/// Verify that the declared array length equals the number of inputs.
LogicalResult ArrayCreateOp::verify() {
  // Keep the full 64-bit element count; narrowing it to `unsigned` could wrap.
  uint64_t returnSize = hw::type_cast<ArrayType>(getType()).getNumElements();
  if (getInputs().size() != returnSize)
    // Emit a diagnostic. A bare failure() would make verification fail with
    // no explanation of what is wrong.
    return emitOpError("number of inputs (")
           << getInputs().size() << ") does not match the result array size ("
           << returnSize << ")";
  return success();
}
|
|
|
|
/// When every input is a constant, fold to an ArrayAttr holding the input
/// constants in operand order; otherwise do not fold.
OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
  auto elements = adaptor.getInputs();
  bool allConstant = llvm::all_of(
      elements, [](Attribute a) { return static_cast<bool>(a); });
  if (!allConstant)
    return {};
  return ArrayAttr::get(getContext(), elements);
}
|
|
|
|
// Check whether an integer value is an offset from a base.
|
|
bool hw::isOffset(Value base, Value index, uint64_t offset) {
|
|
if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
|
|
if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
|
|
// If both values are a constant, check if index == base + offset.
|
|
// To account for overflow, the addition is performed with an extra bit
|
|
// and the offset is asserted to fit in the bit width of the base.
|
|
auto baseValue = constBase.getValue();
|
|
auto indexValue = constIndex.getValue();
|
|
|
|
unsigned bits = baseValue.getBitWidth();
|
|
assert(bits == indexValue.getBitWidth() && "mismatched widths");
|
|
|
|
if (bits < 64 && offset >= (1ull << bits))
|
|
return false;
|
|
|
|
APInt baseExt = baseValue.zextOrTrunc(bits + 1);
|
|
APInt indexExt = indexValue.zextOrTrunc(bits + 1);
|
|
return baseExt + offset == indexExt;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Canonicalize a create of consecutive elements to a slice.
|
|
static LogicalResult foldCreateToSlice(ArrayCreateOp op,
|
|
PatternRewriter &rewriter) {
|
|
// Do not canonicalize create of get into a slice.
|
|
auto arrayTy = hw::type_cast<ArrayType>(op.getType());
|
|
if (arrayTy.getNumElements() <= 1)
|
|
return failure();
|
|
auto elemTy = arrayTy.getElementType();
|
|
|
|
// Check if create arguments are consecutive elements of the same array.
|
|
// Attempt to break a create of gets into a sequence of consecutive intervals.
|
|
struct Chunk {
|
|
Value input;
|
|
Value index;
|
|
size_t size;
|
|
};
|
|
SmallVector<Chunk> chunks;
|
|
for (Value value : llvm::reverse(op.getInputs())) {
|
|
auto get = value.getDefiningOp<ArrayGetOp>();
|
|
if (!get)
|
|
return failure();
|
|
|
|
Value input = get.getInput();
|
|
Value index = get.getIndex();
|
|
if (!chunks.empty()) {
|
|
auto &c = *chunks.rbegin();
|
|
if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
|
|
c.size++;
|
|
continue;
|
|
}
|
|
}
|
|
|
|
chunks.push_back(Chunk{input, index, 1});
|
|
}
|
|
|
|
// If there is a single slice, eliminate the create.
|
|
if (chunks.size() == 1) {
|
|
auto &chunk = chunks[0];
|
|
rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
|
|
op.getLoc(), arrayTy, chunk.input, chunk.index));
|
|
return success();
|
|
}
|
|
|
|
// If the number of chunks is significantly less than the number of
|
|
// elements, replace the create with a concat of the identified slices.
|
|
if (chunks.size() * 2 < arrayTy.getNumElements()) {
|
|
SmallVector<Value> slices;
|
|
for (auto &chunk : llvm::reverse(chunks)) {
|
|
auto sliceTy = ArrayType::get(elemTy, chunk.size);
|
|
slices.push_back(rewriter.createOrFold<ArraySliceOp>(
|
|
op.getLoc(), sliceTy, chunk.input, chunk.index));
|
|
}
|
|
rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
/// Canonicalize array_create. The only rewrite applied is turning creates of
/// consecutive array_get results into slices / concats of slices.
LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
                                          PatternRewriter &rewriter) {
  return foldCreateToSlice(op, rewriter);
}
|
|
|
|
/// Return the value repeated in every input slot. Returns a null Value when
/// there are no inputs or when the inputs are not all the same value.
Value ArrayCreateOp::getUniformElement() {
  auto inputs = getInputs();
  if (inputs.empty() || !llvm::all_equal(inputs))
    return {};
  return inputs.front();
}
|
|
|
|
/// Return the unsigned value of `value` when it is produced by an hw.constant
/// at most 64 bits wide; std::nullopt otherwise.
static std::optional<uint64_t> getUIntFromValue(Value value) {
  auto cst = value.getDefiningOp<ConstantOp>();
  if (!cst)
    return std::nullopt;
  APInt bits = cst.getValue();
  if (bits.getBitWidth() <= 64)
    return bits.getLimitedValue();
  return std::nullopt;
}
|
|
|
|
/// Verify that the low index is exactly clog2(N) bits wide, where N is the
/// element count of the sliced input array.
LogicalResult ArraySliceOp::verify() {
  // Keep the element count in 64 bits. Narrowing it to `unsigned` would wrap
  // for inputs with 2^32 or more elements and check against the wrong width.
  uint64_t inputSize =
      type_cast<ArrayType>(getInput().getType()).getNumElements();
  if (llvm::Log2_64_Ceil(inputSize) !=
      getLowIndex().getType().getIntOrFloatBitWidth())
    return emitOpError(
        "ArraySlice: index width must match clog2 of array size");
  return success();
}
|
|
|
|
/// A slice whose type equals its input's type spans the whole input, so it
/// folds to that input.
OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
  Value input = getInput();
  if (getType() != input.getType())
    return {};
  return input;
}
|
|
|
|
/// Canonicalize array_slice:
///   slice(a, n) of one element     -> create(a[n])
///   slice(slice(a, n), m)          -> slice(a, n + m)
///   slice(create(a0, ..., an), m)  -> create(<selected inputs>)
///   slice(concat(a1, a2, ...), m)  -> concat(<whole or sliced inputs>)
/// All but the first require a constant low index. They only fire when the
/// slice lies entirely within its input array.
LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
                                         PatternRewriter &rewriter) {
  auto sliceTy = hw::type_cast<ArrayType>(op.getType());
  auto elemTy = sliceTy.getElementType();
  uint64_t sliceSize = sliceTy.getNumElements();
  if (sliceSize == 0)
    return failure();

  if (sliceSize == 1) {
    // slice(a, n) -> create(a[n])
    auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
                                           op.getLowIndex());
    rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
                                               get.getResult());
    return success();
  }

  auto offsetOpt = getUIntFromValue(op.getLowIndex());
  if (!offsetOpt)
    return failure();

  // The verifier accepts constant slices that run past the end of the input.
  // The rewrites below compute element positions from the offset. For such a
  // slice that arithmetic underflows: the create case would call slice() with
  // a bogus start, and the concat case would produce an empty or wrongly
  // sized result. Leave those slices alone.
  uint64_t inputArraySize =
      hw::type_cast<ArrayType>(op.getInput().getType()).getNumElements();
  if (sliceSize > inputArraySize || *offsetOpt > inputArraySize - sliceSize)
    return failure();

  auto inputOp = op.getInput().getDefiningOp();
  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
    // slice(slice(a, n), m) -> slice(a, n + m)
    if (inputSlice == op)
      return failure();

    auto inputIndex = inputSlice.getLowIndex();
    auto inputOffsetOpt = getUIntFromValue(inputIndex);
    if (!inputOffsetOpt)
      return failure();

    // The combined slice must also fit in the inner slice's input. This can
    // only be violated when the inner slice is itself out of bounds.
    uint64_t innerInputSize =
        hw::type_cast<ArrayType>(inputSlice.getInput().getType())
            .getNumElements();
    if (*inputOffsetOpt > innerInputSize ||
        *offsetOpt > innerInputSize - *inputOffsetOpt ||
        sliceSize > innerInputSize - *inputOffsetOpt - *offsetOpt)
      return failure();

    uint64_t offset = *offsetOpt + *inputOffsetOpt;
    auto lowIndex =
        rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
    rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
                                              inputSlice.getInput(), lowIndex);
    return success();
  }

  if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
    // slice(create(a0, a1, ..., an), m) -> create(am, ...)
    // The create's last operand is element 0, hence the reversed start.
    auto inputs = inputCreate.getInputs();

    uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
    rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
                                               inputs.slice(begin, sliceSize));
    return success();
  }

  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
    // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
    // Concat operands are walked from the last one, which holds element 0.
    SmallVector<Value> chunks;
    uint64_t sliceStart = *offsetOpt;
    for (auto input : llvm::reverse(inputConcat.getInputs())) {
      // Check whether the input intersects with the slice.
      uint64_t inputSize =
          hw::type_cast<ArrayType>(input.getType()).getNumElements();
      if (inputSize == 0 || inputSize <= sliceStart) {
        sliceStart -= inputSize;
        continue;
      }

      // Find the indices to slice from this input by intersection.
      uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
      uint64_t cutSize = cutEnd - sliceStart;
      assert(cutSize != 0 && "slice cannot be empty");

      if (cutSize == inputSize) {
        // The whole input fits in the slice, add it.
        assert(sliceStart == 0 && "invalid cut size");
        chunks.push_back(input);
      } else {
        // Slice the required bits from the input.
        unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
        auto lowIndex = rewriter.create<ConstantOp>(
            op.getLoc(), rewriter.getIntegerType(width), sliceStart);
        chunks.push_back(rewriter.create<ArraySliceOp>(
            op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
      }

      sliceStart = 0;
      sliceSize -= cutSize;
      if (sliceSize == 0)
        break;
    }

    assert(chunks.size() > 0 && "missing sliced items");
    if (chunks.size() == 1)
      rewriter.replaceOp(op, chunks[0]);
    else
      rewriter.replaceOpWithNewOp<ArrayConcatOp>(
          op, llvm::to_vector(llvm::reverse(chunks)));
    return success();
  }
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ArrayConcatOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse the comma-separated input array types of an array_concat. All inputs
/// must share one element type. The result type is an array of that element
/// type whose length is the sum of the input lengths.
static ParseResult parseArrayConcatTypes(OpAsmParser &p,
                                         SmallVectorImpl<Type> &inputTypes,
                                         Type &resultType) {
  Type sharedElemTy;
  uint64_t totalSize = 0;

  auto parseOne = [&]() -> ParseResult {
    Type parsed;
    if (p.parseType(parsed))
      return failure();
    auto asArray = type_dyn_cast<ArrayType>(parsed);
    if (!asArray)
      return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
    Type elemTy = asArray.getElementType();
    if (sharedElemTy && sharedElemTy != elemTy)
      return p.emitError(p.getCurrentLocation(), "Expected array element type ")
             << sharedElemTy;

    sharedElemTy = elemTy;
    inputTypes.push_back(parsed);
    totalSize += asArray.getNumElements();
    return success();
  };

  if (p.parseCommaSeparatedList(parseOne))
    return failure();
  resultType = ArrayType::get(sharedElemTy, totalSize);
  return success();
}
|
|
|
|
/// Print the input array types as a comma-separated list. The result type is
/// implied by them and is not printed.
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
                                  TypeRange inputTypes, Type resultType) {
  llvm::interleaveComma(inputTypes, p);
}
|
|
|
|
/// Build an array_concat of non-empty `values`, each an !hw.array sharing one
/// element type. The result length is the sum of the input lengths.
void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
                          ValueRange values) {
  assert(!values.empty() && "Cannot build array of zero elements");
  Type elemTy = values[0].getType().cast<ArrayType>().getElementType();
  assert(llvm::all_of(values,
                      [elemTy](Value v) -> bool {
                        auto arrTy = v.getType().dyn_cast<ArrayType>();
                        return arrTy && arrTy.getElementType() == elemTy;
                      }) &&
         "All values must be of ArrayType with the same element type.");

  uint64_t totalSize = 0;
  for (Value v : values)
    totalSize += v.getType().cast<ArrayType>().getNumElements();
  build(b, state, ArrayType::get(elemTy, totalSize), values);
}
|
|
|
|
/// When every input is a constant ArrayAttr, fold to one flat ArrayAttr that
/// holds the inputs' elements in operand order.
OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
  SmallVector<Attribute> flattened;
  for (Attribute operand : adaptor.getInputs()) {
    if (!operand)
      return {};
    auto elems = operand.cast<ArrayAttr>();
    flattened.append(elems.begin(), elems.end());
  }
  return ArrayAttr::get(getContext(), flattened);
}
|
|
|
|
// Flatten a concatenation of array creates into a single create.
|
|
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
|
|
for (auto input : op.getInputs())
|
|
if (!input.getDefiningOp<ArrayCreateOp>())
|
|
return false;
|
|
|
|
SmallVector<Value> items;
|
|
for (auto input : op.getInputs()) {
|
|
auto create = cast<ArrayCreateOp>(input.getDefiningOp());
|
|
for (auto item : create.getInputs())
|
|
items.push_back(item);
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
|
|
return true;
|
|
}
|
|
|
|
// Merge consecutive slice expressions in a concatenation.
|
|
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
|
|
struct Slice {
|
|
Value input;
|
|
Value index;
|
|
size_t size;
|
|
Value op;
|
|
SmallVector<Location> locs;
|
|
};
|
|
|
|
SmallVector<Value> items;
|
|
std::optional<Slice> last;
|
|
bool changed = false;
|
|
|
|
auto concatenate = [&] {
|
|
// If there is only one op in the slice, place it to the items list.
|
|
if (!last)
|
|
return;
|
|
if (last->op) {
|
|
items.push_back(last->op);
|
|
last.reset();
|
|
return;
|
|
}
|
|
|
|
// Otherwise, create a new slice of with the given size and place it.
|
|
// In this case, the concat op is replaced, using the new argument.
|
|
changed = true;
|
|
auto loc = FusedLoc::get(op.getContext(), last->locs);
|
|
auto origTy = hw::type_cast<ArrayType>(last->input.getType());
|
|
auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
|
|
items.push_back(rewriter.createOrFold<ArraySliceOp>(
|
|
loc, arrayTy, last->input, last->index));
|
|
|
|
last.reset();
|
|
};
|
|
|
|
auto append = [&](Value op, Value input, Value index, size_t size) {
|
|
// If this slice is an extension of the previous one, extend the size
|
|
// saved. In this case, a new slice of is created and the concatenation
|
|
// operator is rewritten. Otherwise, flush the last slice.
|
|
if (last) {
|
|
if (last->input == input && isOffset(last->index, index, last->size)) {
|
|
last->size += size;
|
|
last->op = {};
|
|
last->locs.push_back(op.getLoc());
|
|
return;
|
|
}
|
|
concatenate();
|
|
}
|
|
last.emplace(Slice{input, index, size, op, {op.getLoc()}});
|
|
};
|
|
|
|
for (auto item : llvm::reverse(op.getInputs())) {
|
|
if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
|
|
auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
|
|
append(item, slice.getInput(), slice.getLowIndex(), size);
|
|
continue;
|
|
}
|
|
|
|
if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
|
|
if (create.getInputs().size() == 1) {
|
|
if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
|
|
append(item, get.getInput(), get.getIndex(), 1);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
concatenate();
|
|
items.push_back(item);
|
|
}
|
|
concatenate();
|
|
|
|
if (!changed)
|
|
return false;
|
|
|
|
if (items.size() == 1) {
|
|
rewriter.replaceOp(op, items[0]);
|
|
} else {
|
|
std::reverse(items.begin(), items.end());
|
|
rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Canonicalize array_concat by flattening concats of creates and merging
/// adjacent slices. Returns failure() when neither rewrite applies.
LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
                                          PatternRewriter &rewriter) {
  // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
  if (flattenConcatOp(op, rewriter))
    return success();

  // concat(slice(a, n, m), slice(a, n + m, p)) -> slice(a, n, m + p)
  if (mergeConcatSlices(op, rewriter))
    return success();

  // Nothing matched and the IR is untouched. Returning success() here would
  // tell the pattern driver a rewrite happened, so it would keep re-applying
  // this pattern without converging.
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// EnumConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse `<field> : <type>`.
///
/// A generic Type (rather than an EnumType) is parsed because the type may be
/// a type alias. The canonical type is validated while building the
/// EnumFieldAttr. A null attribute means the field/type pair was rejected.
ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
  StringRef fieldName;
  Type type;
  if (parser.parseKeyword(&fieldName) || parser.parseColonType(type))
    return failure();

  auto nameAttr = StringAttr::get(parser.getContext(), fieldName);
  auto fieldAttr = EnumFieldAttr::get(loc, nameAttr, type);
  if (!fieldAttr)
    return failure();

  result.addAttribute("field", fieldAttr);
  result.addTypes(type);
  return success();
}
|
|
|
|
/// Print as `<field> : <type>`, using the type stored in the field attribute.
void EnumConstantOp::print(OpAsmPrinter &p) {
  auto field = getField();
  p << " " << field.getField().getValue();
  p << " : " << field.getType().getValue();
}
|
|
|
|
/// Suggest the enum field's name as the SSA name of the result.
void EnumConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringAttr name = getField().getField();
  setNameFn(getResult(), name.getValue());
}
|
|
|
|
/// Build an enum constant whose result type is the type carried by `field`.
void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
                           EnumFieldAttr field) {
  Type resultType = field.getType().getValue();
  build(builder, odsState, resultType, field);
}
|
|
|
|
/// An enum constant always folds to its field attribute.
OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto fieldAttr = getFieldAttr();
  return fieldAttr;
}
|
|
|
|
/// Verify that the result type is exactly the type recorded in the field
/// attribute.
LogicalResult EnumConstantOp::verify() {
  auto fieldAttr = getFieldAttr();
  auto fieldType = fieldAttr.getType().getValue();
  // This check ensures that we are using the exact same type, without looking
  // through type aliases.
  if (fieldType != getType())
    // Return the diagnostic so verification actually fails. Emitting it and
    // then returning success() would let the invalid op through.
    return emitOpError("return type ")
           << getType() << " does not match attribute type " << fieldAttr;
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// EnumCmpOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verify that both operands have the same (canonical) enum type.
LogicalResult EnumCmpOp::verify() {
  // Compare the canonical types.
  auto lhsType = type_cast<EnumType>(getLhs().getType());
  auto rhsType = type_cast<EnumType>(getRhs().getType());
  if (rhsType != lhsType)
    // Return the diagnostic so verification actually fails.
    return emitOpError("types do not match");
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StructCreateOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse `(%a, %b, ...) {attrs} : <struct type or alias>`. Each operand is
/// resolved against the corresponding struct field type.
ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> fieldValues;
  Type resultTy;

  if (parser.parseLParen() || parser.parseOperandList(fieldValues) ||
      parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultTy))
    return failure();

  auto structTy = type_dyn_cast<StructType>(resultTy);
  if (!structTy)
    return parser.emitError(parser.getNameLoc(),
                            "expected !hw.struct type or alias");

  llvm::SmallVector<Type, 4> fieldTypes;
  structTy.getInnerTypes(fieldTypes);
  result.addTypes(resultTy);

  return parser.resolveOperands(fieldValues, fieldTypes, operandsLoc,
                                result.operands);
}
|
|
|
|
/// Print as `(%a, %b, ...) {attrs} : <result type>`.
void StructCreateOp::print(OpAsmPrinter &printer) {
  Type resultTy = getType();
  printer << " (";
  printer.printOperands(getInput());
  printer << ")";
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : " << resultTy;
}
|
|
|
|
/// Verify that there is exactly one operand per struct field and that each
/// operand's type is exactly the corresponding field type.
LogicalResult StructCreateOp::verify() {
  auto fields = hw::type_cast<StructType>(getType()).getElements();
  auto values = getInput();
  if (fields.size() != values.size())
    return emitOpError("structure field count mismatch");

  for (size_t i = 0, e = fields.size(); i != e; ++i) {
    if (fields[i].type == values[i].getType())
      continue;
    return emitOpError("structure field `")
           << fields[i].name << "` type does not match";
  }
  return success();
}
|
|
|
|
/// Fold struct_create(struct_explode(x)) => x when the operands are exactly
/// the explode's results, in order, and the types agree. Otherwise, when every
/// operand is constant, fold to an ArrayAttr of the field constants.
OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
  auto operands = getInput();
  if (!operands.empty()) {
    auto explode = operands[0].getDefiningOp<StructExplodeOp>();
    if (explode && operands == explode.getResults() &&
        getResult().getType() == explode.getInput().getType())
      return explode.getInput();
  }

  auto constants = adaptor.getInput();
  bool allConstant = llvm::all_of(
      constants, [](Attribute a) { return static_cast<bool>(a); });
  if (!allConstant)
    return {};
  return ArrayAttr::get(getContext(), constants);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StructExplodeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse `%value {attrs} : <struct type or alias>`. One result is produced per
/// struct field, typed with that field's type.
ParseResult StructExplodeOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::UnresolvedOperand structValue;
  Type inputTy;
  if (parser.parseOperand(structValue) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(inputTy))
    return failure();

  auto structTy = type_dyn_cast<StructType>(inputTy);
  if (!structTy)
    return parser.emitError(parser.getNameLoc(),
                            "invalid kind of type specified");

  llvm::SmallVector<Type, 4> fieldTypes;
  structTy.getInnerTypes(fieldTypes);
  result.addTypes(fieldTypes);

  return parser.resolveOperand(structValue, inputTy, result.operands);
}
|
|
|
|
/// Print as `%value {attrs} : <input type>`.
void StructExplodeOp::print(OpAsmPrinter &printer) {
  Value input = getInput();
  printer << " ";
  printer.printOperand(input);
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : " << input.getType();
}
|
|
|
|
/// With a constant struct input (an ArrayAttr of field constants), fold each
/// result to the corresponding field constant.
LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  Attribute input = adaptor.getInput();
  if (!input)
    return failure();
  for (Attribute field : input.cast<ArrayAttr>())
    results.push_back(field);
  return success();
}
|
|
|
|
/// Forward each result whose value can be read directly off the input's
/// defining op (via foldStructExtract) to that result's users. Succeeds only
/// if at least one use was actually rewritten.
LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
                                            PatternRewriter &rewriter) {
  auto *inputOp = op.getInput().getDefiningOp();
  auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
  auto result = failure();
  auto opResults = op.getResults();
  for (uint32_t index = 0; index < elements.size(); index++) {
    // Skip results with no uses. Replacing non-existent uses changes nothing,
    // and reporting success for it would make the driver re-apply this
    // pattern forever (e.g. when some other result is not foldable).
    if (opResults[index].use_empty())
      continue;
    auto foldResult = foldStructExtract(inputOp, index);
    // Replacing a value with itself is likewise not a change.
    if (!foldResult || foldResult == opResults[index])
      continue;
    rewriter.replaceAllUsesWith(opResults[index], foldResult);
    result = success();
  }
  return result;
}
|
|
|
|
/// Name each result after the struct field it carries.
void StructExplodeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto fields = type_cast<StructType>(getInput().getType()).getElements();
  for (auto [value, field] : llvm::zip(getResults(), fields))
    setNameFn(value, field.name.getValue());
}
|
|
|
|
/// Build a struct_explode of `input`, producing one result per struct field
/// with that field's type. `input` may be typed with an alias of a struct.
void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            Value input) {
  // type_cast looks through type aliases, matching the parser (which accepts
  // aliases via type_dyn_cast) and the other builders in this file. A plain
  // dyn_cast yields null for an aliased struct.
  StructType inputType = type_cast<StructType>(input.getType());
  SmallVector<Type, 16> fieldTypes;
  for (auto field : inputType.getElements())
    fieldTypes.push_back(field.type);
  build(odsBuilder, odsState, fieldTypes, input);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StructExtractOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Ensure an aggregate op's field index is within the bounds of
|
|
/// the aggregate type and the accessed field is of 'elementType'.
|
|
template <typename AggregateOp, typename AggregateType>
|
|
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
|
|
AggregateType aggType,
|
|
Type elementType) {
|
|
auto index = op.getFieldIndex();
|
|
if (index >= aggType.getElements().size())
|
|
return op.emitOpError() << "field index " << index
|
|
<< " exceeds element count of aggregate type";
|
|
|
|
if (getCanonicalType(elementType) !=
|
|
getCanonicalType(aggType.getElements()[index].type))
|
|
return op.emitOpError()
|
|
<< "type " << aggType.getElements()[index].type
|
|
<< " of accessed field in aggregate at index " << index
|
|
<< " does not match expected type " << elementType;
|
|
|
|
return success();
|
|
}
|
|
|
|
/// The extracted field must exist and its type must match the result type.
LogicalResult StructExtractOp::verify() {
  Type resultType = getType();
  return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
      *this, getInput().getType(), resultType);
}
|
|
|
|
/// Use the same parser for both struct_extract and union_extract since the
|
|
/// syntax is identical.
|
|
template <typename AggregateType>
|
|
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
|
|
OpAsmParser::UnresolvedOperand operand;
|
|
StringAttr fieldName;
|
|
Type declType;
|
|
|
|
if (parser.parseOperand(operand) || parser.parseLSquare() ||
|
|
parser.parseAttribute(fieldName) || parser.parseRSquare() ||
|
|
parser.parseOptionalAttrDict(result.attributes) ||
|
|
parser.parseColonType(declType))
|
|
return failure();
|
|
auto aggType = type_dyn_cast<AggregateType>(declType);
|
|
if (!aggType)
|
|
return parser.emitError(parser.getNameLoc(),
|
|
"invalid kind of type specified");
|
|
|
|
auto fieldIndex = aggType.getFieldIndex(fieldName);
|
|
if (!fieldIndex) {
|
|
parser.emitError(parser.getNameLoc(), "field name '" +
|
|
fieldName.getValue() +
|
|
"' not found in aggregate type");
|
|
return failure();
|
|
}
|
|
|
|
auto indexAttr =
|
|
IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
|
|
result.addAttribute("fieldIndex", indexAttr);
|
|
Type resultType = aggType.getElements()[*fieldIndex].type;
|
|
result.addTypes(resultType);
|
|
|
|
if (parser.resolveOperand(operand, declType, result.operands))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
/// Use the same printer for both struct_extract and union_extract since the
|
|
/// syntax is identical.
|
|
template <typename AggType>
|
|
static void printExtractOp(OpAsmPrinter &printer, AggType op) {
|
|
printer << " ";
|
|
printer.printOperand(op.getInput());
|
|
printer << "[\"" << op.getFieldName() << "\"]";
|
|
printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
|
|
printer << " : " << op.getInput().getType();
|
|
}
|
|
|
|
/// Parse with the shared extract-op parser, specialized for struct types.
ParseResult StructExtractOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  ParseResult parsed = parseExtractOp<StructType>(parser, result);
  return parsed;
}
|
|
|
|
/// Print with the shared extract-op printer.
void StructExtractOp::print(OpAsmPrinter &printer) {
  printExtractOp<StructExtractOp>(printer, *this);
}
|
|
|
|
/// Build an extract of `field` from `input`. The field is located by name and
/// its declared type becomes the result type.
void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
                            Value input, StructType::FieldInfo field) {
  auto structTy = type_cast<StructType>(input.getType());
  auto fieldIndex = structTy.getFieldIndex(field.name);
  assert(fieldIndex.has_value() && "field name not found in aggregate type");
  build(builder, odsState, field.type, input, *fieldIndex);
}
|
|
|
|
/// Build an extract of the field named `fieldName` from `input`, using that
/// field's type as the result type.
void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
                            Value input, StringAttr fieldName) {
  auto structType = type_cast<StructType>(input.getType());
  auto fieldIndex = structType.getFieldIndex(fieldName);
  assert(fieldIndex.has_value() && "field name not found in aggregate type");
  const auto &field = structType.getElements()[*fieldIndex];
  build(builder, odsState, field.type, input, *fieldIndex);
}
|
|
|
|
/// Fold an extract from a constant struct (an ArrayAttr of field constants) to
/// the field constant. Otherwise defer to foldStructExtract on the input's
/// defining op.
OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute constInput = adaptor.getInput())
    return llvm::cast<ArrayAttr>(constInput)[getFieldIndex()];

  Operation *producer = getInput().getDefiningOp();
  if (auto forwarded = foldStructExtract(producer, getFieldIndex()))
    return forwarded;
  return {};
}
|
|
|
|
/// extract(inject(x, "a", v), "b") with a != b reads a field the inject leaves
/// untouched, so it is rewritten to extract(x, "b").
LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
                                            PatternRewriter &rewriter) {
  auto inject =
      dyn_cast_or_null<StructInjectOp>(op.getInput().getDefiningOp());
  if (!inject || inject.getFieldIndex() == op.getFieldIndex())
    return failure();

  rewriter.replaceOpWithNewOp<StructExtractOp>(
      op, op.getType(), inject.getInput(), op.getFieldIndexAttr());
  return success();
}
|
|
|
|
/// Suggest the extracted field's name as the SSA name of the result.
void StructExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto name = getFieldName();
  setNameFn(getResult(), name);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StructInjectOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build an inject of `newValue` into the field named `fieldName` of `input`.
void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
                           Value input, StringAttr fieldName, Value newValue) {
  auto fieldIndex =
      type_cast<StructType>(input.getType()).getFieldIndex(fieldName);
  assert(fieldIndex.has_value() && "field name not found in aggregate type");
  build(builder, odsState, input, *fieldIndex, newValue);
}
|
|
|
|
/// The injected field must exist and the new value must have its type.
LogicalResult StructInjectOp::verify() {
  Type newValueType = getNewValue().getType();
  return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
      *this, getInput().getType(), newValueType);
}
|
|
|
|
/// Parse `%input["field"], %value {attrs} : <struct type>`. The field name
/// becomes the i32 `fieldIndex` attribute. `%value` is resolved against that
/// field's type, and the result has the struct type.
ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  OpAsmParser::UnresolvedOperand structOperand, valueOperand;
  StringAttr name;
  Type declaredTy;

  if (parser.parseOperand(structOperand) || parser.parseLSquare() ||
      parser.parseAttribute(name) || parser.parseRSquare() ||
      parser.parseComma() || parser.parseOperand(valueOperand) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(declaredTy))
    return failure();

  auto structType = type_dyn_cast<StructType>(declaredTy);
  if (!structType)
    return parser.emitError(operandsLoc, "invalid kind of type specified");

  auto idx = structType.getFieldIndex(name);
  if (!idx)
    return parser.emitError(parser.getNameLoc(),
                            "field name '" + name.getValue() +
                                "' not found in aggregate type");

  result.addAttribute(
      "fieldIndex",
      IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *idx));
  result.addTypes(declaredTy);

  Type fieldType = structType.getElements()[*idx].type;
  return parser.resolveOperands({structOperand, valueOperand},
                                {declaredTy, fieldType}, operandsLoc,
                                result.operands);
}
|
|
|
|
/// Print as `%input["field"], %value {attrs} : <struct type>`, eliding the
/// `fieldIndex` attribute since the field is printed by name.
void StructInjectOp::print(OpAsmPrinter &printer) {
  Value input = getInput();
  printer << " ";
  printer.printOperand(input);
  printer << "[\"" << getFieldName() << "\"], ";
  printer.printOperand(getNewValue());
  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
  printer << " : " << input.getType();
}
|
|
|
|
/// With a constant struct and a constant new value, fold to a copy of the
/// struct's field constants with the injected field replaced.
OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
  Attribute structConst = adaptor.getInput();
  Attribute valueConst = adaptor.getNewValue();
  if (!structConst || !valueConst)
    return {};

  auto oldFields = structConst.cast<ArrayAttr>().getValue();
  SmallVector<Attribute> newFields(oldFields.begin(), oldFields.end());
  newFields[getFieldIndex()] = valueConst;
  return ArrayAttr::get(getContext(), newFields);
}
|
|
|
|
/// Canonicalize the chain of struct_inject ops that ends at `op`:
///  - if the chain writes every field, replace it with a struct_create;
///  - otherwise, if some field is written more than once, rebuild the chain
///    with one inject per written field, keeping only the last write.
/// Chains that loop back on themselves are left untouched.
LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
                                           PatternRewriter &rewriter) {
  SmallPtrSet<Operation *, 4> visited;
  // Field name -> value of the last write to it. The walk starts at `op` (the
  // latest write), so try_emplace keeps the first value it sees for a field.
  DenseMap<StringAttr, Value> lastWrite;

  Value base;
  for (StructInjectOp cur = op; cur;
       cur = dyn_cast_or_null<StructInjectOp>(base.getDefiningOp())) {
    if (!visited.insert(cur).second)
      return failure();
    lastWrite.try_emplace(cur.getFieldNameAttr(), cur.getNewValue());
    base = cur.getInput();
  }
  assert(base && "missing input to inject chain");

  auto structTy = hw::type_cast<StructType>(op.getType());
  auto fields = structTy.getElements();

  // Every field is written: the original input is irrelevant.
  if (lastWrite.size() == fields.size()) {
    SmallVector<Value> createOperands;
    createOperands.reserve(fields.size());
    for (const auto &field : fields) {
      auto it = lastWrite.find(field.name);
      assert(it != lastWrite.end() && "missing field");
      createOperands.push_back(it->second);
    }
    rewriter.replaceOpWithNewOp<StructCreateOp>(op, structTy, createOperands);
    return success();
  }

  // Each inject wrote a distinct field, so there are no overwrites to drop.
  if (visited.size() == lastWrite.size())
    return failure();

  // Rebuild the chain on top of the original input, one inject per field.
  for (uint32_t idx = 0; idx < fields.size(); ++idx) {
    auto it = lastWrite.find(fields[idx].name);
    if (it != lastWrite.end())
      base = rewriter.create<StructInjectOp>(op.getLoc(), structTy, base, idx,
                                             it->second);
  }
  rewriter.replaceOp(op, base);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UnionCreateOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Check that `fieldIndex` addresses a member of the result union and that the
/// input operand's type matches that member's type (delegated to the shared
/// aggregate verifier).
LogicalResult UnionCreateOp::verify() {
  Type inputType = getInput().getType();
  return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
      *this, getType(), inputType);
}
|
|
|
|
/// Build a union_create that selects the member by name. The name is resolved
/// to an index within `unionType`; a name that does not exist is a programming
/// error and is asserted.
void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
                          Type unionType, StringAttr fieldName, Value input) {
  auto unionTy = type_cast<UnionType>(unionType);
  auto memberIndex = unionTy.getFieldIndex(fieldName);
  assert(memberIndex.has_value() && "field name not found in aggregate type");
  build(builder, odsState, unionType, *memberIndex, input);
}
|
|
|
|
/// Parse `"field", %input {attrs} : <union type or alias>`.
/// The field name is resolved to the `fieldIndex` attribute, and the operand is
/// resolved against the type of the selected union member.
ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc fieldNameLoc = parser.getCurrentLocation();
  StringAttr memberName;
  OpAsmParser::UnresolvedOperand inputOperand;
  Type resultType;

  if (parser.parseAttribute(memberName) || parser.parseComma() ||
      parser.parseOperand(inputOperand) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  auto unionTy = type_dyn_cast<UnionType>(resultType);
  if (!unionTy)
    return parser.emitError(parser.getNameLoc(),
                            "expected !hw.union type or alias");

  auto memberIndex = unionTy.getFieldIndex(memberName);
  if (!memberIndex)
    return parser.emitError(fieldNameLoc, "cannot find union field '")
           << memberName.getValue() << '\'';

  result.addAttribute(
      "fieldIndex",
      IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *memberIndex));

  Type memberType = unionTy.getElements()[*memberIndex].type;
  if (parser.resolveOperand(inputOperand, memberType, result.operands))
    return failure();
  result.addTypes({resultType});
  return success();
}
|
|
|
|
/// Print `"field", %input {attrs} : <union type>`; `fieldIndex` is implied by
/// the printed field name and therefore elided.
void UnionCreateOp::print(OpAsmPrinter &printer) {
  printer << " \"" << getFieldName() << "\", " << getInput();
  printer.printOptionalAttrDict((*this)->getAttrs(),
                                /*elidedAttrs=*/{"fieldIndex"});
  printer << " : " << getType();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UnionExtractOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse using the shared aggregate-extract syntax, instantiated for unions.
ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseExtractOp<UnionType>(parser, result);
  return parsed;
}
|
|
|
|
/// Print using the shared aggregate-extract syntax.
void UnionExtractOp::print(OpAsmPrinter &printer) {
  UnionExtractOp self = *this;
  printExtractOp(printer, self);
}
|
|
|
|
/// Infer the result type of a union_extract: the type of the union member
/// selected by the `fieldIndex` attribute.
///
/// Inference may run before the op is verified (e.g. while parsing or
/// building), so the operand and attribute are validated here rather than
/// assumed. Returns failure (with a diagnostic when `loc` is provided) if the
/// operand is not a union, the attribute is missing, or the index is out of
/// range.
LogicalResult UnionExtractOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> loc, ValueRange operands,
    DictionaryAttr attrs, mlir::OpaqueProperties properties,
    mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
  auto fail = [&](const Twine &message) -> LogicalResult {
    if (loc)
      mlir::emitError(*loc, message);
    return failure();
  };

  if (operands.empty())
    return fail("expected a union operand");
  auto unionType = hw::type_dyn_cast<UnionType>(operands[0].getType());
  if (!unionType)
    return fail("expected !hw.union type or alias");

  auto fieldIndexAttr =
      attrs ? attrs.getAs<IntegerAttr>("fieldIndex") : IntegerAttr();
  if (!fieldIndexAttr)
    return fail("missing integer 'fieldIndex' attribute");

  auto unionElements = unionType.getElements();
  // Use the full 64-bit value: narrowing to `unsigned` could wrap a huge index
  // into range. getLimitedValue() saturates instead of asserting on >64 bits.
  uint64_t fieldIndex = fieldIndexAttr.getValue().getLimitedValue();
  if (fieldIndex >= unionElements.size())
    return fail("field index " + Twine(fieldIndex) +
                " exceeds element count of aggregate type");

  results.push_back(unionElements[fieldIndex].type);
  return success();
}
|
|
|
|
/// Build a union_extract selecting the member named `fieldName` from `input`.
/// The result type is the selected member's type; an unknown name is a
/// programming error and is asserted.
void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                           Value input, StringAttr fieldName) {
  auto unionTy = type_cast<UnionType>(input.getType());
  auto memberIndex = unionTy.getFieldIndex(fieldName);
  assert(memberIndex.has_value() && "field name not found in aggregate type");
  Type memberType = unionTy.getElements()[*memberIndex].type;
  build(odsBuilder, odsState, memberType, input, *memberIndex);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ArrayGetOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// An array_get of an array_create with a constant index can just be the
|
|
// array_create operand at the constant index. If the array_create has a
|
|
// single uniform value for each element, just return that value regardless of
|
|
// the index. If the array is constructed from a constant by a bitcast
|
|
// operation, we can fold into a constant.
|
|
OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
|
|
auto inputCst = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
|
|
auto indexCst = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
|
|
|
|
if (inputCst) {
|
|
// Constant array index.
|
|
if (indexCst) {
|
|
auto indexVal = indexCst.getValue();
|
|
if (indexVal.getBitWidth() < 64) {
|
|
auto index = indexVal.getZExtValue();
|
|
return inputCst[inputCst.size() - 1 - index];
|
|
}
|
|
}
|
|
// If all elements of the array are the same, we can return any element of
|
|
// array.
|
|
if (!inputCst.empty() && llvm::all_equal(inputCst))
|
|
return inputCst[0];
|
|
}
|
|
|
|
// array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
|
|
if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
|
|
auto intTy = getType().dyn_cast<IntegerType>();
|
|
if (!intTy)
|
|
return {};
|
|
auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
|
|
if (!bitcastInputOp)
|
|
return {};
|
|
if (!indexCst)
|
|
return {};
|
|
auto bitcastInputCst = bitcastInputOp.getValue();
|
|
// Calculate the index. Make sure to zero-extend the index value before
|
|
// multiplying the element width.
|
|
auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
|
|
getType().getIntOrFloatBitWidth();
|
|
// Extract [startIdx + width - 1: startIdx].
|
|
return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
|
|
intTy.getIntOrFloatBitWidth()));
|
|
}
|
|
|
|
auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
|
|
if (!inputCreate)
|
|
return {};
|
|
|
|
if (auto uniformValue = inputCreate.getUniformElement())
|
|
return uniformValue;
|
|
|
|
if (!indexCst || indexCst.getValue().getBitWidth() > 64)
|
|
return {};
|
|
|
|
uint64_t index = indexCst.getValue().getLimitedValue();
|
|
auto createInputs = inputCreate.getInputs();
|
|
if (index >= createInputs.size())
|
|
return {};
|
|
return createInputs[createInputs.size() - index - 1];
|
|
}
|
|
|
|
/// Canonicalization patterns for an array_get whose index is a constant:
///  - get(slice(a, n), m) with constant n         -> get(a, n + m)
///  - get(concat(a0, ..., ak), m)                  -> get(ai, m') where ai is
///    the concat operand holding element m
///  - get(get(create(x0, ..., xn), sel), m) with non-constant sel
///    -> get(create(get(x0, m), ..., get(xn, m)), sel)
LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
                                       PatternRewriter &rewriter) {
  auto constIdx = getUIntFromValue(op.getIndex());
  if (!constIdx)
    return failure();

  Operation *producer = op.getInput().getDefiningOp();

  // get(slice(a, n), m) -> get(a, n + m)
  if (auto slice = dyn_cast_or_null<ArraySliceOp>(producer)) {
    auto lowIndex = slice.getLowIndex();
    auto constLow = getUIntFromValue(lowIndex);
    if (!constLow)
      return failure();

    uint64_t combined = *constLow + *constIdx;
    auto combinedIdx =
        rewriter.create<ConstantOp>(op.getLoc(), lowIndex.getType(), combined);
    rewriter.replaceOpWithNewOp<ArrayGetOp>(op, slice.getInput(), combinedIdx);
    return success();
  }

  // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
  // Element 0 lives in the last operand, so scan operands from last to first,
  // subtracting each operand's length until the one holding element m is hit.
  if (auto concat = dyn_cast_or_null<ArrayConcatOp>(producer)) {
    uint64_t remaining = *constIdx;
    for (Value part : llvm::reverse(concat.getInputs())) {
      size_t partLen =
          hw::type_cast<ArrayType>(part.getType()).getNumElements();
      if (remaining < partLen) {
        unsigned idxWidth = partLen == 1 ? 1 : llvm::Log2_64_Ceil(partLen);
        auto partIdx = rewriter.create<ConstantOp>(
            op.getLoc(), rewriter.getIntegerType(idxWidth), remaining);
        rewriter.replaceOpWithNewOp<ArrayGetOp>(op, part, partIdx);
        return success();
      }
      remaining -= partLen;
    }
    return failure();
  }

  // array_get const, (array_get sel, (array_create a, b, c, d)) -->
  // array_get sel, (array_create (array_get const a), (array_get const b),
  //                              (array_get const c), (array_get const d))
  auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(producer);
  if (!innerGet || innerGet.getIndex().getDefiningOp<hw::ConstantOp>())
    return failure();
  auto create = innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>();
  if (!create)
    return failure();

  SmallVector<Value> indexedElems;
  for (Value elem : create.getOperands())
    indexedElems.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
        op.getLoc(), elem, op.getIndex()));

  Value newCreate =
      rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), indexedElems);
  rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(op, newCreate,
                                              innerGet.getIndex());
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TypedeclOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// The name to use for this typedecl: the explicit Verilog name when one is
/// set, otherwise the symbol name.
StringRef TypedeclOp::getPreferredName() {
  if (auto verilogName = getVerilogName())
    return *verilogName;
  return getName();
}
|
|
|
|
/// Build the `!hw.typealias` type naming this declaration, referenced as
/// @<enclosing type scope>::@<this typedecl>, aliasing the declared type.
/// The parent op is required to be a hw.type_scope (checked by `cast`).
Type TypedeclOp::getAliasType() {
  auto scope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
  auto declRef = FlatSymbolRefAttr::get(*this);
  auto aliasRef = SymbolRefAttr::get(scope.getSymNameAttr(), {declRef});
  return hw::TypeAliasType::get(aliasRef, getType());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BitcastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Identity fold: a bitcast to the operand's own type is a no-op.
///   bitcast(%a) : A -> A  ==>  %a
OpFoldResult BitcastOp::fold(FoldAdaptor) {
  Value input = getOperand();
  if (input.getType() != getType())
    return {};
  return input;
}
|
|
|
|
/// Collapse a bitcast of a bitcast into a single bitcast:
///   %b = bitcast(%a) : A -> B
///   bitcast(%b) : B -> C      ===>  bitcast(%a) : A -> C
/// The replacement is created with createOrFold, so an A -> A result folds away.
LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
  auto producer = op.getInput().getDefiningOp<BitcastOp>();
  if (!producer)
    return failure();

  Value collapsed = rewriter.createOrFold<BitcastOp>(
      op.getLoc(), op.getType(), producer.getInput());
  rewriter.replaceOp(op, collapsed);
  return success();
}
|
|
|
|
/// A bitcast must preserve the total bit width of the value.
LogicalResult BitcastOp::verify() {
  auto inputWidth = getBitWidth(getInput().getType());
  auto resultWidth = getBitWidth(getResult().getType());
  if (inputWidth == resultWidth)
    return success();
  return emitOpError("Bitwidth of input must match result");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// HierPathOp helpers.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Remove every namepath element that refers to `moduleToDrop`, either as the
/// module of an inner reference or as a bare module reference. The namepath
/// attribute is rewritten only if something was removed.
/// Returns true if the path changed.
bool HierPathOp::dropModule(StringAttr moduleToDrop) {
  SmallVector<Attribute, 4> kept;
  bool dropped = false;
  for (Attribute elt : getNamepath()) {
    // Each element is either an InnerRefAttr or a FlatSymbolRefAttr.
    StringAttr eltModule;
    if (auto innerRef = elt.dyn_cast<hw::InnerRefAttr>())
      eltModule = innerRef.getModule();
    else
      eltModule = elt.cast<FlatSymbolRefAttr>().getAttr();

    if (eltModule == moduleToDrop) {
      dropped = true;
      continue;
    }
    kept.push_back(elt);
  }

  if (dropped)
    setNamepathAttr(ArrayAttr::get(getContext(), kept));
  return dropped;
}
|
|
|
|
/// Rewrite the namepath for `moduleToDrop` being inlined.
///
/// Inner references into `moduleToDrop` and a bare module reference to it are
/// removed. The symbol name of a removed inner reference is remembered and
/// applied as a `<name>_` prefix to the symbol of the next retained inner
/// reference. The namepath attribute is rewritten only if something was
/// removed. Returns true if the path changed.
bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
  SmallVector<Attribute, 4> rewritten;
  bool changed = false;
  // Symbol name of the most recently removed inner reference, waiting to be
  // used as a prefix; empty when nothing is pending.
  StringRef pendingPrefix;

  for (Attribute elt : getNamepath()) {
    auto innerRef = elt.dyn_cast<hw::InnerRefAttr>();
    if (!innerRef) {
      // Otherwise the element is a FlatSymbolRefAttr.
      if (elt.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
        changed = true;
      else
        rewritten.push_back(elt);
      continue;
    }

    if (innerRef.getModule() == moduleToDrop) {
      pendingPrefix = innerRef.getName().getValue();
      changed = true;
      continue;
    }

    if (pendingPrefix.empty()) {
      rewritten.push_back(innerRef);
      continue;
    }

    auto prefixedName = StringAttr::get(
        getContext(), pendingPrefix + "_" + innerRef.getName().getValue());
    rewritten.push_back(hw::InnerRefAttr::get(innerRef.getModule(), prefixedName));
    pendingPrefix = StringRef();
  }

  if (changed)
    setNamepathAttr(ArrayAttr::get(getContext(), rewritten));
  return changed;
}
|
|
|
|
/// Retarget every namepath element that refers to `oldMod` so it refers to
/// `newMod` instead. Inner references keep their symbol name. The namepath
/// attribute is rewritten only if some element matched `oldMod`.
/// Returns true in that case.
bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
  SmallVector<Attribute, 4> rewritten;
  bool changed = false;
  for (Attribute elt : getNamepath()) {
    // Each element is either an InnerRefAttr or a FlatSymbolRefAttr.
    if (auto innerRef = elt.dyn_cast<hw::InnerRefAttr>()) {
      if (innerRef.getModule() == oldMod) {
        elt = hw::InnerRefAttr::get(newMod, innerRef.getName());
        changed = true;
      }
    } else if (elt.cast<FlatSymbolRefAttr>().getAttr() == oldMod) {
      elt = FlatSymbolRefAttr::get(newMod);
      changed = true;
    }
    rewritten.push_back(elt);
  }

  if (changed)
    setNamepathAttr(ArrayAttr::get(getContext(), rewritten));
  return changed;
}
|
|
|
|
/// Retarget the first namepath element that refers to `oldMod` so it refers to
/// `newMod`.
///
/// For an inner reference, the symbol name is also translated through
/// `innerSymRenameMap` when it has an entry there. Only the first match is
/// rewritten, since a module can occur in the path at most once. Returns
/// false without touching anything when `oldMod == newMod` or when no element
/// matches; otherwise updates the namepath and returns true.
bool HierPathOp::updateModuleAndInnerRef(
    StringAttr oldMod, StringAttr newMod,
    const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
  if (oldMod == newMod)
    return false;

  auto oldModRef = FlatSymbolRefAttr::get(oldMod);
  ArrayRef<Attribute> original = getNamepath().getValue();
  SmallVector<Attribute> path(original.begin(), original.end());

  for (Attribute &elt : path) {
    if (auto innerRef = elt.dyn_cast<hw::InnerRefAttr>()) {
      if (innerRef.getModule() != oldMod)
        continue;
      // The symbol inside oldMod may have been renamed in newMod.
      StringAttr symName = innerRef.getName();
      auto renamed = innerSymRenameMap.find(symName);
      if (renamed != innerSymRenameMap.end())
        symName = renamed->second;
      elt = hw::InnerRefAttr::get(newMod, symName);
    } else if (elt == oldModRef) {
      elt = FlatSymbolRefAttr::get(newMod);
    } else {
      continue;
    }

    setNamepathAttr(ArrayAttr::get(getContext(), path));
    return true;
  }
  return false;
}
|
|
|
|
/// Truncate the namepath at the first element that refers to `atMod`.
///
/// The cut happens at an inner reference into `atMod`, which is itself kept
/// only when `includeMod` is set, or at a bare module reference to `atMod`
/// when `includeMod` is not set (that element is then removed). Elements
/// after the cut are dropped. Returns true and rewrites the namepath if a cut
/// happened; otherwise leaves it untouched and returns false.
bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
  SmallVector<Attribute, 4> prefix;
  for (Attribute elt : getNamepath()) {
    // Each element is either an InnerRefAttr or a FlatSymbolRefAttr.
    auto innerRef = elt.dyn_cast<hw::InnerRefAttr>();
    bool cutHere =
        innerRef ? innerRef.getModule() == atMod
                 : (elt.cast<FlatSymbolRefAttr>().getAttr() == atMod &&
                    !includeMod);
    if (!cutHere) {
      prefix.push_back(elt);
      continue;
    }

    if (innerRef && includeMod)
      prefix.push_back(elt);
    setNamepathAttr(ArrayAttr::get(getContext(), prefix));
    return true;
  }
  return false;
}
|
|
|
|
/// Return just the module part of the namepath at a specific index.
|
|
StringAttr HierPathOp::modPart(unsigned i) {
|
|
return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
|
|
.Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
|
|
.Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
|
|
}
|
|
|
|
/// Return the root module.
|
|
StringAttr HierPathOp::root() {
|
|
assert(!getNamepath().empty());
|
|
return modPart(0);
|
|
}
|
|
|
|
/// Return true if the NLA has the module in its path.
|
|
bool HierPathOp::hasModule(StringAttr modName) {
|
|
for (auto nameRef : getNamepath()) {
|
|
// nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
|
|
if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
|
|
if (ref.getModule() == modName)
|
|
return true;
|
|
} else {
|
|
if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == modName)
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Return true if the NLA has the InnerSym .
|
|
bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
|
|
for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
|
|
if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>())
|
|
if (ref.getName() == symName && ref.getModule() == modName)
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Return just the reference part of the namepath at a specific index. This
|
|
/// will return an empty attribute if this is the leaf and the leaf is a module.
|
|
StringAttr HierPathOp::refPart(unsigned i) {
|
|
return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
|
|
.Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
|
|
.Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
|
|
}
|
|
|
|
/// Return the leaf reference. This returns an empty attribute if the leaf
|
|
/// reference is a module.
|
|
StringAttr HierPathOp::ref() {
|
|
assert(!getNamepath().empty());
|
|
return refPart(getNamepath().size() - 1);
|
|
}
|
|
|
|
/// Return the leaf module.
|
|
StringAttr HierPathOp::leafMod() {
|
|
assert(!getNamepath().empty());
|
|
return modPart(getNamepath().size() - 1);
|
|
}
|
|
|
|
/// Returns true if this NLA targets an instance of a module (as opposed to
|
|
/// an instance's port or something inside an instance).
|
|
bool HierPathOp::isModule() { return !ref(); }
|
|
|
|
/// Returns true if this NLA targets something inside a module (as opposed
|
|
/// to a module or an instance of a module);
|
|
bool HierPathOp::isComponent() { return (bool)ref(); }
|
|
|
|
// Verify the HierPathOp.
|
|
// 1. Iterate over the namepath.
|
|
// 2. The namepath should be a valid instance path, specified either on a
|
|
// module or a declaration inside a module.
|
|
// 3. Each element in the namepath is an InnerRefAttr except possibly the
|
|
// last element.
|
|
// 4. Make sure that the InnerRefAttr is legal, by verifying the module name
|
|
// and the corresponding inner_sym on the instance.
|
|
// 5. Make sure that the instance path is legal, by verifying the sequence of
|
|
// instance and the expected module occurs as the next element in the path.
|
|
// 6. The last element of the namepath, can be an InnerRefAttr on either a
|
|
// module port or a declaration inside the module.
|
|
// 7. The last element of the namepath can also be a module symbol.
|
|
LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
|
|
ArrayAttr expectedModuleNames = {};
|
|
auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
|
|
if (!expectedModuleNames)
|
|
return success();
|
|
if (llvm::any_of(expectedModuleNames,
|
|
[name](Attribute attr) { return attr == name; }))
|
|
return success();
|
|
auto diag = emitOpError() << "instance path is incorrect. Expected ";
|
|
size_t n = expectedModuleNames.size();
|
|
if (n != 1) {
|
|
diag << "one of ";
|
|
}
|
|
for (size_t i = 0; i < n; ++i) {
|
|
if (i != 0)
|
|
diag << ((i + 1 == n) ? " or " : ", ");
|
|
diag << expectedModuleNames[i].cast<StringAttr>();
|
|
}
|
|
diag << ". Instead found: " << name;
|
|
return diag;
|
|
};
|
|
|
|
if (!getNamepath() || getNamepath().empty())
|
|
return emitOpError() << "the instance path cannot be empty";
|
|
for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
|
|
hw::InnerRefAttr innerRef = getNamepath()[i].dyn_cast<hw::InnerRefAttr>();
|
|
if (!innerRef)
|
|
return emitOpError()
|
|
<< "the instance path can only contain inner sym reference"
|
|
<< ", only the leaf can refer to a module symbol";
|
|
|
|
if (failed(checkExpectedModule(innerRef.getModule())))
|
|
return failure();
|
|
|
|
auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
|
|
if (!instOp)
|
|
return emitOpError() << " module: " << innerRef.getModule()
|
|
<< " does not contain any instance with symbol: "
|
|
<< innerRef.getName();
|
|
expectedModuleNames = instOp.getReferencedModuleNamesAttr();
|
|
}
|
|
|
|
// The instance path has been verified. Now verify the last element.
|
|
auto leafRef = getNamepath()[getNamepath().size() - 1];
|
|
if (auto innerRef = leafRef.dyn_cast<hw::InnerRefAttr>()) {
|
|
if (!ns.lookup(innerRef)) {
|
|
return emitOpError() << " operation with symbol: " << innerRef
|
|
<< " was not found ";
|
|
}
|
|
if (failed(checkExpectedModule(innerRef.getModule())))
|
|
return failure();
|
|
} else if (failed(checkExpectedModule(
|
|
leafRef.cast<FlatSymbolRefAttr>().getAttr()))) {
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Print `[visibility] @sym [@A::@b, ..., @Leaf] {attrs}`. Inner references
/// print as `@Module::@name`, module references as `@Module`. The symbol name,
/// namepath and visibility attributes are printed inline and elided from the
/// attribute dictionary.
void HierPathOp::print(OpAsmPrinter &p) {
  p << " ";

  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
  if (auto visibility =
          getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
    p << visibility.getValue() << ' ';

  p.printSymbolName(getSymName());

  auto printElement = [&](Attribute elt) {
    auto innerRef = elt.dyn_cast<hw::InnerRefAttr>();
    if (!innerRef) {
      p.printSymbolName(elt.cast<FlatSymbolRefAttr>().getValue());
      return;
    }
    p.printSymbolName(innerRef.getModule().getValue());
    p << "::";
    p.printSymbolName(innerRef.getName().getValue());
  };
  p << " [";
  llvm::interleaveComma(getNamepath().getValue(), p, printElement);
  p << "]";

  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
}
|
|
|
|
/// Parse `[visibility] @sym [@A::@b, ..., @Leaf] {attrs}`.
/// `@A` becomes a FlatSymbolRefAttr, `@A::@b` an InnerRefAttr; any deeper
/// nesting such as `@A::@B::@c` is rejected.
ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
  // The visibility keyword is optional, so its absence is not an error.
  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);

  StringAttr symName;
  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  SmallVector<Attribute> namepath;
  auto parseElement = [&]() -> ParseResult {
    auto eltLoc = parser.getCurrentLocation();
    SymbolRefAttr symRef;
    if (parser.parseAttribute(symRef))
      return failure();

    switch (symRef.getNestedReferences().size()) {
    case 0:
      namepath.push_back(FlatSymbolRefAttr::get(symRef.getRootReference()));
      return success();
    case 1:
      namepath.push_back(hw::InnerRefAttr::get(symRef.getRootReference(),
                                               symRef.getLeafReference()));
      return success();
    default:
      return parser.emitError(eltLoc, "only one nested reference is allowed");
    }
  };
  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
                                     parseElement))
    return failure();
  result.addAttribute("namepath",
                      ArrayAttr::get(parser.getContext(), namepath));

  return parser.parseOptionalAttrDict(result.attributes);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TriggeredOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build a triggered op: operands are `trigger` followed by `inputs`, the
/// `event` attribute is attached, and a single-block body region is created
/// whose block arguments mirror `inputs` (same types, same locations).
void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
                        EventControlAttr event, Value trigger,
                        ValueRange inputs) {
  odsState.addOperands(trigger);
  odsState.addOperands(inputs);
  odsState.addAttribute(getEventAttrName(odsState.name), event);

  SmallVector<Location> argLocs;
  argLocs.reserve(inputs.size());
  for (Value input : inputs)
    argLocs.push_back(input.getLoc());

  Region *body = odsState.addRegion();
  auto *entry = new Block();
  body->push_back(entry);
  entry->addArguments(inputs.getTypes(), argLocs);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen generated logic.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Provide the autogenerated implementation guts for the Op classes.
|
|
#define GET_OP_CLASSES
|
|
#include "circt/Dialect/HW/HW.cpp.inc"
|