mirror of https://github.com/llvm/circt.git
[Handshake] Add memref legalization pass (#4191)
This commit is contained in:
parent
d6ca0e6ff9
commit
1db9d8499d
|
@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createHandshakeDematerializeForksSinksPass();
|
|||
std::unique_ptr<mlir::Pass> createHandshakeRemoveBuffersPass();
|
||||
std::unique_ptr<mlir::Pass> createHandshakeAddIDsPass();
|
||||
std::unique_ptr<mlir::Pass> createHandshakeLowerExtmemToHWPass();
|
||||
std::unique_ptr<mlir::Pass> createHandshakeLegalizeMemrefsPass();
|
||||
std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
|
||||
createHandshakeInsertBuffersPass(const std::string &strategy = "all",
|
||||
unsigned bufferSize = 2);
|
||||
|
|
|
@ -113,4 +113,14 @@ def HandshakeLockFunctions : Pass<"handshake-lock-functions", "handshake::FuncOp
|
|||
let constructor = "circt::handshake::createHandshakeLockFunctionsPass()";
|
||||
}
|
||||
|
||||
def HandshakeLegalizeMemrefs : Pass<"handshake-legalize-memrefs", "mlir::func::FuncOp"> {
|
||||
let summary = "Memref legalization and lowering pass.";
|
||||
let description = [{
|
||||
Lowers various memref operations to a state suitable for passing to the
|
||||
StandardToHandshake lowering.
|
||||
}];
|
||||
let constructor = "circt::handshake::createHandshakeLegalizeMemrefsPass()";
|
||||
let dependentDialects = ["mlir::scf::SCFDialect"];
|
||||
}
|
||||
|
||||
#endif // CIRCT_DIALECT_HANDSHAKE_HANDSHAKEPASSES_TD
|
||||
|
|
|
@ -5,6 +5,7 @@ add_circt_dialect_library(CIRCTHandshakeTransforms
|
|||
Buffers.cpp
|
||||
LockFunctions.cpp
|
||||
LowerExtmemToHW.cpp
|
||||
LegalizeMemrefs.cpp
|
||||
|
||||
DEPENDS
|
||||
CIRCTHandshakeTransformsIncGen
|
||||
|
@ -14,7 +15,9 @@ add_circt_dialect_library(CIRCTHandshakeTransforms
|
|||
CIRCTESI
|
||||
CIRCTHandshake
|
||||
CIRCTSupport
|
||||
CIRCTTransforms
|
||||
MLIRIR
|
||||
MLIRSCFDialect
|
||||
MLIRPass
|
||||
MLIRTransformUtils
|
||||
MLIRMemRefDialect
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
//===- LegalizeMemrefs.cpp - handshake memref legalization pass -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Contains the definitions of the memref legalization pass.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetails.h"
|
||||
#include "circt/Dialect/Handshake/HandshakeOps.h"
|
||||
#include "circt/Dialect/Handshake/HandshakePasses.h"
|
||||
#include "circt/Support/BackedgeBuilder.h"
|
||||
#include "circt/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace circt;
|
||||
using namespace handshake;
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct HandshakeLegalizeMemrefsPass
|
||||
: public HandshakeLegalizeMemrefsBase<HandshakeLegalizeMemrefsPass> {
|
||||
void runOnOperation() override {
|
||||
func::FuncOp op = getOperation();
|
||||
if (op.isExternal())
|
||||
return;
|
||||
|
||||
// Erase all memref.dealloc operations - this implies that we consider all
|
||||
// memref.alloc's in the IR to be "static", in the C sense. It is then up to
|
||||
// callers of the handshake module to determine whether a call to said
|
||||
// module implies a _call_ (shared semantics) or an _instance_.
|
||||
for (auto dealloc :
|
||||
llvm::make_early_inc_range(op.getOps<memref::DeallocOp>()))
|
||||
dealloc.erase();
|
||||
|
||||
auto b = OpBuilder(op);
|
||||
|
||||
// Convert any memref.copy to explicit store operations (scf loop in case of
|
||||
// an array).
|
||||
for (auto copy : llvm::make_early_inc_range(op.getOps<memref::CopyOp>())) {
|
||||
b.setInsertionPoint(copy);
|
||||
auto loc = copy.getLoc();
|
||||
auto src = copy.getSource();
|
||||
auto dst = copy.getTarget();
|
||||
auto memrefType = src.getType().cast<MemRefType>();
|
||||
if (!isUniDimensional(memrefType)) {
|
||||
llvm::errs() << "Cannot legalize multi-dimensional memref operation "
|
||||
<< copy
|
||||
<< ". Please run the memref flattening pass before this "
|
||||
"pass.";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
auto emitLoadStore = [&](Value index) {
|
||||
llvm::SmallVector<Value> indices = {index};
|
||||
auto loadValue = b.create<memref::LoadOp>(loc, src, indices);
|
||||
b.create<memref::StoreOp>(loc, loadValue, dst, indices);
|
||||
};
|
||||
|
||||
auto n = memrefType.getShape()[0];
|
||||
|
||||
if (n > 1) {
|
||||
auto lb = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
|
||||
auto ub = b.create<arith::ConstantIndexOp>(loc, n).getResult();
|
||||
auto step = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
|
||||
|
||||
b.create<scf::ForOp>(
|
||||
loc, lb, ub, step, llvm::SmallVector<Value>(),
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
|
||||
emitLoadStore(iv);
|
||||
b.create<scf::YieldOp>(loc);
|
||||
});
|
||||
} else
|
||||
emitLoadStore(b.create<arith::ConstantIndexOp>(loc, 0));
|
||||
|
||||
copy.erase();
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass>
|
||||
circt::handshake::createHandshakeLegalizeMemrefsPass() {
|
||||
return std::make_unique<HandshakeLegalizeMemrefsPass>();
|
||||
}
|
|
@ -21,6 +21,8 @@
|
|||
#include "circt/Dialect/HW/HWDialect.h"
|
||||
#include "circt/Dialect/Handshake/HandshakeOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace circt {
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// RUN: circt-opt --handshake-legalize-memrefs %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @dealloc_copy(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: memref<4xi32>) -> memref<4xi32> {
|
||||
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[VAL_3:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
||||
// CHECK: scf.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_4]] {
|
||||
// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<4xi32>
|
||||
// CHECK: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<4xi32>
|
||||
// CHECK: }
|
||||
// CHECK: return %[[VAL_1]] : memref<4xi32>
|
||||
// CHECK: }
|
||||
|
||||
func.func @dealloc_copy(%arg : memref<4xi32>) -> memref<4xi32> {
|
||||
%0 = memref.alloc() : memref<4xi32>
|
||||
memref.copy %arg, %0 : memref<4xi32> to memref<4xi32>
|
||||
memref.dealloc %0 : memref<4xi32>
|
||||
return %0 : memref<4xi32>
|
||||
}
|
|
@ -192,7 +192,13 @@ static void loadDHLSPipeline(OpPassManager &pm) {
|
|||
// Software lowering
|
||||
pm.addPass(mlir::createLowerAffinePass());
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
|
||||
// Memref legalization.
|
||||
pm.addPass(circt::createFlattenMemRefPass());
|
||||
pm.nest<func::FuncOp>().addPass(
|
||||
circt::handshake::createHandshakeLegalizeMemrefsPass());
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
pm.nest<handshake::FuncOp>().addPass(createSimpleCanonicalizerPass());
|
||||
|
||||
// DHLS conversion
|
||||
pm.addPass(circt::createStandardToHandshakePass(
|
||||
|
|
Loading…
Reference in New Issue