[Handshake] Add memref legalization pass (#4191)

This commit is contained in:
Morten Borup Petersen 2022-10-26 21:35:51 +02:00 committed by GitHub
parent d6ca0e6ff9
commit 1db9d8499d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 138 additions and 0 deletions

View File

@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createHandshakeDematerializeForksSinksPass();
std::unique_ptr<mlir::Pass> createHandshakeRemoveBuffersPass();
std::unique_ptr<mlir::Pass> createHandshakeAddIDsPass();
std::unique_ptr<mlir::Pass> createHandshakeLowerExtmemToHWPass();
std::unique_ptr<mlir::Pass> createHandshakeLegalizeMemrefsPass();
std::unique_ptr<mlir::OperationPass<handshake::FuncOp>>
createHandshakeInsertBuffersPass(const std::string &strategy = "all",
unsigned bufferSize = 2);

View File

@ -113,4 +113,14 @@ def HandshakeLockFunctions : Pass<"handshake-lock-functions", "handshake::FuncOp
let constructor = "circt::handshake::createHandshakeLockFunctionsPass()";
}
def HandshakeLegalizeMemrefs : Pass<"handshake-legalize-memrefs", "mlir::func::FuncOp"> {
let summary = "Memref legalization and lowering pass.";
let description = [{
Lowers various memref operations to a state suitable for passing to the
StandardToHandshake lowering.
}];
let constructor = "circt::handshake::createHandshakeLegalizeMemrefsPass()";
let dependentDialects = ["mlir::scf::SCFDialect"];
}
#endif // CIRCT_DIALECT_HANDSHAKE_HANDSHAKEPASSES_TD

View File

@ -5,6 +5,7 @@ add_circt_dialect_library(CIRCTHandshakeTransforms
Buffers.cpp
LockFunctions.cpp
LowerExtmemToHW.cpp
LegalizeMemrefs.cpp
DEPENDS
CIRCTHandshakeTransformsIncGen
@ -14,7 +15,9 @@ add_circt_dialect_library(CIRCTHandshakeTransforms
CIRCTESI
CIRCTHandshake
CIRCTSupport
CIRCTTransforms
MLIRIR
MLIRSCFDialect
MLIRPass
MLIRTransformUtils
MLIRMemRefDialect

View File

@ -0,0 +1,95 @@
//===- LegalizeMemrefs.cpp - handshake memref legalization pass -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the memref legalization pass.
//
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
#include "circt/Support/BackedgeBuilder.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace circt;
using namespace handshake;
using namespace mlir;
namespace {
struct HandshakeLegalizeMemrefsPass
: public HandshakeLegalizeMemrefsBase<HandshakeLegalizeMemrefsPass> {
void runOnOperation() override {
func::FuncOp op = getOperation();
if (op.isExternal())
return;
// Erase all memref.dealloc operations - this implies that we consider all
// memref.alloc's in the IR to be "static", in the C sense. It is then up to
// callers of the handshake module to determine whether a call to said
// module implies a _call_ (shared semantics) or an _instance_.
for (auto dealloc :
llvm::make_early_inc_range(op.getOps<memref::DeallocOp>()))
dealloc.erase();
auto b = OpBuilder(op);
// Convert any memref.copy to explicit store operations (scf loop in case of
// an array).
for (auto copy : llvm::make_early_inc_range(op.getOps<memref::CopyOp>())) {
b.setInsertionPoint(copy);
auto loc = copy.getLoc();
auto src = copy.getSource();
auto dst = copy.getTarget();
auto memrefType = src.getType().cast<MemRefType>();
if (!isUniDimensional(memrefType)) {
llvm::errs() << "Cannot legalize multi-dimensional memref operation "
<< copy
<< ". Please run the memref flattening pass before this "
"pass.";
signalPassFailure();
return;
}
auto emitLoadStore = [&](Value index) {
llvm::SmallVector<Value> indices = {index};
auto loadValue = b.create<memref::LoadOp>(loc, src, indices);
b.create<memref::StoreOp>(loc, loadValue, dst, indices);
};
auto n = memrefType.getShape()[0];
if (n > 1) {
auto lb = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
auto ub = b.create<arith::ConstantIndexOp>(loc, n).getResult();
auto step = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
b.create<scf::ForOp>(
loc, lb, ub, step, llvm::SmallVector<Value>(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
emitLoadStore(iv);
b.create<scf::YieldOp>(loc);
});
} else
emitLoadStore(b.create<arith::ConstantIndexOp>(loc, 0));
copy.erase();
}
};
};
} // namespace
std::unique_ptr<mlir::Pass>
circt::handshake::createHandshakeLegalizeMemrefsPass() {
return std::make_unique<HandshakeLegalizeMemrefsPass>();
}

View File

@ -21,6 +21,8 @@
#include "circt/Dialect/HW/HWDialect.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
namespace circt {

View File

@ -0,0 +1,21 @@
// RUN: circt-opt --handshake-legalize-memrefs %s | FileCheck %s
// CHECK-LABEL: func.func @dealloc_copy(
// CHECK-SAME: %[[VAL_0:.*]]: memref<4xi32>) -> memref<4xi32> {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi32>
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: scf.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_4]] {
// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref<4xi32>
// CHECK: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref<4xi32>
// CHECK: }
// CHECK: return %[[VAL_1]] : memref<4xi32>
// CHECK: }
func.func @dealloc_copy(%arg : memref<4xi32>) -> memref<4xi32> {
%0 = memref.alloc() : memref<4xi32>
memref.copy %arg, %0 : memref<4xi32> to memref<4xi32>
memref.dealloc %0 : memref<4xi32>
return %0 : memref<4xi32>
}

View File

@ -192,7 +192,13 @@ static void loadDHLSPipeline(OpPassManager &pm) {
// Software lowering
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createConvertSCFToCFPass());
// Memref legalization.
pm.addPass(circt::createFlattenMemRefPass());
pm.nest<func::FuncOp>().addPass(
circt::handshake::createHandshakeLegalizeMemrefsPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.nest<handshake::FuncOp>().addPass(createSimpleCanonicalizerPass());
// DHLS conversion
pm.addPass(circt::createStandardToHandshakePass(