[MLIR][GPU] Replace fdiv on fp16 with promoted (fp32) multiplication with reciprocal plus one (conditional) Newton iteration.
This is correct for all values, i.e. the same as promoting the division to fp32 in the NVPTX backend. But it is faster (~10% in average, sometimes more) because:
- it performs less Newton iterations
- it avoids the slow path for e.g. denormals
- it allows reuse of the reciprocal for multiple divisions by the same divisor
Test program:
```
#include <stdio.h>
#include "cuda_fp16.h"
// This is a variant of CUDA's own __hdiv which is fast than hdiv_promote below
// and doesn't suffer from the perf cliff of div.rn.fp32 with 'special' values.
__device__ half hdiv_newton(half a, half b) {
float fa = __half2float(a);
float fb = __half2float(b);
float rcp;
asm("{rcp.approx.ftz.f32 %0, %1;\n}" : "=f"(rcp) : "f"(fb));
float result = fa * rcp;
auto exponent = reinterpret_cast<const unsigned&>(result) & 0x7f800000;
if (exponent != 0 && exponent != 0x7f800000) {
float err = __fmaf_rn(-fb, result, fa);
result = __fmaf_rn(rcp, err, result);
}
return __float2half(result);
}
// Surprisingly, this is faster than CUDA's own __hdiv.
__device__ half hdiv_promote(half a, half b) {
return __float2half(__half2float(a) / __half2float(b));
}
// This is an approximation that is accurate up to 1 ulp.
__device__ half hdiv_approx(half a, half b) {
float fa = __half2float(a);
float fb = __half2float(b);
float result;
asm("{div.approx.ftz.f32 %0, %1, %2;\n}" : "=f"(result) : "f"(fa), "f"(fb));
return __float2half(result);
}
__global__ void CheckCorrectness() {
int i = threadIdx.x + blockIdx.x * blockDim.x;
half x = reinterpret_cast<const half&>(i);
for (int j = 0; j < 65536; ++j) {
half y = reinterpret_cast<const half&>(j);
half d1 = hdiv_newton(x, y);
half d2 = hdiv_promote(x, y);
auto s1 = reinterpret_cast<const short&>(d1);
auto s2 = reinterpret_cast<const short&>(d2);
if (s1 != s2) {
printf("%f (%u) / %f (%u), got %f (%hu), expected: %f (%hu)\n",
__half2float(x), i, __half2float(y), j, __half2float(d1), s1,
__half2float(d2), s2);
//__trap();
}
}
}
__device__ half dst;
__global__ void ProfileBuiltin(half x) {
#pragma unroll 1
for (int i = 0; i < 10000000; ++i) {
x = x / x;
}
dst = x;
}
__global__ void ProfilePromote(half x) {
#pragma unroll 1
for (int i = 0; i < 10000000; ++i) {
x = hdiv_promote(x, x);
}
dst = x;
}
__global__ void ProfileNewton(half x) {
#pragma unroll 1
for (int i = 0; i < 10000000; ++i) {
x = hdiv_newton(x, x);
}
dst = x;
}
__global__ void ProfileApprox(half x) {
#pragma unroll 1
for (int i = 0; i < 10000000; ++i) {
x = hdiv_approx(x, x);
}
dst = x;
}
int main() {
CheckCorrectness<<<256, 256>>>();
half one = __float2half(1.0f);
ProfileBuiltin<<<1, 1>>>(one); // 1.001s
ProfilePromote<<<1, 1>>>(one); // 0.560s
ProfileNewton<<<1, 1>>>(one); // 0.508s
ProfileApprox<<<1, 1>>>(one); // 0.304s
auto status = cudaDeviceSynchronize();
printf("%s\n", cudaGetErrorString(status));
}
```
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D126158
This commit is contained in:
parent
9c54d76251
commit
bcfc0a9051
|
|
@ -51,21 +51,21 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
|
|||
// NVVM intrinsic operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class NVVM_IntrOp<string mnem, list<int> overloadedResults,
|
||||
list<int> overloadedOperands, list<Trait> traits,
|
||||
class NVVM_IntrOp<string mnem, list<Trait> traits,
|
||||
int numResults>
|
||||
: LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
|
||||
overloadedResults, overloadedOperands, traits, numResults>;
|
||||
/*list<int> overloadedResults=*/[],
|
||||
/*list<int> overloadedOperands=*/[],
|
||||
traits, numResults>;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVM special register op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class NVVM_SpecialRegisterOp<string mnemonic,
|
||||
list<Trait> traits = []> :
|
||||
NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
|
||||
Arguments<(ins)> {
|
||||
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
|
||||
NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
|
||||
let arguments = (ins);
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
|
|
@ -92,6 +92,16 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
|
|||
def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
|
||||
def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVM approximate op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
|
||||
let arguments = (ins F32:$arg);
|
||||
let results = (outs F32:$res);
|
||||
let assemblyFormat = "$arg attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVM synchronization op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
//===- OptimizeForNVVM.h - Optimize LLVM IR for NVVM -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H
|
||||
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
|
||||
namespace NVVM {
|
||||
|
||||
/// Creates a pass that optimizes LLVM IR for the NVVM target.
|
||||
std::unique_ptr<Pass> createOptimizeForTargetPass();
|
||||
|
||||
} // namespace NVVM
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H
|
||||
|
|
@ -10,6 +10,7 @@
|
|||
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
|
||||
#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
|||
|
|
@ -16,4 +16,9 @@ def LLVMLegalizeForExport : Pass<"llvm-legalize-for-export"> {
|
|||
let constructor = "mlir::LLVM::createLegalizeForExportPass()";
|
||||
}
|
||||
|
||||
def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {
|
||||
let summary = "Optimize NVVM IR";
|
||||
let constructor = "mlir::NVVM::createOptimizeForTargetPass()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRLLVMIRTransforms
|
||||
LegalizeForExport.cpp
|
||||
OptimizeForNVVM.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRLLVMPassIncGen
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
|
||||
// (conditional) Newton iteration.
|
||||
//
|
||||
// This as accurate as promoting the division to fp32 in the NVPTX backend, but
|
||||
// faster because it performs less Newton iterations, avoids the slow path
|
||||
// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
|
||||
// by the same divisor.
|
||||
struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
|
||||
using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
|
||||
|
||||
private:
|
||||
LogicalResult matchAndRewrite(LLVM::FDivOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct NVVMOptimizeForTarget
|
||||
: public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
|
||||
void runOnOperation() override;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<NVVM::NVVMDialect>();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (!op.getType().isF16())
|
||||
return rewriter.notifyMatchFailure(op, "not f16");
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Type f32Type = rewriter.getF32Type();
|
||||
Type i32Type = rewriter.getI32Type();
|
||||
|
||||
// Extend lhs and rhs to fp32.
|
||||
Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
|
||||
Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
|
||||
|
||||
// float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
|
||||
Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
|
||||
Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
|
||||
|
||||
// Refine the approximation with one Newton iteration:
|
||||
// float refined = approx + (lhs - approx * rhs) * rcp;
|
||||
Value err = rewriter.create<LLVM::FMAOp>(
|
||||
loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
|
||||
Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
|
||||
|
||||
// Use refined value if approx is normal (exponent neither all 0 or all 1).
|
||||
Value mask = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
|
||||
Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
|
||||
Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32Type, rewriter.getUI32IntegerAttr(0));
|
||||
Value pred = rewriter.create<LLVM::OrOp>(
|
||||
loc,
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
|
||||
Value result =
|
||||
rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
|
||||
|
||||
// Replace with trucation back to fp16.
|
||||
rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void NVVMOptimizeForTarget::runOnOperation() {
|
||||
MLIRContext *ctx = getOperation()->getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<ExpandDivF16>(ctx);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
|
||||
return std::make_unique<NVVMOptimizeForTarget>();
|
||||
}
|
||||
|
|
@ -29,6 +29,13 @@ func.func @nvvm_special_regs() -> i32 {
|
|||
llvm.return %0 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nvvm_rcp
|
||||
func.func @nvvm_rcp(%arg0: f32) -> f32 {
|
||||
// CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
|
||||
%0 = nvvm.rcp.approx.ftz.f %arg0 : f32
|
||||
llvm.return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @llvm_nvvm_barrier0
|
||||
func.func @llvm_nvvm_barrier0() {
|
||||
// CHECK: nvvm.barrier0
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
// RUN: mlir-opt %s -llvm-optimize-for-nvvm-target | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: llvm.func @fdiv_fp16
|
||||
llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
|
||||
// CHECK-DAG: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32
|
||||
// CHECK-DAG: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32
|
||||
// CHECK-DAG: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32
|
||||
// CHECK-DAG: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32
|
||||
// CHECK-DAG: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
|
||||
// CHECK-DAG: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32
|
||||
// CHECK-DAG: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32
|
||||
// CHECK-DAG: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
|
||||
// CHECK-DAG: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
|
||||
// CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32
|
||||
// CHECK-DAG: %[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32
|
||||
// CHECK-DAG: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
|
||||
// CHECK-DAG: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
|
||||
// CHECK-DAG: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1
|
||||
// CHECK-DAG: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
|
||||
// CHECK-DAG: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16
|
||||
%result = llvm.fdiv %arg0, %arg1 : f16
|
||||
// CHECK: llvm.return %[[result]] : f16
|
||||
llvm.return %result : f16
|
||||
}
|
||||
|
|
@ -33,6 +33,13 @@ llvm.func @nvvm_special_regs() -> i32 {
|
|||
llvm.return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nvvm_rcp
|
||||
llvm.func @nvvm_rcp(%0: f32) -> f32 {
|
||||
// CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
|
||||
%1 = nvvm.rcp.approx.ftz.f %0 : f32
|
||||
llvm.return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @llvm_nvvm_barrier0
|
||||
llvm.func @llvm_nvvm_barrier0() {
|
||||
// CHECK: call void @llvm.nvvm.barrier0()
|
||||
|
|
|
|||
|
|
@ -3386,7 +3386,9 @@ cc_library(
|
|||
":IR",
|
||||
":LLVMDialect",
|
||||
":LLVMPassIncGen",
|
||||
":NVVMDialect",
|
||||
":Pass",
|
||||
":Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue