[VENTUS][fix] modify rounding mode, functions, instructions to follow OpenCL2.0 conversions specifications (#183)

This commit ensures strict compliance with OpenCL 2.0 floating-point
conversion specifications:

* **RISCVISelLowering.cpp**:
  - Map FRINT to RNE (round-to-nearest-even) instead of DYN (dynamic)
  - Add proper FRM save/restore for vector floating-point operations

* **VentusInstrInfoV.td**:
  - Enable VFCVT_RTZ_* instructions for truncation-based conversions
  - Use RTZ (round-to-zero) mode for fp-to-int conversions to match
    OpenCL spec
  - Replace dynamic rounding with explicit RTZ for integer conversions

* **gen_convert.py**:
  - Improve saturation handling in type conversions
  - Add proper edge case handling for integer source saturation
  - Distinguish between integer and float source conversion logic

* **Test Updates**:
  - Update float.ll to expect RTZ instructions for fp-to-int conversions
  - Add fround.ll test cases for ceil/floor/rint operations

These changes ensure that Ventus GPGPU backend produces OpenCL 2.0 compliant
floating-point conversion behavior, particularly for rounding modes and
saturation handling.
This commit is contained in:
wenhu1024 2025-07-05 17:34:04 +08:00 committed by GitHub
parent 81158ed0fa
commit 8dee421bf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 126 additions and 23 deletions

View File

@ -445,14 +445,20 @@ def generate_float_conversion(src, dst, size, mode, sat):
USRC=unsigned_type[src], N=size
)
)
# For integer sources, check for saturation case
print(
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x || (abs_y == abs_x && abs_y == ({USRC}{N}){SRC_MAX})));".format(
DST=dst, N=size, BOOL=bool_type[dst], USRC=unsigned_type[src], SRC_MAX=limit_max[src]
)
)
else:
print(" {SRC}{N} abs_x = fabs(x);".format(SRC=src, N=size))
print(" {SRC}{N} abs_y = fabs(y);".format(SRC=src, N=size))
print(
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
DST=dst, N=size, BOOL=bool_type[dst]
print(
" return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
DST=dst, N=size, BOOL=bool_type[dst]
)
)
)
if mode == "_rtp":
print(
" return select(r, nextafter(r, ({DST}{N})INFINITY), convert_{BOOL}{N}(y < x));".format(
@ -460,11 +466,21 @@ def generate_float_conversion(src, dst, size, mode, sat):
)
)
if mode == "_rtn":
print(
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
DST=dst, N=size, BOOL=bool_type[dst]
if src in int_types:
# For integer sources, check for saturation case when converting back
# Cast the constant to source type to match vector element type
print(
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x || (y == x && y == ({SRC}{N}){SRC_MAX})));".format(
DST=dst, N=size, BOOL=bool_type[dst], SRC=src, SRC_MAX=limit_max[src]
)
)
else:
# For float sources, use original logic
print(
" return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
DST=dst, N=size, BOOL=bool_type[dst]
)
)
)
# Footer
print("}")

View File

@ -2048,6 +2048,7 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
switch (Opc) {
case ISD::FROUNDEVEN:
case ISD::VP_FROUNDEVEN:
case ISD::FRINT:
return RISCVFPRndMode::RNE;
case ISD::FTRUNC:
case ISD::VP_FROUNDTOZERO:
@ -2061,8 +2062,6 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
case ISD::FROUND:
case ISD::VP_FROUND:
return RISCVFPRndMode::RMM;
case ISD::FRINT:
return RISCVFPRndMode::DYN;
}
return RISCVFPRndMode::Invalid;
@ -11262,21 +11261,61 @@ static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
// Convert to integer.
Register F2IReg = MRI.createVirtualRegister(isDivergent ?
&RISCV::VGPRRegClass : &RISCV::GPRRegClass);
MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
if(!isDivergent)
Register OldFRMReg;
if (isDivergent) {
// For vector version, set FRM once for both conversions
// Save current FRM value
OldFRMReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRS), OldFRMReg)
.addImm(0x002) // FRM CSR address
.addReg(RISCV::X0);
// Set new rounding mode
BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRWI))
.addReg(RISCV::X0)
.addImm(0x002) // FRM CSR address
.addImm(FRM);
}
// First conversion: float to int
if (isDivergent) {
// Vector version, FRM already set
MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
} else {
// Scalar version, add rounding mode directly
MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
MIB.addImm(FRM);
}
if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
// Convert back to FP.
Register I2FReg = MRI.createVirtualRegister(isDivergent ?
&RISCV::VGPRRegClass : &RISCV::GPRRegClass);
MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
if(!isDivergent)
// Second conversion: int to float
if (isDivergent) {
// Vector version, FRM still effective
MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
} else {
// Scalar version, add rounding mode directly
MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
MIB.addImm(FRM);
}
if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
// For vector version, restore original FRM
if (isDivergent) {
BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRW))
.addReg(RISCV::X0)
.addImm(0x002) // FRM CSR address
.addReg(OldFRMReg);
}
// Restore the sign bit.
Register CvtReg = MRI.createVirtualRegister(RC);
BuildMI(CvtMBB, DL, TII.get(FSGNJOpc), CvtReg).addReg(I2FReg).addReg(SrcReg);

View File

@ -1140,10 +1140,10 @@ let Uses = [FRM] in {
defm VFCVT_XU_F_V : VCVTI_FV_VS2<"vfcvt.xu.f.v", 0b010010, 0b00000>;
defm VFCVT_X_F_V : VCVTI_FV_VS2<"vfcvt.x.f.v", 0b010010, 0b00001>;
}
// Follow the way by RISCVInstrInfoF
// TODO: later support
// defm VFCVT_RTZ_XU_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.xu.f.v", 0b010010, 0b00110>;
// defm VFCVT_RTZ_X_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.x.f.v", 0b010010, 0b00111>;
defm VFCVT_RTZ_XU_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.xu.f.v", 0b010010, 0b00110>;
defm VFCVT_RTZ_X_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.x.f.v", 0b010010, 0b00111>;
let Uses = [FRM] in {
defm VFCVT_F_XU_V : VCVTF_IV_VS2<"vfcvt.f.xu.v", 0b010010, 0b00010>;
defm VFCVT_F_X_V : VCVTF_IV_VS2<"vfcvt.f.x.v", 0b010010, 0b00011>;
@ -1342,10 +1342,10 @@ defm : PatFloatSetCC<[VGPR, GPRF32], [SETOGE, SETGE], VMFGE_VF>;
// (VFCVT_RTZ_X_F_V (f32 VGPR:$rs1), $frm)>;
// def : Pat<(i32 (DivergentBinFrag<riscv_fcvt_xu> (f32 VGPR:$rs1), timm:$frm)),
// (VFCVT_RTZ_XU_F_V (f32 VGPR:$rs1), $frm)>;
def : PatFXConvert<DivergentUnaryFrag<any_fp_to_sint>,
[XLenVT, f32], VFCVT_X_F_V>;
def : PatFXConvert<DivergentUnaryFrag<any_fp_to_uint>,
[XLenVT, f32], VFCVT_XU_F_V>;
def : Pat<(i32 (DivergentUnaryFrag<any_fp_to_sint> (f32 VGPR:$rs1))),
(VFCVT_RTZ_X_F_V VGPR:$rs1, 0b001)>;
def : Pat<(i32 (DivergentUnaryFrag<any_fp_to_uint> (f32 VGPR:$rs1))),
(VFCVT_RTZ_XU_F_V VGPR:$rs1, 0b001)>;
def : PatFXConvert<DivergentUnaryFrag<any_sint_to_fp>,
[f32, XLenVT], VFCVT_F_X_V>;
def : PatFXConvert<DivergentUnaryFrag<any_uint_to_fp>,

View File

@ -277,7 +277,7 @@ entry:
define dso_local i32 @fcvt_x_f(float noundef %a) local_unnamed_addr {
; VENTUS-LABEL: fcvt_x_f:
; VENTUS: # %bb.0: # %entry
; VENTUS-NEXT: vfcvt.x.f.v v0, v0
; VENTUS-NEXT: vfcvt.rtz.x.f.v v0, v0
; VENTUS-NEXT: ret
entry:
%conv = fptosi float %a to i32
@ -288,7 +288,7 @@ entry:
define dso_local i32 @fcvtu_xu_f(float noundef %a) local_unnamed_addr {
; VENTUS-LABEL: fcvtu_xu_f:
; VENTUS: # %bb.0: # %entry
; VENTUS-NEXT: vfcvt.xu.f.v v0, v0
; VENTUS-NEXT: vfcvt.rtz.xu.f.v v0, v0
; VENTUS-NEXT: ret
entry:
%conv = fptoui float %a to i32

View File

@ -0,0 +1,48 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mcpu=ventus-gpgpu < %s \
; RUN: | FileCheck -check-prefix=VENTUS %s
define float @test_fceil(float noundef %a) #0 {
; VENTUS-LABEL: test_fceil:
; VENTUS: frrm t0
; VENTUS-NEXT: fsrmi 3
; VENTUS-NEXT: vfcvt.x.f.v v1, v0
; VENTUS-NEXT: vfcvt.f.x.v v1, v1
; VENTUS-NEXT: fsrm t0
entry:
%call = call float @llvm.ceil.f32(float noundef %a)
ret float %call
}
define float @test_ffloor(float noundef %a) #0 {
; VENTUS-LABEL: test_ffloor:
; VENTUS: frrm t0
; VENTUS-NEXT: fsrmi 2
; VENTUS-NEXT: vfcvt.x.f.v v1, v0
; VENTUS-NEXT: vfcvt.f.x.v v1, v1
; VENTUS-NEXT: fsrm t0
entry:
%call = call float @llvm.floor.f32(float noundef %a)
ret float %call
}
define float @test_frint(float noundef %a) #0 {
; VENTUS-LABEL: test_frint:
; VENTUS: frrm t0
; VENTUS-NEXT: fsrmi 0
; VENTUS-NEXT: vfcvt.x.f.v v1, v0
; VENTUS-NEXT: vfcvt.f.x.v v1, v1
; VENTUS-NEXT: fsrm t0
entry:
%call = call float @llvm.rint.f32(float noundef %a)
ret float %call
}
declare float @llvm.ceil.f32(float) #1
declare float @llvm.floor.f32(float) #1
declare float @llvm.rint.f32(float) #1
attributes #0 = { alwaysinline convergent mustprogress nofree norecurse nosync nounwind willreturn memory(none) vscale_range(1,2048) "disable-tail-calls"="true" "frame-pointer"="all" "min-legal-vector-width"="0" "no-builtins" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="ventus-gpgpu" "target-features"="+32bit,+a,+m,+relax,+zdinx,+zfinx,+zhinx,+zve32f,+zve32x,+zvl32b,-64bit,-save-restore" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }