diff --git a/libclc/generic/lib/gen_convert.py b/libclc/generic/lib/gen_convert.py
index 95ec9fa7ba58..fec6ee5f8887 100644
--- a/libclc/generic/lib/gen_convert.py
+++ b/libclc/generic/lib/gen_convert.py
@@ -445,14 +445,20 @@ def generate_float_conversion(src, dst, size, mode, sat):
                     USRC=unsigned_type[src], N=size
                 )
             )
+            # For integer sources, check for saturation case
+            print(
+                "  return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x || (abs_y == abs_x && abs_y == ({USRC}{N}){SRC_MAX})));".format(
+                    DST=dst, N=size, BOOL=bool_type[dst], USRC=unsigned_type[src], SRC_MAX=limit_max[src]
+                )
+            )
         else:
             print("  {SRC}{N} abs_x = fabs(x);".format(SRC=src, N=size))
             print("  {SRC}{N} abs_y = fabs(y);".format(SRC=src, N=size))
-        print(
-            "  return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
-                DST=dst, N=size, BOOL=bool_type[dst]
+            print(
+                "  return select(r, nextafter(r, sign(r) * ({DST}{N})-INFINITY), convert_{BOOL}{N}(abs_y > abs_x));".format(
+                    DST=dst, N=size, BOOL=bool_type[dst]
+                )
             )
-        )
     if mode == "_rtp":
         print(
             "  return select(r, nextafter(r, ({DST}{N})INFINITY), convert_{BOOL}{N}(y < x));".format(
@@ -460,11 +466,21 @@ def generate_float_conversion(src, dst, size, mode, sat):
             )
         )
     if mode == "_rtn":
-        print(
-            "  return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
-                DST=dst, N=size, BOOL=bool_type[dst]
+        if src in int_types:
+            # For integer sources, check for saturation case when converting back
+            # Cast the constant to source type to match vector element type
+            print(
+                "  return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x || (y == x && y == ({SRC}{N}){SRC_MAX})));".format(
+                    DST=dst, N=size, BOOL=bool_type[dst], SRC=src, SRC_MAX=limit_max[src]
+                )
+            )
+        else:
+            # For float sources, use original logic
+            print(
+                "  return select(r, nextafter(r, ({DST}{N})-INFINITY), convert_{BOOL}{N}(y > x));".format(
+                    DST=dst, N=size, BOOL=bool_type[dst]
+                )
             )
-        )
 
     # Footer
     print("}")
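The extra "abs_y == SRC_MAX" / "y == SRC_MAX" clause targets integer sources such as long whose maximum value is not exactly representable in the destination float type: the initial conversion rounds up past the source, but the saturating round trip back lands exactly on the maximum again, so the old abs_y > abs_x (resp. y > x) test never fires. Below is a minimal C++ model of the new _rtz guard; convert_float_rtz_model and sat_to_i64 are illustrative stand-ins for the generated OpenCL builtins (convert_*, abs, sign, select, nextafter), not code taken from this patch.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Saturating float -> int64 conversion, standing in for a convert_long_sat
// style round trip (NaN handling omitted for brevity).
static int64_t sat_to_i64(float r) {
  if (r >= 9223372036854775808.0f) // 2^63 and above saturate to INT64_MAX
    return INT64_MAX;
  if (r < -9223372036854775808.0f)
    return INT64_MIN;
  return static_cast<int64_t>(r);
}

static float convert_float_rtz_model(int64_t x) {
  float r = static_cast<float>(x); // default conversion, round to nearest
  int64_t y = sat_to_i64(r);       // round-trip back to the source type
  uint64_t abs_x = x < 0 ? 0 - static_cast<uint64_t>(x) : static_cast<uint64_t>(x);
  uint64_t abs_y = y < 0 ? 0 - static_cast<uint64_t>(y) : static_cast<uint64_t>(y);
  // Step toward zero if the conversion overshot the source, or if the round
  // trip only looks exact because it saturated at INT64_MAX (the new clause).
  if (abs_y > abs_x ||
      (abs_y == abs_x && abs_y == static_cast<uint64_t>(INT64_MAX)))
    r = std::nextafterf(r, -std::copysignf(INFINITY, r));
  return r;
}

int main() {
  // INT64_MAX first converts up to 2^63; the guard steps it back below 2^63.
  std::printf("%g\n", convert_float_rtz_model(INT64_MAX));
}
```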
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index dc69ae823e95..4a1fbaa486b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2048,6 +2048,7 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   switch (Opc) {
   case ISD::FROUNDEVEN:
   case ISD::VP_FROUNDEVEN:
+  case ISD::FRINT:
     return RISCVFPRndMode::RNE;
   case ISD::FTRUNC:
   case ISD::VP_FROUNDTOZERO:
@@ -2061,8 +2062,6 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   case ISD::FROUND:
   case ISD::VP_FROUND:
     return RISCVFPRndMode::RMM;
-  case ISD::FRINT:
-    return RISCVFPRndMode::DYN;
   }
 
   return RISCVFPRndMode::Invalid;
@@ -11262,21 +11261,61 @@ static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
   // Convert to integer.
   Register F2IReg = MRI.createVirtualRegister(isDivergent ?
     &RISCV::VGPRRegClass : &RISCV::GPRRegClass);
-  MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
-  if(!isDivergent)
+
+  Register OldFRMReg;
+  if (isDivergent) {
+    // For vector version, set FRM once for both conversions
+    // Save current FRM value
+    OldFRMReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+    BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRS), OldFRMReg)
+        .addImm(0x002) // FRM CSR address
+        .addReg(RISCV::X0);
+
+    // Set new rounding mode
+    BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRWI))
+        .addReg(RISCV::X0)
+        .addImm(0x002) // FRM CSR address
+        .addImm(FRM);
+  }
+
+  // First conversion: float to int
+  if (isDivergent) {
+    // Vector version, FRM already set
+    MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
+  } else {
+    // Scalar version, add rounding mode directly
+    MIB = BuildMI(CvtMBB, DL, TII.get(F2IOpc), F2IReg).addReg(SrcReg);
     MIB.addImm(FRM);
+  }
+
   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
     MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
 
   // Convert back to FP.
   Register I2FReg = MRI.createVirtualRegister(isDivergent ?
     &RISCV::VGPRRegClass : &RISCV::GPRRegClass);
-  MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
-  if(!isDivergent)
+
+  // Second conversion: int to float
+  if (isDivergent) {
+    // Vector version, FRM still effective
+    MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
+  } else {
+    // Scalar version, add rounding mode directly
+    MIB = BuildMI(CvtMBB, DL, TII.get(I2FOpc), I2FReg).addReg(F2IReg);
     MIB.addImm(FRM);
+  }
+
   if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept))
     MIB->setFlag(MachineInstr::MIFlag::NoFPExcept);
 
+  // For vector version, restore original FRM
+  if (isDivergent) {
+    BuildMI(CvtMBB, DL, TII.get(RISCV::CSRRW))
+        .addReg(RISCV::X0)
+        .addImm(0x002) // FRM CSR address
+        .addReg(OldFRMReg);
+  }
+
   // Restore the sign bit.
   Register CvtReg = MRI.createVirtualRegister(RC);
   BuildMI(CvtMBB, DL, TII.get(FSGNJOpc), CvtReg).addReg(I2FReg).addReg(SrcReg);
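The split above reflects how the two instruction forms take their rounding mode: the scalar FCVT opcodes carry a static rm field, so the non-divergent path keeps appending FRM as an immediate operand, while the Ventus vector conversions read the dynamic FRM CSR, so the divergent path saves FRM (frrm), writes the requested mode (fsrmi), runs both conversions, and restores the old value (fsrm). The sketch below is a hedged, host-side C++ model of what the expanded f32 sequence computes; fegetround/fesetround stand in for the FRM CSR accesses, fround_model is an illustrative name rather than backend code, and a strict build would also want #pragma STDC FENV_ACCESS ON.

```cpp
#include <cfenv>
#include <cmath>
#include <cstdio>

// Model of the expanded FROUND/FRINT sequence for f32: values whose magnitude
// is already integral (>= 2^23) pass through untouched; otherwise convert to
// integer and back under the requested mode, then re-apply the original sign
// (the final fsgnj), which also preserves -0.0.
static float fround_model(float x, int rounding_mode /* FE_TONEAREST, ... */) {
  if (!(std::fabs(x) < 0x1p23f)) // also filters NaN and infinities
    return x;

  const int old_mode = std::fegetround();       // "frrm t0"
  std::fesetround(rounding_mode);               // "fsrmi <mode>"
  float r = static_cast<float>(std::lrintf(x)); // vfcvt.x.f.v + vfcvt.f.x.v
  std::fesetround(old_mode);                    // "fsrm t0"

  return std::copysignf(r, x);
}

int main() {
  // Ceil-like and floor-like behaviour obtained via the dynamic rounding mode.
  std::printf("%g %g\n", fround_model(2.3f, FE_UPWARD),
              fround_model(-2.3f, FE_DOWNWARD)); // prints: 3 -3
}
```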
diff --git a/llvm/lib/Target/RISCV/VentusInstrInfoV.td b/llvm/lib/Target/RISCV/VentusInstrInfoV.td
index 37a41e540e29..467394fa4faa 100644
--- a/llvm/lib/Target/RISCV/VentusInstrInfoV.td
+++ b/llvm/lib/Target/RISCV/VentusInstrInfoV.td
@@ -1140,10 +1140,10 @@ let Uses = [FRM] in {
 defm VFCVT_XU_F_V : VCVTI_FV_VS2<"vfcvt.xu.f.v", 0b010010, 0b00000>;
 defm VFCVT_X_F_V : VCVTI_FV_VS2<"vfcvt.x.f.v", 0b010010, 0b00001>;
 }
+
 // Follow the way by RISCVInstrInfoF
-// TODO: later support
-// defm VFCVT_RTZ_XU_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.xu.f.v", 0b010010, 0b00110>;
-// defm VFCVT_RTZ_X_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.x.f.v", 0b010010, 0b00111>;
+defm VFCVT_RTZ_XU_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.xu.f.v", 0b010010, 0b00110>;
+defm VFCVT_RTZ_X_F_V : VCVTI_FV_VS2_FRM<"vfcvt.rtz.x.f.v", 0b010010, 0b00111>;
 let Uses = [FRM] in {
 defm VFCVT_F_XU_V : VCVTF_IV_VS2<"vfcvt.f.xu.v", 0b010010, 0b00010>;
 defm VFCVT_F_X_V : VCVTF_IV_VS2<"vfcvt.f.x.v", 0b010010, 0b00011>;
@@ -1342,10 +1342,10 @@ defm : PatFloatSetCC<[VGPR, GPRF32], [SETOGE, SETGE], VMFGE_VF>;
 //             (VFCVT_RTZ_X_F_V (f32 VGPR:$rs1), $frm)>;
 // def : Pat<(i32 (DivergentBinFrag (f32 VGPR:$rs1), timm:$frm)),
 //             (VFCVT_RTZ_XU_F_V (f32 VGPR:$rs1), $frm)>;
-def : PatFXConvert<fp_to_sint,
-                   [XLenVT, f32], VFCVT_X_F_V>;
-def : PatFXConvert<fp_to_uint,
-                   [XLenVT, f32], VFCVT_XU_F_V>;
+def : Pat<(i32 (DivergentUnaryFrag<fp_to_sint> (f32 VGPR:$rs1))),
+          (VFCVT_RTZ_X_F_V VGPR:$rs1, 0b001)>;
+def : Pat<(i32 (DivergentUnaryFrag<fp_to_uint> (f32 VGPR:$rs1))),
+          (VFCVT_RTZ_XU_F_V VGPR:$rs1, 0b001)>;
 def : PatFXConvert<sint_to_fp,
                    [f32, XLenVT], VFCVT_F_X_V>;
 def : PatFXConvert<uint_to_fp,
diff --git a/llvm/test/CodeGen/RISCV/VentusGPGPU/float.ll b/llvm/test/CodeGen/RISCV/VentusGPGPU/float.ll
index 812584749c7c..dfde2431e008 100644
--- a/llvm/test/CodeGen/RISCV/VentusGPGPU/float.ll
+++ b/llvm/test/CodeGen/RISCV/VentusGPGPU/float.ll
@@ -277,7 +277,7 @@ entry:
 define dso_local i32 @fcvt_x_f(float noundef %a) local_unnamed_addr {
 ; VENTUS-LABEL: fcvt_x_f:
 ; VENTUS:       # %bb.0: # %entry
-; VENTUS-NEXT:    vfcvt.x.f.v v0, v0
+; VENTUS-NEXT:    vfcvt.rtz.x.f.v v0, v0
 ; VENTUS-NEXT:    ret
 entry:
   %conv = fptosi float %a to i32
@@ -288,7 +288,7 @@ entry:
 define dso_local i32 @fcvtu_xu_f(float noundef %a) local_unnamed_addr {
 ; VENTUS-LABEL: fcvtu_xu_f:
 ; VENTUS:       # %bb.0: # %entry
-; VENTUS-NEXT:    vfcvt.xu.f.v v0, v0
+; VENTUS-NEXT:    vfcvt.rtz.xu.f.v v0, v0
 ; VENTUS-NEXT:    ret
 entry:
   %conv = fptoui float %a to i32
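The replacement patterns select the statically RTZ-encoded forms because fptosi/fptoui come from C-style float-to-integer conversions, which truncate toward zero no matter what the dynamic FRM register currently holds; keeping the FRM-reading vfcvt.x.f.v here would silently change results whenever FRM is not RTZ (for example inside the frrm/fsrmi bracket emitted above). The updated float.ll checks encode exactly that. A small host-side C++ illustration of the language semantics being matched:

```cpp
#include <cfenv>
#include <cstdio>

int main() {
  // Even with the dynamic rounding mode forced to "round up", a float-to-int
  // cast (the source of fptosi) still truncates toward zero.
  std::fesetround(FE_UPWARD);
  std::printf("%d %d\n", static_cast<int>(2.7f), static_cast<int>(-2.7f)); // 2 -2
  return 0;
}
```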
diff --git a/llvm/test/CodeGen/RISCV/VentusGPGPU/fround.ll b/llvm/test/CodeGen/RISCV/VentusGPGPU/fround.ll
new file mode 100644
index 000000000000..38f1d5ea6857
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/VentusGPGPU/fround.ll
@@ -0,0 +1,48 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mcpu=ventus-gpgpu < %s \
+; RUN:   | FileCheck -check-prefix=VENTUS %s
+
+define float @test_fceil(float noundef %a) #0 {
+; VENTUS-LABEL: test_fceil:
+; VENTUS:         frrm t0
+; VENTUS-NEXT:    fsrmi 3
+; VENTUS-NEXT:    vfcvt.x.f.v v1, v0
+; VENTUS-NEXT:    vfcvt.f.x.v v1, v1
+; VENTUS-NEXT:    fsrm t0
+entry:
+  %call = call float @llvm.ceil.f32(float noundef %a)
+  ret float %call
+}
+
+define float @test_ffloor(float noundef %a) #0 {
+; VENTUS-LABEL: test_ffloor:
+; VENTUS:         frrm t0
+; VENTUS-NEXT:    fsrmi 2
+; VENTUS-NEXT:    vfcvt.x.f.v v1, v0
+; VENTUS-NEXT:    vfcvt.f.x.v v1, v1
+; VENTUS-NEXT:    fsrm t0
+entry:
+  %call = call float @llvm.floor.f32(float noundef %a)
+  ret float %call
+}
+
+define float @test_frint(float noundef %a) #0 {
+; VENTUS-LABEL: test_frint:
+; VENTUS:         frrm t0
+; VENTUS-NEXT:    fsrmi 0
+; VENTUS-NEXT:    vfcvt.x.f.v v1, v0
+; VENTUS-NEXT:    vfcvt.f.x.v v1, v1
+; VENTUS-NEXT:    fsrm t0
+entry:
+  %call = call float @llvm.rint.f32(float noundef %a)
+  ret float %call
+}
+
+
+declare float @llvm.ceil.f32(float) #1
+declare float @llvm.floor.f32(float) #1
+declare float @llvm.rint.f32(float) #1
+
+
+attributes #0 = { alwaysinline convergent mustprogress nofree norecurse nosync nounwind willreturn memory(none) vscale_range(1,2048) "disable-tail-calls"="true" "frame-pointer"="all" "min-legal-vector-width"="0" "no-builtins" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="ventus-gpgpu" "target-features"="+32bit,+a,+m,+relax,+zdinx,+zfinx,+zhinx,+zve32f,+zve32x,+zvl32b,-64bit,-save-restore" }
+attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
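For reference, the fsrmi immediates checked above are the standard RISC-V rounding-mode encodings: 0 = RNE, 2 = RDN, 3 = RUP, which line up with rint, floor, and ceil given that matchRoundingOp now maps ISD::FRINT to RNE (the default dynamic mode) rather than DYN. A quick host-side C++ sanity check of the values these lowerings should produce:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float x = 2.5f;
  std::printf("ceil  (fsrmi 3, RUP): %g\n", std::ceil(x));   // 3
  std::printf("floor (fsrmi 2, RDN): %g\n", std::floor(x));  // 2
  std::printf("rint  (fsrmi 0, RNE): %g\n", std::rint(x));   // 2 (ties to even)
  return 0;
}
```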