mirror of https://github.com/llvm/circt.git
349 lines
12 KiB
C++
349 lines
12 KiB
C++
//===- CombOps.cpp - Implement the Comb operations ------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements combinational ops.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/HW/HWOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
using namespace circt;
|
|
using namespace comb;
|
|
|
|
/// Create a sign extension operation from a value of integer type to an equal
|
|
/// or larger integer type.
|
|
Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
|
|
OpBuilder &builder) {
|
|
IntegerType valueType = dyn_cast<IntegerType>(value.getType());
|
|
assert(valueType && isa<IntegerType>(destTy) &&
|
|
valueType.getWidth() <= destTy.getIntOrFloatBitWidth() &&
|
|
valueType.getWidth() != 0 && "invalid sext operands");
|
|
// If already the right size, we are done.
|
|
if (valueType == destTy)
|
|
return value;
|
|
|
|
// sext is concat with a replicate of the sign bits and the bottom part.
|
|
auto signBit =
|
|
builder.createOrFold<ExtractOp>(loc, value, valueType.getWidth() - 1, 1);
|
|
auto signBits = builder.createOrFold<ReplicateOp>(
|
|
loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth());
|
|
return builder.createOrFold<ConcatOp>(loc, signBits, value);
|
|
}
|
|
|
|
Value comb::createOrFoldSExt(Value value, Type destTy,
|
|
ImplicitLocOpBuilder &builder) {
|
|
return createOrFoldSExt(builder.getLoc(), value, destTy, builder);
|
|
}
|
|
|
|
Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder,
|
|
bool twoState) {
|
|
auto allOnes = builder.create<hw::ConstantOp>(loc, value.getType(), -1);
|
|
return builder.createOrFold<XorOp>(loc, value, allOnes, twoState);
|
|
}
|
|
|
|
Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
|
|
bool twoState) {
|
|
return createOrFoldNot(builder.getLoc(), value, builder, twoState);
|
|
}
|
|
|
|
// Extract individual bits from a value
|
|
void comb::extractBits(OpBuilder &builder, Value val,
|
|
SmallVectorImpl<Value> &bits) {
|
|
assert(val.getType().isInteger() && "expected integer");
|
|
auto width = val.getType().getIntOrFloatBitWidth();
|
|
bits.reserve(width);
|
|
|
|
// Check if we can reuse concat operands
|
|
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
|
|
if (concat.getNumOperands() == width &&
|
|
llvm::all_of(concat.getOperandTypes(), [](Type type) {
|
|
return type.getIntOrFloatBitWidth() == 1;
|
|
})) {
|
|
// Reverse the operands to match the bit order
|
|
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
|
|
std::make_reverse_iterator(concat.getOperands().begin()));
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Extract individual bits
|
|
for (int64_t i = 0; i < width; ++i)
|
|
bits.push_back(
|
|
builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
|
|
}
|
|
|
|
// Construct a mux tree for given leaf nodes. `selectors` is the selector for
|
|
// each level of the tree. Currently the selector is tested from MSB to LSB.
|
|
Value comb::constructMuxTree(OpBuilder &builder, Location loc,
|
|
ArrayRef<Value> selectors,
|
|
ArrayRef<Value> leafNodes,
|
|
Value outOfBoundsValue) {
|
|
// Recursive helper function to construct the mux tree
|
|
std::function<Value(size_t, size_t)> constructTreeHelper =
|
|
[&](size_t id, size_t level) -> Value {
|
|
// Base case: at the lowest level, return the result
|
|
if (level == 0) {
|
|
// Return the result for the given index. If the index is out of bounds,
|
|
// return the out-of-bound value.
|
|
return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
|
|
}
|
|
|
|
auto selector = selectors[level - 1];
|
|
|
|
// Recursive case: create muxes for true and false branches
|
|
auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
|
|
auto falseVal = constructTreeHelper(2 * id, level - 1);
|
|
|
|
// Combine the results with a mux
|
|
return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
|
|
};
|
|
|
|
return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ICmpOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) {
|
|
switch (predicate) {
|
|
case ICmpPredicate::eq:
|
|
return ICmpPredicate::eq;
|
|
case ICmpPredicate::ne:
|
|
return ICmpPredicate::ne;
|
|
case ICmpPredicate::slt:
|
|
return ICmpPredicate::sgt;
|
|
case ICmpPredicate::sle:
|
|
return ICmpPredicate::sge;
|
|
case ICmpPredicate::sgt:
|
|
return ICmpPredicate::slt;
|
|
case ICmpPredicate::sge:
|
|
return ICmpPredicate::sle;
|
|
case ICmpPredicate::ult:
|
|
return ICmpPredicate::ugt;
|
|
case ICmpPredicate::ule:
|
|
return ICmpPredicate::uge;
|
|
case ICmpPredicate::ugt:
|
|
return ICmpPredicate::ult;
|
|
case ICmpPredicate::uge:
|
|
return ICmpPredicate::ule;
|
|
case ICmpPredicate::ceq:
|
|
return ICmpPredicate::ceq;
|
|
case ICmpPredicate::cne:
|
|
return ICmpPredicate::cne;
|
|
case ICmpPredicate::weq:
|
|
return ICmpPredicate::weq;
|
|
case ICmpPredicate::wne:
|
|
return ICmpPredicate::wne;
|
|
}
|
|
llvm_unreachable("unknown comparison predicate");
|
|
}
|
|
|
|
bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) {
|
|
switch (predicate) {
|
|
case ICmpPredicate::ult:
|
|
case ICmpPredicate::ugt:
|
|
case ICmpPredicate::ule:
|
|
case ICmpPredicate::uge:
|
|
case ICmpPredicate::ne:
|
|
case ICmpPredicate::eq:
|
|
case ICmpPredicate::cne:
|
|
case ICmpPredicate::ceq:
|
|
case ICmpPredicate::wne:
|
|
case ICmpPredicate::weq:
|
|
return false;
|
|
case ICmpPredicate::slt:
|
|
case ICmpPredicate::sgt:
|
|
case ICmpPredicate::sle:
|
|
case ICmpPredicate::sge:
|
|
return true;
|
|
}
|
|
llvm_unreachable("unknown comparison predicate");
|
|
}
|
|
|
|
/// Returns the predicate for a logically negated comparison, e.g. mapping
|
|
/// EQ => NE and SLE => SGT.
|
|
ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) {
|
|
switch (predicate) {
|
|
case ICmpPredicate::eq:
|
|
return ICmpPredicate::ne;
|
|
case ICmpPredicate::ne:
|
|
return ICmpPredicate::eq;
|
|
case ICmpPredicate::slt:
|
|
return ICmpPredicate::sge;
|
|
case ICmpPredicate::sle:
|
|
return ICmpPredicate::sgt;
|
|
case ICmpPredicate::sgt:
|
|
return ICmpPredicate::sle;
|
|
case ICmpPredicate::sge:
|
|
return ICmpPredicate::slt;
|
|
case ICmpPredicate::ult:
|
|
return ICmpPredicate::uge;
|
|
case ICmpPredicate::ule:
|
|
return ICmpPredicate::ugt;
|
|
case ICmpPredicate::ugt:
|
|
return ICmpPredicate::ule;
|
|
case ICmpPredicate::uge:
|
|
return ICmpPredicate::ult;
|
|
case ICmpPredicate::ceq:
|
|
return ICmpPredicate::cne;
|
|
case ICmpPredicate::cne:
|
|
return ICmpPredicate::ceq;
|
|
case ICmpPredicate::weq:
|
|
return ICmpPredicate::wne;
|
|
case ICmpPredicate::wne:
|
|
return ICmpPredicate::weq;
|
|
}
|
|
llvm_unreachable("unknown comparison predicate");
|
|
}
|
|
|
|
/// Return true if this is an equality test with -1, which is a "reduction
|
|
/// and" operation in Verilog.
|
|
bool ICmpOp::isEqualAllOnes() {
|
|
if (getPredicate() != ICmpPredicate::eq)
|
|
return false;
|
|
|
|
if (auto op1 =
|
|
dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
|
|
return op1.getValue().isAllOnes();
|
|
return false;
|
|
}
|
|
|
|
/// Return true if this is a not equal test with 0, which is a "reduction
|
|
/// or" operation in Verilog.
|
|
bool ICmpOp::isNotEqualZero() {
|
|
if (getPredicate() != ICmpPredicate::ne)
|
|
return false;
|
|
|
|
if (auto op1 =
|
|
dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
|
|
return op1.getValue().isZero();
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Unary Operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult ReplicateOp::verify() {
|
|
// The source must be equal or smaller than the dest type, and an even
|
|
// multiple of it. Both are already known to be signless integers.
|
|
auto srcWidth = cast<IntegerType>(getOperand().getType()).getWidth();
|
|
auto dstWidth = cast<IntegerType>(getType()).getWidth();
|
|
if (srcWidth == 0)
|
|
return emitOpError("replicate does not take zero bit integer");
|
|
|
|
if (srcWidth > dstWidth)
|
|
return emitOpError("replicate cannot shrink bitwidth of operand"),
|
|
failure();
|
|
|
|
if (dstWidth % srcWidth)
|
|
return emitOpError("replicate must produce integer multiple of operand"),
|
|
failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Variadic operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static LogicalResult verifyUTBinOp(Operation *op) {
|
|
if (op->getOperands().empty())
|
|
return op->emitOpError("requires 1 or more args");
|
|
return success();
|
|
}
|
|
|
|
LogicalResult AddOp::verify() { return verifyUTBinOp(*this); }
|
|
|
|
LogicalResult MulOp::verify() { return verifyUTBinOp(*this); }
|
|
|
|
LogicalResult AndOp::verify() { return verifyUTBinOp(*this); }
|
|
|
|
LogicalResult OrOp::verify() { return verifyUTBinOp(*this); }
|
|
|
|
LogicalResult XorOp::verify() { return verifyUTBinOp(*this); }
|
|
|
|
/// Return true if this is a two operand xor with an all ones constant as its
|
|
/// RHS operand.
|
|
bool XorOp::isBinaryNot() {
|
|
if (getNumOperands() != 2)
|
|
return false;
|
|
if (auto cst = getOperand(1).getDefiningOp<hw::ConstantOp>())
|
|
if (cst.getValue().isAllOnes())
|
|
return true;
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConcatOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static unsigned getTotalWidth(ValueRange inputs) {
|
|
unsigned resultWidth = 0;
|
|
for (auto input : inputs) {
|
|
resultWidth += hw::type_cast<IntegerType>(input.getType()).getWidth();
|
|
}
|
|
return resultWidth;
|
|
}
|
|
|
|
void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd,
|
|
ValueRange tl) {
|
|
result.addOperands(ValueRange{hd});
|
|
result.addOperands(tl);
|
|
unsigned hdWidth = cast<IntegerType>(hd.getType()).getWidth();
|
|
result.addTypes(builder.getIntegerType(getTotalWidth(tl) + hdWidth));
|
|
}
|
|
|
|
LogicalResult ConcatOp::inferReturnTypes(
|
|
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
|
|
DictionaryAttr attrs, mlir::OpaqueProperties properties,
|
|
mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
|
|
unsigned resultWidth = getTotalWidth(operands);
|
|
results.push_back(IntegerType::get(context, resultWidth));
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Other Operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult ExtractOp::verify() {
|
|
unsigned srcWidth = cast<IntegerType>(getInput().getType()).getWidth();
|
|
unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
|
|
if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth)
|
|
return emitOpError("from bit too large for input"), failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
LogicalResult TruthTableOp::verify() {
|
|
size_t numInputs = getInputs().size();
|
|
if (numInputs >= sizeof(size_t) * 8)
|
|
return emitOpError("Truth tables support a maximum of ")
|
|
<< sizeof(size_t) * 8 - 1 << " inputs on your platform";
|
|
|
|
ArrayAttr table = getLookupTable();
|
|
if (table.size() != (1ull << numInputs))
|
|
return emitOpError("Expected lookup table of 2^n length");
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen generated logic.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Provide the autogenerated implementation guts for the Op classes.
|
|
#define GET_OP_CLASSES
|
|
#include "circt/Dialect/Comb/Comb.cpp.inc"
|