233 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			C++
		
	
	
	
//===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
 | 
						|
//
 | 
						|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
						|
// See https://llvm.org/LICENSE.txt for license information.
 | 
						|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
						|
//
 | 
						|
//===----------------------------------------------------------------------===//
 | 
						|
//
 | 
						|
// Utility function to lower a printf call into a series of device
 | 
						|
// library calls on the AMDGPU target.
 | 
						|
//
 | 
						|
// WARNING: This file knows about certain library functions. It recognizes them
 | 
						|
// by name, and hardwires knowledge of their semantics.
 | 
						|
//
 | 
						|
//===----------------------------------------------------------------------===//
 | 
						|
 | 
						|
#include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
 | 
						|
#include "llvm/ADT/SparseBitVector.h"
 | 
						|
#include "llvm/Analysis/ValueTracking.h"
 | 
						|
 | 
						|
using namespace llvm;
 | 
						|
 | 
						|
#define DEBUG_TYPE "amdgpu-emit-printf"
 | 
						|
 | 
						|
// Converts one scalar printf argument into the i64 slot format used by
// __ockl_printf_append_args (see callAppendArgs):
//  - i64 passes through unchanged; any narrower integer is zero-extended
//    (previously only i32 was accepted, anything else hit llvm_unreachable);
//  - double is reinterpreted bit-for-bit as i64;
//  - half and float are first widened to double, mirroring C's default
//    argument promotion of floating-point values passed to a variadic call,
//    then reinterpreted as i64;
//  - pointers are converted with ptrtoint.
// Any other type (integers wider than 64 bits, vectors, other FP formats)
// is unsupported.
static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
  auto Int64Ty = Builder.getInt64Ty();
  auto Ty = Arg->getType();

  if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
    unsigned Width = IntTy->getBitWidth();
    if (Width == 64)
      return Arg;
    if (Width < 64)
      return Builder.CreateZExt(Arg, Int64Ty);
  }

  if (Ty->isHalfTy() || Ty->isFloatTy()) {
    Value *AsDouble = Builder.CreateFPExt(Arg, Builder.getDoubleTy());
    return Builder.CreateBitCast(AsDouble, Int64Ty);
  }

  if (Ty->isDoubleTy())
    return Builder.CreateBitCast(Arg, Int64Ty);

  if (isa<PointerType>(Ty))
    return Builder.CreatePtrToInt(Arg, Int64Ty);

  llvm_unreachable("unexpected type");
}
 | 
						|
 | 
						|
// Emits a call to __ockl_printf_begin(Version). If the function is not yet
// declared in the current module, it is declared as i64(i64). Returns the
// i64 value produced by the call; the caller threads it through subsequent
// append calls as the printf descriptor.
static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
  Type *I64 = Builder.getInt64Ty();
  Module *Mod = Builder.GetInsertBlock()->getModule();
  auto BeginFn = Mod->getOrInsertFunction("__ockl_printf_begin",
                                          /*ret=*/I64, /*version=*/I64);
  return Builder.CreateCall(BeginFn, Version);
}
 | 
						|
 | 
						|
// Emits a call to __ockl_printf_append_args, declaring it in the current
// module if necessary with the signature
//   i64 (i64 Desc, i32 NumArgs, i64 x 7 argument slots, i32 IsLast).
// NumArgs says how many of the seven slots are meaningful. Callers in this
// file pass zero in the unused slots. IsLast is passed as an i32 0/1.
// Returns the i64 descriptor produced by the call.
static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
                             Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
                             Value *Arg4, Value *Arg5, Value *Arg6,
                             bool IsLast) {
  Type *I32 = Builder.getInt32Ty();
  Type *I64 = Builder.getInt64Ty();
  Module *Mod = Builder.GetInsertBlock()->getModule();

  auto AppendFn = Mod->getOrInsertFunction(
      "__ockl_printf_append_args", /*ret=*/I64, /*desc=*/I64, /*nargs=*/I32,
      I64, I64, I64, I64, I64, I64, I64, /*is_last=*/I32);

  Value *Operands[] = {Desc,
                       Builder.getInt32(NumArgs),
                       Arg0,
                       Arg1,
                       Arg2,
                       Arg3,
                       Arg4,
                       Arg5,
                       Arg6,
                       Builder.getInt32(IsLast)};
  return Builder.CreateCall(AppendFn, Operands);
}
 | 
						|
 | 
						|
