266 lines
7.5 KiB
C++
266 lines
7.5 KiB
C++
//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
///
|
|
/// \file This file contains passes and utilities to lower llvm intrinsic call
|
|
/// to DXILOp function call.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "DXILConstants.h"
|
|
#include "DirectX.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/CodeGen/Passes.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Instruction.h"
|
|
#include "llvm/IR/Intrinsics.h"
|
|
#include "llvm/IR/IntrinsicsDirectX.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/PassManager.h"
|
|
#include "llvm/Pass.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
#define DEBUG_TYPE "dxil-op-lower"
|
|
|
|
using namespace llvm;
|
|
using namespace llvm::DXIL;
|
|
|
|
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
|
|
|
|
enum OverloadKind : uint16_t {
|
|
VOID = 1,
|
|
HALF = 1 << 1,
|
|
FLOAT = 1 << 2,
|
|
DOUBLE = 1 << 3,
|
|
I1 = 1 << 4,
|
|
I8 = 1 << 5,
|
|
I16 = 1 << 6,
|
|
I32 = 1 << 7,
|
|
I64 = 1 << 8,
|
|
UserDefineType = 1 << 9,
|
|
ObjectType = 1 << 10,
|
|
};
|
|
|
|
static const char *getOverloadTypeName(OverloadKind Kind) {
|
|
switch (Kind) {
|
|
case OverloadKind::HALF:
|
|
return "f16";
|
|
case OverloadKind::FLOAT:
|
|
return "f32";
|
|
case OverloadKind::DOUBLE:
|
|
return "f64";
|
|
case OverloadKind::I1:
|
|
return "i1";
|
|
case OverloadKind::I8:
|
|
return "i8";
|
|
case OverloadKind::I16:
|
|
return "i16";
|
|
case OverloadKind::I32:
|
|
return "i32";
|
|
case OverloadKind::I64:
|
|
return "i64";
|
|
case OverloadKind::VOID:
|
|
case OverloadKind::ObjectType:
|
|
case OverloadKind::UserDefineType:
|
|
break;
|
|
}
|
|
llvm_unreachable("invalid overload type for name");
|
|
return "void";
|
|
}
|
|
|
|
static OverloadKind getOverloadKind(Type *Ty) {
|
|
Type::TypeID T = Ty->getTypeID();
|
|
switch (T) {
|
|
case Type::VoidTyID:
|
|
return OverloadKind::VOID;
|
|
case Type::HalfTyID:
|
|
return OverloadKind::HALF;
|
|
case Type::FloatTyID:
|
|
return OverloadKind::FLOAT;
|
|
case Type::DoubleTyID:
|
|
return OverloadKind::DOUBLE;
|
|
case Type::IntegerTyID: {
|
|
IntegerType *ITy = cast<IntegerType>(Ty);
|
|
unsigned Bits = ITy->getBitWidth();
|
|
switch (Bits) {
|
|
case 1:
|
|
return OverloadKind::I1;
|
|
case 8:
|
|
return OverloadKind::I8;
|
|
case 16:
|
|
return OverloadKind::I16;
|
|
case 32:
|
|
return OverloadKind::I32;
|
|
case 64:
|
|
return OverloadKind::I64;
|
|
default:
|
|
llvm_unreachable("invalid overload type");
|
|
return OverloadKind::VOID;
|
|
}
|
|
}
|
|
case Type::PointerTyID:
|
|
return OverloadKind::UserDefineType;
|
|
case Type::StructTyID:
|
|
return OverloadKind::ObjectType;
|
|
default:
|
|
llvm_unreachable("invalid overload type");
|
|
return OverloadKind::VOID;
|
|
}
|
|
}
|
|
|
|
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
|
|
if (Kind < OverloadKind::UserDefineType) {
|
|
return getOverloadTypeName(Kind);
|
|
} else if (Kind == OverloadKind::UserDefineType) {
|
|
StructType *ST = cast<StructType>(Ty);
|
|
return ST->getStructName().str();
|
|
} else if (Kind == OverloadKind::ObjectType) {
|
|
StructType *ST = cast<StructType>(Ty);
|
|
return ST->getStructName().str();
|
|
} else {
|
|
std::string Str;
|
|
raw_string_ostream OS(Str);
|
|
Ty->print(OS);
|
|
return OS.str();
|
|
}
|
|
}
|
|
|
|
// Static properties.
|
|
struct OpCodeProperty {
|
|
DXIL::OpCode OpCode;
|
|
// Offset in DXILOpCodeNameTable.
|
|
unsigned OpCodeNameOffset;
|
|
DXIL::OpCodeClass OpCodeClass;
|
|
// Offset in DXILOpCodeClassNameTable.
|
|
unsigned OpCodeClassNameOffset;
|
|
uint16_t OverloadTys;
|
|
llvm::Attribute::AttrKind FuncAttr;
|
|
};
|
|
|
|
// Include getOpCodeClassName getOpCodeProperty and getOpCodeName which
|
|
// generated by tableGen.
|
|
#define DXIL_OP_OPERATION_TABLE
|
|
#include "DXILOperation.inc"
|
|
#undef DXIL_OP_OPERATION_TABLE
|
|
|
|
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
|
|
const OpCodeProperty &Prop) {
|
|
if (Kind == OverloadKind::VOID) {
|
|
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
|
|
}
|
|
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
|
|
getTypeName(Kind, Ty))
|
|
.str();
|
|
}
|
|
|
|
static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
|
|
Module &M) {
|
|
const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
|
|
|
|
// Get return type as overload type for DXILOp.
|
|
// Only simple mapping case here, so return type is good enough.
|
|
Type *OverloadTy = F.getReturnType();
|
|
|
|
OverloadKind Kind = getOverloadKind(OverloadTy);
|
|
// FIXME: find the issue and report error in clang instead of check it in
|
|
// backend.
|
|
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
|
|
llvm_unreachable("invalid overload");
|
|
}
|
|
|
|
std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
|
|
assert(!M.getFunction(FnName) && "Function already exists");
|
|
|
|
auto &Ctx = M.getContext();
|
|
Type *OpCodeTy = Type::getInt32Ty(Ctx);
|
|
|
|
SmallVector<Type *> ArgTypes;
|
|
// DXIL has i32 opcode as first arg.
|
|
ArgTypes.emplace_back(OpCodeTy);
|
|
FunctionType *FT = F.getFunctionType();
|
|
ArgTypes.append(FT->param_begin(), FT->param_end());
|
|
FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
|
|
return M.getOrInsertFunction(FnName, DXILOpFT);
|
|
}
|
|
|
|
static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
|
|
auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
|
|
IRBuilder<> B(M.getContext());
|
|
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
|
|
for (User *U : make_early_inc_range(F.users())) {
|
|
CallInst *CI = dyn_cast<CallInst>(U);
|
|
if (!CI)
|
|
continue;
|
|
|
|
SmallVector<Value *> Args;
|
|
Args.emplace_back(DXILOpArg);
|
|
Args.append(CI->arg_begin(), CI->arg_end());
|
|
B.SetInsertPoint(CI);
|
|
CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
|
|
LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp)));
|
|
CI->replaceAllUsesWith(DXILCI);
|
|
CI->eraseFromParent();
|
|
}
|
|
if (F.user_empty())
|
|
F.eraseFromParent();
|
|
}
|
|
|
|
static bool lowerIntrinsics(Module &M) {
|
|
bool Updated = false;
|
|
|
|
#define DXIL_OP_INTRINSIC_MAP
|
|
#include "DXILOperation.inc"
|
|
#undef DXIL_OP_INTRINSIC_MAP
|
|
|
|
for (Function &F : make_early_inc_range(M.functions())) {
|
|
if (!F.isDeclaration())
|
|
continue;
|
|
Intrinsic::ID ID = F.getIntrinsicID();
|
|
if (ID == Intrinsic::not_intrinsic)
|
|
continue;
|
|
auto LowerIt = LowerMap.find(ID);
|
|
if (LowerIt == LowerMap.end())
|
|
continue;
|
|
lowerIntrinsic(LowerIt->second, F, M);
|
|
Updated = true;
|
|
}
|
|
return Updated;
|
|
}
|
|
|
|
namespace {
|
|
/// A pass that transforms external global definitions into declarations.
|
|
class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
|
|
public:
|
|
PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
|
|
if (lowerIntrinsics(M))
|
|
return PreservedAnalyses::none();
|
|
return PreservedAnalyses::all();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
class DXILOpLoweringLegacy : public ModulePass {
|
|
public:
|
|
bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
|
|
StringRef getPassName() const override { return "DXIL Op Lowering"; }
|
|
DXILOpLoweringLegacy() : ModulePass(ID) {}
|
|
|
|
static char ID; // Pass identification.
|
|
};
|
|
char DXILOpLoweringLegacy::ID = 0;
|
|
|
|
} // end anonymous namespace
|
|
|
|
INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
|
|
false, false)
|
|
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
|
|
false)
|
|
|
|
ModulePass *llvm::createDXILOpLoweringLegacyPass() {
|
|
return new DXILOpLoweringLegacy();
|
|
}
|