Make AsyncParallelForRewrite parameterizable with a cost model which drives deciding the parallelization granularity.

Reviewed By: ezhulenev, mehdi_amini

Differential Revision: https://reviews.llvm.org/D115423
This commit is contained in:
bakhtiyar 2021-12-19 08:35:37 -08:00 committed by Eugene Zhulenev
parent 47bd9ebda4
commit ec0e4545ca
5 changed files with 140 additions and 43 deletions

View File

@ -21,7 +21,7 @@ std::unique_ptr<Pass> createAsyncParallelForPass();
std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
int32_t numWorkerThreads,
int32_t targetBlockSize);
int32_t minTaskSize);
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();

View File

@ -0,0 +1,40 @@
//===- Transforms.h - Async dialect transformation utilities ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines transformations on Async operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ASYNC_TRANSFORMS_H_
#define MLIR_DIALECT_ASYNC_TRANSFORMS_H_
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
namespace mlir {
namespace async {
/// Emit the IR to compute the minimum number of iterations of scf.parallel body
/// that would be viable for a single parallel task. Allows the user to avoid
/// incurring the overheads of spawning costly parallel tasks in absence of
/// sufficient amount of parallelizable work.
///
/// Must return an index type.
using AsyncMinTaskSizeComputationFunction =
std::function<Value(ImplicitLocOpBuilder, scf::ParallelOp)>;
/// Add a pattern to the given pattern list to lower scf.parallel to async
/// operations.
void populateAsyncParallelForPatterns(
RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
AsyncMinTaskSizeComputationFunction computeMinTaskSize);
} // namespace async
} // namespace mlir
#endif // MLIR_DIALECT_ASYNC_TRANSFORMS_H_

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -105,10 +106,12 @@ struct AsyncParallelForPass
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
int32_t numWorkerThreads, int32_t minTaskSize)
AsyncParallelForRewrite(
MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
AsyncMinTaskSizeComputationFunction computeMinTaskSize)
: OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
numWorkerThreads(numWorkerThreads),
computeMinTaskSize(computeMinTaskSize) {}
LogicalResult matchAndRewrite(scf::ParallelOp op,
PatternRewriter &rewriter) const override;
@ -116,7 +119,7 @@ public:
private:
bool asyncDispatch;
int32_t numWorkerThreads;
int32_t minTaskSize;
AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};
struct ParallelComputeFunctionType {
@ -252,7 +255,11 @@ static ParallelComputeFunction createParallelComputeFunction(
getParallelComputeFunctionType(op, rewriter);
FunctionType type = computeFuncType.type;
FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
FuncOp func = FuncOp::create(op.getLoc(),
numBlockAlignedInnerLoops > 0
? "parallel_compute_fn_with_aligned_loops"
: "parallel_compute_fn",
type);
func.setPrivate();
// Insert function into the module symbol table and assign it unique name.
@ -702,6 +709,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Computing minTaskSize emits IR and can be implemented as executing a cost
// model on the body of the scf.parallel. Thus it needs to be computed before
// the body of the scf.parallel has been manipulated.
Value minTaskSize = computeMinTaskSize(b, op);
// Make sure that all constants will be inside the parallel operation body to
// reduce the number of parallel compute function arguments.
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
@ -752,7 +764,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
};
// Find how many inner iteration dimensions are statically known, and their
// product is smaller than the `512`. We aling the parallel compute block
// product is smaller than the `512`. We align the parallel compute block
// size by the product of statically known dimensions, so that we can
// guarantee that the inner loops executes from 0 to the loop trip counts
// and we can elide dynamic loop boundaries, and give LLVM an opportunity to
@ -793,50 +805,64 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
// Target block size from the pass parameters.
Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize);
// Compute parallel block size from the parallel problem size:
// blockSize = min(tripCount,
// max(ceil_div(tripCount, maxComputeBlocks),
// ceil_div(minTaskSize, bodySize)))
// minTaskSize))
Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
// Align the block size to be a multiple of the statically known number
// of iterations in the inner loops.
if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) {
Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value bs2 = b.create<arith::MulIOp>(
b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
blockSize = b.create<arith::MinSIOp>(tripCount, bs2);
} else {
// Reset the number of unrollable loops if we didn't align the block size.
numUnrollableLoops = 0;
}
ParallelComputeFunction notUnrollableParallelComputeFunction =
createParallelComputeFunction(op, staticBounds, 0, rewriter);
// Dispatch parallel compute function using async recursive work splitting,
// or by submitting compute task sequentially from a caller thread.
auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
// Create a parallel compute function that takes a block id and computes
// the parallel operation body for a subset of iteration space.
// Compute the number of parallel compute blocks.
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
// Create a parallel compute function that takes a block id and computes
// the parallel operation body for a subset of iteration space.
ParallelComputeFunction parallelComputeFunction =
createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
rewriter);
// Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
bool staticShouldUnroll = numUnrollableLoops > 0;
auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder nb(loc, nestedBuilder);
doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
blockSize, blockCount, tripCounts);
nb.create<scf::YieldOp>();
};
// Dispatch parallel compute function using async recursive work splitting,
// or by submitting compute task sequentially from a caller thread.
if (asyncDispatch) {
doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
blockCount, tripCounts);
if (staticShouldUnroll) {
Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, blockSize,
b.create<arith::ConstantIndexOp>(maxIterations));
ParallelComputeFunction unrollableParallelComputeFunction =
createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
rewriter);
auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder nb(loc, nestedBuilder);
// Align the block size to be a multiple of the statically known
// number of iterations in the inner loops.
Value numIters = nb.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value alignedBlockSize = nb.create<arith::MulIOp>(
nb.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
alignedBlockSize, blockCount, tripCounts);
nb.create<scf::YieldOp>();
};
b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
dispatchNotUnrollable);
nb.create<scf::YieldOp>();
} else {
doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
blockCount, tripCounts);
dispatchNotUnrollable(nb, loc);
}
nb.create<scf::YieldOp>();
};
// Replace the `scf.parallel` operation with the parallel compute function.
@ -852,9 +878,11 @@ void AsyncParallelForPass::runOnOperation() {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
minTaskSize);
populateAsyncParallelForPatterns(
patterns, asyncDispatch, numWorkerThreads,
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
return builder.create<arith::ConstantIndexOp>(minTaskSize);
});
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
@ -869,3 +897,11 @@ std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
minTaskSize);
}
void mlir::async::populateAsyncParallelForPatterns(
RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
AsyncMinTaskSizeComputationFunction computeMinTaskSize) {
MLIRContext *ctx = patterns.getContext();
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
computeMinTaskSize);
}

View File

@ -100,9 +100,27 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
// CHECK: scf.for %[[I:arg[0-9]+]]
// CHECK: select
// CHECK: scf.for %[[J:arg[0-9]+]]
// CHECK: memref.store
// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT1:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB0:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB1:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB0:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB1:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP0:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C10:.*]] = arith.constant 10 : index
// CHECK: scf.for %[[I:arg[0-9]+]]
// CHECK-NOT: select
// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1

View File

@ -2018,7 +2018,10 @@ cc_library(
"lib/Dialect/Async/Transforms/*.cpp",
"lib/Dialect/Async/Transforms/*.h",
]),
hdrs = ["include/mlir/Dialect/Async/Passes.h"],
hdrs = [
"include/mlir/Dialect/Async/Passes.h",
"include/mlir/Dialect/Async/Transforms.h",
],
includes = ["include"],
deps = [
":Analysis",