189 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			189 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			C++
		
	
	
	
//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
 | 
						|
//
 | 
						|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
						|
// See https://llvm.org/LICENSE.txt for license information.
 | 
						|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
						|
//
 | 
						|
//===----------------------------------------------------------------------===//
 | 
						|
 | 
						|
#include "mlir/Analysis/DataFlowFramework.h"
 | 
						|
#include "mlir/Dialect/Func/IR/FuncOps.h"
 | 
						|
#include "mlir/Pass/Pass.h"
 | 
						|
 | 
						|
using namespace mlir;
 | 
						|
 | 
						|
namespace {
 | 
						|
/// This analysis state represents an integer that is XOR'd with other states.
 | 
						|
class FooState : public AnalysisState {
 | 
						|
public:
 | 
						|
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
 | 
						|
 | 
						|
  using AnalysisState::AnalysisState;
 | 
						|
 | 
						|
  /// Default-initialize the state to zero.
 | 
						|
  ChangeResult defaultInitialize() override { return join(0); }
 | 
						|
 | 
						|
  /// Returns true if the state is uninitialized.
 | 
						|
  bool isUninitialized() const override { return !state; }
 | 
						|
 | 
						|
  /// Print the integer value or "none" if uninitialized.
 | 
						|
  void print(raw_ostream &os) const override {
 | 
						|
    if (state)
 | 
						|
      os << *state;
 | 
						|
    else
 | 
						|
      os << "none";
 | 
						|
  }
 | 
						|
 | 
						|
  /// Join the state with another. If either is unintialized, take the
 | 
						|
  /// initialized value. Otherwise, XOR the integer values.
 | 
						|
  ChangeResult join(const FooState &rhs) {
 | 
						|
    if (rhs.isUninitialized())
 | 
						|
      return ChangeResult::NoChange;
 | 
						|
    return join(*rhs.state);
 | 
						|
  }
 | 
						|
  ChangeResult join(uint64_t value) {
 | 
						|
    if (isUninitialized()) {
 | 
						|
      state = value;
 | 
						|
      return ChangeResult::Change;
 | 
						|
    }
 | 
						|
    uint64_t before = *state;
 | 
						|
    state = before ^ value;
 | 
						|
    return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
 | 
						|
  }
 | 
						|
 | 
						|
  /// Set the value of the state directly.
 | 
						|
  ChangeResult set(const FooState &rhs) {
 | 
						|
    if (state == rhs.state)
 | 
						|
      return ChangeResult::NoChange;
 | 
						|
    state = rhs.state;
 | 
						|
    return ChangeResult::Change;
 | 
						|
  }
 | 
						|
 | 
						|
  /// Returns the integer value of the state.
 | 
						|
  uint64_t getValue() const { return *state; }
 | 
						|
 | 
						|
private:
 | 
						|
  /// An optional integer value.
 | 
						|
  Optional<uint64_t> state;
 | 
						|
};
 | 
						|
 | 
						|
/// This analysis computes `FooState` across operations and control-flow edges.
 | 
						|
/// If an op specifies a `foo` integer attribute, the contained value is XOR'd
 | 
						|
/// with the value before the operation.
 | 
						|
class FooAnalysis : public DataFlowAnalysis {
 | 
						|
public:
 | 
						|
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
 | 
						|
 | 
						|
  using DataFlowAnalysis::DataFlowAnalysis;
 | 
						|
 | 
						|
  LogicalResult initialize(Operation *top) override;
 | 
						|
  LogicalResult visit(ProgramPoint point) override;
 | 
						|
 | 
						|
private:
 | 
						|
  void visitBlock(Block *block);
 | 
						|
  void visitOperation(Operation *op);
 | 
						|
};
 | 
						|
 | 
						|
struct TestFooAnalysisPass
 | 
						|
    : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
 | 
						|
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
 | 
						|
 | 
						|
  StringRef getArgument() const override { return "test-foo-analysis"; }
 | 
						|
 | 
						|
  void runOnOperation() override;
 | 
						|
};
 | 
						|
} // namespace
 | 
						|
 | 
						|
