[llvm][Codegen] Make `getVectorTypeBreakdownMVT` work with scalable types.

Reviewers: efriedma, andwar, sdesmalen

Reviewed By: efriedma

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77434
This commit is contained in:
Francesco Petrogalli 2020-04-10 00:10:52 +01:00
parent 9f87d951fc
commit c846d2682b
2 changed files with 19 additions and 15 deletions

View File

@ -37,6 +37,7 @@ public:
return { Min * RHS, Scalable }; return { Min * RHS, Scalable };
} }
ElementCount operator/(unsigned RHS) { ElementCount operator/(unsigned RHS) {
assert(Min % RHS == 0 && "Min is not a multiple of RHS.");
return { Min / RHS, Scalable }; return { Min / RHS, Scalable };
} }

View File

@ -944,42 +944,45 @@ static unsigned getVectorTypeBreakdownMVT(MVT VT, MVT &IntermediateVT,
MVT &RegisterVT, MVT &RegisterVT,
TargetLoweringBase *TLI) { TargetLoweringBase *TLI) {
// Figure out the right, legal destination reg to copy into. // Figure out the right, legal destination reg to copy into.
unsigned NumElts = VT.getVectorNumElements(); ElementCount EC = VT.getVectorElementCount();
MVT EltTy = VT.getVectorElementType(); MVT EltTy = VT.getVectorElementType();
unsigned NumVectorRegs = 1; unsigned NumVectorRegs = 1;
// FIXME: We don't support non-power-of-2-sized vectors for now. Ideally we // FIXME: We don't support non-power-of-2-sized vectors for now.
// could break down into LHS/RHS like LegalizeDAG does. // Ideally we could break down into LHS/RHS like LegalizeDAG does.
if (!isPowerOf2_32(NumElts)) { if (!isPowerOf2_32(EC.Min)) {
NumVectorRegs = NumElts; // Split EC to unit size (scalable property is preserved).
NumElts = 1; NumVectorRegs = EC.Min;
EC = EC / NumVectorRegs;
} }
// Divide the input until we get to a supported size. This will always // Divide the input until we get to a supported size. This will
// end with a scalar if the target doesn't support vectors. // always end up with an EC that represent a scalar or a scalable
while (NumElts > 1 && !TLI->isTypeLegal(MVT::getVectorVT(EltTy, NumElts))) { // scalar.
NumElts >>= 1; while (EC.Min > 1 && !TLI->isTypeLegal(MVT::getVectorVT(EltTy, EC))) {
EC.Min >>= 1;
NumVectorRegs <<= 1; NumVectorRegs <<= 1;
} }
NumIntermediates = NumVectorRegs; NumIntermediates = NumVectorRegs;
MVT NewVT = MVT::getVectorVT(EltTy, NumElts); MVT NewVT = MVT::getVectorVT(EltTy, EC);
if (!TLI->isTypeLegal(NewVT)) if (!TLI->isTypeLegal(NewVT))
NewVT = EltTy; NewVT = EltTy;
IntermediateVT = NewVT; IntermediateVT = NewVT;
unsigned NewVTSize = NewVT.getSizeInBits(); unsigned LaneSizeInBits = NewVT.getScalarSizeInBits().getFixedSize();
// Convert sizes such as i33 to i64. // Convert sizes such as i33 to i64.
if (!isPowerOf2_32(NewVTSize)) if (!isPowerOf2_32(LaneSizeInBits))
NewVTSize = NextPowerOf2(NewVTSize); LaneSizeInBits = NextPowerOf2(LaneSizeInBits);
MVT DestVT = TLI->getRegisterType(NewVT); MVT DestVT = TLI->getRegisterType(NewVT);
RegisterVT = DestVT; RegisterVT = DestVT;
if (EVT(DestVT).bitsLT(NewVT)) // Value is expanded, e.g. i64 -> i16. if (EVT(DestVT).bitsLT(NewVT)) // Value is expanded, e.g. i64 -> i16.
return NumVectorRegs*(NewVTSize/DestVT.getSizeInBits()); return NumVectorRegs *
(LaneSizeInBits / DestVT.getScalarSizeInBits().getFixedSize());
// Otherwise, promotion or legal types use the same number of registers as // Otherwise, promotion or legal types use the same number of registers as
// the vector decimated to the appropriate level. // the vector decimated to the appropriate level.