diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index baf6d06da7d9..8a53ca96e1fe 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -24960,26 +24960,31 @@ static bool matchUnaryPermuteVectorShuffle(MVT MaskVT, ArrayRef Mask, // shuffle instructions. // TODO: Investigate sharing more of this with shuffle lowering. static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef Mask, + SDValue &V1, SDValue &V2, unsigned &Shuffle, MVT &ShuffleVT) { bool FloatDomain = MaskVT.isFloatingPoint(); if (MaskVT.is128BitVector()) { if (isTargetShuffleEquivalent(Mask, {0, 0}) && FloatDomain) { + V2 = V1; Shuffle = X86ISD::MOVLHPS; ShuffleVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {1, 1}) && FloatDomain) { + V2 = V1; Shuffle = X86ISD::MOVHLPS; ShuffleVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {0, 0, 1, 1}) && FloatDomain) { + V2 = V1; Shuffle = X86ISD::UNPCKL; ShuffleVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {2, 2, 3, 3}) && FloatDomain) { + V2 = V1; Shuffle = X86ISD::UNPCKH; ShuffleVT = MVT::v4f32; return true; @@ -24987,6 +24992,7 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef Mask, if (isTargetShuffleEquivalent(Mask, {0, 0, 1, 1, 2, 2, 3, 3}) || isTargetShuffleEquivalent( Mask, {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7})) { + V2 = V1; Shuffle = X86ISD::UNPCKL; ShuffleVT = Mask.size() == 8 ? MVT::v8i16 : MVT::v16i8; return true; @@ -24994,6 +25000,7 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef Mask, if (isTargetShuffleEquivalent(Mask, {4, 4, 5, 5, 6, 6, 7, 7}) || isTargetShuffleEquivalent(Mask, {8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15})) { + V2 = V1; Shuffle = X86ISD::UNPCKH; ShuffleVT = Mask.size() == 8 ? MVT::v8i16 : MVT::v16i8; return true; @@ -25201,19 +25208,20 @@ static bool combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, /*AddTo*/ true); return true; } + } - // TODO - this should support binary shuffles. - if (matchBinaryVectorShuffle(MaskVT, Mask, Shuffle, ShuffleVT)) { - if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, V1); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res); - DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; - } + if (matchBinaryVectorShuffle(MaskVT, Mask, V1, V2, Shuffle, ShuffleVT)) { + if (Depth == 1 && Root.getOpcode() == Shuffle) + return false; // Nothing to do! + V1 = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(V1.getNode()); + V2 = DAG.getBitcast(ShuffleVT, V2); + DCI.AddToWorklist(V2.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; } if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, V1, V2, DL, DAG, Subtarget,