Make AsyncParallelForRewrite parameterizable with a cost model that drives the choice of parallelization granularity.

Reviewed By: ezhulenev, mehdi_amini

Differential Revision: https://reviews.llvm.org/D115423
Authored by bakhtiyar on 2021-12-19 08:35:37 -08:00; committed by Eugene Zhulenev
parent 47bd9ebda4
commit ec0e4545ca
5 changed files with 140 additions and 43 deletions


@@ -21,7 +21,7 @@ std::unique_ptr<Pass> createAsyncParallelForPass();
 std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
                                                  int32_t numWorkerThreads,
-                                                 int32_t targetBlockSize);
+                                                 int32_t minTaskSize);
 
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();


@@ -0,0 +1,40 @@
+//===- Transforms.h - Async dialect transformation utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines transformations on Async operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ASYNC_TRANSFORMS_H_
+#define MLIR_DIALECT_ASYNC_TRANSFORMS_H_
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+
+namespace mlir {
+namespace async {
+
+/// Emit the IR to compute the minimum number of iterations of scf.parallel body
+/// that would be viable for a single parallel task. Allows the user to avoid
+/// incurring the overheads of spawning costly parallel tasks in absence of
+/// sufficient amount of parallelizable work.
+///
+/// Must return an index type.
+using AsyncMinTaskSizeComputationFunction =
+    std::function<Value(ImplicitLocOpBuilder, scf::ParallelOp)>;
+
+/// Add a pattern to the given pattern list to lower scf.parallel to async
+/// operations.
+void populateAsyncParallelForPatterns(
+    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
+    AsyncMinTaskSizeComputationFunction computeMinTaskSize);
+
+} // namespace async
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ASYNC_TRANSFORMS_H_
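
The callback receives an ImplicitLocOpBuilder and the scf.parallel op, emits whatever index arithmetic its cost model needs, and returns the result as an index-typed Value. As a minimal sketch of what a downstream cost model could look like (not part of this commit; kMinWorkPerTask and the op-counting heuristic are illustrative assumptions):

#include <algorithm>
#include <cstdint>

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/Transforms.h"

using namespace mlir;

// Hypothetical cost model: require roughly kMinWorkPerTask "op executions"
// per parallel task, approximating the per-iteration cost by the number of
// operations in the scf.parallel body.
static Value estimateMinTaskSize(ImplicitLocOpBuilder builder,
                                 scf::ParallelOp op) {
  constexpr int64_t kMinWorkPerTask = 2048; // assumed tuning constant
  int64_t bodyCost =
      std::max<int64_t>(1, op.getBody()->getOperations().size());
  // The callback contract requires an index-typed Value, emitted with the
  // provided builder so it is created before the loop body is rewritten.
  return builder.create<arith::ConstantIndexOp>(
      std::max<int64_t>(1, kMinWorkPerTask / bodyCost));
}

A client would then register the lowering with populateAsyncParallelForPatterns(patterns, asyncDispatch, numWorkerThreads, estimateMinTaskSize); the pass itself installs a constant-returning lambda, as shown in AsyncParallelForPass::runOnOperation below.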


@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Async/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -105,10 +106,12 @@ struct AsyncParallelForPass
 struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 public:
-  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
-                          int32_t numWorkerThreads, int32_t minTaskSize)
+  AsyncParallelForRewrite(
+      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
+      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
       : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
-        numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
+        numWorkerThreads(numWorkerThreads),
+        computeMinTaskSize(computeMinTaskSize) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp op,
                                 PatternRewriter &rewriter) const override;
 
@@ -116,7 +119,7 @@ public:
 private:
   bool asyncDispatch;
   int32_t numWorkerThreads;
-  int32_t minTaskSize;
+  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
 };
 
 struct ParallelComputeFunctionType {
@@ -252,7 +255,11 @@ static ParallelComputeFunction createParallelComputeFunction(
       getParallelComputeFunctionType(op, rewriter);
 
   FunctionType type = computeFuncType.type;
-  FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
+  FuncOp func = FuncOp::create(op.getLoc(),
+                               numBlockAlignedInnerLoops > 0
+                                   ? "parallel_compute_fn_with_aligned_loops"
+                                   : "parallel_compute_fn",
+                               type);
   func.setPrivate();
 
   // Insert function into the module symbol table and assign it unique name.
@@ -702,6 +709,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
+  // Computing minTaskSize emits IR and can be implemented as executing a cost
+  // model on the body of the scf.parallel. Thus it needs to be computed before
+  // the body of the scf.parallel has been manipulated.
+  Value minTaskSize = computeMinTaskSize(b, op);
+
   // Make sure that all constants will be inside the parallel operation body to
   // reduce the number of parallel compute function arguments.
   cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
@@ -752,7 +764,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   };
 
   // Find how many inner iteration dimensions are statically known, and their
-  // product is smaller than the `512`. We aling the parallel compute block
+  // product is smaller than the `512`. We align the parallel compute block
   // size by the product of statically known dimensions, so that we can
   // guarantee that the inner loops executes from 0 to the loop trip counts
   // and we can elide dynamic loop boundaries, and give LLVM an opportunity to
@@ -793,50 +805,64 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
       std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
 
-  // Target block size from the pass parameters.
-  Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize);
-
   // Compute parallel block size from the parallel problem size:
   //   blockSize = min(tripCount,
   //                   max(ceil_div(tripCount, maxComputeBlocks),
-  //                       ceil_div(minTaskSize, bodySize)))
+  //                       minTaskSize))
   Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
-  Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
+  Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
   Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
 
-  // Align the block size to be a multiple of the statically known number
-  // of iterations in the inner loops.
-  if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) {
-    Value numIters = b.create<arith::ConstantIndexOp>(
-        numIterations[op.getNumLoops() - numUnrollableLoops]);
-    Value bs2 = b.create<arith::MulIOp>(
-        b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
-    blockSize = b.create<arith::MinSIOp>(tripCount, bs2);
-  } else {
-    // Reset the number of unrollable loops if we didn't align the block size.
-    numUnrollableLoops = 0;
-  }
+  ParallelComputeFunction notUnrollableParallelComputeFunction =
+      createParallelComputeFunction(op, staticBounds, 0, rewriter);
+
+  // Dispatch parallel compute function using async recursive work splitting,
+  // or by submitting compute task sequentially from a caller thread.
+  auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
+
+  // Create a parallel compute function that takes a block id and computes
+  // the parallel operation body for a subset of iteration space.
 
   // Compute the number of parallel compute blocks.
   Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
-  // Create a parallel compute function that takes a block id and computes
-  // the parallel operation body for a subset of iteration space.
-  ParallelComputeFunction parallelComputeFunction =
-      createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
-                                    rewriter);
+  // Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
+  bool staticShouldUnroll = numUnrollableLoops > 0;
+  auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+    doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
+               blockSize, blockCount, tripCounts);
+    nb.create<scf::YieldOp>();
+  };
 
-  // Dispatch parallel compute function using async recursive work splitting,
-  // or by submitting compute task sequentially from a caller thread.
-  if (asyncDispatch) {
-    doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
-                    blockCount, tripCounts);
+  if (staticShouldUnroll) {
+    Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::sge, blockSize,
+        b.create<arith::ConstantIndexOp>(maxIterations));
+
+    ParallelComputeFunction unrollableParallelComputeFunction =
+        createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
+                                      rewriter);
+
+    auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+      ImplicitLocOpBuilder nb(loc, nestedBuilder);
+      // Align the block size to be a multiple of the statically known
+      // number of iterations in the inner loops.
+      Value numIters = nb.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value alignedBlockSize = nb.create<arith::MulIOp>(
+          nb.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+      doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
+                 alignedBlockSize, blockCount, tripCounts);
+      nb.create<scf::YieldOp>();
+    };
+
+    b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
+                        dispatchNotUnrollable);
+    nb.create<scf::YieldOp>();
   } else {
-    doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
-                         blockCount, tripCounts);
+    dispatchNotUnrollable(nb, loc);
   }
-
-  nb.create<scf::YieldOp>();
   };
 
   // Replace the `scf.parallel` operation with the parallel compute function.
@@ -852,9 +878,11 @@ void AsyncParallelForPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
 
-  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
-                                        minTaskSize);
+  populateAsyncParallelForPatterns(
+      patterns, asyncDispatch, numWorkerThreads,
+      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
+        return builder.create<arith::ConstantIndexOp>(minTaskSize);
+      });
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
 }
@@ -869,3 +897,11 @@ std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
   return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
                                                 minTaskSize);
 }
+
+void mlir::async::populateAsyncParallelForPatterns(
+    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
+    AsyncMinTaskSizeComputationFunction computeMinTaskSize) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
+                                        computeMinTaskSize);
+}
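
To make the new granularity formula concrete, here is a host-side mirror of the emitted index arithmetic with illustrative numbers (assumed for the example, not taken from the pass):

#include <algorithm>
#include <cstdint>

// Mirrors the IR emitted above:
//   blockSize = min(tripCount,
//                   max(ceil_div(tripCount, maxComputeBlocks), minTaskSize))
int64_t computeBlockSize(int64_t tripCount, int64_t maxComputeBlocks,
                         int64_t minTaskSize) {
  // Ceiling division for positive operands, matching arith.ceildivsi here.
  int64_t bs0 = (tripCount + maxComputeBlocks - 1) / maxComputeBlocks;
  int64_t bs1 = std::max(bs0, minTaskSize);
  return std::min(tripCount, bs1);
}

// With tripCount = 4096 and maxComputeBlocks = 32:
//   minTaskSize = 1   -> blockSize = 128 -> 32 parallel tasks
//   minTaskSize = 256 -> blockSize = 256 -> 16 coarser tasks

A cost model that returns a larger minTaskSize therefore directly trades parallelism for lower task-spawning overhead, which is the point of the parameterization.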


@@ -100,9 +100,27 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
 // CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
 // CHECK-SAME: ) {
+// CHECK:        scf.for %[[I:arg[0-9]+]]
+// CHECK:          select
+// CHECK:          scf.for %[[J:arg[0-9]+]]
+// CHECK:            memref.store
+
+// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
+// CHECK-SAME: ) {
 // CHECK:        %[[C0:.*]] = arith.constant 0 : index
 // CHECK:        %[[C1:.*]] = arith.constant 1 : index
 // CHECK:        %[[C10:.*]] = arith.constant 10 : index
 // CHECK:        scf.for %[[I:arg[0-9]+]]
 // CHECK-NOT:      select
 // CHECK:          scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
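
Taken together, the updated checks cover both dispatch paths introduced in the rewrite: @parallel_compute_fn still clamps its dynamic inner bounds (hence the expected select), while @parallel_compute_fn_with_aligned_loops iterates its inner loop over the constant range %c0 to %c10 with no select, the property that lets later optimizations treat it as a constant-trip-count loop.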


@@ -2018,7 +2018,10 @@ cc_library(
         "lib/Dialect/Async/Transforms/*.cpp",
         "lib/Dialect/Async/Transforms/*.h",
     ]),
-    hdrs = ["include/mlir/Dialect/Async/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/Async/Passes.h",
+        "include/mlir/Dialect/Async/Transforms.h",
+    ],
     includes = ["include"],
     deps = [
         ":Analysis",