* Correct the function prototypes for some of the functions to match the
actual spec (int -> uint)
* Add the ability to get/cache the strlen function prototype.
* Make sure generated values are appropriately named for debugging purposes
* Add the SPrintFOptimiation for 4 casts of sprintf optimization:
    sprintf(str,cstr) -> llvm.memcpy(str,cstr) (if cstr has no %)
    sprintf(str,"")   -> store sbyte 0, str
    sprintf(str,"%s",src) -> llvm.memcpy(str,src) (if src is constant)
    sprintf(str,"%c",chr) -> store chr, str   ; store sbyte 0, str+1
The sprintf optimization didn't fire as much as I had hoped:
  2 MultiSource/Applications/SPASS
  5 MultiSource/Benchmarks/McCat/18-imp
 22 MultiSource/Benchmarks/Prolangs-C/TimberWolfMC
  1 MultiSource/Benchmarks/Prolangs-C/assembler
  6 MultiSource/Benchmarks/Prolangs-C/unix-smail
  2 MultiSource/Benchmarks/mediabench/mpeg2/mpeg2dec
llvm-svn: 21679
			
			
This commit is contained in:
		
							parent
							
								
									23e9f163ad
								
							
						
					
					
						commit
						1e520fd661
					
				| 
						 | 
				
			
			@ -256,6 +256,21 @@ public:
 | 
			
		|||
    return sqrt_func;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// @brief Return a Function* for the strlen libcall
 | 
			
		||||
  Function* get_strcpy()
 | 
			
		||||
  {
 | 
			
		||||
    if (!strcpy_func)
 | 
			
		||||
    {
 | 
			
		||||
      std::vector<const Type*> args;
 | 
			
		||||
      args.push_back(PointerType::get(Type::SByteTy));
 | 
			
		||||
      args.push_back(PointerType::get(Type::SByteTy));
 | 
			
		||||
      FunctionType* strcpy_type = 
 | 
			
		||||
        FunctionType::get(PointerType::get(Type::SByteTy), args, false);
 | 
			
		||||
      strcpy_func = M->getOrInsertFunction("strcpy",strcpy_type);
 | 
			
		||||
    }
 | 
			
		||||
    return strcpy_func;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// @brief Return a Function* for the strlen libcall
 | 
			
		||||
  Function* get_strlen()
 | 
			
		||||
  {
 | 
			
		||||
| 
						 | 
				
			
			@ -295,8 +310,8 @@ public:
 | 
			
		|||
      std::vector<const Type*> args;
 | 
			
		||||
      args.push_back(PointerType::get(Type::SByteTy));
 | 
			
		||||
      args.push_back(PointerType::get(Type::SByteTy));
 | 
			
		||||
      args.push_back(Type::IntTy);
 | 
			
		||||
      args.push_back(Type::IntTy);
 | 
			
		||||
      args.push_back(Type::UIntTy);
 | 
			
		||||
      args.push_back(Type::UIntTy);
 | 
			
		||||
      FunctionType* memcpy_type = FunctionType::get(Type::VoidTy, args, false);
 | 
			
		||||
      memcpy_func = M->getOrInsertFunction("llvm.memcpy",memcpy_type);
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -314,6 +329,7 @@ private:
 | 
			
		|||
    memcpy_func = 0;
 | 
			
		||||
    memchr_func = 0;
 | 
			
		||||
    sqrt_func   = 0;
 | 
			
		||||
    strcpy_func = 0;
 | 
			
		||||
    strlen_func = 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -323,6 +339,7 @@ private:
 | 
			
		|||
  Function* memcpy_func; ///< Cached llvm.memcpy function
 | 
			
		||||
  Function* memchr_func; ///< Cached memchr function
 | 
			
		||||
  Function* sqrt_func;   ///< Cached sqrt function
 | 
			
		||||
  Function* strcpy_func; ///< Cached strcpy function
 | 
			
		||||
  Function* strlen_func; ///< Cached strlen function
 | 
			
		||||
  Module* M;             ///< Cached Module
 | 
			
		||||
  TargetData* TD;        ///< Cached TargetData
 | 
			
		||||
| 
						 | 
				
			
			@ -493,8 +510,8 @@ public:
 | 
			
		|||
    std::vector<Value*> vals;
 | 
			
		||||
    vals.push_back(gep); // destination
 | 
			
		||||
    vals.push_back(ci->getOperand(2)); // source
 | 
			
		||||
    vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length
 | 
			
		||||
    vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment
 | 
			
		||||
    vals.push_back(ConstantUInt::get(Type::UIntTy,len)); // length
 | 
			
		||||
    vals.push_back(ConstantUInt::get(Type::UIntTy,1)); // alignment
 | 
			
		||||
    new CallInst(SLC.get_memcpy(), vals, "", ci);
 | 
			
		||||
 | 
			
		||||
    // Finally, substitute the first operand of the strcat call for the 
 | 
			
		||||
| 
						 | 
				
			
			@ -862,8 +879,8 @@ public:
 | 
			
		|||
    std::vector<Value*> vals;
 | 
			
		||||
    vals.push_back(dest); // destination
 | 
			
		||||
    vals.push_back(src); // source
 | 
			
		||||
    vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length
 | 
			
		||||
    vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment
 | 
			
		||||
    vals.push_back(ConstantUInt::get(Type::UIntTy,len)); // length
 | 
			
		||||
    vals.push_back(ConstantUInt::get(Type::UIntTy,1)); // alignment
 | 
			
		||||
    new CallInst(SLC.get_memcpy(), vals, "", ci);
 | 
			
		||||
 | 
			
		||||
    // Finally, substitute the first operand of the strcat call for the 
 | 
			
		||||
| 
						 | 
				
			
			@ -1255,7 +1272,8 @@ public:
 | 
			
		|||
      args.push_back(ConstantUInt::get(SLC.getIntPtrType(),len));
 | 
			
		||||
      args.push_back(ConstantUInt::get(SLC.getIntPtrType(),1));
 | 
			
		||||
      args.push_back(ci->getOperand(1));
 | 
			
		||||
      new CallInst(fwrite_func,args,"",ci);
 | 
			
		||||
      new CallInst(fwrite_func,args,ci->getName(),ci);
 | 
			
		||||
      ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len));
 | 
			
		||||
      ci->eraseFromParent();
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -1281,7 +1299,7 @@ public:
 | 
			
		|||
        if (!getConstantStringLength(ci->getOperand(3), len, &CA))
 | 
			
		||||
          return false;
 | 
			
		||||
 | 
			
		||||
        // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),1,file) 
 | 
			
		||||
        // fprintf(file,"%s",str) -> fwrite(fmt,strlen(fmt),1,file) 
 | 
			
		||||
        const Type* FILEptr_type = ci->getOperand(1)->getType();
 | 
			
		||||
        Function* fwrite_func = SLC.get_fwrite(FILEptr_type);
 | 
			
		||||
        if (!fwrite_func)
 | 
			
		||||
