[PowerPC] refactor rewriteLoadStores for reusing; nfc

This is split from https://reviews.llvm.org/D108750.
Refactor rewriteLoadStores() so that we can reuse the outlined
functions.

Reviewed By: jsji

Differential Revision: https://reviews.llvm.org/D110314
This commit is contained in:
Chen Zheng 2021-09-23 05:48:46 +00:00
parent 7ee133d3fc
commit 1bf05fbc98
1 changed files with 238 additions and 169 deletions

View File

@ -125,10 +125,10 @@ STATISTIC(UpdFormChainRewritten, "Num of update form chain rewritten");
namespace {
struct BucketElement {
BucketElement(const SCEVConstant *O, Instruction *I) : Offset(O), Instr(I) {}
BucketElement(const SCEV *O, Instruction *I) : Offset(O), Instr(I) {}
BucketElement(Instruction *I) : Offset(nullptr), Instr(I) {}
const SCEVConstant *Offset;
const SCEV *Offset;
Instruction *Instr;
};
@ -234,6 +234,19 @@ namespace {
bool rewriteLoadStores(Loop *L, Bucket &BucketChain,
SmallSet<BasicBlock *, 16> &BBChanged,
InstrForm Form);
/// Rewrite for the base load/store of a chain.
std::pair<Instruction *, Instruction *>
rewriteForBase(Loop *L, const SCEVAddRecExpr *BasePtrSCEV,
Instruction *BaseMemI, bool CanPreInc, InstrForm Form,
SCEVExpander &SCEVE, SmallPtrSet<Value *, 16> &DeletedPtrs);
/// Rewrite for the other load/stores of a chain according to the new \p
/// Base.
Instruction *
rewriteForBucketElement(std::pair<Instruction *, Instruction *> Base,
const BucketElement &Element, Value *OffToBase,
SmallPtrSet<Value *, 16> &DeletedPtrs);
};
} // end anonymous namespace
@ -321,6 +334,193 @@ bool PPCLoopInstrFormPrep::runOnFunction(Function &F) {
return MadeChange;
}
// Rewrite the new base according to BasePtrSCEV.
// bb.loop.preheader:
// %newstart = ...
// bb.loop.body:
// %phinode = phi [ %newstart, %bb.loop.preheader ], [ %add, %bb.loop.body ]
// ...
// %add = getelementptr %phinode, %inc
//
// First returned instruciton is %phinode (or a type cast to %phinode), caller
// needs this value to rewrite other load/stores in the same chain.
// Second returned instruction is %add, caller needs this value to rewrite other
// load/stores in the same chain.
std::pair<Instruction *, Instruction *>
PPCLoopInstrFormPrep::rewriteForBase(Loop *L, const SCEVAddRecExpr *BasePtrSCEV,
Instruction *BaseMemI, bool CanPreInc,
InstrForm Form, SCEVExpander &SCEVE,
SmallPtrSet<Value *, 16> &DeletedPtrs) {
LLVM_DEBUG(dbgs() << "PIP: Transforming: " << *BasePtrSCEV << "\n");
assert(BasePtrSCEV->getLoop() == L && "AddRec for the wrong loop?");
Value *BasePtr = getPointerOperandAndType(BaseMemI);
assert(BasePtr && "No pointer operand");
Type *I8Ty = Type::getInt8Ty(BaseMemI->getParent()->getContext());
Type *I8PtrTy =
Type::getInt8PtrTy(BaseMemI->getParent()->getContext(),
BasePtr->getType()->getPointerAddressSpace());
bool IsConstantInc = false;
const SCEV *BasePtrIncSCEV = BasePtrSCEV->getStepRecurrence(*SE);
Value *IncNode = getNodeForInc(L, BaseMemI, BasePtrIncSCEV);
const SCEVConstant *BasePtrIncConstantSCEV =
dyn_cast<SCEVConstant>(BasePtrIncSCEV);
if (BasePtrIncConstantSCEV)
IsConstantInc = true;
// No valid representation for the increment.
if (!IncNode) {
LLVM_DEBUG(dbgs() << "Loop Increasement can not be represented!\n");
return std::make_pair(nullptr, nullptr);
}
const SCEV *BasePtrStartSCEV = nullptr;
if (CanPreInc) {
assert(SE->isLoopInvariant(BasePtrIncSCEV, L) &&
"Increment is not loop invariant!\n");
BasePtrStartSCEV = SE->getMinusSCEV(BasePtrSCEV->getStart(),
IsConstantInc ? BasePtrIncConstantSCEV
: BasePtrIncSCEV);
} else
BasePtrStartSCEV = BasePtrSCEV->getStart();
if (alreadyPrepared(L, BaseMemI, BasePtrStartSCEV, BasePtrIncSCEV, Form)) {
LLVM_DEBUG(dbgs() << "Instruction form is already prepared!\n");
return std::make_pair(nullptr, nullptr);
}
LLVM_DEBUG(dbgs() << "PIP: New start is: " << *BasePtrStartSCEV << "\n");
BasicBlock *Header = L->getHeader();
unsigned HeaderLoopPredCount = pred_size(Header);
BasicBlock *LoopPredecessor = L->getLoopPredecessor();
PHINode *NewPHI = PHINode::Create(I8PtrTy, HeaderLoopPredCount,
getInstrName(BaseMemI, PHINodeNameSuffix),
Header->getFirstNonPHI());
Value *BasePtrStart = SCEVE.expandCodeFor(BasePtrStartSCEV, I8PtrTy,
LoopPredecessor->getTerminator());
// Note that LoopPredecessor might occur in the predecessor list multiple
// times, and we need to add it the right number of times.
for (auto PI : predecessors(Header)) {
if (PI != LoopPredecessor)
continue;
NewPHI->addIncoming(BasePtrStart, LoopPredecessor);
}
Instruction *PtrInc = nullptr;
Instruction *NewBasePtr = nullptr;
if (CanPreInc) {
Instruction *InsPoint = &*Header->getFirstInsertionPt();
PtrInc = GetElementPtrInst::Create(
I8Ty, NewPHI, IncNode, getInstrName(BaseMemI, GEPNodeIncNameSuffix),
InsPoint);
cast<GetElementPtrInst>(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr));
for (auto PI : predecessors(Header)) {
if (PI == LoopPredecessor)
continue;
NewPHI->addIncoming(PtrInc, PI);
}
if (PtrInc->getType() != BasePtr->getType())
NewBasePtr =
new BitCastInst(PtrInc, BasePtr->getType(),
getInstrName(PtrInc, CastNodeNameSuffix), InsPoint);
else
NewBasePtr = PtrInc;
} else {
// Note that LoopPredecessor might occur in the predecessor list multiple
// times, and we need to make sure no more incoming value for them in PHI.
for (auto PI : predecessors(Header)) {
if (PI == LoopPredecessor)
continue;
// For the latch predecessor, we need to insert a GEP just before the
// terminator to increase the address.
BasicBlock *BB = PI;
Instruction *InsPoint = BB->getTerminator();
PtrInc = GetElementPtrInst::Create(
I8Ty, NewPHI, IncNode, getInstrName(BaseMemI, GEPNodeIncNameSuffix),
InsPoint);
cast<GetElementPtrInst>(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr));
NewPHI->addIncoming(PtrInc, PI);
}
PtrInc = NewPHI;
if (NewPHI->getType() != BasePtr->getType())
NewBasePtr = new BitCastInst(NewPHI, BasePtr->getType(),
getInstrName(NewPHI, CastNodeNameSuffix),
&*Header->getFirstInsertionPt());
else
NewBasePtr = NewPHI;
}
BasePtr->replaceAllUsesWith(NewBasePtr);
DeletedPtrs.insert(BasePtr);
return std::make_pair(NewBasePtr, PtrInc);
}
Instruction *PPCLoopInstrFormPrep::rewriteForBucketElement(
std::pair<Instruction *, Instruction *> Base, const BucketElement &Element,
Value *OffToBase, SmallPtrSet<Value *, 16> &DeletedPtrs) {
Instruction *NewBasePtr = Base.first;
Instruction *PtrInc = Base.second;
assert((NewBasePtr && PtrInc) && "base does not exist!\n");
Type *I8Ty = Type::getInt8Ty(PtrInc->getParent()->getContext());
Value *Ptr = getPointerOperandAndType(Element.Instr);
assert(Ptr && "No pointer operand");
Instruction *RealNewPtr;
if (!Element.Offset ||
(isa<SCEVConstant>(Element.Offset) &&
cast<SCEVConstant>(Element.Offset)->getValue()->isZero())) {
RealNewPtr = NewBasePtr;
} else {
Instruction *PtrIP = dyn_cast<Instruction>(Ptr);
if (PtrIP && isa<Instruction>(NewBasePtr) &&
cast<Instruction>(NewBasePtr)->getParent() == PtrIP->getParent())
PtrIP = nullptr;
else if (PtrIP && isa<PHINode>(PtrIP))
PtrIP = &*PtrIP->getParent()->getFirstInsertionPt();
else if (!PtrIP)
PtrIP = Element.Instr;
assert(OffToBase && "There should be an offset for non base element!\n");
GetElementPtrInst *NewPtr = GetElementPtrInst::Create(
I8Ty, PtrInc, OffToBase,
getInstrName(Element.Instr, GEPNodeOffNameSuffix), PtrIP);
if (!PtrIP)
NewPtr->insertAfter(cast<Instruction>(PtrInc));
NewPtr->setIsInBounds(IsPtrInBounds(Ptr));
RealNewPtr = NewPtr;
}
Instruction *ReplNewPtr;
if (Ptr->getType() != RealNewPtr->getType()) {
ReplNewPtr = new BitCastInst(RealNewPtr, Ptr->getType(),
getInstrName(Ptr, CastNodeNameSuffix));
ReplNewPtr->insertAfter(RealNewPtr);
} else
ReplNewPtr = RealNewPtr;
Ptr->replaceAllUsesWith(ReplNewPtr);
DeletedPtrs.insert(Ptr);
return ReplNewPtr;
}
void PPCLoopInstrFormPrep::addOneCandidate(Instruction *MemI, const SCEV *LSCEV,
SmallVector<Bucket, 16> &Buckets,
unsigned MaxCandidateNum) {
@ -390,8 +590,9 @@ bool PPCLoopInstrFormPrep::prepareBaseForDispFormChain(Bucket &BucketChain,
if (!BucketChain.Elements[j].Offset)
RemainderOffsetInfo[0] = std::make_pair(0, 1);
else {
unsigned Remainder =
BucketChain.Elements[j].Offset->getAPInt().urem(Form);
unsigned Remainder = cast<SCEVConstant>(BucketChain.Elements[j].Offset)
->getAPInt()
.urem(Form);
if (RemainderOffsetInfo.find(Remainder) == RemainderOffsetInfo.end())
RemainderOffsetInfo[Remainder] = std::make_pair(j, 1);
else
@ -473,7 +674,7 @@ bool PPCLoopInstrFormPrep::prepareBaseForUpdateFormChain(Bucket &BucketChain) {
// If our chosen element has no offset from the base pointer, there's
// nothing to do.
if (!BucketChain.Elements[j].Offset ||
BucketChain.Elements[j].Offset->isZero())
cast<SCEVConstant>(BucketChain.Elements[j].Offset)->isZero())
break;
const SCEV *Offset = BucketChain.Elements[j].Offset;
@ -491,157 +692,46 @@ bool PPCLoopInstrFormPrep::prepareBaseForUpdateFormChain(Bucket &BucketChain) {
return true;
}
bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain,
SmallSet<BasicBlock *, 16> &BBChanged,
InstrForm Form) {
bool PPCLoopInstrFormPrep::rewriteLoadStores(
Loop *L, Bucket &BucketChain, SmallSet<BasicBlock *, 16> &BBChanged,
InstrForm Form) {
bool MadeChange = false;
const SCEVAddRecExpr *BasePtrSCEV =
cast<SCEVAddRecExpr>(BucketChain.BaseSCEV);
if (!BasePtrSCEV->isAffine())
return MadeChange;
LLVM_DEBUG(dbgs() << "PIP: Transforming: " << *BasePtrSCEV << "\n");
assert(BasePtrSCEV->getLoop() == L && "AddRec for the wrong loop?");
// The instruction corresponding to the Bucket's BaseSCEV must be the first
// in the vector of elements.
Instruction *MemI = BucketChain.Elements.begin()->Instr;
Value *BasePtr = getPointerOperandAndType(MemI);
assert(BasePtr && "No pointer operand");
Type *I8Ty = Type::getInt8Ty(MemI->getParent()->getContext());
Type *I8PtrTy = Type::getInt8PtrTy(MemI->getParent()->getContext(),
BasePtr->getType()->getPointerAddressSpace());
if (!SE->isLoopInvariant(BasePtrSCEV->getStart(), L))
if (!isSafeToExpand(BasePtrSCEV->getStart(), *SE))
return MadeChange;
bool IsConstantInc = false;
const SCEV *BasePtrIncSCEV = BasePtrSCEV->getStepRecurrence(*SE);
Value *IncNode = getNodeForInc(L, MemI, BasePtrIncSCEV);
SmallPtrSet<Value *, 16> DeletedPtrs;
const SCEVConstant *BasePtrIncConstantSCEV =
dyn_cast<SCEVConstant>(BasePtrIncSCEV);
if (BasePtrIncConstantSCEV)
IsConstantInc = true;
// No valid representation for the increment.
if (!IncNode) {
LLVM_DEBUG(dbgs() << "Loop Increasement can not be represented!\n");
return MadeChange;
}
BasicBlock *Header = L->getHeader();
SCEVExpander SCEVE(*SE, Header->getModule()->getDataLayout(), "pistart");
// For some DS form load/store instructions, it can also be an update form,
// if the stride is constant and is a multipler of 4. Use update form if
// prefer it.
bool CanPreInc =
(Form == UpdateForm ||
((Form == DSForm) && IsConstantInc &&
!BasePtrIncConstantSCEV->getAPInt().urem(4) && PreferUpdateForm));
const SCEV *BasePtrStartSCEV = nullptr;
if (CanPreInc) {
assert(SE->isLoopInvariant(BasePtrIncSCEV, L) &&
"Increment is not loop invariant!\n");
BasePtrStartSCEV = SE->getMinusSCEV(BasePtrSCEV->getStart(),
IsConstantInc ? BasePtrIncConstantSCEV
: BasePtrIncSCEV);
} else
BasePtrStartSCEV = BasePtrSCEV->getStart();
bool CanPreInc = (Form == UpdateForm ||
((Form == DSForm) &&
isa<SCEVConstant>(BasePtrSCEV->getStepRecurrence(*SE)) &&
!cast<SCEVConstant>(BasePtrSCEV->getStepRecurrence(*SE))
->getAPInt()
.urem(4) &&
PreferUpdateForm));
if (!isSafeToExpand(BasePtrStartSCEV, *SE))
std::pair<Instruction *, Instruction *> Base =
rewriteForBase(L, BasePtrSCEV, BucketChain.Elements.begin()->Instr,
CanPreInc, Form, SCEVE, DeletedPtrs);
if (!Base.first || !Base.second)
return MadeChange;
if (alreadyPrepared(L, MemI, BasePtrStartSCEV, BasePtrIncSCEV, Form)) {
LLVM_DEBUG(dbgs() << "Instruction form is already prepared!\n");
return MadeChange;
}
LLVM_DEBUG(dbgs() << "PIP: New start is: " << *BasePtrStartSCEV << "\n");
BasicBlock *Header = L->getHeader();
unsigned HeaderLoopPredCount = pred_size(Header);
BasicBlock *LoopPredecessor = L->getLoopPredecessor();
PHINode *NewPHI =
PHINode::Create(I8PtrTy, HeaderLoopPredCount,
getInstrName(MemI, PHINodeNameSuffix),
Header->getFirstNonPHI());
SCEVExpander SCEVE(*SE, Header->getModule()->getDataLayout(), "pistart");
Value *BasePtrStart = SCEVE.expandCodeFor(BasePtrStartSCEV, I8PtrTy,
LoopPredecessor->getTerminator());
// Note that LoopPredecessor might occur in the predecessor list multiple
// times, and we need to add it the right number of times.
for (auto PI : predecessors(Header)) {
if (PI != LoopPredecessor)
continue;
NewPHI->addIncoming(BasePtrStart, LoopPredecessor);
}
Instruction *PtrInc = nullptr;
Instruction *NewBasePtr = nullptr;
if (CanPreInc) {
Instruction *InsPoint = &*Header->getFirstInsertionPt();
PtrInc = GetElementPtrInst::Create(I8Ty, NewPHI, IncNode,
getInstrName(MemI, GEPNodeIncNameSuffix),
InsPoint);
cast<GetElementPtrInst>(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr));
for (auto PI : predecessors(Header)) {
if (PI == LoopPredecessor)
continue;
NewPHI->addIncoming(PtrInc, PI);
}
if (PtrInc->getType() != BasePtr->getType())
NewBasePtr = new BitCastInst(
PtrInc, BasePtr->getType(),
getInstrName(PtrInc, CastNodeNameSuffix), InsPoint);
else
NewBasePtr = PtrInc;
} else {
// Note that LoopPredecessor might occur in the predecessor list multiple
// times, and we need to make sure no more incoming value for them in PHI.
for (auto PI : predecessors(Header)) {
if (PI == LoopPredecessor)
continue;
// For the latch predecessor, we need to insert a GEP just before the
// terminator to increase the address.
BasicBlock *BB = PI;
Instruction *InsPoint = BB->getTerminator();
PtrInc = GetElementPtrInst::Create(
I8Ty, NewPHI, IncNode, getInstrName(MemI, GEPNodeIncNameSuffix),
InsPoint);
cast<GetElementPtrInst>(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr));
NewPHI->addIncoming(PtrInc, PI);
}
PtrInc = NewPHI;
if (NewPHI->getType() != BasePtr->getType())
NewBasePtr =
new BitCastInst(NewPHI, BasePtr->getType(),
getInstrName(NewPHI, CastNodeNameSuffix),
&*Header->getFirstInsertionPt());
else
NewBasePtr = NewPHI;
}
// Clear the rewriter cache, because values that are in the rewriter's cache
// can be deleted below, causing the AssertingVH in the cache to trigger.
SCEVE.clear();
if (Instruction *IDel = dyn_cast<Instruction>(BasePtr))
BBChanged.insert(IDel->getParent());
BasePtr->replaceAllUsesWith(NewBasePtr);
RecursivelyDeleteTriviallyDeadInstructions(BasePtr);
// Keep track of the replacement pointer values we've inserted so that we
// don't generate more pointer values than necessary.
SmallPtrSet<Value *, 16> NewPtrs;
NewPtrs.insert(NewBasePtr);
NewPtrs.insert(Base.first);
for (auto I = std::next(BucketChain.Elements.begin()),
IE = BucketChain.Elements.end(); I != IE; ++I) {
@ -650,43 +740,22 @@ bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain,
if (NewPtrs.count(Ptr))
continue;
Instruction *RealNewPtr;
if (!I->Offset || I->Offset->getValue()->isZero()) {
RealNewPtr = NewBasePtr;
} else {
Instruction *PtrIP = dyn_cast<Instruction>(Ptr);
if (PtrIP && isa<Instruction>(NewBasePtr) &&
cast<Instruction>(NewBasePtr)->getParent() == PtrIP->getParent())
PtrIP = nullptr;
else if (PtrIP && isa<PHINode>(PtrIP))
PtrIP = &*PtrIP->getParent()->getFirstInsertionPt();
else if (!PtrIP)
PtrIP = I->Instr;
Instruction *NewPtr = rewriteForBucketElement(
Base, *I,
I->Offset ? cast<SCEVConstant>(I->Offset)->getValue() : nullptr,
DeletedPtrs);
assert(NewPtr && "wrong rewrite!\n");
NewPtrs.insert(NewPtr);
}
GetElementPtrInst *NewPtr = GetElementPtrInst::Create(
I8Ty, PtrInc, I->Offset->getValue(),
getInstrName(I->Instr, GEPNodeOffNameSuffix), PtrIP);
if (!PtrIP)
NewPtr->insertAfter(cast<Instruction>(PtrInc));
NewPtr->setIsInBounds(IsPtrInBounds(Ptr));
RealNewPtr = NewPtr;
}
// Clear the rewriter cache, because values that are in the rewriter's cache
// can be deleted below, causing the AssertingVH in the cache to trigger.
SCEVE.clear();
for (auto *Ptr : DeletedPtrs) {
if (Instruction *IDel = dyn_cast<Instruction>(Ptr))
BBChanged.insert(IDel->getParent());
Instruction *ReplNewPtr;
if (Ptr->getType() != RealNewPtr->getType()) {
ReplNewPtr = new BitCastInst(RealNewPtr, Ptr->getType(),
getInstrName(Ptr, CastNodeNameSuffix));
ReplNewPtr->insertAfter(RealNewPtr);
} else
ReplNewPtr = RealNewPtr;
Ptr->replaceAllUsesWith(ReplNewPtr);
RecursivelyDeleteTriviallyDeadInstructions(Ptr);
NewPtrs.insert(RealNewPtr);
}
MadeChange = true;