//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//
 | 
						|
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
 | 
						|
#include "llvm/IR/LegacyPassManager.h"
 | 
						|
#include "llvm/Support/InitLLVM.h"
 | 
						|
#include "llvm/Support/TargetSelect.h"
 | 
						|
#include "llvm/Support/raw_ostream.h"
 | 
						|
#include "llvm/Transforms/IPO.h"
 | 
						|
#include "llvm/Transforms/Scalar.h"
 | 
						|
 | 
						|
#include "../ExampleModules.h"
 | 
						|
 | 
						|
using namespace llvm;
 | 
						|
using namespace llvm::orc;
 | 
						|
 | 
						|
// Global error handler: main() sets its banner from argv[0] and wraps each
// fallible call (LLJIT creation, module parsing/adding, symbol lookup) in it.
ExitOnError ExitOnErr;
 | 
						|
 | 
						|
// Example IR module.
 | 
						|
//
 | 
						|
// This IR contains a recursive definition of the factorial function:
 | 
						|
//
 | 
						|
// fac(n) | n == 0    = 1
 | 
						|
//        | otherwise = n * fac(n - 1)
 | 
						|
//
 | 
						|
// It also contains an entry function which calls the factorial function with
 | 
						|
// an input value of 5.
 | 
						|
//
 | 
						|
// We expect the IR optimization transform that we build below to transform
 | 
						|
// this into a non-recursive factorial function and an entry function that
 | 
						|
// returns a constant value of 5!, or 120.
 | 
						|
 | 
						|
const llvm::StringRef MainMod =
 | 
						|
    R"(
 | 
						|
 | 
						|
  define i32 @fac(i32 %n) {
 | 
						|
  entry:
 | 
						|
    %tobool = icmp eq i32 %n, 0
 | 
						|
    br i1 %tobool, label %return, label %if.then
 | 
						|
 | 
						|
  if.then:                                          ; preds = %entry
 | 
						|
    %arg = add nsw i32 %n, -1
 | 
						|
    %call_result = call i32 @fac(i32 %arg)
 | 
						|
    %result = mul nsw i32 %n, %call_result
 | 
						|
    br label %return
 | 
						|
 | 
						|
  return:                                           ; preds = %entry, %if.then
 | 
						|
    %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
 | 
						|
    ret i32 %final_result
 | 
						|
  }
 | 
						|
 | 
						|
  define i32 @entry() {
 | 
						|
  entry:
 | 
						|
    %result = call i32 @fac(i32 5)
 | 
						|
    ret i32 %result
 | 
						|
  }
 | 
						|
 | 
						|
)";
 | 
						|
 | 
						|
// A function object that creates a simple pass pipeline to apply to each
 | 
						|
// module as it passes through the IRTransformLayer.
 | 
						|
class MyOptimizationTransform {
 | 
						|
public:
 | 
						|
  MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
 | 
						|
    PM->add(createTailCallEliminationPass());
 | 
						|
    PM->add(createFunctionInliningPass());
 | 
						|
    PM->add(createIndVarSimplifyPass());
 | 
						|
    PM->add(createCFGSimplificationPass());
 | 
						|
  }
 | 
						|
 | 
						|
  Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
 | 
						|
                                        MaterializationResponsibility &R) {
 | 
						|
    TSM.withModuleDo([this](Module &M) {
 | 
						|
      dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
 | 
						|
      PM->run(M);
 | 
						|
      dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
 | 
						|
    });
 | 
						|
    return std::move(TSM);
 | 
						|
  }
 | 
						|
 | 
						|
private:
 | 
						|
  std::unique_ptr<legacy::PassManager> PM;
 | 
						|
};
 | 
						|
 | 
						|
int main(int argc, char *argv[]) {
 | 
						|
  // Initialize LLVM.
 | 
						|
  InitLLVM X(argc, argv);
 | 
						|
 | 
						|
  InitializeNativeTarget();
 | 
						|
  InitializeNativeTargetAsmPrinter();
 | 
						|
 | 
						|
  ExitOnErr.setBanner(std::string(argv[0]) + ": ");
 | 
						|
 | 
						|
  // (1) Create LLJIT instance.
 | 
						|
  auto J = ExitOnErr(LLJITBuilder().create());
 | 
						|
 | 
						|
  // (2) Install transform to optimize modules when they're materialized.
 | 
						|
  J->getIRTransformLayer().setTransform(MyOptimizationTransform());
 | 
						|
 | 
						|
  // (3) Add modules.
 | 
						|
  ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod"))));
 | 
						|
 | 
						|
  // (4) Look up the JIT'd function and call it.
 | 
						|
  auto EntrySym = ExitOnErr(J->lookup("entry"));
 | 
						|
  auto *Entry = (int (*)())EntrySym.getAddress();
 | 
						|
 | 
						|
  int Result = Entry();
 | 
						|
  outs() << "--- Result ---\n"
 | 
						|
         << "entry() = " << Result << "\n";
 | 
						|
 | 
						|
  return 0;
 | 
						|
}
 |