forked from OSchip/llvm-project
[X86][AVX] lowerShuffleWithVTRUNC - pull out TRUNCATE/VTRUNC creation into helper code. NFCI.
Prep work toward adding v16i16/v32i8 support for lowerShuffleWithVTRUNC and improving lowerShuffleWithVPMOV.
This commit is contained in:
parent
2f5f5febf3
commit
d5621b83a5
|
|
@ -11286,6 +11286,37 @@ static bool matchShuffleAsVTRUNC(MVT &SrcVT, MVT &DstVT, MVT VT,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper to create TRUNCATE/VTRUNC nodes, optionally with zero/undef upper
|
||||||
|
// element padding to the final DstVT.
|
||||||
|
static SDValue getAVX512TruncNode(const SDLoc &DL, MVT DstVT, SDValue Src,
|
||||||
|
const X86Subtarget &Subtarget,
|
||||||
|
SelectionDAG &DAG, bool ZeroUppers) {
|
||||||
|
MVT SrcVT = Src.getSimpleValueType();
|
||||||
|
unsigned NumDstElts = DstVT.getVectorNumElements();
|
||||||
|
unsigned NumSrcElts = SrcVT.getVectorNumElements();
|
||||||
|
|
||||||
|
// Perform a direct ISD::TRUNCATE if possible.
|
||||||
|
if (NumSrcElts == NumDstElts)
|
||||||
|
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Src);
|
||||||
|
|
||||||
|
if (NumSrcElts > NumDstElts) {
|
||||||
|
MVT TruncVT = MVT::getVectorVT(DstVT.getScalarType(), NumSrcElts);
|
||||||
|
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Src);
|
||||||
|
return extractSubVector(Trunc, 0, DAG, DL, DstVT.getSizeInBits());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-VLX targets must truncate from a 512-bit type, so we need to
|
||||||
|
// widen, truncate and then possibly extract the original subvector.
|
||||||
|
if (!Subtarget.hasVLX() && !SrcVT.is512BitVector()) {
|
||||||
|
SDValue NewSrc = widenSubVector(Src, ZeroUppers, Subtarget, DAG, DL, 512);
|
||||||
|
return getAVX512TruncNode(DL, DstVT, NewSrc, Subtarget, DAG, ZeroUppers);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to a X86ISD::VTRUNC.
|
||||||
|
// TODO: Handle cases where we go from 512-bit vectors to sub-128-bit vectors.
|
||||||
|
return DAG.getNode(X86ISD::VTRUNC, DL, DstVT, Src);
|
||||||
|
}
|
||||||
|
|
||||||
static bool matchShuffleAsVPMOV(ArrayRef<int> Mask, bool SwappedOps,
|
static bool matchShuffleAsVPMOV(ArrayRef<int> Mask, bool SwappedOps,
|
||||||
int Delta) {
|
int Delta) {
|
||||||
int Size = (int)Mask.size();
|
int Size = (int)Mask.size();
|
||||||
|
|
@ -11388,7 +11419,7 @@ static SDValue lowerShuffleAsVTRUNC(const SDLoc &DL, MVT VT, SDValue V1,
|
||||||
|
|
||||||
unsigned NumElts = VT.getVectorNumElements();
|
unsigned NumElts = VT.getVectorNumElements();
|
||||||
unsigned EltSizeInBits = VT.getScalarSizeInBits();
|
unsigned EltSizeInBits = VT.getScalarSizeInBits();
|
||||||
unsigned MaxScale = 64 / VT.getScalarSizeInBits();
|
unsigned MaxScale = 64 / EltSizeInBits;
|
||||||
for (unsigned Scale = 2; Scale <= MaxScale; Scale += Scale) {
|
for (unsigned Scale = 2; Scale <= MaxScale; Scale += Scale) {
|
||||||
// TODO: Support non-BWI VPMOVWB truncations?
|
// TODO: Support non-BWI VPMOVWB truncations?
|
||||||
unsigned SrcEltBits = EltSizeInBits * Scale;
|
unsigned SrcEltBits = EltSizeInBits * Scale;
|
||||||
|
|
@ -11408,36 +11439,18 @@ static SDValue lowerShuffleAsVTRUNC(const SDLoc &DL, MVT VT, SDValue V1,
|
||||||
if (UpperElts > 0 &&
|
if (UpperElts > 0 &&
|
||||||
!Zeroable.extractBits(UpperElts, NumSrcElts).isAllOnesValue())
|
!Zeroable.extractBits(UpperElts, NumSrcElts).isAllOnesValue())
|
||||||
continue;
|
continue;
|
||||||
|
bool UndefUppers =
|
||||||
|
UpperElts > 0 && isUndefInRange(Mask, NumSrcElts, UpperElts);
|
||||||
|
|
||||||
// As we're using both sources then we need to concat them together
|
// As we're using both sources then we need to concat them together
|
||||||
// and truncate from the 256-bit src.
|
// and truncate from the double-sized src.
|
||||||
MVT ConcatVT = MVT::getVectorVT(VT.getScalarType(), NumElts * 2);
|
MVT ConcatVT = MVT::getVectorVT(VT.getScalarType(), NumElts * 2);
|
||||||
SDValue Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, V1, V2);
|
SDValue Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, V1, V2);
|
||||||
|
|
||||||
MVT SrcSVT = MVT::getIntegerVT(SrcEltBits);
|
MVT SrcSVT = MVT::getIntegerVT(SrcEltBits);
|
||||||
MVT SrcVT = MVT::getVectorVT(SrcSVT, NumSrcElts);
|
MVT SrcVT = MVT::getVectorVT(SrcSVT, NumSrcElts);
|
||||||
Src = DAG.getBitcast(SrcVT, Src);
|
Src = DAG.getBitcast(SrcVT, Src);
|
||||||
|
return getAVX512TruncNode(DL, VT, Src, Subtarget, DAG, !UndefUppers);
|
||||||
if (SrcVT.getVectorNumElements() == NumElts)
|
|
||||||
return DAG.getNode(ISD::TRUNCATE, DL, VT, Src);
|
|
||||||
|
|
||||||
if (!Subtarget.hasVLX()) {
|
|
||||||
// Non-VLX targets must truncate from a 512-bit type, so we need to
|
|
||||||
// widen, truncate and then possibly extract the original 128-bit
|
|
||||||
// vector.
|
|
||||||
bool UndefUppers = isUndefInRange(Mask, NumSrcElts, UpperElts);
|
|
||||||
Src = widenSubVector(Src, !UndefUppers, Subtarget, DAG, DL, 512);
|
|
||||||
unsigned NumWideSrcElts = Src.getValueType().getVectorNumElements();
|
|
||||||
if (NumWideSrcElts >= NumElts) {
|
|
||||||
// Widening means we can now use a regular TRUNCATE.
|
|
||||||
MVT WideVT = MVT::getVectorVT(VT.getScalarType(), NumWideSrcElts);
|
|
||||||
SDValue WideRes = DAG.getNode(ISD::TRUNCATE, DL, WideVT, Src);
|
|
||||||
if (!WideVT.is128BitVector())
|
|
||||||
WideRes = extract128BitVector(WideRes, 0, DAG, DL);
|
|
||||||
return WideRes;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return DAG.getNode(X86ISD::VTRUNC, DL, VT, Src);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return SDValue();
|
return SDValue();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue