mirror of https://github.com/llvm/circt.git
188 lines
6.0 KiB
C++
188 lines
6.0 KiB
C++
//===- MakeTables.cpp -----------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/Arc/ArcOps.h"
|
|
#include "circt/Dialect/Arc/ArcPasses.h"
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/HW/HWOps.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "arc-lookup-tables"
|
|
|
|
namespace circt {
|
|
namespace arc {
|
|
#define GEN_PASS_DEF_MAKETABLES
|
|
#include "circt/Dialect/Arc/ArcPasses.h.inc"
|
|
} // namespace arc
|
|
} // namespace circt
|
|
|
|
using namespace circt;
|
|
using namespace arc;
|
|
using namespace hw;
|
|
|
|
namespace {
|
|
|
|
/// Arcs whose body holds fewer non-constant operations than this are not
/// converted into lookup tables.
static constexpr int tableMinOpCount = 20;

/// Largest lookup table that may be created for a single arc, measured in
/// bits (number of table entries times number of output bits).
static constexpr int tableMaxSize = 1 << 15; // 32768 bits
|
|
|
|
/// Pass that tries to replace the body of each `DefineOp` with constant
/// lookup tables indexed by the concatenated arc inputs.
struct MakeTablesPass : arc::impl::MakeTablesBase<MakeTablesPass> {
  /// Pass entry point; hands every `DefineOp` of the operation the pass runs
  /// on to `runOnArc`.
  void runOnOperation() override;

  /// Attempts the lookup-table conversion on a single arc definition.
  void runOnArc(DefineOp defineOp);
};
|
|
} // namespace
|
|
|
|
/// Returns a value whose `nbits` least significant bits are set and whose
/// remaining bits are clear.
///
/// Requests for 32 or more bits yield an all-ones mask, since a 32-bit value
/// cannot hold more than that.
static inline uint32_t bitsMask(uint32_t nbits) {
  // Shifting a 32-bit value by 32 or more positions is undefined, so handle
  // the full-width (and wider) case without shifting at all.
  if (nbits >= 32)
    return ~uint32_t{0};
  // Shift an *unsigned* one: with a signed `1`, `nbits == 31` would produce
  // INT_MIN and the following subtraction would overflow a signed integer.
  return (uint32_t{1} << nbits) - 1;
}
|
|
|
|
static inline uint32_t bitsGet(uint32_t x, uint32_t lb, uint32_t ub) {
|
|
return (x >> lb) & bitsMask(ub - lb + 1);
|
|
}
|
|
|
|
void MakeTablesPass::runOnOperation() {
|
|
auto module = getOperation();
|
|
for (auto op : module.getOps<DefineOp>())
|
|
runOnArc(op);
|
|
}
|
|
|
|
/// Tries to replace the body of `defineOp` with one constant lookup table per
/// output, indexed by the concatenation of all arc inputs.
///
/// The arc is left untouched when
///  - any argument or output is not an integer, or there are no input or
///    output bits at all,
///  - the inputs add up to 31 bits or more,
///  - the body has fewer than `tableMinOpCount` non-constant operations,
///  - the resulting table would exceed `tableMaxSize` bits, or
///  - any operation in the body cannot be folded to constants for some input
///    combination.
///
/// The IR is only modified once every input combination has been evaluated
/// successfully, so none of the bail-outs above leaves stray operations
/// behind in the arc.
void MakeTablesPass::runOnArc(DefineOp defineOp) {
  // Determine the number of input bits.
  unsigned numInputBits = 0;
  for (auto &type : defineOp.getArgumentTypes()) {
    auto intType = dyn_cast<IntegerType>(type);
    if (!intType)
      return;
    numInputBits += intType.getWidth();
  }
  if (numInputBits == 0)
    return;

  // Count the number of non-constant operations in the block.
  unsigned numOps = 0;
  for (auto &op : defineOp.getBodyBlock().without_terminator())
    if (!op.hasTrait<OpTrait::ConstantLike>())
      ++numOps;

  // Determine the number of output bits.
  unsigned numOutputBits = 0;
  auto outputOp = cast<arc::OutputOp>(defineOp.getBodyBlock().getTerminator());
  for (auto type : outputOp.getOperandTypes()) {
    auto intType = dyn_cast<IntegerType>(type);
    if (!intType)
      return;
    numOutputBits += intType.getWidth();
  }
  if (numOutputBits == 0)
    return;

  LLVM_DEBUG(llvm::dbgs() << "Making lookup tables in `" << defineOp.getName()
                          << "`\n");
  LLVM_DEBUG(llvm::dbgs() << "- " << numInputBits << " input bits, "
                          << numOutputBits << " output bits, " << numOps
                          << " ops\n");

  // Check whether the table dimensions are within bounds. Keeping the input
  // width below 31 bits also guarantees that `1U << numInputBits` and the
  // signed loop counter further down cannot overflow.
  if (numInputBits >= 31) {
    LLVM_DEBUG(llvm::dbgs() << "- Skip; too many input bits\n");
    return;
  }
  if (numOps < tableMinOpCount) {
    LLVM_DEBUG(llvm::dbgs() << "- Skip; not enough ops\n");
    return;
  }

  unsigned numTableEntries = 1U << numInputBits;
  // Dividing the budget instead of multiplying the dimensions avoids overflow.
  if (numTableEntries > tableMaxSize / numOutputBits) {
    LLVM_DEBUG(llvm::dbgs() << "- Skip; table too large\n");
    return;
  }
  LLVM_DEBUG(llvm::dbgs() << "- Creating table of "
                          << numTableEntries * numOutputBits << " bits\n");

  // Snapshot the operations that will be replaced. This happens before any new
  // operation is created so that the replacements are not erased again below.
  SmallVector<Operation *, 64> tabularizedOps;
  for (auto &op : defineOp.getBodyBlock().without_terminator())
    tabularizedOps.push_back(&op);

  // Until the evaluation below has succeeded, the builder is used solely to
  // construct attributes; no operation is inserted into the arc yet.
  auto builder = ImplicitLocOpBuilder::atBlockBegin(defineOp.getLoc(),
                                                    &defineOp.getBodyBlock());

  // Compute a lookup table for every output.
  SmallVector<SmallVector<Attribute, 0>> tables;
  DenseMap<Value, Attribute> values;
  tables.resize(outputOp->getNumOperands());
  for (auto &table : tables)
    table.reserve(numTableEntries);

  // Entries are appended from the highest input value down to zero.
  // NOTE(review): presumably this matches the element order expected by the
  // array attribute of `hw::AggregateConstantOp` -- confirm.
  for (int input = static_cast<int>(numTableEntries) - 1; input >= 0; input--) {
    // Assign the input values: the first argument takes the least significant
    // bits of `input`, the following arguments the bits above it.
    values.clear();
    unsigned bits = 0;
    for (auto arg : defineOp.getArguments()) {
      // All argument types were verified to be integers above.
      auto w = cast<IntegerType>(arg.getType()).getWidth();
      values[arg] = builder.getIntegerAttr(arg.getType(),
                                           bitsGet(input, bits, bits + w - 1));
      bits += w;
    }

    // Evaluate the operations.
    SmallVector<Attribute> constants;
    for (auto *operation : tabularizedOps) {
      constants.clear();
      for (auto operand : operation->getOperands())
        constants.push_back(values[operand]);

      SmallVector<OpFoldResult, 8> resultValues;
      if (failed(operation->fold(constants, resultValues))) {
        LLVM_DEBUG(llvm::dbgs() << "- Skip; operation folder failed\n");
        return;
      }

      // A folder that updated the operation in place reports success without
      // producing results; there is nothing to record for the table then.
      if (resultValues.size() != operation->getNumResults()) {
        LLVM_DEBUG(llvm::dbgs() << "- Skip; operation folded in place\n");
        return;
      }

      for (auto [result, resultValue] :
           llvm::zip(operation->getResults(), resultValues)) {
        auto attr = dyn_cast<Attribute>(resultValue);
        // The folder may forward an existing value instead of a constant; use
        // the constant already computed for that value.
        if (!attr)
          attr = values[dyn_cast<Value>(resultValue)];
        // Never store a null attribute: it would end up in the table and
        // crash when the array attribute is built.
        if (!attr) {
          LLVM_DEBUG(llvm::dbgs()
                     << "- Skip; operation did not fold to a constant\n");
          return;
        }
        values[result] = attr;
      }
    }

    // Add the evaluated values to the output tables.
    for (auto [table, outputOperand] :
         llvm::zip(tables, outputOp->getOpOperands()))
      table.push_back(values[outputOperand.get()]);
  }

  // Every input combination was evaluated successfully; from here on the arc
  // is actually rewritten.

  // Concatenate the inputs into a single index value. The arguments are
  // reversed so that the first argument ends up in the least significant bits,
  // consistent with how `input` was split up above.
  SmallVector<Value> inputsToConcat(defineOp.getArguments());
  std::reverse(inputsToConcat.begin(), inputsToConcat.end());
  auto concatInputs = inputsToConcat.size() > 1
                          ? comb::ConcatOp::create(builder, inputsToConcat)
                          : inputsToConcat[0];

  // Create the table lookup ops.
  for (auto [table, outputOperand] :
       llvm::zip(tables, outputOp->getOpOperands())) {
    auto array = hw::AggregateConstantOp::create(
        builder, ArrayType::get(outputOperand.get().getType(), numTableEntries),
        builder.getArrayAttr(table));
    outputOperand.set(hw::ArrayGetOp::create(builder, array, concatInputs));
  }

  // The original operations are no longer needed. Uses are dropped first
  // because the operations may still reference each other.
  for (auto *op : tabularizedOps) {
    op->dropAllUses();
    op->erase();
  }
}
|
|
|
|
std::unique_ptr<Pass> arc::createMakeTablesPass() {
|
|
return std::make_unique<MakeTablesPass>();
|
|
}
|