LogicalResult FooAnalysis::initialize(Operation *top) {
 | 
						|
  if (top->getNumRegions() != 1)
 | 
						|
    return top->emitError("expected a single region top-level op");
 | 
						|
 | 
						|
  // Initialize the top-level state.
 | 
						|
  getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
 | 
						|
 | 
						|
  // Visit all nested blocks and operations.
 | 
						|
  for (Block &block : top->getRegion(0)) {
 | 
						|
    visitBlock(&block);
 | 
						|
    for (Operation &op : block) {
 | 
						|
      if (op.getNumRegions())
 | 
						|
        return op.emitError("unexpected op with regions");
 | 
						|
      visitOperation(&op);
 | 
						|
    }
 | 
						|
  }
 | 
						|
  return success();
 | 
						|
}
 | 
						|
 | 
						|
LogicalResult FooAnalysis::visit(ProgramPoint point) {
 | 
						|
  if (auto *op = point.dyn_cast<Operation *>()) {
 | 
						|
    visitOperation(op);
 | 
						|
    return success();
 | 
						|
  }
 | 
						|
  if (auto *block = point.dyn_cast<Block *>()) {
 | 
						|
    visitBlock(block);
 | 
						|
    return success();
 | 
						|
  }
 | 
						|
  return emitError(point.getLoc(), "unknown point kind");
 | 
						|
}
 | 
						|
 | 
						|
void FooAnalysis::visitBlock(Block *block) {
 | 
						|
  if (block->isEntryBlock()) {
 | 
						|
    // This is the initial state. Let the framework default-initialize it.
 | 
						|
    return;
 | 
						|
  }
 | 
						|
  FooState *state = getOrCreate<FooState>(block);
 | 
						|
  ChangeResult result = ChangeResult::NoChange;
 | 
						|
  for (Block *pred : block->getPredecessors()) {
 | 
						|
    // Join the state at the terminators of all predecessors.
 | 
						|
    const FooState *predState =
 | 
						|
        getOrCreateFor<FooState>(block, pred->getTerminator());
 | 
						|
    result |= state->join(*predState);
 | 
						|
  }
 | 
						|
  propagateIfChanged(state, result);
 | 
						|
}
 | 
						|
 | 
						|
void FooAnalysis::visitOperation(Operation *op) {
 | 
						|
  FooState *state = getOrCreate<FooState>(op);
 | 
						|
  ChangeResult result = ChangeResult::NoChange;
 | 
						|
 | 
						|
  // Copy the state across the operation.
 | 
						|
  const FooState *prevState;
 | 
						|
  if (Operation *prev = op->getPrevNode())
 | 
						|
    prevState = getOrCreateFor<FooState>(op, prev);
 | 
						|
  else
 | 
						|
    prevState = getOrCreateFor<FooState>(op, op->getBlock());
 | 
						|
  result |= state->set(*prevState);
 | 
						|
 | 
						|
  // Modify the state with the attribute, if specified.
 | 
						|
  if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
 | 
						|
    uint64_t value = attr.getUInt();
 | 
						|
    result |= state->join(value);
 | 
						|
  }
 | 
						|
  propagateIfChanged(state, result);
 | 
						|
}
 | 
						|
 | 
						|
void TestFooAnalysisPass::runOnOperation() {
 | 
						|
  func::FuncOp func = getOperation();
 | 
						|
  DataFlowSolver solver;
 | 
						|
  solver.load<FooAnalysis>();
 | 
						|
  if (failed(solver.initializeAndRun(func)))
 | 
						|
    return signalPassFailure();
 | 
						|
 | 
						|
  raw_ostream &os = llvm::errs();
 | 
						|
  os << "function: @" << func.getSymName() << "\n";
 | 
						|
 | 
						|
  func.walk([&](Operation *op) {
 | 
						|
    auto tag = op->getAttrOfType<StringAttr>("tag");
 | 
						|
    if (!tag)
 | 
						|
      return;
 | 
						|
    const FooState *state = solver.lookupState<FooState>(op);
 | 
						|
    assert(state && !state->isUninitialized());
 | 
						|
    os << tag.getValue() << " -> " << state->getValue() << "\n";
 | 
						|
  });
 | 
						|
}
 | 
						|
 | 
						|
namespace mlir {
 | 
						|
namespace test {
 | 
						|
void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
 | 
						|
} // namespace test
 | 
						|
} // namespace mlir
 |