Make AsyncParallelForRewrite parameterizable with a cost model that determines the parallelization granularity.
Reviewed By: ezhulenev, mehdi_amini Differential Revision: https://reviews.llvm.org/D115423
This commit is contained in:
parent
47bd9ebda4
commit
ec0e4545ca
|
|
@ -21,7 +21,7 @@ std::unique_ptr<Pass> createAsyncParallelForPass();
|
|||
|
||||
std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
|
||||
int32_t numWorkerThreads,
|
||||
int32_t targetBlockSize);
|
||||
int32_t minTaskSize);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
//===- Transforms.h - Async dialect transformation utilities ----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file defines transformations on Async operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_ASYNC_TRANSFORMS_H_
|
||||
#define MLIR_DIALECT_ASYNC_TRANSFORMS_H_
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
|
||||
/// Emit the IR to compute the minimum number of iterations of scf.parallel body
|
||||
/// that would be viable for a single parallel task. Allows the user to avoid
|
||||
/// incurring the overhead of spawning costly parallel tasks in the absence of
|
||||
/// a sufficient amount of parallelizable work.
|
||||
///
|
||||
/// The returned Value must have index type.
|
||||
using AsyncMinTaskSizeComputationFunction =
|
||||
std::function<Value(ImplicitLocOpBuilder, scf::ParallelOp)>;
|
||||
|
||||
/// Add a pattern to the given pattern list to lower scf.parallel to async
|
||||
/// operations.
|
||||
void populateAsyncParallelForPatterns(
|
||||
RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
|
||||
AsyncMinTaskSizeComputationFunction computeMinTaskSize);
|
||||
|
||||
} // namespace async
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_ASYNC_TRANSFORMS_H_
|
||||
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "mlir/Dialect/Async/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
|
@ -105,10 +106,12 @@ struct AsyncParallelForPass
|
|||
|
||||
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
|
||||
public:
|
||||
AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
|
||||
int32_t numWorkerThreads, int32_t minTaskSize)
|
||||
AsyncParallelForRewrite(
|
||||
MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
|
||||
AsyncMinTaskSizeComputationFunction computeMinTaskSize)
|
||||
: OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
|
||||
numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
|
||||
numWorkerThreads(numWorkerThreads),
|
||||
computeMinTaskSize(computeMinTaskSize) {}
|
||||
|
||||
LogicalResult matchAndRewrite(scf::ParallelOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
|
@ -116,7 +119,7 @@ public:
|
|||
private:
|
||||
bool asyncDispatch;
|
||||
int32_t numWorkerThreads;
|
||||
int32_t minTaskSize;
|
||||
AsyncMinTaskSizeComputationFunction computeMinTaskSize;
|
||||
};
|
||||
|
||||
struct ParallelComputeFunctionType {
|
||||
|
|
@ -252,7 +255,11 @@ static ParallelComputeFunction createParallelComputeFunction(
|
|||
getParallelComputeFunctionType(op, rewriter);
|
||||
|
||||
FunctionType type = computeFuncType.type;
|
||||
FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
|
||||
FuncOp func = FuncOp::create(op.getLoc(),
|
||||
numBlockAlignedInnerLoops > 0
|
||||
? "parallel_compute_fn_with_aligned_loops"
|
||||
: "parallel_compute_fn",
|
||||
type);
|
||||
func.setPrivate();
|
||||
|
||||
// Insert function into the module symbol table and assign it unique name.
|
||||
|
|
@ -702,6 +709,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
|
|||
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
|
||||
// Computing minTaskSize emits IR and can be implemented as executing a cost
|
||||
// model on the body of the scf.parallel. Thus it needs to be computed before
|
||||
// the body of the scf.parallel has been manipulated.
|
||||
Value minTaskSize = computeMinTaskSize(b, op);
|
||||
|
||||
// Make sure that all constants will be inside the parallel operation body to
|
||||
// reduce the number of parallel compute function arguments.
|
||||
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
|
||||
|
|
@ -752,7 +764,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
|
|||
};
|
||||
|
||||
// Find how many inner iteration dimensions are statically known, and their
|
||||
// product is smaller than the `512`. We aling the parallel compute block
|
||||
// product is smaller than the `512`. We align the parallel compute block
|
||||
// size by the product of statically known dimensions, so that we can
|
||||
// guarantee that the inner loops execute from 0 to the loop trip counts
|
||||
// and we can elide dynamic loop boundaries, and give LLVM an opportunity to
|
||||
|
|
@ -793,52 +805,66 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
|
|||
Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
|
||||
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
|
||||
|
||||
// Target block size from the pass parameters.
|
||||
Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize);
|
||||
|
||||
// Compute parallel block size from the parallel problem size:
|
||||
// blockSize = min(tripCount,
|
||||
// max(ceil_div(tripCount, maxComputeBlocks),
|
||||
// ceil_div(minTaskSize, bodySize)))
|
||||
// minTaskSize))
|
||||
Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
|
||||
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
|
||||
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
|
||||
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
|
||||
|
||||
// Align the block size to be a multiple of the statically known number
|
||||
// of iterations in the inner loops.
|
||||
if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) {
|
||||
Value numIters = b.create<arith::ConstantIndexOp>(
|
||||
numIterations[op.getNumLoops() - numUnrollableLoops]);
|
||||
Value bs2 = b.create<arith::MulIOp>(
|
||||
b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
|
||||
blockSize = b.create<arith::MinSIOp>(tripCount, bs2);
|
||||
} else {
|
||||
// Reset the number of unrollable loops if we didn't align the block size.
|
||||
numUnrollableLoops = 0;
|
||||
}
|
||||
ParallelComputeFunction notUnrollableParallelComputeFunction =
|
||||
createParallelComputeFunction(op, staticBounds, 0, rewriter);
|
||||
|
||||
// Dispatch parallel compute function using async recursive work splitting,
|
||||
// or by submitting compute task sequentially from a caller thread.
|
||||
auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
|
||||
|
||||
// Create a parallel compute function that takes a block id and computes
|
||||
// the parallel operation body for a subset of iteration space.
|
||||
|
||||
// Compute the number of parallel compute blocks.
|
||||
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
|
||||
|
||||
// Create a parallel compute function that takes a block id and computes
|
||||
// the parallel operation body for a subset of iteration space.
|
||||
ParallelComputeFunction parallelComputeFunction =
|
||||
// Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
|
||||
bool staticShouldUnroll = numUnrollableLoops > 0;
|
||||
auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
|
||||
ImplicitLocOpBuilder nb(loc, nestedBuilder);
|
||||
doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
|
||||
blockSize, blockCount, tripCounts);
|
||||
nb.create<scf::YieldOp>();
|
||||
};
|
||||
|
||||
if (staticShouldUnroll) {
|
||||
Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
|
||||
arith::CmpIPredicate::sge, blockSize,
|
||||
b.create<arith::ConstantIndexOp>(maxIterations));
|
||||
|
||||
ParallelComputeFunction unrollableParallelComputeFunction =
|
||||
createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
|
||||
rewriter);
|
||||
|
||||
// Dispatch parallel compute function using async recursive work splitting,
|
||||
// or by submitting compute task sequentially from a caller thread.
|
||||
if (asyncDispatch) {
|
||||
doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
|
||||
blockCount, tripCounts);
|
||||
} else {
|
||||
doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
|
||||
blockCount, tripCounts);
|
||||
}
|
||||
|
||||
auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
|
||||
ImplicitLocOpBuilder nb(loc, nestedBuilder);
|
||||
// Align the block size to be a multiple of the statically known
|
||||
// number of iterations in the inner loops.
|
||||
Value numIters = nb.create<arith::ConstantIndexOp>(
|
||||
numIterations[op.getNumLoops() - numUnrollableLoops]);
|
||||
Value alignedBlockSize = nb.create<arith::MulIOp>(
|
||||
nb.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
|
||||
doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
|
||||
alignedBlockSize, blockCount, tripCounts);
|
||||
nb.create<scf::YieldOp>();
|
||||
};
|
||||
|
||||
b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
|
||||
dispatchNotUnrollable);
|
||||
nb.create<scf::YieldOp>();
|
||||
} else {
|
||||
dispatchNotUnrollable(nb, loc);
|
||||
}
|
||||
};
|
||||
|
||||
// Replace the `scf.parallel` operation with the parallel compute function.
|
||||
b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
|
||||
|
||||
|
|
@ -852,9 +878,11 @@ void AsyncParallelForPass::runOnOperation() {
|
|||
MLIRContext *ctx = &getContext();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
|
||||
minTaskSize);
|
||||
|
||||
populateAsyncParallelForPatterns(
|
||||
patterns, asyncDispatch, numWorkerThreads,
|
||||
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
|
||||
return builder.create<arith::ConstantIndexOp>(minTaskSize);
|
||||
});
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
@ -869,3 +897,11 @@ std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
|
|||
return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
|
||||
minTaskSize);
|
||||
}
|
||||
|
||||
void mlir::async::populateAsyncParallelForPatterns(
|
||||
RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
|
||||
AsyncMinTaskSizeComputationFunction computeMinTaskSize) {
|
||||
MLIRContext *ctx = patterns.getContext();
|
||||
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
|
||||
computeMinTaskSize);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,6 +100,24 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
|
|||
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
|
||||
// CHECK-SAME: ) {
|
||||
// CHECK: scf.for %[[I:arg[0-9]+]]
|
||||
// CHECK: select
|
||||
// CHECK: scf.for %[[J:arg[0-9]+]]
|
||||
// CHECK: memref.store
|
||||
|
||||
// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
|
||||
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[TRIP_COUNT1:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[LB0:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[LB1:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[UB0:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[UB1:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[STEP0:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
|
||||
// CHECK-SAME: ) {
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[C10:.*]] = arith.constant 10 : index
|
||||
|
|
|
|||
|
|
@ -2018,7 +2018,10 @@ cc_library(
|
|||
"lib/Dialect/Async/Transforms/*.cpp",
|
||||
"lib/Dialect/Async/Transforms/*.h",
|
||||
]),
|
||||
hdrs = ["include/mlir/Dialect/Async/Passes.h"],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Async/Passes.h",
|
||||
"include/mlir/Dialect/Async/Transforms.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":Analysis",
|
||||
|
|
|
|||
Loading…
Reference in New Issue