[VENTUS][fix] modify rounding mode, functions, instructions to follow OpenCL2.0 conversions specifications (#183)
This commit ensures strict compliance with the OpenCL 2.0 floating-point conversion specifications:

* **RISCVISelLowering.cpp**:
  - Map FRINT to RNE (round-to-nearest-even) instead of DYN (dynamic)
  - Add proper FRM save/restore for vector floating-point operations
* **VentusInstrInfoV.td**:
  - Enable VFCVT_RTZ_* instructions for truncation-based conversions
  - Use RTZ (round-to-zero) mode for fp-to-int conversions to match the OpenCL spec
  - Replace dynamic rounding with explicit RTZ for integer conversions
* **gen_convert.py**:
  - Improve saturation handling in type conversions
  - Add proper edge-case handling for integer source saturation
  - Distinguish between integer and float source conversion logic
* **Test Updates**:
  - Update float.ll to expect RTZ instructions for fp-to-int conversions
  - Add fround.ll test cases for ceil/floor/rint operations

These changes ensure that the Ventus GPGPU backend produces OpenCL 2.0 compliant floating-point conversion behavior, particularly for rounding modes and saturation handling.
This commit is contained in:
parent
81158ed0fa
commit
8dee421bf2
|
@ -445,14 +445,20 @@ def generate_float_conversion(src, dst, size, mode, sat):
|
|||
USRC=unsigned_type[src], N=size
|
||||
)
|
||||
)
|
||||
# For integer sources, check for saturation case
|
||||
print(
|
||||
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x || (abs_y == abs_x && abs_y == ({USRC}{N}){SRC_MAX})));".format(
|
||||
DST=dst, N=size, BOOL=bool_type[dst], USRC=unsigned_type[src], SRC_MAX=limit_max[src]
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(" {SRC}{N} abs_x = fabs(x);".format(SRC=src, N=size))
|
||||
print(" {SRC}{N} abs_y = fabs(y);".format(SRC=src, N=size))
|
||||
print(
|
||||
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
|
||||
DST=dst, N=size, BOOL=bool_type[dst]
|
||||
print(
|
||||
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
|
||||
DST=dst, N=size, BOOL=bool_type[dst]
|
||||
)
|
||||
)
|
||||
)
|
||||
if mode == "_rtp":
|
||||
print(
|
||||
" return select(r, nextafter(r, ({DST}{N})INFINITY), convert_{BOOL}{N}(y < x));".format(
|
||||
|
@ -460,11 +466,21 @@ def generate_float_conversion(src, dst, size, mode, sat):
|
|||
)
|
||||
)
|
||||
if mode == "_rtn":
|
||||
print(
|
||||
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
|
||||
DST=dst, N=size, BOOL=bool_type[dst]
|
||||
if src in int_types:
|
||||
# For integer sources, check for saturation case when converting back
|
||||
# Cast the constant to source type to match vector element type
|
||||
print(
|
||||
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x || (y == x && y == ({SRC}{N}){SRC_MAX})));".format(
|
||||
DST=dst, N=size, BOOL=bool_type[dst], SRC=src, SRC_MAX=limit_max[src]
|
||||
)
|
||||
)
|
||||
else:
|
||||
# For float sources, use original logic
|
||||
print(
|
||||
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
|
||||
DST=dst, N=size, BOOL=bool_type[dst]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Footer
|
||||
print("}")
|
||||
|
|
|
@ -2048,6 +2048,7 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
|
|||
switch (Opc) {
|
||||
case ISD::FROUNDEVEN:
|
||||
case ISD::VP_FROUNDEVEN:
|
||||
case ISD::FRINT:
|
||||
return RISCVFPRndMode::RNE;
|
||||
case ISD::FTRUNC:
|
||||
case ISD::VP_FROUNDTOZERO:
|
||||
|
@ -2061,8 +2062,6 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
|
|||
case ISD::FROUND:
|
||||
case ISD::VP_FROUND:
|
||||
return RISCVFPRndMode::RMM;
|
||||
case ISD::FRINT:
|
||||
return RISCVFPRndMode::DYN;
|
||||
}
|
||||
|
||||
return RISCVFPRndMode::Invalid;
|
||||
|
@ -11262,21 +11261,61 @@ static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
|
|||
// Convert to integer.
|
||||
Register F2IReg = MRI.createVirtualRegister(isDivergent ?
|
||||
&RISCV::VGPRRegClass : &RISCV::GPRRegClass);
|
||||
MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
|
||||
if(!isDivergent)
|
||||
|
||||
Register OldFRMReg;
|
||||
if (isDivergent) {
|
||||
// For vector version, set FRM once for both conversions
|
||||
// Save current FRM value
|
||||
OldFRMReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
|
||||
BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRS), OldFRMReg)
|
||||
.addImm(0x002) // FRM CSR address
|
||||
.addReg(RISCV::X0);
|
||||
|
||||
// Set new rounding mode
|
||||
BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRWI))
|
||||
.addReg(RISCV::X0)
|
||||
.addImm(0x002) // FRM CSR address
|
||||
.addImm(FRM);
|
||||
}
|
||||
|
||||
// First conversion: float to int
|
||||
if (isDivergent) {
|
||||
// Vector version, FRM already set
|
||||
MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
|
||||
} else {
|
||||
// Scalar version, add rounding mode directly
|
||||
MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
|
||||
MIB.addImm(FRM);
|
||||
}
|
||||
|
||||
if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
|
||||
MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
|
||||
|
||||
// Convert back to FP.
|
||||
Register I2FReg = MRI.createVirtualRegister(isDivergent ?
|
||||
&RISCV::VGPRRegClass : &RISCV::GPRRegClass);
|
||||
MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
|
||||
if(!isDivergent)
|
||||
|
||||
// Second conversion: int to float
|
||||
if (isDivergent) {
|
||||
// Vector version, FRM still effective
|
||||
MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
|
||||
} else {
|
||||
// Scalar version, add rounding mode directly
|
||||
MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
|
||||
MIB.addImm(FRM);
|
||||
}
|
||||
|
||||
if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
|
||||
MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
|
||||
|
||||
// For vector version, restore original FRM
|
||||
if (isDivergent) {
|
||||
BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRW))
|
||||
.addReg(RISCV::X0)
|
||||
.addImm(0x002) // FRM CSR address
|
||||
.addReg(OldFRMReg);
|
||||
}
|
||||
|
||||
// Restore the sign bit.
|
||||
Register CvtReg = MRI.createVirtualRegister(RC);
|
||||
BuildMI(CvtMBB, DL, TII.get(FSGNJOpc), CvtReg).addReg(I2FReg).addReg(SrcReg);
|
||||
|
|
|
@ -1140,10 +1140,10 @@ let Uses = [FRM] in {
|
|||
defm VFCVT_XU_F_V : VCVTI_FV_VS2<"vfcvt.xu.f.v", 0b010010, 0b00000>;
|
||||
defm VFCVT_X_F_V : VCVTI_FV_VS2<"vfcvt.x.f.v", 0b010010, 0b00001>;
|
||||
}
|
||||
|
||||
// Follow the way by RISCVInstrInfoF
|
||||
// TODO: later support
|
||||
// defm VFCVT_RTZ_XU_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.xu.f.v", 0b010010, 0b00110>;
|
||||
// defm VFCVT_RTZ_X_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.x.f.v", 0b010010, 0b00111>;
|
||||
defm VFCVT_RTZ_XU_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.xu.f.v", 0b010010, 0b00110>;
|
||||
defm VFCVT_RTZ_X_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.x.f.v", 0b010010, 0b00111>;
|
||||
let Uses = [FRM] in {
|
||||
defm VFCVT_F_XU_V : VCVTF_IV_VS2<"vfcvt.f.xu.v", 0b010010, 0b00010>;
|
||||
defm VFCVT_F_X_V : VCVTF_IV_VS2<"vfcvt.f.x.v", 0b010010, 0b00011>;
|
||||
|
@ -1342,10 +1342,10 @@ defm : PatFloatSetCC<[VGPR, GPRF32], [SETOGE, SETGE], VMFGE_VF>;
|
|||
// (VFCVT_RTZ_X_F_V (f32 VGPR:$rs1), $frm)>;
|
||||
// def : Pat<(i32 (DivergentBinFrag<riscv_fcvt_xu> (f32 VGPR:$rs1), timm:$frm)),
|
||||
// (VFCVT_RTZ_XU_F_V (f32 VGPR:$rs1), $frm)>;
|
||||
def : PatFXConvert<DivergentUnaryFrag<any_fp_to_sint>,
|
||||
[XLenVT, f32], VFCVT_X_F_V>;
|
||||
def : PatFXConvert<DivergentUnaryFrag<any_fp_to_uint>,
|
||||
[XLenVT, f32], VFCVT_XU_F_V>;
|
||||
def : Pat<(i32 (DivergentUnaryFrag<any_fp_to_sint> (f32 VGPR:$rs1))),
|
||||
(VFCVT_RTZ_X_F_V VGPR:$rs1, 0b001)>;
|
||||
def : Pat<(i32 (DivergentUnaryFrag<any_fp_to_uint> (f32 VGPR:$rs1))),
|
||||
(VFCVT_RTZ_XU_F_V VGPR:$rs1, 0b001)>;
|
||||
def : PatFXConvert<DivergentUnaryFrag<any_sint_to_fp>,
|
||||
[f32, XLenVT], VFCVT_F_X_V>;
|
||||
def : PatFXConvert<DivergentUnaryFrag<any_uint_to_fp>,
|
||||
|
|
|
@ -277,7 +277,7 @@ entry:
|
|||
define dso_local i32 @fcvt_x_f(float noundef %a) local_unnamed_addr {
|
||||
; VENTUS-LABEL: fcvt_x_f:
|
||||
; VENTUS: # %bb.0: # %entry
|
||||
; VENTUS-NEXT: vfcvt.x.f.v v0, v0
|
||||
; VENTUS-NEXT: vfcvt.rtz.x.f.v v0, v0
|
||||
; VENTUS-NEXT: ret
|
||||
entry:
|
||||
%conv = fptosi float %a to i32
|
||||
|
@ -288,7 +288,7 @@ entry:
|
|||
define dso_local i32 @fcvtu_xu_f(float noundef %a) local_unnamed_addr {
|
||||
; VENTUS-LABEL: fcvtu_xu_f:
|
||||
; VENTUS: # %bb.0: # %entry
|
||||
; VENTUS-NEXT: vfcvt.xu.f.v v0, v0
|
||||
; VENTUS-NEXT: vfcvt.rtz.xu.f.v v0, v0
|
||||
; VENTUS-NEXT: ret
|
||||
entry:
|
||||
%conv = fptoui float %a to i32
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
|
||||
; RUN: llc -mtriple=riscv32 -mcpu=ventus-gpgpu < %s \
|
||||
; RUN: | FileCheck -check-prefix=VENTUS %s
|
||||
|
||||
define float @test_fceil(float noundef %a) #0 {
|
||||
; VENTUS-LABEL: test_fceil:
|
||||
; VENTUS: frrm t0
|
||||
; VENTUS-NEXT: fsrmi 3
|
||||
; VENTUS-NEXT: vfcvt.x.f.v v1, v0
|
||||
; VENTUS-NEXT: vfcvt.f.x.v v1, v1
|
||||
; VENTUS-NEXT: fsrm t0
|
||||
entry:
|
||||
%call = call float @llvm.ceil.f32(float noundef %a)
|
||||
ret float %call
|
||||
}
|
||||
|
||||
define float @test_ffloor(float noundef %a) #0 {
|
||||
; VENTUS-LABEL: test_ffloor:
|
||||
; VENTUS: frrm t0
|
||||
; VENTUS-NEXT: fsrmi 2
|
||||
; VENTUS-NEXT: vfcvt.x.f.v v1, v0
|
||||
; VENTUS-NEXT: vfcvt.f.x.v v1, v1
|
||||
; VENTUS-NEXT: fsrm t0
|
||||
entry:
|
||||
%call = call float @llvm.floor.f32(float noundef %a)
|
||||
ret float %call
|
||||
}
|
||||
|
||||
define float @test_frint(float noundef %a) #0 {
|
||||
; VENTUS-LABEL: test_frint:
|
||||
; VENTUS: frrm t0
|
||||
; VENTUS-NEXT: fsrmi 0
|
||||
; VENTUS-NEXT: vfcvt.x.f.v v1, v0
|
||||
; VENTUS-NEXT: vfcvt.f.x.v v1, v1
|
||||
; VENTUS-NEXT: fsrm t0
|
||||
entry:
|
||||
%call = call float @llvm.rint.f32(float noundef %a)
|
||||
ret float %call
|
||||
}
|
||||
|
||||
|
||||
declare float @llvm.ceil.f32(float) #1
|
||||
declare float @llvm.floor.f32(float) #1
|
||||
declare float @llvm.rint.f32(float) #1
|
||||
|
||||
|
||||
attributes #0 = { alwaysinline convergent mustprogress nofree norecurse nosync nounwind willreturn memory(none) vscale_range(1,2048) "disable-tail-calls"="true" "frame-pointer"="all" "min-legal-vector-width"="0" "no-builtins" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="ventus-gpgpu" "target-features"="+32bit,+a,+m,+relax,+zdinx,+zfinx,+zhinx,+zve32f,+zve32x,+zvl32b,-64bit,-save-restore" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
|
Loading…
Reference in New Issue