[X86] Improve combineVectorShiftImm

Summary:
Fold (shift (shift X, C2), C1) -> (shift X, (C1 + C2)) for logical as
well as arithmetic shifts. This is needed to prevent regressions from
an upcoming funnel shift expansion change.

While we're here, fold (VSRAI -1, C) -> -1 too.
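
A scalar sketch of the two folds (illustration only, not part of the patch;
the 16-bit helpers srl16/sra16 are hypothetical stand-ins for the vector
shift-by-immediate nodes, and signed >> is assumed to be arithmetic as on
x86 targets):

  // Composing two immediate shifts of the same kind equals one shift by the
  // summed amount; out-of-range logical shifts give zero, out-of-range
  // arithmetic shifts clamp to NumBits - 1 (splatting the sign bit).
  #include <cassert>
  #include <cstdint>

  static uint16_t srl16(uint16_t x, unsigned c) {
    return c >= 16 ? 0 : uint16_t(x >> c);
  }
  static int16_t sra16(int16_t x, unsigned c) {
    return int16_t(x >> (c >= 16 ? 15 : c));
  }

  int main() {
    uint16_t u = 0xBEEF;
    int16_t s = -32768; // sign bit set
    assert(srl16(srl16(u, 5), 3) == srl16(u, 5 + 3)); // in-range sum
    assert(srl16(srl16(u, 12), 9) == 0);              // 12 + 9 >= 16 -> 0
    assert(sra16(sra16(s, 12), 9) == sra16(s, 15));   // clamped to 15
    assert(sra16(int16_t(-1), 7) == int16_t(-1));     // (VSRAI -1, C) -> -1
    return 0;
  }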

Reviewers: RKSimon, craig.topper

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77300
Jay Foad 2020-04-02 12:20:35 +01:00
parent e8111502d8
commit bc78baec4c
4 changed files with 926 additions and 917 deletions


@@ -41084,26 +41084,37 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
   if (ShiftVal >= NumBitsPerElt) {
     if (LogicalShift)
       return DAG.getConstant(0, SDLoc(N), VT);
-    else
-      ShiftVal = NumBitsPerElt - 1;
+    ShiftVal = NumBitsPerElt - 1;
   }
 
-  // Shift N0 by zero -> N0.
+  // (shift X, 0) -> X
   if (!ShiftVal)
     return N0;
 
-  // Shift zero -> zero.
+  // (shift 0, C) -> 0
   if (ISD::isBuildVectorAllZeros(N0.getNode()))
+    // N0 is all zeros or undef. We guarantee that the bits shifted into the
+    // result are all zeros, not undef.
     return DAG.getConstant(0, SDLoc(N), VT);
 
-  // Fold (VSRAI (VSRAI X, C1), C2) --> (VSRAI X, (C1 + C2)) with (C1 + C2)
-  // clamped to (NumBitsPerElt - 1).
-  if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSRAI) {
+  // (VSRAI -1, C) -> -1
+  if (!LogicalShift && ISD::isBuildVectorAllOnes(N0.getNode()))
+    // N0 is all ones or undef. We guarantee that the bits shifted into the
+    // result are all ones, not undef.
+    return DAG.getConstant(-1, SDLoc(N), VT);
+
+  // (shift (shift X, C2), C1) -> (shift X, (C1 + C2))
+  if (Opcode == N0.getOpcode()) {
     unsigned ShiftVal2 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue();
     unsigned NewShiftVal = ShiftVal + ShiftVal2;
-    if (NewShiftVal >= NumBitsPerElt)
+    if (NewShiftVal >= NumBitsPerElt) {
+      // Out of range logical bit shifts are guaranteed to be zero.
+      // Out of range arithmetic bit shifts splat the sign bit.
+      if (LogicalShift)
+        return DAG.getConstant(0, SDLoc(N), VT);
       NewShiftVal = NumBitsPerElt - 1;
-    return DAG.getNode(X86ISD::VSRAI, SDLoc(N), VT, N0.getOperand(0),
+    }
+    return DAG.getNode(Opcode, SDLoc(N), VT, N0.getOperand(0),
                        DAG.getTargetConstant(NewShiftVal, SDLoc(N), MVT::i8));
   }
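
The same-opcode fold is what drives the test churn below: the 64-bit midpoint
expansion shifts the difference right by 1 and, for the high half of the
multiply, by another 32, and those two VSRLI nodes now combine into a single
shift by 33 (the vpsrlq $33 in the updated CHECK lines). A minimal check of
the underlying identity (illustration only, not part of the patch):

  #include <cassert>
  #include <cstdint>

  int main() {
    uint64_t x = 0x0123456789ABCDEFULL;
    // (srl (srl x, 1), 32) -> (srl x, 33)
    assert(((x >> 1) >> 32) == (x >> 33));
    return 0;
  }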

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -146,16 +146,16 @@ define <8 x i64> @vec512_i64_signed_reg_reg(<8 x i64> %a1, <8 x i64> %a2) nounwi
 ; ALL-NEXT: vpminsq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT: vpmaxsq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT: vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT: vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT: vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT: vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT: vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT: vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT: vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT: vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT: vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT: vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT: vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT: vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT: vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT: vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT: vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT: retq
   %t3 = icmp sgt <8 x i64> %a1, %a2 ; signed
   %t4 = select <8 x i1> %t3, <8 x i64> <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, <8 x i64> <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>
@@ -178,16 +178,16 @@ define <8 x i64> @vec512_i64_unsigned_reg_reg(<8 x i64> %a1, <8 x i64> %a2) noun
 ; ALL-NEXT: vpminuq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT: vpmaxuq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT: vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT: vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT: vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT: vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT: vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT: vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT: vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT: vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT: vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT: vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT: vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT: vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT: vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT: vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT: vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT: retq
   %t3 = icmp ugt <8 x i64> %a1, %a2
   %t4 = select <8 x i1> %t3, <8 x i64> <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, <8 x i64> <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>
@@ -213,16 +213,16 @@ define <8 x i64> @vec512_i64_signed_mem_reg(<8 x i64>* %a1_addr, <8 x i64> %a2)
 ; ALL-NEXT: vpminsq %zmm0, %zmm1, %zmm2
 ; ALL-NEXT: vpmaxsq %zmm0, %zmm1, %zmm0
 ; ALL-NEXT: vpsubq %zmm2, %zmm0, %zmm0
-; ALL-NEXT: vpsrlq $1, %zmm0, %zmm0
-; ALL-NEXT: vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT: vpmuludq %zmm2, %zmm0, %zmm2
-; ALL-NEXT: vpsrlq $32, %zmm0, %zmm4
-; ALL-NEXT: vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT: vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT: vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT: vpsrlq $1, %zmm0, %zmm2
+; ALL-NEXT: vpsrlq $33, %zmm0, %zmm0
 ; ALL-NEXT: vpmuludq %zmm3, %zmm0, %zmm0
-; ALL-NEXT: vpaddq %zmm1, %zmm2, %zmm1
+; ALL-NEXT: vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT: vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT: vpaddq %zmm0, %zmm4, %zmm0
+; ALL-NEXT: vpsllq $32, %zmm0, %zmm0
+; ALL-NEXT: vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT: vpaddq %zmm1, %zmm0, %zmm0
+; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT: retq
   %a1 = load <8 x i64>, <8 x i64>* %a1_addr
   %t3 = icmp sgt <8 x i64> %a1, %a2 ; signed
@@ -247,16 +247,16 @@ define <8 x i64> @vec512_i64_signed_reg_mem(<8 x i64> %a1, <8 x i64>* %a2_addr)
 ; ALL-NEXT: vpminsq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT: vpmaxsq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT: vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT: vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT: vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT: vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT: vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT: vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT: vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT: vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT: vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT: vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT: vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT: vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT: vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT: vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT: vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT: retq
   %a2 = load <8 x i64>, <8 x i64>* %a2_addr
   %t3 = icmp sgt <8 x i64> %a1, %a2 ; signed
@@ -282,16 +282,16 @@ define <8 x i64> @vec512_i64_signed_mem_mem(<8 x i64>* %a1_addr, <8 x i64>* %a2_
 ; ALL-NEXT: vpminsq %zmm1, %zmm0, %zmm2
 ; ALL-NEXT: vpmaxsq %zmm1, %zmm0, %zmm1
 ; ALL-NEXT: vpsubq %zmm2, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $1, %zmm1, %zmm1
-; ALL-NEXT: vpsrlq $32, %zmm3, %zmm2
-; ALL-NEXT: vpmuludq %zmm2, %zmm1, %zmm2
-; ALL-NEXT: vpsrlq $32, %zmm1, %zmm4
-; ALL-NEXT: vpmuludq %zmm3, %zmm4, %zmm4
-; ALL-NEXT: vpaddq %zmm4, %zmm2, %zmm2
-; ALL-NEXT: vpsllq $32, %zmm2, %zmm2
+; ALL-NEXT: vpsrlq $1, %zmm1, %zmm2
+; ALL-NEXT: vpsrlq $33, %zmm1, %zmm1
 ; ALL-NEXT: vpmuludq %zmm3, %zmm1, %zmm1
-; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
+; ALL-NEXT: vpsrlq $32, %zmm3, %zmm4
+; ALL-NEXT: vpmuludq %zmm4, %zmm2, %zmm4
+; ALL-NEXT: vpaddq %zmm1, %zmm4, %zmm1
+; ALL-NEXT: vpsllq $32, %zmm1, %zmm1
+; ALL-NEXT: vpmuludq %zmm3, %zmm2, %zmm2
 ; ALL-NEXT: vpaddq %zmm0, %zmm1, %zmm0
+; ALL-NEXT: vpaddq %zmm0, %zmm2, %zmm0
 ; ALL-NEXT: retq
   %a1 = load <8 x i64>, <8 x i64>* %a1_addr
   %a2 = load <8 x i64>, <8 x i64>* %a2_addr