| 
						 | 
				
			
			@ -1291,7 +1309,8 @@ public:
 | 
			
		|||
        args.push_back(ConstantUInt::get(SLC.getIntPtrType(),len));
 | 
			
		||||
        args.push_back(ConstantUInt::get(SLC.getIntPtrType(),1));
 | 
			
		||||
        args.push_back(ci->getOperand(1));
 | 
			
		||||
        new CallInst(fwrite_func,args,"",ci);
 | 
			
		||||
        new CallInst(fwrite_func,args,ci->getName(),ci);
 | 
			
		||||
        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len));
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      case 'c':
 | 
			
		||||
| 
						 | 
				
			
			@ -1306,6 +1325,7 @@ public:
 | 
			
		|||
          return false;
 | 
			
		||||
        CastInst* cast = new CastInst(CI,Type::IntTy,CI->getName()+".int",ci);
 | 
			
		||||
        new CallInst(fputc_func,cast,ci->getOperand(1),"",ci);
 | 
			
		||||
        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,1));
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      default:
 | 
			
		||||
| 
						 | 
				
			
			@ -1317,6 +1337,149 @@ public:
 | 
			
		|||
} FPrintFOptimizer;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/// This LibCallOptimization will simplify calls to the "sprintf" library 
 | 
			
		||||
/// function. It looks for cases where the result of sprintf is not used and the
 | 
			
		||||
/// operation can be reduced to something simpler.
 | 
			
		||||
/// @brief Simplify the pow library function.
 | 
			
		||||
struct SPrintFOptimization : public LibCallOptimization
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
  /// @brief Default Constructor
 | 
			
		||||
  SPrintFOptimization() : LibCallOptimization("sprintf",
 | 
			
		||||
      "simplify-libcalls:sprintf", "Number of 'sprintf' calls simplified") {}
 | 
			
		||||
 | 
			
		||||
  /// @brief Destructor
 | 
			
		||||
  virtual ~SPrintFOptimization() {}
 | 
			
		||||
 | 
			
		||||
  /// @brief Make sure that the "fprintf" function has the right prototype
 | 
			
		||||
  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
 | 
			
		||||
  {
 | 
			
		||||
    // Just make sure this has at least 2 arguments
 | 
			
		||||
    return (f->getReturnType() == Type::IntTy && f->arg_size() >= 2);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// @brief Perform the sprintf optimization.
 | 
			
		||||
  virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC)
 | 
			
		||||
  {
 | 
			
		||||
    // If the call has more than 3 operands, we can't optimize it
 | 
			
		||||
    if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3)
 | 
			
		||||
      return false;
 | 
			
		||||
 | 
			
		||||
    // All the optimizations depend on the length of the second argument and the
 | 
			
		||||
    // fact that it is a constant string array. Check that now
 | 
			
		||||
    uint64_t len = 0; 
 | 
			
		||||
    ConstantArray* CA = 0;
 | 
			
		||||
    if (!getConstantStringLength(ci->getOperand(2), len, &CA))
 | 
			
		||||
      return false;
 | 
			
		||||
 | 
			
		||||
    if (ci->getNumOperands() == 3)
 | 
			
		||||
    {
 | 
			
		||||
      if (len == 0)
 | 
			
		||||
      {
 | 
			
		||||
        // If the length is 0, we just need to store a null byte
 | 
			
		||||
        new StoreInst(ConstantInt::get(Type::SByteTy,0),ci->getOperand(1),ci);
 | 
			
		||||
        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,0));
 | 
			
		||||
        ci->eraseFromParent();
 | 
			
		||||
        return true;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Make sure there's no % in the constant array
 | 
			
		||||
      for (unsigned i = 0; i < len; ++i)
 | 
			
		||||
      {
 | 
			
		||||
        if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i)))
 | 
			
		||||
        {
 | 
			
		||||
          // Check for the null terminator
 | 
			
		||||
          if (CI->getRawValue() == '%')
 | 
			
		||||
            return false; // we found a %, can't optimize
 | 
			
		||||
        }
 | 
			
		||||
        else 
 | 
			
		||||
          return false; // initializer is not constant int, can't optimize
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Increment length because we want to copy the null byte too
 | 
			
		||||
      len++;
 | 
			
		||||
 | 
			
		||||
      // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1) 
 | 
			
		||||
      Function* memcpy_func = SLC.get_memcpy();
 | 
			
		||||
      if (!memcpy_func)
 | 
			
		||||
        return false;
 | 
			
		||||
      std::vector<Value*> args;
 | 
			
		||||
      args.push_back(ci->getOperand(1));
 | 
			
		||||
      args.push_back(ci->getOperand(2));
 | 
			
		||||
      args.push_back(ConstantUInt::get(Type::UIntTy,len));
 | 
			
		||||
      args.push_back(ConstantUInt::get(Type::UIntTy,1));
 | 
			
		||||
      new CallInst(memcpy_func,args,"",ci);
 | 
			
		||||
      ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len));
 | 
			
		||||
      ci->eraseFromParent();
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // The remaining optimizations require the format string to be length 2
 | 
			
		||||
    // "%s" or "%c".
 | 
			
		||||
    if (len != 2)
 | 
			
		||||
      return false;
 | 
			
		||||
 | 
			
		||||
    // The first character has to be a %
 | 
			
		||||
    if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(0)))
 | 
			
		||||
      if (CI->getRawValue() != '%')
 | 
			
		||||
        return false;
 | 
			
		||||
 | 
			
		||||
    // Get the second character and switch on its value
 | 
			
		||||
    ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(1));
 | 
			
		||||
    switch (CI->getRawValue())
 | 
			
		||||
    {
 | 
			
		||||
      case 's':
 | 
			
		||||
      {
 | 
			
		||||
        uint64_t len = 0;
 | 
			
		||||
        if (ci->hasNUses(0))
 | 
			
		||||
        {
 | 
			
		||||
          // sprintf(dest,"%s",str) -> strcpy(dest,str) 
 | 
			
		||||
          Function* strcpy_func = SLC.get_strcpy();
 | 
			
		||||
          if (!strcpy_func)
 | 
			
		||||
            return false;
 | 
			
		||||
          std::vector<Value*> args;
 | 
			
		||||
          args.push_back(ci->getOperand(1));
 | 
			
		||||
          args.push_back(ci->getOperand(3));
 | 
			
		||||
          new CallInst(strcpy_func,args,"",ci);
 | 
			
		||||
        }
 | 
			
		||||
        else if (getConstantStringLength(ci->getOperand(3),len))
 | 
			
		||||
        {
 | 
			
		||||
          // sprintf(dest,"%s",cstr) -> llvm.memcpy(dest,str,strlen(str),1)
 | 
			
		||||
          len++; // get the null-terminator
 | 
			
		||||
          Function* memcpy_func = SLC.get_memcpy();
 | 
			
		||||
          if (!memcpy_func)
 | 
			
		||||
            return false;
 | 
			
		||||
          std::vector<Value*> args;
 | 
			
		||||
          args.push_back(ci->getOperand(1));
 | 
			
		||||
          args.push_back(ci->getOperand(3));
 | 
			
		||||
          args.push_back(ConstantUInt::get(Type::UIntTy,len));
 | 
			
		||||
          args.push_back(ConstantUInt::get(Type::UIntTy,1));
 | 
			
		||||
          new CallInst(memcpy_func,args,"",ci);
 | 
			
		||||
          ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,len));
 | 
			
		||||
        }
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      case 'c':
 | 
			
		||||
      {
 | 
			
		||||
        // sprintf(dest,"%c",chr) -> store chr, dest
 | 
			
		||||
        CastInst* cast = 
 | 
			
		||||
          new CastInst(ci->getOperand(3),Type::SByteTy,"char",ci);
 | 
			
		||||
        new StoreInst(cast, ci->getOperand(1), ci);
 | 
			
		||||
        GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1),
 | 
			
		||||
          ConstantUInt::get(Type::UIntTy,1),ci->getOperand(1)->getName()+".end",
 | 
			
		||||
          ci);
 | 
			
		||||
        new StoreInst(ConstantInt::get(Type::SByteTy,0),gep,ci);
 | 
			
		||||
        ci->replaceAllUsesWith(ConstantSInt::get(Type::IntTy,1));
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      default:
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    ci->eraseFromParent();
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
} SPrintFOptimizer;
 | 
			
		||||
 | 
			
		||||
/// This LibCallOptimization will simplify calls to the "fputs" library 
 | 
			
		||||
/// function. It looks for cases where the result of fputs is not used and the
 | 
			
		||||
/// operation can be reduced to something simpler.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue