1b1c73532SDimitry Andric #include "llvm/Transforms/Utils/LoopConstrainer.h"
2b1c73532SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
3b1c73532SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
4b1c73532SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
5b1c73532SDimitry Andric #include "llvm/IR/Dominators.h"
6b1c73532SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
7b1c73532SDimitry Andric #include "llvm/Transforms/Utils/LoopSimplify.h"
8b1c73532SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
9b1c73532SDimitry Andric #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10b1c73532SDimitry Andric
11b1c73532SDimitry Andric using namespace llvm;
12b1c73532SDimitry Andric
/// Metadata kind that cloneLoop attaches to the latch terminator of every loop
/// copy it creates; parseLoopStructure rejects loops whose latch carries it, so
/// a cloned loop is never constrained again. Both the pointer and the
/// characters are const so the tag cannot be reassigned at runtime.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";
14b1c73532SDimitry Andric
15b1c73532SDimitry Andric #define DEBUG_TYPE "loop-constrainer"
16b1c73532SDimitry Andric
17b1c73532SDimitry Andric /// Given a loop with an deccreasing induction variable, is it possible to
18b1c73532SDimitry Andric /// safely calculate the bounds of a new loop using the given Predicate.
isSafeDecreasingBound(const SCEV * Start,const SCEV * BoundSCEV,const SCEV * Step,ICmpInst::Predicate Pred,unsigned LatchBrExitIdx,Loop * L,ScalarEvolution & SE)19b1c73532SDimitry Andric static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20b1c73532SDimitry Andric const SCEV *Step, ICmpInst::Predicate Pred,
21b1c73532SDimitry Andric unsigned LatchBrExitIdx, Loop *L,
22b1c73532SDimitry Andric ScalarEvolution &SE) {
23b1c73532SDimitry Andric if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24b1c73532SDimitry Andric Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25b1c73532SDimitry Andric return false;
26b1c73532SDimitry Andric
27b1c73532SDimitry Andric if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
28b1c73532SDimitry Andric return false;
29b1c73532SDimitry Andric
30b1c73532SDimitry Andric assert(SE.isKnownNegative(Step) && "expecting negative step");
31b1c73532SDimitry Andric
32b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
38b1c73532SDimitry Andric
39b1c73532SDimitry Andric bool IsSigned = ICmpInst::isSigned(Pred);
40b1c73532SDimitry Andric // The predicate that we need to check that the induction variable lies
41b1c73532SDimitry Andric // within bounds.
42b1c73532SDimitry Andric ICmpInst::Predicate BoundPred =
43b1c73532SDimitry Andric IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
44b1c73532SDimitry Andric
45ac9a064cSDimitry Andric auto StartLG = SE.applyLoopGuards(Start, L);
46ac9a064cSDimitry Andric auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
47ac9a064cSDimitry Andric
48b1c73532SDimitry Andric if (LatchBrExitIdx == 1)
49ac9a064cSDimitry Andric return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
50b1c73532SDimitry Andric
51b1c73532SDimitry Andric assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
52b1c73532SDimitry Andric
53b1c73532SDimitry Andric const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
54b1c73532SDimitry Andric unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
55b1c73532SDimitry Andric APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
56b1c73532SDimitry Andric : APInt::getMinValue(BitWidth);
57b1c73532SDimitry Andric const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
58b1c73532SDimitry Andric
59b1c73532SDimitry Andric const SCEV *MinusOne =
60ac9a064cSDimitry Andric SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType()));
61b1c73532SDimitry Andric
62ac9a064cSDimitry Andric return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) &&
63ac9a064cSDimitry Andric SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit);
64b1c73532SDimitry Andric }
65b1c73532SDimitry Andric
66b1c73532SDimitry Andric /// Given a loop with an increasing induction variable, is it possible to
67b1c73532SDimitry Andric /// safely calculate the bounds of a new loop using the given Predicate.
isSafeIncreasingBound(const SCEV * Start,const SCEV * BoundSCEV,const SCEV * Step,ICmpInst::Predicate Pred,unsigned LatchBrExitIdx,Loop * L,ScalarEvolution & SE)68b1c73532SDimitry Andric static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
69b1c73532SDimitry Andric const SCEV *Step, ICmpInst::Predicate Pred,
70b1c73532SDimitry Andric unsigned LatchBrExitIdx, Loop *L,
71b1c73532SDimitry Andric ScalarEvolution &SE) {
72b1c73532SDimitry Andric if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
73b1c73532SDimitry Andric Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
74b1c73532SDimitry Andric return false;
75b1c73532SDimitry Andric
76b1c73532SDimitry Andric if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
77b1c73532SDimitry Andric return false;
78b1c73532SDimitry Andric
79b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
80b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
81b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
82b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
83b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
84b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
85b1c73532SDimitry Andric
86b1c73532SDimitry Andric bool IsSigned = ICmpInst::isSigned(Pred);
87b1c73532SDimitry Andric // The predicate that we need to check that the induction variable lies
88b1c73532SDimitry Andric // within bounds.
89b1c73532SDimitry Andric ICmpInst::Predicate BoundPred =
90b1c73532SDimitry Andric IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
91b1c73532SDimitry Andric
92ac9a064cSDimitry Andric auto StartLG = SE.applyLoopGuards(Start, L);
93ac9a064cSDimitry Andric auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
94ac9a064cSDimitry Andric
95b1c73532SDimitry Andric if (LatchBrExitIdx == 1)
96ac9a064cSDimitry Andric return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
97b1c73532SDimitry Andric
98b1c73532SDimitry Andric assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
99b1c73532SDimitry Andric
100b1c73532SDimitry Andric const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
101b1c73532SDimitry Andric unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
102b1c73532SDimitry Andric APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
103b1c73532SDimitry Andric : APInt::getMaxValue(BitWidth);
104b1c73532SDimitry Andric const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
105b1c73532SDimitry Andric
106ac9a064cSDimitry Andric return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG,
107ac9a064cSDimitry Andric SE.getAddExpr(BoundLG, Step)) &&
108ac9a064cSDimitry Andric SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit));
109b1c73532SDimitry Andric }
110b1c73532SDimitry Andric
111b1c73532SDimitry Andric /// Returns estimate for max latch taken count of the loop of the narrowest
112b1c73532SDimitry Andric /// available type. If the latch block has such estimate, it is returned.
113b1c73532SDimitry Andric /// Otherwise, we use max exit count of whole loop (that is potentially of wider
114b1c73532SDimitry Andric /// type than latch check itself), which is still better than no estimate.
getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution & SE,const Loop & L)115b1c73532SDimitry Andric static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
116b1c73532SDimitry Andric const Loop &L) {
117b1c73532SDimitry Andric const SCEV *FromBlock =
118b1c73532SDimitry Andric SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
119b1c73532SDimitry Andric if (isa<SCEVCouldNotCompute>(FromBlock))
120b1c73532SDimitry Andric return SE.getSymbolicMaxBackedgeTakenCount(&L);
121b1c73532SDimitry Andric return FromBlock;
122b1c73532SDimitry Andric }
123b1c73532SDimitry Andric
/// Try to recognize \p L as a loop LoopConstrainer can handle and, on success,
/// describe it as a LoopStructure.
///
/// Requirements on \p L:
///   * it is in LoopSimplify form and has a preheader;
///   * its latch terminator does not carry the ClonedLoopTag marker;
///   * the latch exits the loop through a conditional branch on an integer
///     icmp;
///   * one icmp operand is an affine add recurrence of \p L with a constant
///     step.
/// The comparison is canonicalized: operands are swapped and eq/ne are
/// rewritten where possible, giving a strict less-than/greater-than form
/// whose direction matches the sign of the step.
///
/// \param SE ScalarEvolution used to analyze \p L.
/// \param L The loop to analyze.
/// \param AllowUnsignedLatchCond If false, loops whose canonicalized latch
///        predicate is unsigned are rejected.
/// \param FailureReason Set to a static description of why \p L was rejected,
///        or to nullptr on success.
/// \returns The parsed structure, or std::nullopt if \p L is unsuitable.
///
/// Note: on success this function emits instructions before the preheader
/// terminator through SCEVExpander. These cover the induction variable's
/// start value and, when needed, a recomputed exit bound.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  bool AllowUnsignedLatchCond,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // cloneLoop marks the latch terminator of every copy it makes; never
  // process such a copy again.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Index of the latch successor that leaves the loop. If successor 0 is the
  // header (the backedge), the exit is successor 1; otherwise it is
  // successor 0.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  // Loops without a computable latch count are rejected. The count's type is
  // recorded below as Result.ExitCountTy.
  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  // Taken before any operand swap; both icmp operands share this type.
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if AR is known not to wrap in the signed sense. That holds
  // when AR already carries NSW, or when sign-extending AR to twice its width
  // gives an add recurrence whose start and step are the sign-extended start
  // and step of AR.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  // The eq/ne rewrites below require the recurrence not to wrap in the signed
  // sense.
  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // IndVarBase models the already-incremented ("next") value, so the
  // induction variable's initial value is its start minus one step.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, the exit bound is re-expanded from this SCEV in the
  // preheader instead of reusing RightValue directly.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within loop (but still being loop invariant),
  // regenerate it as preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        // The unsigned form needs NUW on the IV; both forms need len - 1 to
        // be computable without wrapping (len not the minimum value).
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // For an increasing IV the loop must continue while the next value is
    // below the bound: an LT test with the exit on successor 1, or a GT test
    // with the exit on successor 0.
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    // With a GT exit test the loop keeps going while next <= RightSCEV, so
    // the exclusive exit bound recorded in LoopExitAt is RightSCEV + 1.
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        // len + 1 must be computable without wrapping (len not the maximum
        // value) in the chosen signedness.
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // Mirror of the increasing case: the loop must continue while the next
    // value is above the bound.
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    // With an LT exit test the loop keeps going while next >= RightSCEV, so
    // the exclusive exit bound recorded in LoopExitAt is RightSCEV - 1.
    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // Materialize the computed values in the preheader, just before its
  // terminator.
  const DataLayout &DL = Preheader->getDataLayout();
  SCEVExpander Expander(SE, DL, "loop-constrainer");
  Instruction *Ins = Preheader->getTerminator();

  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;
  Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());

  FailureReason = nullptr;

  return Result;
}
431b1c73532SDimitry Andric
432b1c73532SDimitry Andric // Add metadata to the loop L to disable loop optimizations. Callers need to
433b1c73532SDimitry Andric // confirm that optimizing loop L is not beneficial.
DisableAllLoopOptsOnLoop(Loop & L)434b1c73532SDimitry Andric static void DisableAllLoopOptsOnLoop(Loop &L) {
435b1c73532SDimitry Andric // We do not care about any existing loopID related metadata for L, since we
436b1c73532SDimitry Andric // are setting all loop metadata to false.
437b1c73532SDimitry Andric LLVMContext &Context = L.getHeader()->getContext();
438b1c73532SDimitry Andric // Reserve first location for self reference to the LoopID metadata node.
439b1c73532SDimitry Andric MDNode *Dummy = MDNode::get(Context, {});
440b1c73532SDimitry Andric MDNode *DisableUnroll = MDNode::get(
441b1c73532SDimitry Andric Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
442b1c73532SDimitry Andric Metadata *FalseVal =
443b1c73532SDimitry Andric ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
444b1c73532SDimitry Andric MDNode *DisableVectorize = MDNode::get(
445b1c73532SDimitry Andric Context,
446b1c73532SDimitry Andric {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
447b1c73532SDimitry Andric MDNode *DisableLICMVersioning = MDNode::get(
448b1c73532SDimitry Andric Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
449b1c73532SDimitry Andric MDNode *DisableDistribution = MDNode::get(
450b1c73532SDimitry Andric Context,
451b1c73532SDimitry Andric {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
452b1c73532SDimitry Andric MDNode *NewLoopID =
453b1c73532SDimitry Andric MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
454b1c73532SDimitry Andric DisableLICMVersioning, DisableDistribution});
455b1c73532SDimitry Andric // Set operand 0 to refer to the loop id itself.
456b1c73532SDimitry Andric NewLoopID->replaceOperandWith(0, NewLoopID);
457b1c73532SDimitry Andric L.setLoopID(NewLoopID);
458b1c73532SDimitry Andric }
459b1c73532SDimitry Andric
LoopConstrainer(Loop & L,LoopInfo & LI,function_ref<void (Loop *,bool)> LPMAddNewLoop,const LoopStructure & LS,ScalarEvolution & SE,DominatorTree & DT,Type * T,SubRanges SR)460b1c73532SDimitry Andric LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
461b1c73532SDimitry Andric function_ref<void(Loop *, bool)> LPMAddNewLoop,
462b1c73532SDimitry Andric const LoopStructure &LS, ScalarEvolution &SE,
463b1c73532SDimitry Andric DominatorTree &DT, Type *T, SubRanges SR)
464b1c73532SDimitry Andric : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
465b1c73532SDimitry Andric DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
466b1c73532SDimitry Andric MainLoopStructure(LS), SR(SR) {}
467b1c73532SDimitry Andric
cloneLoop(LoopConstrainer::ClonedLoop & Result,const char * Tag) const468b1c73532SDimitry Andric void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
469b1c73532SDimitry Andric const char *Tag) const {
470b1c73532SDimitry Andric for (BasicBlock *BB : OriginalLoop.getBlocks()) {
471b1c73532SDimitry Andric BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
472b1c73532SDimitry Andric Result.Blocks.push_back(Clone);
473b1c73532SDimitry Andric Result.Map[BB] = Clone;
474b1c73532SDimitry Andric }
475b1c73532SDimitry Andric
476b1c73532SDimitry Andric auto GetClonedValue = [&Result](Value *V) {
477b1c73532SDimitry Andric assert(V && "null values not in domain!");
478b1c73532SDimitry Andric auto It = Result.Map.find(V);
479b1c73532SDimitry Andric if (It == Result.Map.end())
480b1c73532SDimitry Andric return V;
481b1c73532SDimitry Andric return static_cast<Value *>(It->second);
482b1c73532SDimitry Andric };
483b1c73532SDimitry Andric
484b1c73532SDimitry Andric auto *ClonedLatch =
485b1c73532SDimitry Andric cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
486b1c73532SDimitry Andric ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
487b1c73532SDimitry Andric MDNode::get(Ctx, {}));
488b1c73532SDimitry Andric
489b1c73532SDimitry Andric Result.Structure = MainLoopStructure.map(GetClonedValue);
490b1c73532SDimitry Andric Result.Structure.Tag = Tag;
491b1c73532SDimitry Andric
492b1c73532SDimitry Andric for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
493b1c73532SDimitry Andric BasicBlock *ClonedBB = Result.Blocks[i];
494b1c73532SDimitry Andric BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
495b1c73532SDimitry Andric
496b1c73532SDimitry Andric assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
497b1c73532SDimitry Andric
498b1c73532SDimitry Andric for (Instruction &I : *ClonedBB)
499b1c73532SDimitry Andric RemapInstruction(&I, Result.Map,
500b1c73532SDimitry Andric RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
501b1c73532SDimitry Andric
502b1c73532SDimitry Andric // Exit blocks will now have one more predecessor and their PHI nodes need
503b1c73532SDimitry Andric // to be edited to reflect that. No phi nodes need to be introduced because
504b1c73532SDimitry Andric // the loop is in LCSSA.
505b1c73532SDimitry Andric
506b1c73532SDimitry Andric for (auto *SBB : successors(OriginalBB)) {
507b1c73532SDimitry Andric if (OriginalLoop.contains(SBB))
508b1c73532SDimitry Andric continue; // not an exit block
509b1c73532SDimitry Andric
510b1c73532SDimitry Andric for (PHINode &PN : SBB->phis()) {
511b1c73532SDimitry Andric Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
512b1c73532SDimitry Andric PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
513b1c73532SDimitry Andric SE.forgetValue(&PN);
514b1c73532SDimitry Andric }
515b1c73532SDimitry Andric }
516b1c73532SDimitry Andric }
517b1c73532SDimitry Andric }
518b1c73532SDimitry Andric
changeIterationSpaceEnd(const LoopStructure & LS,BasicBlock * Preheader,Value * ExitSubloopAt,BasicBlock * ContinuationBlock) const519b1c73532SDimitry Andric LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
520b1c73532SDimitry Andric const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
521b1c73532SDimitry Andric BasicBlock *ContinuationBlock) const {
522b1c73532SDimitry Andric // We start with a loop with a single latch:
523b1c73532SDimitry Andric //
524b1c73532SDimitry Andric // +--------------------+
525b1c73532SDimitry Andric // | |
526b1c73532SDimitry Andric // | preheader |
527b1c73532SDimitry Andric // | |
528b1c73532SDimitry Andric // +--------+-----------+
529b1c73532SDimitry Andric // | ----------------\
530b1c73532SDimitry Andric // | / |
531b1c73532SDimitry Andric // +--------v----v------+ |
532b1c73532SDimitry Andric // | | |
533b1c73532SDimitry Andric // | header | |
534b1c73532SDimitry Andric // | | |
535b1c73532SDimitry Andric // +--------------------+ |
536b1c73532SDimitry Andric // |
537b1c73532SDimitry Andric // ..... |
538b1c73532SDimitry Andric // |
539b1c73532SDimitry Andric // +--------------------+ |
540b1c73532SDimitry Andric // | | |
541b1c73532SDimitry Andric // | latch >----------/
542b1c73532SDimitry Andric // | |
543b1c73532SDimitry Andric // +-------v------------+
544b1c73532SDimitry Andric // |
545b1c73532SDimitry Andric // |
546b1c73532SDimitry Andric // | +--------------------+
547b1c73532SDimitry Andric // | | |
548b1c73532SDimitry Andric // +---> original exit |
549b1c73532SDimitry Andric // | |
550b1c73532SDimitry Andric // +--------------------+
551b1c73532SDimitry Andric //
552b1c73532SDimitry Andric // We change the control flow to look like
553b1c73532SDimitry Andric //
554b1c73532SDimitry Andric //
555b1c73532SDimitry Andric // +--------------------+
556b1c73532SDimitry Andric // | |
557b1c73532SDimitry Andric // | preheader >-------------------------+
558b1c73532SDimitry Andric // | | |
559b1c73532SDimitry Andric // +--------v-----------+ |
560b1c73532SDimitry Andric // | /-------------+ |
561b1c73532SDimitry Andric // | / | |
562b1c73532SDimitry Andric // +--------v--v--------+ | |
563b1c73532SDimitry Andric // | | | |
564b1c73532SDimitry Andric // | header | | +--------+ |
565b1c73532SDimitry Andric // | | | | | |
566b1c73532SDimitry Andric // +--------------------+ | | +-----v-----v-----------+
567b1c73532SDimitry Andric // | | | |
568b1c73532SDimitry Andric // | | | .pseudo.exit |
569b1c73532SDimitry Andric // | | | |
570b1c73532SDimitry Andric // | | +-----------v-----------+
571b1c73532SDimitry Andric // | | |
572b1c73532SDimitry Andric // ..... | | |
573b1c73532SDimitry Andric // | | +--------v-------------+
574b1c73532SDimitry Andric // +--------------------+ | | | |
575b1c73532SDimitry Andric // | | | | | ContinuationBlock |
576b1c73532SDimitry Andric // | latch >------+ | | |
577b1c73532SDimitry Andric // | | | +----------------------+
578b1c73532SDimitry Andric // +---------v----------+ |
579b1c73532SDimitry Andric // | |
580b1c73532SDimitry Andric // | |
581b1c73532SDimitry Andric // | +---------------^-----+
582b1c73532SDimitry Andric // | | |
583b1c73532SDimitry Andric // +-----> .exit.selector |
584b1c73532SDimitry Andric // | |
585b1c73532SDimitry Andric // +----------v----------+
586b1c73532SDimitry Andric // |
587b1c73532SDimitry Andric // +--------------------+ |
588b1c73532SDimitry Andric // | | |
589b1c73532SDimitry Andric // | original exit <----+
590b1c73532SDimitry Andric // | |
591b1c73532SDimitry Andric // +--------------------+
592b1c73532SDimitry Andric
593b1c73532SDimitry Andric RewrittenRangeInfo RRI;
594b1c73532SDimitry Andric
595b1c73532SDimitry Andric BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
596b1c73532SDimitry Andric RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
597b1c73532SDimitry Andric &F, BBInsertLocation);
598b1c73532SDimitry Andric RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
599b1c73532SDimitry Andric BBInsertLocation);
600b1c73532SDimitry Andric
601b1c73532SDimitry Andric BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
602b1c73532SDimitry Andric bool Increasing = LS.IndVarIncreasing;
603b1c73532SDimitry Andric bool IsSignedPredicate = LS.IsSignedPredicate;
604b1c73532SDimitry Andric
605b1c73532SDimitry Andric IRBuilder<> B(PreheaderJump);
606b1c73532SDimitry Andric auto NoopOrExt = [&](Value *V) {
607b1c73532SDimitry Andric if (V->getType() == RangeTy)
608b1c73532SDimitry Andric return V;
609b1c73532SDimitry Andric return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
610b1c73532SDimitry Andric : B.CreateZExt(V, RangeTy, "wide." + V->getName());
611b1c73532SDimitry Andric };
612b1c73532SDimitry Andric
613b1c73532SDimitry Andric // EnterLoopCond - is it okay to start executing this `LS'?
614b1c73532SDimitry Andric Value *EnterLoopCond = nullptr;
615b1c73532SDimitry Andric auto Pred =
616b1c73532SDimitry Andric Increasing
617b1c73532SDimitry Andric ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
618b1c73532SDimitry Andric : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
619b1c73532SDimitry Andric Value *IndVarStart = NoopOrExt(LS.IndVarStart);
620b1c73532SDimitry Andric EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
621b1c73532SDimitry Andric
622b1c73532SDimitry Andric B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
623b1c73532SDimitry Andric PreheaderJump->eraseFromParent();
624b1c73532SDimitry Andric
625b1c73532SDimitry Andric LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
626b1c73532SDimitry Andric B.SetInsertPoint(LS.LatchBr);
627b1c73532SDimitry Andric Value *IndVarBase = NoopOrExt(LS.IndVarBase);
628b1c73532SDimitry Andric Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
629b1c73532SDimitry Andric
630b1c73532SDimitry Andric Value *CondForBranch = LS.LatchBrExitIdx == 1
631b1c73532SDimitry Andric ? TakeBackedgeLoopCond
632b1c73532SDimitry Andric : B.CreateNot(TakeBackedgeLoopCond);
633b1c73532SDimitry Andric
634b1c73532SDimitry Andric LS.LatchBr->setCondition(CondForBranch);
635b1c73532SDimitry Andric
636b1c73532SDimitry Andric B.SetInsertPoint(RRI.ExitSelector);
637b1c73532SDimitry Andric
638b1c73532SDimitry Andric // IterationsLeft - are there any more iterations left, given the original
639b1c73532SDimitry Andric // upper bound on the induction variable? If not, we branch to the "real"
640b1c73532SDimitry Andric // exit.
641b1c73532SDimitry Andric Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
642b1c73532SDimitry Andric Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
643b1c73532SDimitry Andric B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
644b1c73532SDimitry Andric
645b1c73532SDimitry Andric BranchInst *BranchToContinuation =
646b1c73532SDimitry Andric BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
647b1c73532SDimitry Andric
648b1c73532SDimitry Andric // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
649b1c73532SDimitry Andric // each of the PHI nodes in the loop header. This feeds into the initial
650b1c73532SDimitry Andric // value of the same PHI nodes if/when we continue execution.
651b1c73532SDimitry Andric for (PHINode &PN : LS.Header->phis()) {
652b1c73532SDimitry Andric PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
653ac9a064cSDimitry Andric BranchToContinuation->getIterator());
654b1c73532SDimitry Andric
655b1c73532SDimitry Andric NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
656b1c73532SDimitry Andric NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
657b1c73532SDimitry Andric RRI.ExitSelector);
658b1c73532SDimitry Andric RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
659b1c73532SDimitry Andric }
660b1c73532SDimitry Andric
661b1c73532SDimitry Andric RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
662ac9a064cSDimitry Andric BranchToContinuation->getIterator());
663b1c73532SDimitry Andric RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
664b1c73532SDimitry Andric RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
665b1c73532SDimitry Andric
666b1c73532SDimitry Andric // The latch exit now has a branch from `RRI.ExitSelector' instead of
667b1c73532SDimitry Andric // `LS.Latch'. The PHI nodes need to be updated to reflect that.
668b1c73532SDimitry Andric LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
669b1c73532SDimitry Andric
670b1c73532SDimitry Andric return RRI;
671b1c73532SDimitry Andric }
672b1c73532SDimitry Andric
rewriteIncomingValuesForPHIs(LoopStructure & LS,BasicBlock * ContinuationBlock,const LoopConstrainer::RewrittenRangeInfo & RRI) const673b1c73532SDimitry Andric void LoopConstrainer::rewriteIncomingValuesForPHIs(
674b1c73532SDimitry Andric LoopStructure &LS, BasicBlock *ContinuationBlock,
675b1c73532SDimitry Andric const LoopConstrainer::RewrittenRangeInfo &RRI) const {
676b1c73532SDimitry Andric unsigned PHIIndex = 0;
677b1c73532SDimitry Andric for (PHINode &PN : LS.Header->phis())
678b1c73532SDimitry Andric PN.setIncomingValueForBlock(ContinuationBlock,
679b1c73532SDimitry Andric RRI.PHIValuesAtPseudoExit[PHIIndex++]);
680b1c73532SDimitry Andric
681b1c73532SDimitry Andric LS.IndVarStart = RRI.IndVarEnd;
682b1c73532SDimitry Andric }
683b1c73532SDimitry Andric
createPreheader(const LoopStructure & LS,BasicBlock * OldPreheader,const char * Tag) const684b1c73532SDimitry Andric BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
685b1c73532SDimitry Andric BasicBlock *OldPreheader,
686b1c73532SDimitry Andric const char *Tag) const {
687b1c73532SDimitry Andric BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
688b1c73532SDimitry Andric BranchInst::Create(LS.Header, Preheader);
689b1c73532SDimitry Andric
690b1c73532SDimitry Andric LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
691b1c73532SDimitry Andric
692b1c73532SDimitry Andric return Preheader;
693b1c73532SDimitry Andric }
694b1c73532SDimitry Andric
addToParentLoopIfNeeded(ArrayRef<BasicBlock * > BBs)695b1c73532SDimitry Andric void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
696b1c73532SDimitry Andric Loop *ParentLoop = OriginalLoop.getParentLoop();
697b1c73532SDimitry Andric if (!ParentLoop)
698b1c73532SDimitry Andric return;
699b1c73532SDimitry Andric
700b1c73532SDimitry Andric for (BasicBlock *BB : BBs)
701b1c73532SDimitry Andric ParentLoop->addBasicBlockToLoop(BB, LI);
702b1c73532SDimitry Andric }
703b1c73532SDimitry Andric
createClonedLoopStructure(Loop * Original,Loop * Parent,ValueToValueMapTy & VM,bool IsSubloop)704b1c73532SDimitry Andric Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
705b1c73532SDimitry Andric ValueToValueMapTy &VM,
706b1c73532SDimitry Andric bool IsSubloop) {
707b1c73532SDimitry Andric Loop &New = *LI.AllocateLoop();
708b1c73532SDimitry Andric if (Parent)
709b1c73532SDimitry Andric Parent->addChildLoop(&New);
710b1c73532SDimitry Andric else
711b1c73532SDimitry Andric LI.addTopLevelLoop(&New);
712b1c73532SDimitry Andric LPMAddNewLoop(&New, IsSubloop);
713b1c73532SDimitry Andric
714b1c73532SDimitry Andric // Add all of the blocks in Original to the new loop.
715b1c73532SDimitry Andric for (auto *BB : Original->blocks())
716b1c73532SDimitry Andric if (LI.getLoopFor(BB) == Original)
717b1c73532SDimitry Andric New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
718b1c73532SDimitry Andric
719b1c73532SDimitry Andric // Add all of the subloops to the new loop.
720b1c73532SDimitry Andric for (Loop *SubLoop : *Original)
721b1c73532SDimitry Andric createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
722b1c73532SDimitry Andric
723b1c73532SDimitry Andric return &New;
724b1c73532SDimitry Andric }
725b1c73532SDimitry Andric
run()726b1c73532SDimitry Andric bool LoopConstrainer::run() {
727b1c73532SDimitry Andric BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
728b1c73532SDimitry Andric assert(Preheader != nullptr && "precondition!");
729b1c73532SDimitry Andric
730b1c73532SDimitry Andric OriginalPreheader = Preheader;
731b1c73532SDimitry Andric MainLoopPreheader = Preheader;
732b1c73532SDimitry Andric bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
733b1c73532SDimitry Andric bool Increasing = MainLoopStructure.IndVarIncreasing;
734b1c73532SDimitry Andric IntegerType *IVTy = cast<IntegerType>(RangeTy);
735b1c73532SDimitry Andric
736ac9a064cSDimitry Andric SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
737b1c73532SDimitry Andric Instruction *InsertPt = OriginalPreheader->getTerminator();
738b1c73532SDimitry Andric
739b1c73532SDimitry Andric // It would have been better to make `PreLoop' and `PostLoop'
740b1c73532SDimitry Andric // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
741b1c73532SDimitry Andric // constructor.
742b1c73532SDimitry Andric ClonedLoop PreLoop, PostLoop;
743b1c73532SDimitry Andric bool NeedsPreLoop =
744b1c73532SDimitry Andric Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
745b1c73532SDimitry Andric bool NeedsPostLoop =
746b1c73532SDimitry Andric Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
747b1c73532SDimitry Andric
748b1c73532SDimitry Andric Value *ExitPreLoopAt = nullptr;
749b1c73532SDimitry Andric Value *ExitMainLoopAt = nullptr;
750b1c73532SDimitry Andric const SCEVConstant *MinusOneS =
751b1c73532SDimitry Andric cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
752b1c73532SDimitry Andric
753b1c73532SDimitry Andric if (NeedsPreLoop) {
754b1c73532SDimitry Andric const SCEV *ExitPreLoopAtSCEV = nullptr;
755b1c73532SDimitry Andric
756b1c73532SDimitry Andric if (Increasing)
757b1c73532SDimitry Andric ExitPreLoopAtSCEV = *SR.LowLimit;
758b1c73532SDimitry Andric else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
759b1c73532SDimitry Andric IsSignedPredicate))
760b1c73532SDimitry Andric ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
761b1c73532SDimitry Andric else {
762b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
763b1c73532SDimitry Andric << "preloop exit limit. HighLimit = "
764b1c73532SDimitry Andric << *(*SR.HighLimit) << "\n");
765b1c73532SDimitry Andric return false;
766b1c73532SDimitry Andric }
767b1c73532SDimitry Andric
768b1c73532SDimitry Andric if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
769b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
770b1c73532SDimitry Andric << " preloop exit limit " << *ExitPreLoopAtSCEV
771b1c73532SDimitry Andric << " at block " << InsertPt->getParent()->getName()
772b1c73532SDimitry Andric << "\n");
773b1c73532SDimitry Andric return false;
774b1c73532SDimitry Andric }
775b1c73532SDimitry Andric
776b1c73532SDimitry Andric ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
777b1c73532SDimitry Andric ExitPreLoopAt->setName("exit.preloop.at");
778b1c73532SDimitry Andric }
779b1c73532SDimitry Andric
780b1c73532SDimitry Andric if (NeedsPostLoop) {
781b1c73532SDimitry Andric const SCEV *ExitMainLoopAtSCEV = nullptr;
782b1c73532SDimitry Andric
783b1c73532SDimitry Andric if (Increasing)
784b1c73532SDimitry Andric ExitMainLoopAtSCEV = *SR.HighLimit;
785b1c73532SDimitry Andric else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
786b1c73532SDimitry Andric IsSignedPredicate))
787b1c73532SDimitry Andric ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
788b1c73532SDimitry Andric else {
789b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
790b1c73532SDimitry Andric << "mainloop exit limit. LowLimit = "
791b1c73532SDimitry Andric << *(*SR.LowLimit) << "\n");
792b1c73532SDimitry Andric return false;
793b1c73532SDimitry Andric }
794b1c73532SDimitry Andric
795b1c73532SDimitry Andric if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
796b1c73532SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
797b1c73532SDimitry Andric << " main loop exit limit " << *ExitMainLoopAtSCEV
798b1c73532SDimitry Andric << " at block " << InsertPt->getParent()->getName()
799b1c73532SDimitry Andric << "\n");
800b1c73532SDimitry Andric return false;
801b1c73532SDimitry Andric }
802b1c73532SDimitry Andric
803b1c73532SDimitry Andric ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
804b1c73532SDimitry Andric ExitMainLoopAt->setName("exit.mainloop.at");
805b1c73532SDimitry Andric }
806b1c73532SDimitry Andric
807b1c73532SDimitry Andric // We clone these ahead of time so that we don't have to deal with changing
808b1c73532SDimitry Andric // and temporarily invalid IR as we transform the loops.
809b1c73532SDimitry Andric if (NeedsPreLoop)
810b1c73532SDimitry Andric cloneLoop(PreLoop, "preloop");
811b1c73532SDimitry Andric if (NeedsPostLoop)
812b1c73532SDimitry Andric cloneLoop(PostLoop, "postloop");
813b1c73532SDimitry Andric
814b1c73532SDimitry Andric RewrittenRangeInfo PreLoopRRI;
815b1c73532SDimitry Andric
816b1c73532SDimitry Andric if (NeedsPreLoop) {
817b1c73532SDimitry Andric Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
818b1c73532SDimitry Andric PreLoop.Structure.Header);
819b1c73532SDimitry Andric
820b1c73532SDimitry Andric MainLoopPreheader =
821b1c73532SDimitry Andric createPreheader(MainLoopStructure, Preheader, "mainloop");
822b1c73532SDimitry Andric PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
823b1c73532SDimitry Andric ExitPreLoopAt, MainLoopPreheader);
824b1c73532SDimitry Andric rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
825b1c73532SDimitry Andric PreLoopRRI);
826b1c73532SDimitry Andric }
827b1c73532SDimitry Andric
828b1c73532SDimitry Andric BasicBlock *PostLoopPreheader = nullptr;
829b1c73532SDimitry Andric RewrittenRangeInfo PostLoopRRI;
830b1c73532SDimitry Andric
831b1c73532SDimitry Andric if (NeedsPostLoop) {
832b1c73532SDimitry Andric PostLoopPreheader =
833b1c73532SDimitry Andric createPreheader(PostLoop.Structure, Preheader, "postloop");
834b1c73532SDimitry Andric PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
835b1c73532SDimitry Andric ExitMainLoopAt, PostLoopPreheader);
836b1c73532SDimitry Andric rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
837b1c73532SDimitry Andric PostLoopRRI);
838b1c73532SDimitry Andric }
839b1c73532SDimitry Andric
840b1c73532SDimitry Andric BasicBlock *NewMainLoopPreheader =
841b1c73532SDimitry Andric MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
842b1c73532SDimitry Andric BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
843b1c73532SDimitry Andric PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
844b1c73532SDimitry Andric PostLoopRRI.ExitSelector, NewMainLoopPreheader};
845b1c73532SDimitry Andric
846b1c73532SDimitry Andric // Some of the above may be nullptr, filter them out before passing to
847b1c73532SDimitry Andric // addToParentLoopIfNeeded.
848b1c73532SDimitry Andric auto NewBlocksEnd =
849b1c73532SDimitry Andric std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
850b1c73532SDimitry Andric
851b1c73532SDimitry Andric addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
852b1c73532SDimitry Andric
853b1c73532SDimitry Andric DT.recalculate(F);
854b1c73532SDimitry Andric
855b1c73532SDimitry Andric // We need to first add all the pre and post loop blocks into the loop
856b1c73532SDimitry Andric // structures (as part of createClonedLoopStructure), and then update the
857b1c73532SDimitry Andric // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
858b1c73532SDimitry Andric // LI when LoopSimplifyForm is generated.
859b1c73532SDimitry Andric Loop *PreL = nullptr, *PostL = nullptr;
860b1c73532SDimitry Andric if (!PreLoop.Blocks.empty()) {
861b1c73532SDimitry Andric PreL = createClonedLoopStructure(&OriginalLoop,
862b1c73532SDimitry Andric OriginalLoop.getParentLoop(), PreLoop.Map,
863b1c73532SDimitry Andric /* IsSubLoop */ false);
864b1c73532SDimitry Andric }
865b1c73532SDimitry Andric
866b1c73532SDimitry Andric if (!PostLoop.Blocks.empty()) {
867b1c73532SDimitry Andric PostL =
868b1c73532SDimitry Andric createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
869b1c73532SDimitry Andric PostLoop.Map, /* IsSubLoop */ false);
870b1c73532SDimitry Andric }
871b1c73532SDimitry Andric
872b1c73532SDimitry Andric // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
873b1c73532SDimitry Andric auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
874b1c73532SDimitry Andric formLCSSARecursively(*L, DT, &LI, &SE);
875b1c73532SDimitry Andric simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
876b1c73532SDimitry Andric // Pre/post loops are slow paths, we do not need to perform any loop
877b1c73532SDimitry Andric // optimizations on them.
878b1c73532SDimitry Andric if (!IsOriginalLoop)
879b1c73532SDimitry Andric DisableAllLoopOptsOnLoop(*L);
880b1c73532SDimitry Andric };
881b1c73532SDimitry Andric if (PreL)
882b1c73532SDimitry Andric CanonicalizeLoop(PreL, false);
883b1c73532SDimitry Andric if (PostL)
884b1c73532SDimitry Andric CanonicalizeLoop(PostL, false);
885b1c73532SDimitry Andric CanonicalizeLoop(&OriginalLoop, true);
886b1c73532SDimitry Andric
887b1c73532SDimitry Andric /// At this point:
888b1c73532SDimitry Andric /// - We've broken a "main loop" out of the loop in a way that the "main loop"
889b1c73532SDimitry Andric /// runs with the induction variable in a subset of [Begin, End).
890b1c73532SDimitry Andric /// - There is no overflow when computing "main loop" exit limit.
891b1c73532SDimitry Andric /// - Max latch taken count of the loop is limited.
892b1c73532SDimitry Andric /// It guarantees that induction variable will not overflow iterating in the
893b1c73532SDimitry Andric /// "main loop".
894b1c73532SDimitry Andric if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
895b1c73532SDimitry Andric if (IsSignedPredicate)
896b1c73532SDimitry Andric cast<BinaryOperator>(MainLoopStructure.IndVarBase)
897b1c73532SDimitry Andric ->setHasNoSignedWrap(true);
898b1c73532SDimitry Andric /// TODO: support unsigned predicate.
899b1c73532SDimitry Andric /// To add NUW flag we need to prove that both operands of BO are
900b1c73532SDimitry Andric /// non-negative. E.g:
901b1c73532SDimitry Andric /// ...
902b1c73532SDimitry Andric /// %iv.next = add nsw i32 %iv, -1
903b1c73532SDimitry Andric /// %cmp = icmp ult i32 %iv.next, %n
904b1c73532SDimitry Andric /// br i1 %cmp, label %loopexit, label %loop
905b1c73532SDimitry Andric ///
906b1c73532SDimitry Andric /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
907b1c73532SDimitry Andric /// overflow, therefore NUW flag is not legal here.
908b1c73532SDimitry Andric
909b1c73532SDimitry Andric return true;
910b1c73532SDimitry Andric }
911