[X86][SSE] combineTargetShuffle - rearrange shuffle(hop,hop) matching to delay shuffle mask manipulation. NFC.

Check that we're shuffling hadd/pack ops first before altering shuffle masks.

First step towards adding extra functionality, plus it avoids costly shuffle mask manipulation if not necessary.
This commit is contained in:
Simon Pilgrim 2020-08-10 13:06:36 +01:00
parent 5f104a8099
commit e6dc2c8ce7
1 changed files with 23 additions and 24 deletions

View File

@ -35918,36 +35918,35 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
// Combine binary shuffle of 2 similar 'Horizontal' instructions into a // Combine binary shuffle of 2 similar 'Horizontal' instructions into a
// single instruction. Attempt to match a v2X64 repeating shuffle pattern that // single instruction. Attempt to match a v2X64 repeating shuffle pattern that
// represents the LHS/RHS inputs for the lower/upper halves. // represents the LHS/RHS inputs for the lower/upper halves.
SmallVector<int, 16> TargetMask128; if (!TargetMask.empty() && 0 < TargetOps.size() && TargetOps.size() <= 2) {
if (!TargetMask.empty() && 0 < TargetOps.size() && TargetOps.size() <= 2 && SDValue BC0 = peekThroughBitcasts(TargetOps.front());
isRepeatedTargetShuffleMask(128, VT, TargetMask, TargetMask128)) { SDValue BC1 = peekThroughBitcasts(TargetOps.back());
SmallVector<int, 16> WidenedMask128; EVT VT0 = BC0.getValueType();
if (scaleShuffleElements(TargetMask128, 2, WidenedMask128)) { EVT VT1 = BC1.getValueType();
assert(isUndefOrZeroOrInRange(WidenedMask128, 0, 4) && "Illegal shuffle"); unsigned Opcode0 = BC0.getOpcode();
SDValue BC0 = peekThroughBitcasts(TargetOps.front()); unsigned Opcode1 = BC1.getOpcode();
SDValue BC1 = peekThroughBitcasts(TargetOps.back()); bool isHoriz = (Opcode0 == X86ISD::FHADD || Opcode0 == X86ISD::HADD ||
EVT VT0 = BC0.getValueType(); Opcode0 == X86ISD::FHSUB || Opcode0 == X86ISD::HSUB);
EVT VT1 = BC1.getValueType(); bool isPack = (Opcode0 == X86ISD::PACKSS || Opcode0 == X86ISD::PACKUS);
unsigned Opcode0 = BC0.getOpcode(); if (Opcode0 == Opcode1 && VT0 == VT1 && (isHoriz || isPack)) {
unsigned Opcode1 = BC1.getOpcode(); SmallVector<int, 16> TargetMask128, WideMask128;
bool isHoriz = (Opcode0 == X86ISD::FHADD || Opcode0 == X86ISD::HADD || if (isRepeatedTargetShuffleMask(128, VT, TargetMask, TargetMask128) &&
Opcode0 == X86ISD::FHSUB || Opcode0 == X86ISD::HSUB); scaleShuffleElements(TargetMask128, 2, WideMask128)) {
if (Opcode0 == Opcode1 && VT0 == VT1 && assert(isUndefOrZeroOrInRange(WideMask128, 0, 4) && "Illegal shuffle");
(isHoriz || Opcode0 == X86ISD::PACKSS || Opcode0 == X86ISD::PACKUS)) {
bool SingleOp = (TargetOps.size() == 1); bool SingleOp = (TargetOps.size() == 1);
if (!isHoriz || shouldUseHorizontalOp(SingleOp, DAG, Subtarget)) { if (!isHoriz || shouldUseHorizontalOp(SingleOp, DAG, Subtarget)) {
SDValue Lo = isInRange(WidenedMask128[0], 0, 2) ? BC0 : BC1; SDValue Lo = isInRange(WideMask128[0], 0, 2) ? BC0 : BC1;
SDValue Hi = isInRange(WidenedMask128[1], 0, 2) ? BC0 : BC1; SDValue Hi = isInRange(WideMask128[1], 0, 2) ? BC0 : BC1;
Lo = Lo.getOperand(WidenedMask128[0] & 1); Lo = Lo.getOperand(WideMask128[0] & 1);
Hi = Hi.getOperand(WidenedMask128[1] & 1); Hi = Hi.getOperand(WideMask128[1] & 1);
if (SingleOp) { if (SingleOp) {
MVT SrcVT = BC0.getOperand(0).getSimpleValueType(); MVT SrcVT = BC0.getOperand(0).getSimpleValueType();
SDValue Undef = DAG.getUNDEF(SrcVT); SDValue Undef = DAG.getUNDEF(SrcVT);
SDValue Zero = getZeroVector(SrcVT, Subtarget, DAG, DL); SDValue Zero = getZeroVector(SrcVT, Subtarget, DAG, DL);
Lo = (WidenedMask128[0] == SM_SentinelZero ? Zero : Lo); Lo = (WideMask128[0] == SM_SentinelZero ? Zero : Lo);
Hi = (WidenedMask128[1] == SM_SentinelZero ? Zero : Hi); Hi = (WideMask128[1] == SM_SentinelZero ? Zero : Hi);
Lo = (WidenedMask128[0] == SM_SentinelUndef ? Undef : Lo); Lo = (WideMask128[0] == SM_SentinelUndef ? Undef : Lo);
Hi = (WidenedMask128[1] == SM_SentinelUndef ? Undef : Hi); Hi = (WideMask128[1] == SM_SentinelUndef ? Undef : Hi);
} }
SDValue Horiz = DAG.getNode(Opcode0, DL, VT0, Lo, Hi); SDValue Horiz = DAG.getNode(Opcode0, DL, VT0, Lo, Hi);
return DAG.getBitcast(VT, Horiz); return DAG.getBitcast(VT, Horiz);