// Appends one scalar printf argument to the descriptor Desc. Arg is first
// converted to i64 by fitArgInto64Bits, then sent through
// __ockl_printf_append_args with NumArgs = 1 and the six remaining slots
// filled with zero. Returns the updated descriptor.
static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
                        bool IsLast) {
  Value *Widened = fitArgInto64Bits(Builder, Arg);
  Value *Pad = Builder.getInt64(0);
  return callAppendArgs(Builder, Desc, /*NumArgs=*/1, Widened, Pad, Pad, Pad,
                        Pad, Pad, Pad, IsLast);
}
 | 
						|
 | 
						|
// The device library does not provide strlen, so we build our own loop
 | 
						|
// here. While we are at it, we also include the terminating null in the length.
 | 
						|
static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
 | 
						|
  auto *Prev = Builder.GetInsertBlock();
 | 
						|
  Module *M = Prev->getModule();
 | 
						|
 | 
						|
  auto CharZero = Builder.getInt8(0);
 | 
						|
  auto One = Builder.getInt64(1);
 | 
						|
  auto Zero = Builder.getInt64(0);
 | 
						|
  auto Int64Ty = Builder.getInt64Ty();
 | 
						|
 | 
						|
  // The length is either zero for a null pointer, or the computed value for an
 | 
						|
  // actual string. We need a join block for a phi that represents the final
 | 
						|
  // value.
 | 
						|
  //
 | 
						|
  //  Strictly speaking, the zero does not matter since
 | 
						|
  // __ockl_printf_append_string_n ignores the length if the pointer is null.
 | 
						|
  BasicBlock *Join = nullptr;
 | 
						|
  if (Prev->getTerminator()) {
 | 
						|
    Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
 | 
						|
                                 "strlen.join");
 | 
						|
    Prev->getTerminator()->eraseFromParent();
 | 
						|
  } else {
 | 
						|
    Join = BasicBlock::Create(M->getContext(), "strlen.join",
 | 
						|
                              Prev->getParent());
 | 
						|
  }
 | 
						|
  BasicBlock *While =
 | 
						|
      BasicBlock::Create(M->getContext(), "strlen.while",
 | 
						|
                         Prev->getParent(), Join);
 | 
						|
  BasicBlock *WhileDone = BasicBlock::Create(
 | 
						|
      M->getContext(), "strlen.while.done",
 | 
						|
      Prev->getParent(), Join);
 | 
						|
 | 
						|
  // Emit an early return for when the pointer is null.
 | 
						|
  Builder.SetInsertPoint(Prev);
 | 
						|
  auto CmpNull =
 | 
						|
      Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
 | 
						|
  BranchInst::Create(Join, While, CmpNull, Prev);
 | 
						|
 | 
						|
  // Entry to the while loop.
 | 
						|
  Builder.SetInsertPoint(While);
 | 
						|
 | 
						|
  auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
 | 
						|
  PtrPhi->addIncoming(Str, Prev);
 | 
						|
  auto PtrNext = Builder.CreateGEP(Builder.getInt8Ty(), PtrPhi, One);
 | 
						|
  PtrPhi->addIncoming(PtrNext, While);
 | 
						|
 | 
						|
  // Condition for the while loop.
 | 
						|
  auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi);
 | 
						|
  auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
 | 
						|
  Builder.CreateCondBr(Cmp, WhileDone, While);
 | 
						|
 | 
						|
  // Add one to the computed length.
 | 
						|
  Builder.SetInsertPoint(WhileDone, WhileDone->begin());
 | 
						|
  auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
 | 
						|
  auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
 | 
						|
  auto Len = Builder.CreateSub(End, Begin);
 | 
						|
  Len = Builder.CreateAdd(Len, One);
 | 
						|
 | 
						|
  // Final join.
 | 
						|
  BranchInst::Create(Join, WhileDone);
 | 
						|
  Builder.SetInsertPoint(Join, Join->begin());
 | 
						|
  auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
 | 
						|
  LenPhi->addIncoming(Len, WhileDone);
 | 
						|
  LenPhi->addIncoming(Zero, Prev);
 | 
						|
 | 
						|
  return LenPhi;
 | 
						|
}
 | 
						|
 | 
						|
// Emits a call to __ockl_printf_append_string_n, declaring it in the current
// module if necessary as i64 (i64 Desc, i8* Str, i64 Length, i32 IsLast).
// Length is the byte count to send (callers pass the length including the
// terminating NUL). IsLast is passed as an i32 0/1. Returns the i64
// descriptor produced by the call.
static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
                                Value *Length, bool isLast) {
  auto Int64Ty = Builder.getInt64Ty();
  auto CharPtrTy = Builder.getInt8PtrTy();
  auto Int32Ty = Builder.getInt32Ty();
  auto M = Builder.GetInsertBlock()->getModule();
  auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
                                   Int64Ty, CharPtrTy, Int64Ty, Int32Ty);
  // The library function is declared with an address-space-0 char pointer,
  // but callers may pass a string that lives in another address space
  // (appendString preserves the argument's address space). Cast it so the
  // call matches the declared signature. When the types already agree, this
  // returns Str unchanged.
  Str = Builder.CreatePointerBitCastOrAddrSpaceCast(Str, CharPtrTy);
  auto IsLastInt32 = Builder.getInt32(isLast);
  return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
}
 | 
						|
 | 
						|
// Appends a C string to the descriptor Desc. Arg is reinterpreted as an i8
// pointer in its own address space. Its length, counting the terminating
// NUL, is computed by an emitted loop (getStrlenWithNull), and the bytes are
// then sent via __ockl_printf_append_string_n. Returns the updated
// descriptor.
static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
                           bool IsLast) {
  unsigned AddrSpace = Arg->getType()->getPointerAddressSpace();
  Value *Bytes = Builder.CreateBitCast(Arg, Builder.getInt8PtrTy(AddrSpace));
  Value *LenWithNul = getStrlenWithNull(Builder, Bytes);
  return callAppendStringN(Builder, Desc, Bytes, LenWithNul, IsLast);
}
 | 
						|
 | 
						|
// Appends one printf argument (not the format string) to Desc. It is sent as
// a string only when the format marks it as %s AND it actually has pointer
// type. In every other case it goes through the scalar path.
//
// If the format specifies a string but the argument is not a pointer, the
// frontend will have printed a warning. We just rely on undefined behaviour
// and send the argument anyway.
static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
                         bool SpecIsCString, bool IsLast) {
  const bool SendAsString = SpecIsCString && Arg->getType()->isPointerTy();
  return SendAsString ? appendString(Builder, Desc, Arg, IsLast)
                      : appendArg(Builder, Desc, Arg, IsLast);
}
 | 
						|
 | 
						|
// Scan the format string to locate all specifiers, and mark the ones that
 | 
						|
// specify a string, i.e, the "%s" specifier with optional '*' characters.
 | 
						|
static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) {
 | 
						|
  StringRef Str;
 | 
						|
  if (!getConstantStringInfo(Fmt, Str) || Str.empty())
 | 
						|
    return;
 | 
						|
 | 
						|
  static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
 | 
						|
  size_t SpecPos = 0;
 | 
						|
  // Skip the first argument, the format string.
 | 
						|
  unsigned ArgIdx = 1;
 | 
						|
 | 
						|
  while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
 | 
						|
    if (Str[SpecPos + 1] == '%') {
 | 
						|
      SpecPos += 2;
 | 
						|
      continue;
 | 
						|
    }
 | 
						|
    auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
 | 
						|
    if (SpecEnd == StringRef::npos)
 | 
						|
      return;
 | 
						|
    auto Spec = Str.slice(SpecPos, SpecEnd + 1);
 | 
						|
    ArgIdx += Spec.count('*');
 | 
						|
    if (Str[SpecEnd] == 's') {
 | 
						|
      BV.set(ArgIdx);
 | 
						|
    }
 | 
						|
    SpecPos = SpecEnd + 1;
 | 
						|
    ++ArgIdx;
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
// Lowers a printf call whose operands are Args (Args[0] = format string) into
// a sequence of __ockl_printf_* calls:
//   1. __ockl_printf_begin(0) opens a descriptor;
//   2. the format string is appended as a string;
//   3. each remaining argument is appended in order, as a string when the
//      format marks it as %s (and it is a pointer), otherwise as a scalar.
// The final append is flagged IsLast. Returns the last descriptor truncated
// to i32.
Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder,
                                  ArrayRef<Value *> Args) {
  const size_t NumOps = Args.size();
  assert(NumOps >= 1);

  Value *Fmt = Args.front();
  SparseBitVector<8> StringArgs;
  locateCStrings(StringArgs, Fmt);

  Value *Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
  Desc = appendString(Builder, Desc, Fmt, /*IsLast=*/NumOps == 1);

  // FIXME: This invokes hostcall once for each argument. We can pack up to
  // seven scalar printf arguments in a single hostcall. See the signature of
  // callAppendArgs().
  for (unsigned Idx = 1; Idx < NumOps; ++Idx) {
    const bool Last = Idx + 1 == NumOps;
    Desc = processArg(Builder, Desc, Args[Idx], StringArgs.test(Idx), Last);
  }

  return Builder.CreateTrunc(Desc, Builder.getInt32Ty());
}
 |