86 lines
3.2 KiB
C++
86 lines
3.2 KiB
C++
//===- Distribution.cpp - Linalg tiled_loop distribution ------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the Linalg distribution pattern. It updates `tiled_loop`
|
|
// control variables depending on the distribution type.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#define DEBUG_TYPE "linalg-distribution"
|
|
|
|
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
namespace {
|
|
|
|
struct DistributeTiledLoopPattern
|
|
: public OpRewritePattern<linalg::TiledLoopOp> {
|
|
DistributeTiledLoopPattern(MLIRContext *context,
|
|
LinalgLoopDistributionOptions options,
|
|
LinalgTransformationFilter marker)
|
|
: OpRewritePattern<linalg::TiledLoopOp>(context), options(options),
|
|
marker(marker) {}
|
|
LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(marker.checkAndNotify(rewriter, op)))
|
|
return failure();
|
|
if (!op.distribution_types().hasValue())
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
SmallVector<Value, 2> newLowerBounds = op.lowerBound();
|
|
SmallVector<Value, 2> newUpperBounds = op.upperBound();
|
|
SmallVector<Value, 2> newSteps = op.step();
|
|
|
|
// Update bounds and steps.
|
|
auto distributionTypes = op.distribution_types().getValue();
|
|
for (int i = 0, e = op.getNumLoops(); i < e; ++i) {
|
|
StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
|
|
auto procInfoCallback = options.procInfoMap.find(type);
|
|
if (procInfoCallback == options.procInfoMap.end())
|
|
continue;
|
|
|
|
if (!isParallelIterator(op.iterator_types()[i])) {
|
|
op.emitOpError("only support for parallel loops is implemented");
|
|
return failure();
|
|
}
|
|
ProcInfo info = procInfoCallback->second(rewriter, loc);
|
|
updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
|
|
newLowerBounds[i], newUpperBounds[i],
|
|
newSteps[i]);
|
|
}
|
|
rewriter.updateRootInPlace(op, [&] {
|
|
op.setLowerBounds(newLowerBounds);
|
|
op.setUpperBounds(newUpperBounds);
|
|
op.setSteps(newSteps);
|
|
});
|
|
marker.replaceLinalgTransformationFilter(rewriter, op);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
LinalgLoopDistributionOptions options;
|
|
LinalgTransformationFilter marker;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the `linalg.tiled_loop` distribution pattern in `patterns`.
/// `opts` supplies the processor-info callbacks for each distribution type,
/// and `marker` restricts which ops the pattern is applied to.
void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
    const LinalgTransformationFilter &marker) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DistributeTiledLoopPattern>(ctx, opts, marker);
}
|