1cfca06d7SDimitry Andric //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2cfca06d7SDimitry Andric //
3cfca06d7SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cfca06d7SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5cfca06d7SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cfca06d7SDimitry Andric //
7cfca06d7SDimitry Andric //===----------------------------------------------------------------------===//
8cfca06d7SDimitry Andric //
9cfca06d7SDimitry Andric // This pass optimizes scalar/vector interactions using target cost models. The
10cfca06d7SDimitry Andric // transforms implemented here may not fit in traditional loop-based or SLP
11cfca06d7SDimitry Andric // vectorization passes.
12cfca06d7SDimitry Andric //
13cfca06d7SDimitry Andric //===----------------------------------------------------------------------===//
14cfca06d7SDimitry Andric
15cfca06d7SDimitry Andric #include "llvm/Transforms/Vectorize/VectorCombine.h"
16b1c73532SDimitry Andric #include "llvm/ADT/DenseMap.h"
17ac9a064cSDimitry Andric #include "llvm/ADT/STLExtras.h"
18b1c73532SDimitry Andric #include "llvm/ADT/ScopeExit.h"
19cfca06d7SDimitry Andric #include "llvm/ADT/Statistic.h"
20344a3780SDimitry Andric #include "llvm/Analysis/AssumptionCache.h"
21cfca06d7SDimitry Andric #include "llvm/Analysis/BasicAliasAnalysis.h"
22cfca06d7SDimitry Andric #include "llvm/Analysis/GlobalsModRef.h"
23b60736ecSDimitry Andric #include "llvm/Analysis/Loads.h"
24cfca06d7SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
25cfca06d7SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
26cfca06d7SDimitry Andric #include "llvm/Analysis/VectorUtils.h"
27cfca06d7SDimitry Andric #include "llvm/IR/Dominators.h"
28cfca06d7SDimitry Andric #include "llvm/IR/Function.h"
29cfca06d7SDimitry Andric #include "llvm/IR/IRBuilder.h"
30cfca06d7SDimitry Andric #include "llvm/IR/PatternMatch.h"
31cfca06d7SDimitry Andric #include "llvm/Support/CommandLine.h"
32cfca06d7SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
33ac9a064cSDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
34e3b55780SDimitry Andric #include <numeric>
35b1c73532SDimitry Andric #include <queue>
36cfca06d7SDimitry Andric
37c0981da4SDimitry Andric #define DEBUG_TYPE "vector-combine"
38c0981da4SDimitry Andric #include "llvm/Transforms/Utils/InstructionWorklist.h"
39c0981da4SDimitry Andric
40cfca06d7SDimitry Andric using namespace llvm;
41cfca06d7SDimitry Andric using namespace llvm::PatternMatch;
42cfca06d7SDimitry Andric
// Statistics reported via -stats for each kind of transform this pass makes.
STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

// Debugging/triage knob: turn off every transform in this pass.
static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

// Debugging/triage knob: turn off only the binop extract-to-shuffle folds.
static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

// Upper bound on how far the combiner will walk instructions when searching
// for a profitable pattern (compile-time safeguard).
static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

// Sentinel meaning "no preferred extract index was supplied".
static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
64cfca06d7SDimitry Andric
namespace {
/// Driver for the pass: bundles the per-function analyses, the instruction
/// worklist, and one member function per fold. run() (defined elsewhere in
/// this file) dispatches instructions to the individual fold routines.
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
                const DataLayout *DL, bool TryEarlyFoldsOnly)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
        DL(DL), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}

  /// Apply all combines to the function; returns true if the IR changed.
  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;
  const DataLayout *DL;

  /// If true, only perform beneficial early IR transforms. Do not introduce
  /// new vector operations.
  bool TryEarlyFoldsOnly;

  InstructionWorklist Worklist;

  // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
  //       parameter. That should be updated to specific sub-classes because
  //       the run loop was changed to dispatch on opcode.
  bool vectorizeLoadInsert(Instruction &I);
  bool widenSubvectorLoad(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0,
                             ExtractElementInst *Ext1, const Instruction &I,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldInsExtFNeg(Instruction &I);
  bool foldBitcastShuffle(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool scalarizeVPIntrinsic(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoadExtract(Instruction &I);
  bool foldShuffleOfBinops(Instruction &I);
  bool foldShuffleOfCastops(Instruction &I);
  bool foldShuffleOfShuffles(Instruction &I);
  bool foldShuffleToIdentity(Instruction &I);
  bool foldShuffleFromReductions(Instruction &I);
  bool foldCastFromReductions(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);

  /// RAUW \p Old with \p New and queue the affected values so that further
  /// folds are attempted on them.
  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      New.takeName(&Old);
      // Revisit the replacement and its users; new folds may now apply.
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    Worklist.pushValue(&Old);
  }

  /// Erase \p I and queue its (former) operands, which may have become dead
  /// or newly foldable.
  void eraseInstruction(Instruction &I) {
    for (Value *Op : I.operands())
      Worklist.pushValue(Op);
    Worklist.remove(&I);
    I.eraseFromParent();
  }
};
} // namespace
141cfca06d7SDimitry Andric
142ac9a064cSDimitry Andric /// Return the source operand of a potentially bitcasted value. If there is no
143ac9a064cSDimitry Andric /// bitcast, return the input value itself.
peekThroughBitcasts(Value * V)144ac9a064cSDimitry Andric static Value *peekThroughBitcasts(Value *V) {
145ac9a064cSDimitry Andric while (auto *BitCast = dyn_cast<BitCastInst>(V))
146ac9a064cSDimitry Andric V = BitCast->getOperand(0);
147ac9a064cSDimitry Andric return V;
148ac9a064cSDimitry Andric }
149ac9a064cSDimitry Andric
canWidenLoad(LoadInst * Load,const TargetTransformInfo & TTI)150e3b55780SDimitry Andric static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
151e3b55780SDimitry Andric // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
152e3b55780SDimitry Andric // The widened load may load data from dirty regions or create data races
153e3b55780SDimitry Andric // non-existent in the source.
154e3b55780SDimitry Andric if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
155e3b55780SDimitry Andric Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
156e3b55780SDimitry Andric mustSuppressSpeculation(*Load))
157e3b55780SDimitry Andric return false;
158e3b55780SDimitry Andric
159e3b55780SDimitry Andric // We are potentially transforming byte-sized (8-bit) memory accesses, so make
160e3b55780SDimitry Andric // sure we have all of our type-based constraints in place for this target.
161e3b55780SDimitry Andric Type *ScalarTy = Load->getType()->getScalarType();
162e3b55780SDimitry Andric uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
163e3b55780SDimitry Andric unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
164e3b55780SDimitry Andric if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
165e3b55780SDimitry Andric ScalarSize % 8 != 0)
166e3b55780SDimitry Andric return false;
167e3b55780SDimitry Andric
168e3b55780SDimitry Andric return true;
169e3b55780SDimitry Andric }
170e3b55780SDimitry Andric
/// Try to rewrite "insertelt undef, (load scalar), 0" (optionally with an
/// extractelement between the load and the insert) into a direct vector load
/// when that is provably safe and the target cost model approves.
bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  Value *Scalar;
  if (!match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  auto *Load = dyn_cast<LoadInst>(X);
  if (!canWidenLoad(Load, TTI))
    return false;

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                   &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
    unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    // NOTE: this rewrites SrcPtr to the stripped base; the rest of the
    // function (safety re-check, alignment, codegen) uses the new base.
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element to shuffle cleanly
    // in the element's size.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    // Re-check safety against the stripped base pointer.
    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load,
                                     &AC, &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be
    // negated to more accurately represent "(new) SrcPtr - Offset = (old)
    // SrcPtr", but negation does not change the result of the alignment
    // calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
  // Only element 0 of the widened load is demanded by the original pattern.
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  OldCost +=
      TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                   /* Insert */ true, HasExtract, CostKind);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  auto *Ty = cast<FixedVectorType>(I.getType());
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  // A zero offset makes the shuffle a free size-change; only charge for it
  // when the target element must actually be moved.
  if (OffsetEltIndex)
    NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}
288b60736ecSDimitry Andric
/// If we are loading a vector and then inserting it into a larger vector with
/// undefined elements, try to load the larger vector and eliminate the insert.
/// This removes a shuffle in IR and may allow combining of other loaded values.
bool VectorCombine::widenSubvectorLoad(Instruction &I) {
  // Match subvector insert of fixed vector.
  auto *Shuf = cast<ShuffleVectorInst>(&I);
  if (!Shuf->isIdentityWithPadding())
    return false;

  // Allow a non-canonical shuffle mask that is choosing elements from op1.
  // NOTE: any_of returns a bool that deliberately doubles as the operand
  // index: 0 if the identity elements come from op0, 1 if any come from op1.
  unsigned NumOpElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
    return M >= (int)(NumOpElts);
  });

  auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
  if (!canWidenLoad(Load, TTI))
    return false;

  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  auto *Ty = cast<FixedVectorType>(I.getType());
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
    return false;

  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();

  // Original pattern: insert_subvector (load PtrOp)
  // This conservatively assumes that the cost of a subvector insert into an
  // undef value is 0. We could add that cost if the cost model accurately
  // reflects the real cost of that operation.
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);

  // New pattern: load PtrOp
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}
347e3b55780SDimitry Andric
348cfca06d7SDimitry Andric /// Determine which, if any, of the inputs should be replaced by a shuffle
349cfca06d7SDimitry Andric /// followed by extract from a different index.
getShuffleExtract(ExtractElementInst * Ext0,ExtractElementInst * Ext1,unsigned PreferredExtractIndex=InvalidIndex) const350cfca06d7SDimitry Andric ExtractElementInst *VectorCombine::getShuffleExtract(
351cfca06d7SDimitry Andric ExtractElementInst *Ext0, ExtractElementInst *Ext1,
352cfca06d7SDimitry Andric unsigned PreferredExtractIndex = InvalidIndex) const {
353145449b1SDimitry Andric auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
354145449b1SDimitry Andric auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
355145449b1SDimitry Andric assert(Index0C && Index1C && "Expected constant extract indexes");
356cfca06d7SDimitry Andric
357145449b1SDimitry Andric unsigned Index0 = Index0C->getZExtValue();
358145449b1SDimitry Andric unsigned Index1 = Index1C->getZExtValue();
359cfca06d7SDimitry Andric
360cfca06d7SDimitry Andric // If the extract indexes are identical, no shuffle is needed.
361cfca06d7SDimitry Andric if (Index0 == Index1)
362cfca06d7SDimitry Andric return nullptr;
363cfca06d7SDimitry Andric
364cfca06d7SDimitry Andric Type *VecTy = Ext0->getVectorOperand()->getType();
365e3b55780SDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
366cfca06d7SDimitry Andric assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
367b60736ecSDimitry Andric InstructionCost Cost0 =
368e3b55780SDimitry Andric TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
369b60736ecSDimitry Andric InstructionCost Cost1 =
370e3b55780SDimitry Andric TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
371b60736ecSDimitry Andric
372b60736ecSDimitry Andric // If both costs are invalid no shuffle is needed
373b60736ecSDimitry Andric if (!Cost0.isValid() && !Cost1.isValid())
374b60736ecSDimitry Andric return nullptr;
375cfca06d7SDimitry Andric
376cfca06d7SDimitry Andric // We are extracting from 2 different indexes, so one operand must be shuffled
377cfca06d7SDimitry Andric // before performing a vector operation and/or extract. The more expensive
378cfca06d7SDimitry Andric // extract will be replaced by a shuffle.
379cfca06d7SDimitry Andric if (Cost0 > Cost1)
380cfca06d7SDimitry Andric return Ext0;
381cfca06d7SDimitry Andric if (Cost1 > Cost0)
382cfca06d7SDimitry Andric return Ext1;
383cfca06d7SDimitry Andric
384cfca06d7SDimitry Andric // If the costs are equal and there is a preferred extract index, shuffle the
385cfca06d7SDimitry Andric // opposite operand.
386cfca06d7SDimitry Andric if (PreferredExtractIndex == Index0)
387cfca06d7SDimitry Andric return Ext1;
388cfca06d7SDimitry Andric if (PreferredExtractIndex == Index1)
389cfca06d7SDimitry Andric return Ext0;
390cfca06d7SDimitry Andric
391cfca06d7SDimitry Andric // Otherwise, replace the extract with the higher index.
392cfca06d7SDimitry Andric return Index0 > Index1 ? Ext0 : Ext1;
393cfca06d7SDimitry Andric }
394cfca06d7SDimitry Andric
/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          const Instruction &I,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getOperand(1));
  auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getOperand(1));
  assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");

  unsigned Opcode = I.getOpcode();
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  // Only binops and (integer or FP) compares reach this function.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = Ext0IndexC->getZExtValue();
  unsigned Ext1Index = Ext1IndexC->getZExtValue();
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    // (HasUseTax is a bool used as a 0/1 multiplier on the extract cost.)
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    // (The !hasOneUse() bools act as 0/1 multipliers: an extract with extra
    // uses survives the transform, so its cost stays in the new sequence.)
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    // Respect the debugging knob that disables this particular fold.
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // poison except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { poison, poison, 0, poison }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}
491cfca06d7SDimitry Andric
492cfca06d7SDimitry Andric /// Create a shuffle that translates (shifts) 1 element from the input vector
493cfca06d7SDimitry Andric /// to a new element location.
createShiftShuffle(Value * Vec,unsigned OldIndex,unsigned NewIndex,IRBuilder<> & Builder)494cfca06d7SDimitry Andric static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
495cfca06d7SDimitry Andric unsigned NewIndex, IRBuilder<> &Builder) {
4967fa27ce4SDimitry Andric // The shuffle mask is poison except for 1 lane that is being translated
497cfca06d7SDimitry Andric // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
4987fa27ce4SDimitry Andric // ShufMask = { 2, poison, poison, poison }
499cfca06d7SDimitry Andric auto *VecTy = cast<FixedVectorType>(Vec->getType());
5007fa27ce4SDimitry Andric SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
501cfca06d7SDimitry Andric ShufMask[NewIndex] = OldIndex;
502b60736ecSDimitry Andric return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
503cfca06d7SDimitry Andric }
504cfca06d7SDimitry Andric
505cfca06d7SDimitry Andric /// Given an extract element instruction with constant index operand, shuffle
506cfca06d7SDimitry Andric /// the source vector (shift the scalar element) to a NewIndex for extraction.
507cfca06d7SDimitry Andric /// Return null if the input can be constant folded, so that we are not creating
508cfca06d7SDimitry Andric /// unnecessary instructions.
translateExtract(ExtractElementInst * ExtElt,unsigned NewIndex,IRBuilder<> & Builder)509cfca06d7SDimitry Andric static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
510cfca06d7SDimitry Andric unsigned NewIndex,
511cfca06d7SDimitry Andric IRBuilder<> &Builder) {
5121f917f69SDimitry Andric // Shufflevectors can only be created for fixed-width vectors.
5131f917f69SDimitry Andric if (!isa<FixedVectorType>(ExtElt->getOperand(0)->getType()))
5141f917f69SDimitry Andric return nullptr;
5151f917f69SDimitry Andric
516cfca06d7SDimitry Andric // If the extract can be constant-folded, this code is unsimplified. Defer
517cfca06d7SDimitry Andric // to other passes to handle that.
518cfca06d7SDimitry Andric Value *X = ExtElt->getVectorOperand();
519cfca06d7SDimitry Andric Value *C = ExtElt->getIndexOperand();
520cfca06d7SDimitry Andric assert(isa<ConstantInt>(C) && "Expected a constant index operand");
521cfca06d7SDimitry Andric if (isa<Constant>(X))
522cfca06d7SDimitry Andric return nullptr;
523cfca06d7SDimitry Andric
524cfca06d7SDimitry Andric Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
525cfca06d7SDimitry Andric NewIndex, Builder);
526cfca06d7SDimitry Andric return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
527cfca06d7SDimitry Andric }
528cfca06d7SDimitry Andric
529cfca06d7SDimitry Andric /// Try to reduce extract element costs by converting scalar compares to vector
530cfca06d7SDimitry Andric /// compares followed by extract.
531cfca06d7SDimitry Andric /// cmp (ext0 V0, C), (ext1 V1, C)
foldExtExtCmp(ExtractElementInst * Ext0,ExtractElementInst * Ext1,Instruction & I)532cfca06d7SDimitry Andric void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
533cfca06d7SDimitry Andric ExtractElementInst *Ext1, Instruction &I) {
534cfca06d7SDimitry Andric assert(isa<CmpInst>(&I) && "Expected a compare");
535cfca06d7SDimitry Andric assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
536cfca06d7SDimitry Andric cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
537cfca06d7SDimitry Andric "Expected matching constant extract indexes");
538cfca06d7SDimitry Andric
539cfca06d7SDimitry Andric // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
540cfca06d7SDimitry Andric ++NumVecCmp;
541cfca06d7SDimitry Andric CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
542cfca06d7SDimitry Andric Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
543cfca06d7SDimitry Andric Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
544cfca06d7SDimitry Andric Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
545cfca06d7SDimitry Andric replaceValue(I, *NewExt);
546cfca06d7SDimitry Andric }
547cfca06d7SDimitry Andric
548cfca06d7SDimitry Andric /// Try to reduce extract element costs by converting scalar binops to vector
549cfca06d7SDimitry Andric /// binops followed by extract.
550cfca06d7SDimitry Andric /// bo (ext0 V0, C), (ext1 V1, C)
foldExtExtBinop(ExtractElementInst * Ext0,ExtractElementInst * Ext1,Instruction & I)551cfca06d7SDimitry Andric void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
552cfca06d7SDimitry Andric ExtractElementInst *Ext1, Instruction &I) {
553cfca06d7SDimitry Andric assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
554cfca06d7SDimitry Andric assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
555cfca06d7SDimitry Andric cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
556cfca06d7SDimitry Andric "Expected matching constant extract indexes");
557cfca06d7SDimitry Andric
558cfca06d7SDimitry Andric // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
559cfca06d7SDimitry Andric ++NumVecBO;
560cfca06d7SDimitry Andric Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
561cfca06d7SDimitry Andric Value *VecBO =
562cfca06d7SDimitry Andric Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
563cfca06d7SDimitry Andric
564cfca06d7SDimitry Andric // All IR flags are safe to back-propagate because any potential poison
565cfca06d7SDimitry Andric // created in unused vector elements is discarded by the extract.
566cfca06d7SDimitry Andric if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
567cfca06d7SDimitry Andric VecBOInst->copyIRFlags(&I);
568cfca06d7SDimitry Andric
569cfca06d7SDimitry Andric Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
570cfca06d7SDimitry Andric replaceValue(I, *NewExt);
571cfca06d7SDimitry Andric }
572cfca06d7SDimitry Andric
/// Match an instruction with extracted vector operands.
/// Pattern: cmp/binop (extractelt V0, C0), (extractelt V1, C1)
/// If the cost model approves, rewrite to a vector cmp/binop followed by a
/// single extract, possibly shuffling one source so both extract indexes
/// agree. Returns true if the IR was changed.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  // Match either a compare (capturing its predicate) or any binary operator.
  // Pred keeps its BAD_ICMP_PREDICATE sentinel when the binop form matched;
  // that distinction selects the fold routine at the bottom.
  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  // Both operands must be extracts with constant indexes from same-typed
  // vectors; capture the source vectors and indexes.
  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  // probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    // Match used only for its side effect: capture the insert lane into
    // InsertIndex (stays InvalidIndex if the sole user is not an insert).
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  // Ask the cost model; it may also nominate one extract to be re-indexed
  // (ExtractToChange) so both extracts use the same lane.
  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
    return false;

  if (ExtractToChange) {
    // Re-index the more expensive extract to the other extract's lane.
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    // translateExtract returns null when the shuffle would constant-fold.
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  // Both extracts now use the same constant lane; perform the actual fold.
  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  // Revisit the (possibly new) extracts so follow-on folds can fire.
  Worklist.push(Ext0);
  Worklist.push(Ext1);
  return true;
}
630cfca06d7SDimitry Andric
/// Try to replace an extract + scalar fneg + insert with a vector fneg +
/// shuffle.
/// Pattern: insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
/// Rewrites to: shuffle DestVec, (fneg SrcVec), select-mask
/// Returns true if the IR was changed.
bool VectorCombine::foldInsExtFNeg(Instruction &I) {
  // Match an insert (op (extract)) pattern.
  Value *DestVec;
  uint64_t Index;
  Instruction *FNeg;
  if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
                             m_ConstantInt(Index))))
    return false;

  // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
  // The extract must read the same lane (m_SpecificInt(Index)) that the
  // insert writes, or the select-shuffle rewrite below would be wrong.
  Value *SrcVec;
  Instruction *Extract;
  if (!match(FNeg, m_FNeg(m_CombineAnd(
                       m_Instruction(Extract),
                       m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
    return false;

  // TODO: We could handle this with a length-changing shuffle.
  auto *VecTy = cast<FixedVectorType>(I.getType());
  if (SrcVec->getType() != VecTy)
    return false;

  // Ignore bogus insert/extract index.
  unsigned NumElts = VecTy->getNumElements();
  if (Index >= NumElts)
    return false;

  // We are inserting the negated element into the same lane that we extracted
  // from. This is equivalent to a select-shuffle that chooses all but the
  // negated element from the destination vector.
  // Mask is identity (lane i from DestVec) except Mask[Index], which selects
  // lane Index from the second shuffle operand (the negated SrcVec).
  SmallVector<int> Mask(NumElts);
  std::iota(Mask.begin(), Mask.end(), 0);
  Mask[Index] = Index + NumElts;

  Type *ScalarTy = VecTy->getScalarType();
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  // Old cost: scalar fneg + the insertelement.
  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy) +
      TTI.getVectorInstrCost(I, VecTy, CostKind, Index);

  // If the extract has one use, it will be eliminated, so count it in the
  // original cost. If it has more than one use, ignore the cost because it will
  // be the same before/after.
  if (Extract->hasOneUse())
    OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);

  // New cost: vector fneg + select shuffle.
  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask);

  if (NewCost > OldCost)
    return false;

  // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index -->
  // shuffle DestVec, (fneg SrcVec), Mask
  // CreateFNegFMF propagates the original fneg's fast-math flags.
  Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
  Value *Shuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
  replaceValue(I, *Shuf);
  return true;
}
693e3b55780SDimitry Andric
/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
/// Pattern: bitcast (shuffle V0, V1, Mask) --> shuffle (bitcast V0),
/// (bitcast V1), Mask' — with the mask narrowed or widened to match the new
/// element size. Returns true if the IR was changed.
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> Mask;
  // The shuffle must have one use (the bitcast), so the old shuffle dies.
  if (!match(&I, m_BitCast(m_OneUse(
                     m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
  if (!DestTy || !SrcTy)
    return false;

  // The source's total bit width must divide evenly into destination-sized
  // elements, or we cannot re-type the shuffle operands.
  unsigned DestEltSize = DestTy->getScalarSizeInBits();
  unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
  if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
    return false;

  bool IsUnary = isa<UndefValue>(V1);

  // For binary shuffles, only fold bitcast(shuffle(X,Y))
  // if it won't increase the number of bitcasts.
  if (!IsUnary) {
    // Look through existing bitcasts: if either operand is already in the
    // destination element type, the new cast on that side folds away.
    auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
    auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
    if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
        !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
      return false;
  }

  SmallVector<int, 16> NewMask;
  if (DestEltSize <= SrcEltSize) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcEltSize / DestEltSize;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestEltSize / SrcEltSize;
    // widenShuffleMaskElts fails (returns false) when the mask does not pick
    // whole groups of consecutive narrow elements.
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }

  // Bitcast the shuffle src - keep its original width but using the destination
  // scalar type.
  unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
  auto *NewShuffleTy =
      FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
  auto *OldShuffleTy =
      FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
  // Unary shuffles need one operand bitcast; binary need two.
  unsigned NumOps = IsUnary ? 1 : 2;

  // The new shuffle must not cost more than the old shuffle.
  TargetTransformInfo::TargetCostKind CK =
      TargetTransformInfo::TCK_RecipThroughput;
  TargetTransformInfo::ShuffleKind SK =
      IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
              : TargetTransformInfo::SK_PermuteTwoSrc;

  // New form: re-typed shuffle + NumOps bitcasts on the operands.
  InstructionCost DestCost =
      TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CK) +
      (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
                                     TargetTransformInfo::CastContextHint::None,
                                     CK));
  // Old form: original shuffle + one bitcast of its result.
  InstructionCost SrcCost =
      TTI.getShuffleCost(SK, SrcTy, Mask, CK) +
      TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
                           TargetTransformInfo::CastContextHint::None, CK);
  if (DestCost > SrcCost || !DestCost.isValid())
    return false;

  // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
  ++NumShufOfBitcast;
  Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
  Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
  replaceValue(I, *Shuf);
  return true;
}
783cfca06d7SDimitry Andric
/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
/// Pattern: vp.binop (splat X), (splat Y), all-true-mask, EVL
///   --> splat (binop X, Y)
/// Returns true if the IR was changed.
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
  if (!isa<VPIntrinsic>(I))
    return false;
  VPIntrinsic &VPI = cast<VPIntrinsic>(I);
  Value *Op0 = VPI.getArgOperand(0);
  Value *Op1 = VPI.getArgOperand(1);

  if (!isSplatValue(Op0) || !isSplatValue(Op1))
    return false;

  // Check getSplatValue early in this function, to avoid doing unnecessary
  // work.
  Value *ScalarOp0 = getSplatValue(Op0);
  Value *ScalarOp1 = getSplatValue(Op1);
  if (!ScalarOp0 || !ScalarOp1)
    return false;

  // For the binary VP intrinsics supported here, the result on disabled lanes
  // is a poison value. For now, only do this simplification if all lanes
  // are active.
  // TODO: Relax the condition that all lanes are active by using insertelement
  // on inactive lanes.
  auto IsAllTrueMask = [](Value *MaskVal) {
    if (Value *SplattedVal = getSplatValue(MaskVal))
      if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
        return ConstValue->isAllOnesValue();
    return false;
  };
  if (!IsAllTrueMask(VPI.getArgOperand(2)))
    return false;

  // Check to make sure we support scalarization of the intrinsic
  Intrinsic::ID IntrID = VPI.getIntrinsicID();
  if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
    return false;

  // Calculate cost of splatting both operands into vectors and the vector
  // intrinsic
  VectorType *VecTy = cast<VectorType>(VPI.getType());
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  // Broadcast mask is only needed for fixed-width vectors; left empty for
  // scalable vectors.
  SmallVector<int> Mask;
  if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
    Mask.resize(FVTy->getNumElements(), 0);
  // A splat is modeled as insertelement into lane 0 + broadcast shuffle.
  InstructionCost SplatCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask);

  // Calculate the cost of the VP Intrinsic
  SmallVector<Type *, 4> Args;
  for (Value *V : VPI.args())
    Args.push_back(V->getType());
  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

  // Determine scalar opcode
  // Prefer a plain IR opcode; fall back to a scalar intrinsic ID. Bail out if
  // neither is available.
  std::optional<unsigned> FunctionalOpcode =
      VPI.getFunctionalOpcode();
  std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
  if (!FunctionalOpcode) {
    ScalarIntrID = VPI.getFunctionalIntrinsicID();
    if (!ScalarIntrID)
      return false;
  }

  // Calculate cost of scalarizing
  InstructionCost ScalarOpCost = 0;
  if (ScalarIntrID) {
    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
    ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
  } else {
    ScalarOpCost =
        TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
  }

  // The existing splats may be kept around if other instructions use them.
  InstructionCost CostToKeepSplats =
      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
                    << "\n");
  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                    << ", Cost of scalarizing:" << NewCost << "\n");

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Scalarize the intrinsic
  ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
  Value *EVL = VPI.getArgOperand(3);

  // If the VP op might introduce UB or poison, we can scalarize it provided
  // that we know the EVL > 0: If the EVL is zero, then the original VP op
  // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
  // scalarizing it.
  bool SafeToSpeculate;
  if (ScalarIntrID)
    SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
                          .hasFnAttr(Attribute::AttrKind::Speculatable);
  else
    SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
        *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
  if (!SafeToSpeculate &&
      !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI)))
    return false;

  // Emit the scalar op (intrinsic or binop) and splat the result.
  Value *ScalarVal =
      ScalarIntrID
          ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
                                    {ScalarOp0, ScalarOp1})
          : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
                                ScalarOp0, ScalarOp1);

  replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
  return true;
}
905b1c73532SDimitry Andric
906cfca06d7SDimitry Andric /// Match a vector binop or compare instruction with at least one inserted
907cfca06d7SDimitry Andric /// scalar operand and convert to scalar binop/cmp followed by insertelement.
scalarizeBinopOrCmp(Instruction & I)908cfca06d7SDimitry Andric bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
909cfca06d7SDimitry Andric CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
910cfca06d7SDimitry Andric Value *Ins0, *Ins1;
911cfca06d7SDimitry Andric if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
912cfca06d7SDimitry Andric !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
913cfca06d7SDimitry Andric return false;
914cfca06d7SDimitry Andric
915cfca06d7SDimitry Andric // Do not convert the vector condition of a vector select into a scalar
916cfca06d7SDimitry Andric // condition. That may cause problems for codegen because of differences in
917cfca06d7SDimitry Andric // boolean formats and register-file transfers.
918cfca06d7SDimitry Andric // TODO: Can we account for that in the cost model?
919cfca06d7SDimitry Andric bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
920cfca06d7SDimitry Andric if (IsCmp)
921cfca06d7SDimitry Andric for (User *U : I.users())
922cfca06d7SDimitry Andric if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
923cfca06d7SDimitry Andric return false;
924cfca06d7SDimitry Andric
925cfca06d7SDimitry Andric // Match against one or both scalar values being inserted into constant
926cfca06d7SDimitry Andric // vectors:
927cfca06d7SDimitry Andric // vec_op VecC0, (inselt VecC1, V1, Index)
928cfca06d7SDimitry Andric // vec_op (inselt VecC0, V0, Index), VecC1
929cfca06d7SDimitry Andric // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
930cfca06d7SDimitry Andric // TODO: Deal with mismatched index constants and variable indexes?
931cfca06d7SDimitry Andric Constant *VecC0 = nullptr, *VecC1 = nullptr;
932cfca06d7SDimitry Andric Value *V0 = nullptr, *V1 = nullptr;
933cfca06d7SDimitry Andric uint64_t Index0 = 0, Index1 = 0;
934cfca06d7SDimitry Andric if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
935cfca06d7SDimitry Andric m_ConstantInt(Index0))) &&
936cfca06d7SDimitry Andric !match(Ins0, m_Constant(VecC0)))
937cfca06d7SDimitry Andric return false;
938cfca06d7SDimitry Andric if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
939cfca06d7SDimitry Andric m_ConstantInt(Index1))) &&
940cfca06d7SDimitry Andric !match(Ins1, m_Constant(VecC1)))
941cfca06d7SDimitry Andric return false;
942cfca06d7SDimitry Andric
943cfca06d7SDimitry Andric bool IsConst0 = !V0;
944cfca06d7SDimitry Andric bool IsConst1 = !V1;
945cfca06d7SDimitry Andric if (IsConst0 && IsConst1)
946cfca06d7SDimitry Andric return false;
947cfca06d7SDimitry Andric if (!IsConst0 && !IsConst1 && Index0 != Index1)
948cfca06d7SDimitry Andric return false;
949cfca06d7SDimitry Andric
950cfca06d7SDimitry Andric // Bail for single insertion if it is a load.
951cfca06d7SDimitry Andric // TODO: Handle this once getVectorInstrCost can cost for load/stores.
952cfca06d7SDimitry Andric auto *I0 = dyn_cast_or_null<Instruction>(V0);
953cfca06d7SDimitry Andric auto *I1 = dyn_cast_or_null<Instruction>(V1);
954cfca06d7SDimitry Andric if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
955cfca06d7SDimitry Andric (IsConst1 && I0 && I0->mayReadFromMemory()))
956cfca06d7SDimitry Andric return false;
957cfca06d7SDimitry Andric
958cfca06d7SDimitry Andric uint64_t Index = IsConst0 ? Index1 : Index0;
959cfca06d7SDimitry Andric Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
960cfca06d7SDimitry Andric Type *VecTy = I.getType();
961cfca06d7SDimitry Andric assert(VecTy->isVectorTy() &&
962cfca06d7SDimitry Andric (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
963cfca06d7SDimitry Andric (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
964cfca06d7SDimitry Andric ScalarTy->isPointerTy()) &&
965cfca06d7SDimitry Andric "Unexpected types for insert element into binop or cmp");
966cfca06d7SDimitry Andric
967cfca06d7SDimitry Andric unsigned Opcode = I.getOpcode();
968b60736ecSDimitry Andric InstructionCost ScalarOpCost, VectorOpCost;
969cfca06d7SDimitry Andric if (IsCmp) {
970c0981da4SDimitry Andric CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
971c0981da4SDimitry Andric ScalarOpCost = TTI.getCmpSelInstrCost(
972c0981da4SDimitry Andric Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
973c0981da4SDimitry Andric VectorOpCost = TTI.getCmpSelInstrCost(
974c0981da4SDimitry Andric Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
975cfca06d7SDimitry Andric } else {
976cfca06d7SDimitry Andric ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
977cfca06d7SDimitry Andric VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
978cfca06d7SDimitry Andric }
979cfca06d7SDimitry Andric
980cfca06d7SDimitry Andric // Get cost estimate for the insert element. This cost will factor into
981cfca06d7SDimitry Andric // both sequences.
982e3b55780SDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
983e3b55780SDimitry Andric InstructionCost InsertCost = TTI.getVectorInstrCost(
984e3b55780SDimitry Andric Instruction::InsertElement, VecTy, CostKind, Index);
985b60736ecSDimitry Andric InstructionCost OldCost =
986b60736ecSDimitry Andric (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
987b60736ecSDimitry Andric InstructionCost NewCost = ScalarOpCost + InsertCost +
988cfca06d7SDimitry Andric (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
989cfca06d7SDimitry Andric (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
990cfca06d7SDimitry Andric
991cfca06d7SDimitry Andric // We want to scalarize unless the vector variant actually has lower cost.
992b60736ecSDimitry Andric if (OldCost < NewCost || !NewCost.isValid())
993cfca06d7SDimitry Andric return false;
994cfca06d7SDimitry Andric
995cfca06d7SDimitry Andric // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
996cfca06d7SDimitry Andric // inselt NewVecC, (scalar_op V0, V1), Index
997cfca06d7SDimitry Andric if (IsCmp)
998cfca06d7SDimitry Andric ++NumScalarCmp;
999cfca06d7SDimitry Andric else
1000cfca06d7SDimitry Andric ++NumScalarBO;
1001cfca06d7SDimitry Andric
1002cfca06d7SDimitry Andric // For constant cases, extract the scalar element, this should constant fold.
1003cfca06d7SDimitry Andric if (IsConst0)
1004cfca06d7SDimitry Andric V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
1005cfca06d7SDimitry Andric if (IsConst1)
1006cfca06d7SDimitry Andric V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
1007cfca06d7SDimitry Andric
1008cfca06d7SDimitry Andric Value *Scalar =
1009cfca06d7SDimitry Andric IsCmp ? Builder.CreateCmp(Pred, V0, V1)
1010cfca06d7SDimitry Andric : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
1011cfca06d7SDimitry Andric
1012cfca06d7SDimitry Andric Scalar->setName(I.getName() + ".scalar");
1013cfca06d7SDimitry Andric
1014cfca06d7SDimitry Andric // All IR flags are safe to back-propagate. There is no potential for extra
1015cfca06d7SDimitry Andric // poison to be created by the scalar instruction.
1016cfca06d7SDimitry Andric if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1017cfca06d7SDimitry Andric ScalarInst->copyIRFlags(&I);
1018cfca06d7SDimitry Andric
1019cfca06d7SDimitry Andric // Fold the vector constants in the original vectors into a new base vector.
1020145449b1SDimitry Andric Value *NewVecC =
1021145449b1SDimitry Andric IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
1022145449b1SDimitry Andric : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1023cfca06d7SDimitry Andric Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
1024cfca06d7SDimitry Andric replaceValue(I, *Insert);
1025cfca06d7SDimitry Andric return true;
1026cfca06d7SDimitry Andric }
1027cfca06d7SDimitry Andric
1028cfca06d7SDimitry Andric /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1029cfca06d7SDimitry Andric /// a vector into vector operations followed by extract. Note: The SLP pass
1030cfca06d7SDimitry Andric /// may miss this pattern because of implementation problems.
foldExtractedCmps(Instruction & I)1031cfca06d7SDimitry Andric bool VectorCombine::foldExtractedCmps(Instruction &I) {
1032cfca06d7SDimitry Andric // We are looking for a scalar binop of booleans.
1033cfca06d7SDimitry Andric // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1034cfca06d7SDimitry Andric if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
1035cfca06d7SDimitry Andric return false;
1036cfca06d7SDimitry Andric
1037cfca06d7SDimitry Andric // The compare predicates should match, and each compare should have a
1038cfca06d7SDimitry Andric // constant operand.
1039cfca06d7SDimitry Andric // TODO: Relax the one-use constraints.
1040cfca06d7SDimitry Andric Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1041cfca06d7SDimitry Andric Instruction *I0, *I1;
1042cfca06d7SDimitry Andric Constant *C0, *C1;
1043cfca06d7SDimitry Andric CmpInst::Predicate P0, P1;
1044cfca06d7SDimitry Andric if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
1045cfca06d7SDimitry Andric !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
1046cfca06d7SDimitry Andric P0 != P1)
1047cfca06d7SDimitry Andric return false;
1048cfca06d7SDimitry Andric
1049cfca06d7SDimitry Andric // The compare operands must be extracts of the same vector with constant
1050cfca06d7SDimitry Andric // extract indexes.
1051cfca06d7SDimitry Andric // TODO: Relax the one-use constraints.
1052cfca06d7SDimitry Andric Value *X;
1053cfca06d7SDimitry Andric uint64_t Index0, Index1;
1054cfca06d7SDimitry Andric if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
1055cfca06d7SDimitry Andric !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
1056cfca06d7SDimitry Andric return false;
1057cfca06d7SDimitry Andric
1058cfca06d7SDimitry Andric auto *Ext0 = cast<ExtractElementInst>(I0);
1059cfca06d7SDimitry Andric auto *Ext1 = cast<ExtractElementInst>(I1);
1060cfca06d7SDimitry Andric ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
1061cfca06d7SDimitry Andric if (!ConvertToShuf)
1062cfca06d7SDimitry Andric return false;
1063cfca06d7SDimitry Andric
1064cfca06d7SDimitry Andric // The original scalar pattern is:
1065cfca06d7SDimitry Andric // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1066cfca06d7SDimitry Andric CmpInst::Predicate Pred = P0;
1067cfca06d7SDimitry Andric unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
1068cfca06d7SDimitry Andric : Instruction::ICmp;
1069cfca06d7SDimitry Andric auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1070cfca06d7SDimitry Andric if (!VecTy)
1071cfca06d7SDimitry Andric return false;
1072cfca06d7SDimitry Andric
1073e3b55780SDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1074b60736ecSDimitry Andric InstructionCost OldCost =
1075e3b55780SDimitry Andric TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1076e3b55780SDimitry Andric OldCost += TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1077c0981da4SDimitry Andric OldCost +=
1078c0981da4SDimitry Andric TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
1079c0981da4SDimitry Andric CmpInst::makeCmpResultType(I0->getType()), Pred) *
1080c0981da4SDimitry Andric 2;
1081cfca06d7SDimitry Andric OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
1082cfca06d7SDimitry Andric
1083cfca06d7SDimitry Andric // The proposed vector pattern is:
1084cfca06d7SDimitry Andric // vcmp = cmp Pred X, VecC
1085cfca06d7SDimitry Andric // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1086cfca06d7SDimitry Andric int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1087cfca06d7SDimitry Andric int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1088cfca06d7SDimitry Andric auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
1089c0981da4SDimitry Andric InstructionCost NewCost = TTI.getCmpSelInstrCost(
1090c0981da4SDimitry Andric CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
10917fa27ce4SDimitry Andric SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1092344a3780SDimitry Andric ShufMask[CheapIndex] = ExpensiveIndex;
1093344a3780SDimitry Andric NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
1094344a3780SDimitry Andric ShufMask);
1095cfca06d7SDimitry Andric NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
1096e3b55780SDimitry Andric NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1097cfca06d7SDimitry Andric
1098cfca06d7SDimitry Andric // Aggressively form vector ops if the cost is equal because the transform
1099cfca06d7SDimitry Andric // may enable further optimization.
1100cfca06d7SDimitry Andric // Codegen can reverse this transform (scalarize) if it was not profitable.
1101b60736ecSDimitry Andric if (OldCost < NewCost || !NewCost.isValid())
1102cfca06d7SDimitry Andric return false;
1103cfca06d7SDimitry Andric
1104cfca06d7SDimitry Andric // Create a vector constant from the 2 scalar constants.
1105cfca06d7SDimitry Andric SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
11067fa27ce4SDimitry Andric PoisonValue::get(VecTy->getElementType()));
1107cfca06d7SDimitry Andric CmpC[Index0] = C0;
1108cfca06d7SDimitry Andric CmpC[Index1] = C1;
1109cfca06d7SDimitry Andric Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1110cfca06d7SDimitry Andric
1111cfca06d7SDimitry Andric Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1112cfca06d7SDimitry Andric Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
1113cfca06d7SDimitry Andric VCmp, Shuf);
1114cfca06d7SDimitry Andric Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1115cfca06d7SDimitry Andric replaceValue(I, *NewExt);
1116cfca06d7SDimitry Andric ++NumVecCmpBO;
1117cfca06d7SDimitry Andric return true;
1118cfca06d7SDimitry Andric }
1119cfca06d7SDimitry Andric
1120344a3780SDimitry Andric // Check if memory loc modified between two instrs in the same BB
isMemModifiedBetween(BasicBlock::iterator Begin,BasicBlock::iterator End,const MemoryLocation & Loc,AAResults & AA)1121344a3780SDimitry Andric static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1122344a3780SDimitry Andric BasicBlock::iterator End,
1123344a3780SDimitry Andric const MemoryLocation &Loc, AAResults &AA) {
1124344a3780SDimitry Andric unsigned NumScanned = 0;
1125344a3780SDimitry Andric return std::any_of(Begin, End, [&](const Instruction &Instr) {
1126344a3780SDimitry Andric return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1127344a3780SDimitry Andric ++NumScanned > MaxInstrsToScan;
1128344a3780SDimitry Andric });
1129344a3780SDimitry Andric }
1130344a3780SDimitry Andric
1131e3b55780SDimitry Andric namespace {
1132c0981da4SDimitry Andric /// Helper class to indicate whether a vector index can be safely scalarized and
1133c0981da4SDimitry Andric /// if a freeze needs to be inserted.
1134c0981da4SDimitry Andric class ScalarizationResult {
1135c0981da4SDimitry Andric enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1136c0981da4SDimitry Andric
1137c0981da4SDimitry Andric StatusTy Status;
1138c0981da4SDimitry Andric Value *ToFreeze;
1139c0981da4SDimitry Andric
ScalarizationResult(StatusTy Status,Value * ToFreeze=nullptr)1140c0981da4SDimitry Andric ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1141c0981da4SDimitry Andric : Status(Status), ToFreeze(ToFreeze) {}
1142c0981da4SDimitry Andric
1143c0981da4SDimitry Andric public:
1144c0981da4SDimitry Andric ScalarizationResult(const ScalarizationResult &Other) = default;
~ScalarizationResult()1145c0981da4SDimitry Andric ~ScalarizationResult() {
1146c0981da4SDimitry Andric assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1147c0981da4SDimitry Andric }
1148c0981da4SDimitry Andric
unsafe()1149c0981da4SDimitry Andric static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
safe()1150c0981da4SDimitry Andric static ScalarizationResult safe() { return {StatusTy::Safe}; }
safeWithFreeze(Value * ToFreeze)1151c0981da4SDimitry Andric static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1152c0981da4SDimitry Andric return {StatusTy::SafeWithFreeze, ToFreeze};
1153c0981da4SDimitry Andric }
1154c0981da4SDimitry Andric
1155c0981da4SDimitry Andric /// Returns true if the index can be scalarize without requiring a freeze.
isSafe() const1156c0981da4SDimitry Andric bool isSafe() const { return Status == StatusTy::Safe; }
1157c0981da4SDimitry Andric /// Returns true if the index cannot be scalarized.
isUnsafe() const1158c0981da4SDimitry Andric bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1159c0981da4SDimitry Andric /// Returns true if the index can be scalarize, but requires inserting a
1160c0981da4SDimitry Andric /// freeze.
isSafeWithFreeze() const1161c0981da4SDimitry Andric bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1162c0981da4SDimitry Andric
1163c0981da4SDimitry Andric /// Reset the state of Unsafe and clear ToFreze if set.
discard()1164c0981da4SDimitry Andric void discard() {
1165c0981da4SDimitry Andric ToFreeze = nullptr;
1166c0981da4SDimitry Andric Status = StatusTy::Unsafe;
1167c0981da4SDimitry Andric }
1168c0981da4SDimitry Andric
1169c0981da4SDimitry Andric /// Freeze the ToFreeze and update the use in \p User to use it.
freeze(IRBuilder<> & Builder,Instruction & UserI)1170c0981da4SDimitry Andric void freeze(IRBuilder<> &Builder, Instruction &UserI) {
1171c0981da4SDimitry Andric assert(isSafeWithFreeze() &&
1172c0981da4SDimitry Andric "should only be used when freezing is required");
1173c0981da4SDimitry Andric assert(is_contained(ToFreeze->users(), &UserI) &&
1174c0981da4SDimitry Andric "UserI must be a user of ToFreeze");
1175c0981da4SDimitry Andric IRBuilder<>::InsertPointGuard Guard(Builder);
1176c0981da4SDimitry Andric Builder.SetInsertPoint(cast<Instruction>(&UserI));
1177c0981da4SDimitry Andric Value *Frozen =
1178c0981da4SDimitry Andric Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1179c0981da4SDimitry Andric for (Use &U : make_early_inc_range((UserI.operands())))
1180c0981da4SDimitry Andric if (U.get() == ToFreeze)
1181c0981da4SDimitry Andric U.set(Frozen);
1182c0981da4SDimitry Andric
1183c0981da4SDimitry Andric ToFreeze = nullptr;
1184c0981da4SDimitry Andric }
1185c0981da4SDimitry Andric };
1186e3b55780SDimitry Andric } // namespace
1187c0981da4SDimitry Andric
1188344a3780SDimitry Andric /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1189344a3780SDimitry Andric /// Idx. \p Idx must access a valid vector element.
canScalarizeAccess(VectorType * VecTy,Value * Idx,Instruction * CtxI,AssumptionCache & AC,const DominatorTree & DT)1190b1c73532SDimitry Andric static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1191b1c73532SDimitry Andric Instruction *CtxI,
1192c0981da4SDimitry Andric AssumptionCache &AC,
1193c0981da4SDimitry Andric const DominatorTree &DT) {
1194b1c73532SDimitry Andric // We do checks for both fixed vector types and scalable vector types.
1195b1c73532SDimitry Andric // This is the number of elements of fixed vector types,
1196b1c73532SDimitry Andric // or the minimum number of elements of scalable vector types.
1197b1c73532SDimitry Andric uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1198b1c73532SDimitry Andric
1199c0981da4SDimitry Andric if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1200b1c73532SDimitry Andric if (C->getValue().ult(NumElements))
1201c0981da4SDimitry Andric return ScalarizationResult::safe();
1202c0981da4SDimitry Andric return ScalarizationResult::unsafe();
1203c0981da4SDimitry Andric }
1204344a3780SDimitry Andric
1205c0981da4SDimitry Andric unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1206c0981da4SDimitry Andric APInt Zero(IntWidth, 0);
1207b1c73532SDimitry Andric APInt MaxElts(IntWidth, NumElements);
1208344a3780SDimitry Andric ConstantRange ValidIndices(Zero, MaxElts);
1209c0981da4SDimitry Andric ConstantRange IdxRange(IntWidth, true);
1210c0981da4SDimitry Andric
1211c0981da4SDimitry Andric if (isGuaranteedNotToBePoison(Idx, &AC)) {
12126f8fc217SDimitry Andric if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
12136f8fc217SDimitry Andric true, &AC, CtxI, &DT)))
1214c0981da4SDimitry Andric return ScalarizationResult::safe();
1215c0981da4SDimitry Andric return ScalarizationResult::unsafe();
1216c0981da4SDimitry Andric }
1217c0981da4SDimitry Andric
1218c0981da4SDimitry Andric // If the index may be poison, check if we can insert a freeze before the
1219c0981da4SDimitry Andric // range of the index is restricted.
1220c0981da4SDimitry Andric Value *IdxBase;
1221c0981da4SDimitry Andric ConstantInt *CI;
1222c0981da4SDimitry Andric if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1223c0981da4SDimitry Andric IdxRange = IdxRange.binaryAnd(CI->getValue());
1224c0981da4SDimitry Andric } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1225c0981da4SDimitry Andric IdxRange = IdxRange.urem(CI->getValue());
1226c0981da4SDimitry Andric }
1227c0981da4SDimitry Andric
1228c0981da4SDimitry Andric if (ValidIndices.contains(IdxRange))
1229c0981da4SDimitry Andric return ScalarizationResult::safeWithFreeze(IdxBase);
1230c0981da4SDimitry Andric return ScalarizationResult::unsafe();
1231344a3780SDimitry Andric }
1232344a3780SDimitry Andric
1233344a3780SDimitry Andric /// The memory operation on a vector of \p ScalarType had alignment of
1234344a3780SDimitry Andric /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1235344a3780SDimitry Andric /// alignment that will be valid for the memory operation on a single scalar
1236344a3780SDimitry Andric /// element of the same type with index \p Idx.
computeAlignmentAfterScalarization(Align VectorAlignment,Type * ScalarType,Value * Idx,const DataLayout & DL)1237344a3780SDimitry Andric static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1238344a3780SDimitry Andric Type *ScalarType, Value *Idx,
1239344a3780SDimitry Andric const DataLayout &DL) {
1240344a3780SDimitry Andric if (auto *C = dyn_cast<ConstantInt>(Idx))
1241344a3780SDimitry Andric return commonAlignment(VectorAlignment,
1242344a3780SDimitry Andric C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1243344a3780SDimitry Andric return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1244344a3780SDimitry Andric }
1245344a3780SDimitry Andric
1246344a3780SDimitry Andric // Combine patterns like:
1247344a3780SDimitry Andric // %0 = load <4 x i32>, <4 x i32>* %a
1248344a3780SDimitry Andric // %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1249344a3780SDimitry Andric // store <4 x i32> %1, <4 x i32>* %a
1250344a3780SDimitry Andric // to:
1251344a3780SDimitry Andric // %0 = bitcast <4 x i32>* %a to i32*
1252344a3780SDimitry Andric // %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1253344a3780SDimitry Andric // store i32 %b, i32* %1
foldSingleElementStore(Instruction & I)1254344a3780SDimitry Andric bool VectorCombine::foldSingleElementStore(Instruction &I) {
1255e3b55780SDimitry Andric auto *SI = cast<StoreInst>(&I);
1256b1c73532SDimitry Andric if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1257344a3780SDimitry Andric return false;
1258344a3780SDimitry Andric
1259344a3780SDimitry Andric // TODO: Combine more complicated patterns (multiple insert) by referencing
1260344a3780SDimitry Andric // TargetTransformInfo.
1261344a3780SDimitry Andric Instruction *Source;
1262344a3780SDimitry Andric Value *NewElement;
1263344a3780SDimitry Andric Value *Idx;
1264344a3780SDimitry Andric if (!match(SI->getValueOperand(),
1265344a3780SDimitry Andric m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1266344a3780SDimitry Andric m_Value(Idx))))
1267344a3780SDimitry Andric return false;
1268344a3780SDimitry Andric
1269344a3780SDimitry Andric if (auto *Load = dyn_cast<LoadInst>(Source)) {
1270b1c73532SDimitry Andric auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1271344a3780SDimitry Andric Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1272344a3780SDimitry Andric // Don't optimize for atomic/volatile load or store. Ensure memory is not
1273344a3780SDimitry Andric // modified between, vector type matches store size, and index is inbounds.
1274344a3780SDimitry Andric if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1275ac9a064cSDimitry Andric !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1276c0981da4SDimitry Andric SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1277c0981da4SDimitry Andric return false;
1278c0981da4SDimitry Andric
1279c0981da4SDimitry Andric auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1280c0981da4SDimitry Andric if (ScalarizableIdx.isUnsafe() ||
1281344a3780SDimitry Andric isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1282344a3780SDimitry Andric MemoryLocation::get(SI), AA))
1283344a3780SDimitry Andric return false;
1284344a3780SDimitry Andric
1285c0981da4SDimitry Andric if (ScalarizableIdx.isSafeWithFreeze())
1286c0981da4SDimitry Andric ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1287344a3780SDimitry Andric Value *GEP = Builder.CreateInBoundsGEP(
1288344a3780SDimitry Andric SI->getValueOperand()->getType(), SI->getPointerOperand(),
1289344a3780SDimitry Andric {ConstantInt::get(Idx->getType(), 0), Idx});
1290344a3780SDimitry Andric StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1291344a3780SDimitry Andric NSI->copyMetadata(*SI);
1292344a3780SDimitry Andric Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1293344a3780SDimitry Andric std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1294ac9a064cSDimitry Andric *DL);
1295344a3780SDimitry Andric NSI->setAlignment(ScalarOpAlignment);
1296344a3780SDimitry Andric replaceValue(I, *NSI);
1297c0981da4SDimitry Andric eraseInstruction(I);
1298344a3780SDimitry Andric return true;
1299344a3780SDimitry Andric }
1300344a3780SDimitry Andric
1301344a3780SDimitry Andric return false;
1302344a3780SDimitry Andric }
1303344a3780SDimitry Andric
1304344a3780SDimitry Andric /// Try to scalarize vector loads feeding extractelement instructions.
scalarizeLoadExtract(Instruction & I)1305344a3780SDimitry Andric bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1306344a3780SDimitry Andric Value *Ptr;
1307c0981da4SDimitry Andric if (!match(&I, m_Load(m_Value(Ptr))))
1308344a3780SDimitry Andric return false;
1309344a3780SDimitry Andric
1310b1c73532SDimitry Andric auto *VecTy = cast<VectorType>(I.getType());
1311c0981da4SDimitry Andric auto *LI = cast<LoadInst>(&I);
1312ac9a064cSDimitry Andric if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
1313344a3780SDimitry Andric return false;
1314344a3780SDimitry Andric
131577fc4c14SDimitry Andric InstructionCost OriginalCost =
1316b1c73532SDimitry Andric TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
1317344a3780SDimitry Andric LI->getPointerAddressSpace());
1318344a3780SDimitry Andric InstructionCost ScalarizedCost = 0;
1319344a3780SDimitry Andric
1320344a3780SDimitry Andric Instruction *LastCheckedInst = LI;
1321344a3780SDimitry Andric unsigned NumInstChecked = 0;
1322b1c73532SDimitry Andric DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1323b1c73532SDimitry Andric auto FailureGuard = make_scope_exit([&]() {
1324b1c73532SDimitry Andric // If the transform is aborted, discard the ScalarizationResults.
1325b1c73532SDimitry Andric for (auto &Pair : NeedFreeze)
1326b1c73532SDimitry Andric Pair.second.discard();
1327b1c73532SDimitry Andric });
1328b1c73532SDimitry Andric
1329344a3780SDimitry Andric // Check if all users of the load are extracts with no memory modifications
1330344a3780SDimitry Andric // between the load and the extract. Compute the cost of both the original
1331344a3780SDimitry Andric // code and the scalarized version.
1332344a3780SDimitry Andric for (User *U : LI->users()) {
1333344a3780SDimitry Andric auto *UI = dyn_cast<ExtractElementInst>(U);
1334344a3780SDimitry Andric if (!UI || UI->getParent() != LI->getParent())
1335344a3780SDimitry Andric return false;
1336344a3780SDimitry Andric
1337344a3780SDimitry Andric // Check if any instruction between the load and the extract may modify
1338344a3780SDimitry Andric // memory.
1339344a3780SDimitry Andric if (LastCheckedInst->comesBefore(UI)) {
1340344a3780SDimitry Andric for (Instruction &I :
1341344a3780SDimitry Andric make_range(std::next(LI->getIterator()), UI->getIterator())) {
1342344a3780SDimitry Andric // Bail out if we reached the check limit or the instruction may write
1343344a3780SDimitry Andric // to memory.
1344344a3780SDimitry Andric if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1345344a3780SDimitry Andric return false;
1346344a3780SDimitry Andric NumInstChecked++;
1347344a3780SDimitry Andric }
1348145449b1SDimitry Andric LastCheckedInst = UI;
1349344a3780SDimitry Andric }
1350344a3780SDimitry Andric
1351b1c73532SDimitry Andric auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
1352b1c73532SDimitry Andric if (ScalarIdx.isUnsafe())
1353344a3780SDimitry Andric return false;
1354b1c73532SDimitry Andric if (ScalarIdx.isSafeWithFreeze()) {
1355b1c73532SDimitry Andric NeedFreeze.try_emplace(UI, ScalarIdx);
1356b1c73532SDimitry Andric ScalarIdx.discard();
1357c0981da4SDimitry Andric }
1358344a3780SDimitry Andric
1359344a3780SDimitry Andric auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
1360e3b55780SDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1361344a3780SDimitry Andric OriginalCost +=
1362b1c73532SDimitry Andric TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
1363344a3780SDimitry Andric Index ? Index->getZExtValue() : -1);
1364344a3780SDimitry Andric ScalarizedCost +=
1365b1c73532SDimitry Andric TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
1366344a3780SDimitry Andric Align(1), LI->getPointerAddressSpace());
1367b1c73532SDimitry Andric ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
1368344a3780SDimitry Andric }
1369344a3780SDimitry Andric
1370344a3780SDimitry Andric if (ScalarizedCost >= OriginalCost)
1371344a3780SDimitry Andric return false;
1372344a3780SDimitry Andric
1373344a3780SDimitry Andric // Replace extracts with narrow scalar loads.
1374344a3780SDimitry Andric for (User *U : LI->users()) {
1375344a3780SDimitry Andric auto *EI = cast<ExtractElementInst>(U);
1376344a3780SDimitry Andric Value *Idx = EI->getOperand(1);
1377b1c73532SDimitry Andric
1378b1c73532SDimitry Andric // Insert 'freeze' for poison indexes.
1379b1c73532SDimitry Andric auto It = NeedFreeze.find(EI);
1380b1c73532SDimitry Andric if (It != NeedFreeze.end())
1381b1c73532SDimitry Andric It->second.freeze(Builder, *cast<Instruction>(Idx));
1382b1c73532SDimitry Andric
1383b1c73532SDimitry Andric Builder.SetInsertPoint(EI);
1384344a3780SDimitry Andric Value *GEP =
1385b1c73532SDimitry Andric Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
1386344a3780SDimitry Andric auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1387b1c73532SDimitry Andric VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
1388344a3780SDimitry Andric
1389344a3780SDimitry Andric Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1390ac9a064cSDimitry Andric LI->getAlign(), VecTy->getElementType(), Idx, *DL);
1391344a3780SDimitry Andric NewLoad->setAlignment(ScalarOpAlignment);
1392344a3780SDimitry Andric
1393344a3780SDimitry Andric replaceValue(*EI, *NewLoad);
1394344a3780SDimitry Andric }
1395344a3780SDimitry Andric
1396b1c73532SDimitry Andric FailureGuard.release();
1397344a3780SDimitry Andric return true;
1398344a3780SDimitry Andric }
1399344a3780SDimitry Andric
1400ac9a064cSDimitry Andric /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
foldShuffleOfBinops(Instruction & I)1401c0981da4SDimitry Andric bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1402c0981da4SDimitry Andric BinaryOperator *B0, *B1;
1403ac9a064cSDimitry Andric ArrayRef<int> OldMask;
1404c0981da4SDimitry Andric if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
1405ac9a064cSDimitry Andric m_Mask(OldMask))))
1406c0981da4SDimitry Andric return false;
1407c0981da4SDimitry Andric
1408ac9a064cSDimitry Andric // Don't introduce poison into div/rem.
1409ac9a064cSDimitry Andric if (any_of(OldMask, [](int M) { return M == PoisonMaskElem; }) &&
1410ac9a064cSDimitry Andric B0->isIntDivRem())
1411c0981da4SDimitry Andric return false;
1412c0981da4SDimitry Andric
1413ac9a064cSDimitry Andric // TODO: Add support for addlike etc.
1414ac9a064cSDimitry Andric Instruction::BinaryOps Opcode = B0->getOpcode();
1415ac9a064cSDimitry Andric if (Opcode != B1->getOpcode())
1416ac9a064cSDimitry Andric return false;
1417ac9a064cSDimitry Andric
1418ac9a064cSDimitry Andric auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1419ac9a064cSDimitry Andric auto *BinOpTy = dyn_cast<FixedVectorType>(B0->getType());
1420ac9a064cSDimitry Andric if (!ShuffleDstTy || !BinOpTy)
1421ac9a064cSDimitry Andric return false;
1422ac9a064cSDimitry Andric
1423ac9a064cSDimitry Andric unsigned NumSrcElts = BinOpTy->getNumElements();
1424ac9a064cSDimitry Andric
1425c0981da4SDimitry Andric // If we have something like "add X, Y" and "add Z, X", swap ops to match.
1426c0981da4SDimitry Andric Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
1427c0981da4SDimitry Andric Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
1428ac9a064cSDimitry Andric if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W &&
1429ac9a064cSDimitry Andric (X == W || Y == Z))
1430c0981da4SDimitry Andric std::swap(X, Y);
1431c0981da4SDimitry Andric
1432ac9a064cSDimitry Andric auto ConvertToUnary = [NumSrcElts](int &M) {
1433ac9a064cSDimitry Andric if (M >= (int)NumSrcElts)
1434ac9a064cSDimitry Andric M -= NumSrcElts;
1435ac9a064cSDimitry Andric };
1436ac9a064cSDimitry Andric
1437ac9a064cSDimitry Andric SmallVector<int> NewMask0(OldMask.begin(), OldMask.end());
1438ac9a064cSDimitry Andric TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
1439c0981da4SDimitry Andric if (X == Z) {
1440ac9a064cSDimitry Andric llvm::for_each(NewMask0, ConvertToUnary);
1441ac9a064cSDimitry Andric SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
1442ac9a064cSDimitry Andric Z = PoisonValue::get(BinOpTy);
1443c0981da4SDimitry Andric }
1444c0981da4SDimitry Andric
1445ac9a064cSDimitry Andric SmallVector<int> NewMask1(OldMask.begin(), OldMask.end());
1446ac9a064cSDimitry Andric TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
1447ac9a064cSDimitry Andric if (Y == W) {
1448ac9a064cSDimitry Andric llvm::for_each(NewMask1, ConvertToUnary);
1449ac9a064cSDimitry Andric SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
1450ac9a064cSDimitry Andric W = PoisonValue::get(BinOpTy);
1451ac9a064cSDimitry Andric }
1452ac9a064cSDimitry Andric
1453ac9a064cSDimitry Andric // Try to replace a binop with a shuffle if the shuffle is not costly.
1454ac9a064cSDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1455ac9a064cSDimitry Andric
1456ac9a064cSDimitry Andric InstructionCost OldCost =
1457ac9a064cSDimitry Andric TTI.getArithmeticInstrCost(B0->getOpcode(), BinOpTy, CostKind) +
1458ac9a064cSDimitry Andric TTI.getArithmeticInstrCost(B1->getOpcode(), BinOpTy, CostKind) +
1459ac9a064cSDimitry Andric TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy,
1460ac9a064cSDimitry Andric OldMask, CostKind, 0, nullptr, {B0, B1}, &I);
1461ac9a064cSDimitry Andric
1462ac9a064cSDimitry Andric InstructionCost NewCost =
1463ac9a064cSDimitry Andric TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
1464ac9a064cSDimitry Andric TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}) +
1465ac9a064cSDimitry Andric TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
1466ac9a064cSDimitry Andric
1467ac9a064cSDimitry Andric LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
1468ac9a064cSDimitry Andric << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1469ac9a064cSDimitry Andric << "\n");
1470ac9a064cSDimitry Andric if (NewCost >= OldCost)
1471ac9a064cSDimitry Andric return false;
1472ac9a064cSDimitry Andric
1473ac9a064cSDimitry Andric Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
1474ac9a064cSDimitry Andric Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
1475c0981da4SDimitry Andric Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1476ac9a064cSDimitry Andric
1477c0981da4SDimitry Andric // Intersect flags from the old binops.
1478c0981da4SDimitry Andric if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
1479c0981da4SDimitry Andric NewInst->copyIRFlags(B0);
1480c0981da4SDimitry Andric NewInst->andIRFlags(B1);
1481c0981da4SDimitry Andric }
1482ac9a064cSDimitry Andric
1483ac9a064cSDimitry Andric Worklist.pushValue(Shuf0);
1484ac9a064cSDimitry Andric Worklist.pushValue(Shuf1);
1485c0981da4SDimitry Andric replaceValue(I, *NewBO);
1486c0981da4SDimitry Andric return true;
1487c0981da4SDimitry Andric }
1488c0981da4SDimitry Andric
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
/// into "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
    return false;

  // Both shuffle operands must be casts from a common source type.
  auto *C0 = dyn_cast<CastInst>(V0);
  auto *C1 = dyn_cast<CastInst>(V1);
  if (!C0 || !C1)
    return false;

  Instruction::CastOps Opcode = C0->getOpcode();
  if (C0->getSrcTy() != C1->getSrcTy())
    return false;

  // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
  // A zext with the non-negative flag is equivalent to a sext, so mixed
  // sext/zext_nneg pairs can be canonicalized to a single sext.
  if (Opcode != C1->getOpcode()) {
    if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
      Opcode = Instruction::SExt;
    else
      return false;
  }

  // All three types involved must be fixed-width vectors.
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
  auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
    return false;

  unsigned NumSrcElts = CastSrcTy->getNumElements();
  unsigned NumDstElts = CastDstTy->getNumElements();
  assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
         "Only bitcasts expected to alter src/dst element counts");

  // Check for bitcasting of unscalable vector types.
  // e.g. <32 x i40> -> <40 x i32>
  // Neither element count divides the other, so the mask cannot be rescaled.
  if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
      (NumDstElts % NumSrcElts) != 0)
    return false;

  // Rescale the shuffle mask from the cast's destination element count to
  // its source element count so the shuffle can be performed pre-cast.
  SmallVector<int, 16> NewMask;
  if (NumSrcElts >= NumDstElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumSrcElts / NumDstElts;
    narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumDstElts / NumSrcElts;
    if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
      return false;
  }

  auto *NewShuffleDstTy =
      FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());

  // Try to replace a castop with a shuffle if the shuffle is not costly.
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

  // Old sequence: two casts feeding a two-source shuffle.
  InstructionCost CostC0 =
      TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
                           TTI::CastContextHint::None, CostKind);
  InstructionCost CostC1 =
      TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
                           TTI::CastContextHint::None, CostKind);
  InstructionCost OldCost = CostC0 + CostC1;
  OldCost +=
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
                         OldMask, CostKind, 0, nullptr, std::nullopt, &I);

  // New sequence: one shuffle of the cast sources feeding a single cast.
  InstructionCost NewCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
  NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
                                  TTI::CastContextHint::None, CostKind);
  // Multi-use casts remain live after the fold, so keep charging their cost.
  if (!C0->hasOneUse())
    NewCost += CostC0;
  if (!C1->hasOneUse())
    NewCost += CostC1;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
                                            C1->getOperand(0), NewMask);
  Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);

  // Intersect flags from the old casts.
  if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
    NewInst->copyIRFlags(C0);
    NewInst->andIRFlags(C1);
  }

  Worklist.pushValue(Shuf);
  replaceValue(I, *Cast);
  return true;
}
1593ac9a064cSDimitry Andric
/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
/// into "shuffle x, y".
bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
  Value *V0, *V1;
  UndefValue *U0, *U1;
  ArrayRef<int> OuterMask, InnerMask0, InnerMask1;
  // Match an outer shuffle whose operands are both single-use shuffles with
  // an undef/poison RHS.
  if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_UndefValue(U0),
                                              m_Mask(InnerMask0))),
                           m_OneUse(m_Shuffle(m_Value(V1), m_UndefValue(U1),
                                              m_Mask(InnerMask1))),
                           m_Mask(OuterMask))))
    return false;

  auto *ShufI0 = dyn_cast<Instruction>(I.getOperand(0));
  auto *ShufI1 = dyn_cast<Instruction>(I.getOperand(1));
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
  // Type of the intermediate (inner shuffle result) vectors.
  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
  if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
      V0->getType() != V1->getType())
    return false;

  unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
  unsigned NumImmElts = ShuffleImmTy->getNumElements();

  // Bail if either inner masks reference a RHS undef arg.
  // (Poison RHS lanes are fine - they are folded to PoisonMaskElem below.)
  if ((!isa<PoisonValue>(U0) &&
       any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
      (!isa<PoisonValue>(U1) &&
       any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
    return false;

  // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem,
  SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end());
  for (int &M : NewMask) {
    if (0 <= M && M < (int)NumImmElts) {
      // Outer lane selects from the first inner shuffle.
      M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
    } else if (M >= (int)NumImmElts) {
      // Outer lane selects from the second inner shuffle; offset into the
      // second source unless both inner shuffles read the same vector.
      if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
        M = PoisonMaskElem;
      else
        M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
    }
  }

  // Have we folded to an Identity shuffle?
  if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
    replaceValue(I, *V0);
    return true;
  }

  // Try to merge the shuffles if the new shuffle is not costly.
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

  // Old sequence: two single-source shuffles feeding a two-source shuffle.
  InstructionCost OldCost =
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
                         InnerMask0, CostKind, 0, nullptr, {V0, U0}, ShufI0) +
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
                         InnerMask1, CostKind, 0, nullptr, {V1, U1}, ShufI1) +
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
                         OuterMask, CostKind, 0, nullptr, {ShufI0, ShufI1}, &I);

  // New sequence: a single two-source shuffle of the original vectors.
  InstructionCost NewCost =
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy,
                         NewMask, CostKind, 0, nullptr, {V0, V1});

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  // Clear unused sources to poison.
  if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
    V0 = PoisonValue::get(ShuffleSrcTy);
  if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
    V1 = PoisonValue::get(ShuffleSrcTy);

  Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
  replaceValue(I, *Shuf);
  return true;
}
1676ac9a064cSDimitry Andric
1677ac9a064cSDimitry Andric using InstLane = std::pair<Use *, int>;
1678ac9a064cSDimitry Andric
lookThroughShuffles(Use * U,int Lane)1679ac9a064cSDimitry Andric static InstLane lookThroughShuffles(Use *U, int Lane) {
1680ac9a064cSDimitry Andric while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) {
1681ac9a064cSDimitry Andric unsigned NumElts =
1682ac9a064cSDimitry Andric cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
1683ac9a064cSDimitry Andric int M = SV->getMaskValue(Lane);
1684ac9a064cSDimitry Andric if (M < 0)
1685ac9a064cSDimitry Andric return {nullptr, PoisonMaskElem};
1686ac9a064cSDimitry Andric if (static_cast<unsigned>(M) < NumElts) {
1687ac9a064cSDimitry Andric U = &SV->getOperandUse(0);
1688ac9a064cSDimitry Andric Lane = M;
1689ac9a064cSDimitry Andric } else {
1690ac9a064cSDimitry Andric U = &SV->getOperandUse(1);
1691ac9a064cSDimitry Andric Lane = M - NumElts;
1692ac9a064cSDimitry Andric }
1693ac9a064cSDimitry Andric }
1694ac9a064cSDimitry Andric return InstLane{U, Lane};
1695ac9a064cSDimitry Andric }
1696ac9a064cSDimitry Andric
1697ac9a064cSDimitry Andric static SmallVector<InstLane>
generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item,int Op)1698ac9a064cSDimitry Andric generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
1699ac9a064cSDimitry Andric SmallVector<InstLane> NItem;
1700ac9a064cSDimitry Andric for (InstLane IL : Item) {
1701ac9a064cSDimitry Andric auto [U, Lane] = IL;
1702ac9a064cSDimitry Andric InstLane OpLane =
1703ac9a064cSDimitry Andric U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op),
1704ac9a064cSDimitry Andric Lane)
1705ac9a064cSDimitry Andric : InstLane{nullptr, PoisonMaskElem};
1706ac9a064cSDimitry Andric NItem.emplace_back(OpLane);
1707ac9a064cSDimitry Andric }
1708ac9a064cSDimitry Andric return NItem;
1709ac9a064cSDimitry Andric }
1710ac9a064cSDimitry Andric
1711ac9a064cSDimitry Andric /// Detect concat of multiple values into a vector
isFreeConcat(ArrayRef<InstLane> Item,const TargetTransformInfo & TTI)1712ac9a064cSDimitry Andric static bool isFreeConcat(ArrayRef<InstLane> Item,
1713ac9a064cSDimitry Andric const TargetTransformInfo &TTI) {
1714ac9a064cSDimitry Andric auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
1715ac9a064cSDimitry Andric unsigned NumElts = Ty->getNumElements();
1716ac9a064cSDimitry Andric if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
1717ac9a064cSDimitry Andric return false;
1718ac9a064cSDimitry Andric
1719ac9a064cSDimitry Andric // Check that the concat is free, usually meaning that the type will be split
1720ac9a064cSDimitry Andric // during legalization.
1721ac9a064cSDimitry Andric SmallVector<int, 16> ConcatMask(NumElts * 2);
1722ac9a064cSDimitry Andric std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1723ac9a064cSDimitry Andric if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask,
1724ac9a064cSDimitry Andric TTI::TCK_RecipThroughput) != 0)
1725ac9a064cSDimitry Andric return false;
1726ac9a064cSDimitry Andric
1727ac9a064cSDimitry Andric unsigned NumSlices = Item.size() / NumElts;
1728ac9a064cSDimitry Andric // Currently we generate a tree of shuffles for the concats, which limits us
1729ac9a064cSDimitry Andric // to a power2.
1730ac9a064cSDimitry Andric if (!isPowerOf2_32(NumSlices))
1731ac9a064cSDimitry Andric return false;
1732ac9a064cSDimitry Andric for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
1733ac9a064cSDimitry Andric Use *SliceV = Item[Slice * NumElts].first;
1734ac9a064cSDimitry Andric if (!SliceV || SliceV->get()->getType() != Ty)
1735ac9a064cSDimitry Andric return false;
1736ac9a064cSDimitry Andric for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
1737ac9a064cSDimitry Andric auto [V, Lane] = Item[Slice * NumElts + Elt];
1738ac9a064cSDimitry Andric if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
1739ac9a064cSDimitry Andric return false;
1740ac9a064cSDimitry Andric }
1741ac9a064cSDimitry Andric }
1742ac9a064cSDimitry Andric return true;
1743ac9a064cSDimitry Andric }
1744ac9a064cSDimitry Andric
/// Recursively rebuild the instruction tree described by \p Item without the
/// superfluous shuffles found by foldShuffleToIdentity. Leaves previously
/// classified as identity/splat/concat are materialized directly; interior
/// nodes are re-emitted on their operands' rebuilt trees.
static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
                                  const SmallPtrSet<Use *, 4> &IdentityLeafs,
                                  const SmallPtrSet<Use *, 4> &SplatLeafs,
                                  const SmallPtrSet<Use *, 4> &ConcatLeafs,
                                  IRBuilder<> &Builder) {
  auto [FrontU, FrontLane] = Item.front();

  // Identity leaf: the lanes already line up - reuse the value as-is.
  if (IdentityLeafs.contains(FrontU)) {
    return FrontU->get();
  }
  // Splat leaf: broadcast the front lane across the whole vector.
  if (SplatLeafs.contains(FrontU)) {
    SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
    return Builder.CreateShuffleVector(FrontU->get(), Mask);
  }
  // Concat leaf: gather one value per slice, then combine them pairwise
  // with identity-mask shuffles into a balanced (power-of-2) tree.
  if (ConcatLeafs.contains(FrontU)) {
    unsigned NumElts =
        cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
    SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
    for (unsigned S = 0; S < Values.size(); ++S)
      Values[S] = Item[S * NumElts].first->get();

    while (Values.size() > 1) {
      NumElts *= 2;
      SmallVector<int, 16> Mask(NumElts, 0);
      std::iota(Mask.begin(), Mask.end(), 0);
      SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
      for (unsigned S = 0; S < NewValues.size(); ++S)
        NewValues[S] =
            Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
      Values = NewValues;
    }
    return Values[0];
  }

  // Interior node: rebuild each operand subtree, then re-emit the operation.
  auto *I = cast<Instruction>(FrontU->get());
  auto *II = dyn_cast<IntrinsicInst>(I);
  // For intrinsic calls the last operand is the callee, not a data operand.
  unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
  SmallVector<Value *> Ops(NumOps);
  for (unsigned Idx = 0; Idx < NumOps; Idx++) {
    // Scalar intrinsic args are uniform across lanes - pass them through.
    if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
      Ops[Idx] = II->getOperand(Idx);
      continue;
    }
    Ops[Idx] =
        generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty,
                            IdentityLeafs, SplatLeafs, ConcatLeafs, Builder);
  }

  // Collect the per-lane originals so IR flags can be intersected over them.
  SmallVector<Value *, 8> ValueList;
  for (const auto &Lane : Item)
    if (Lane.first)
      ValueList.push_back(Lane.first->get());

  Type *DstTy =
      FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
  if (auto *BI = dyn_cast<BinaryOperator>(I)) {
    auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
                                      Ops[0], Ops[1]);
    propagateIRFlags(Value, ValueList);
    return Value;
  }
  if (auto *CI = dyn_cast<CmpInst>(I)) {
    auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
    propagateIRFlags(Value, ValueList);
    return Value;
  }
  if (auto *SI = dyn_cast<SelectInst>(I)) {
    auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
    propagateIRFlags(Value, ValueList);
    return Value;
  }
  if (auto *CI = dyn_cast<CastInst>(I)) {
    auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(),
                                     Ops[0], DstTy);
    propagateIRFlags(Value, ValueList);
    return Value;
  }
  if (II) {
    auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
    propagateIRFlags(Value, ValueList);
    return Value;
  }
  assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
  auto *Value =
      Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
  propagateIRFlags(Value, ValueList);
  return Value;
}
1833ac9a064cSDimitry Andric
// Starting from a shuffle, look up through operands tracking the shuffled index
// of each lane. If we can simplify away the shuffles to identities then
// do so.
bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
  if (!Ty || I.use_empty())
    return false;

  // Per-lane starting points: for each result lane, the (use, lane) pair it
  // ultimately reads once shuffles are looked through.
  SmallVector<InstLane> Start(Ty->getNumElements());
  for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
    Start[M] = lookThroughShuffles(&*I.use_begin(), M);

  SmallVector<SmallVector<InstLane>> Worklist;
  Worklist.push_back(Start);
  // Leaves of the tree, classified by how they will be rematerialized.
  SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
  unsigned NumVisited = 0;

  while (!Worklist.empty()) {
    if (++NumVisited > MaxInstrsToScan)
      return false;

    SmallVector<InstLane> Item = Worklist.pop_back_val();
    auto [FrontU, FrontLane] = Item.front();

    // If we found an undef first lane then bail out to keep things simple.
    if (!FrontU)
      return false;

    // Helper to peek through bitcasts to the same value.
    auto IsEquiv = [&](Value *X, Value *Y) {
      return X->getType() == Y->getType() &&
             peekThroughBitcasts(X) == peekThroughBitcasts(Y);
    };

    // Look for an identity value: every non-empty lane reads lane i of the
    // same (bitcast-equivalent) source, which has the right element count.
    if (FrontLane == 0 &&
        cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() ==
            Ty->getNumElements() &&
        all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
          Value *FrontV = Item.front().first->get();
          return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
                                      E.value().second == (int)E.index());
        })) {
      IdentityLeafs.insert(FrontU);
      continue;
    }
    // Look for constants, for the moment only supporting constant splats.
    if (auto *C = dyn_cast<Constant>(FrontU);
        C && C->getSplatValue() &&
        all_of(drop_begin(Item), [Item](InstLane &IL) {
          Value *FrontV = Item.front().first->get();
          Use *U = IL.first;
          return !U || U->get() == FrontV;
        })) {
      SplatLeafs.insert(FrontU);
      continue;
    }
    // Look for a splat value: every non-empty lane reads the same lane of
    // the same value.
    if (all_of(drop_begin(Item), [Item](InstLane &IL) {
          auto [FrontU, FrontLane] = Item.front();
          auto [U, Lane] = IL;
          return !U || (U->get() == FrontU->get() && Lane == FrontLane);
        })) {
      SplatLeafs.insert(FrontU);
      continue;
    }

    // We need each element to be the same type of value, and check that each
    // element has a single use.
    auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
      Value *FrontV = Item.front().first->get();
      if (!IL.first)
        return true;
      Value *V = IL.first->get();
      // Multi-use lanes cannot be rewritten without duplicating work.
      if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
        return false;
      if (V->getValueID() != FrontV->getValueID())
        return false;
      // Compares must share the predicate, casts the source type, selects a
      // vector condition of matching type, and intrinsics the intrinsic ID.
      if (auto *CI = dyn_cast<CmpInst>(V))
        if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
          return false;
      if (auto *CI = dyn_cast<CastInst>(V))
        if (CI->getSrcTy() != cast<CastInst>(FrontV)->getSrcTy())
          return false;
      if (auto *SI = dyn_cast<SelectInst>(V))
        if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
            SI->getOperand(0)->getType() !=
                cast<SelectInst>(FrontV)->getOperand(0)->getType())
          return false;
      if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
        return false;
      auto *II = dyn_cast<IntrinsicInst>(V);
      return !II || (isa<IntrinsicInst>(FrontV) &&
                     II->getIntrinsicID() ==
                         cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
                     !II->hasOperandBundles());
    };
    if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
      // Check the operator is one that we support.
      if (isa<BinaryOperator, CmpInst>(FrontU)) {
        // We exclude div/rem in case they hit UB from poison lanes.
        if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
            BO && BO->isIntDivRem())
          return false;
        Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
        Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
        continue;
      } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) {
        Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
        continue;
      } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) {
        // TODO: Handle vector widening/narrowing bitcasts.
        auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
        auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
        if (DstTy && SrcTy &&
            SrcTy->getNumElements() == DstTy->getNumElements()) {
          Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
          continue;
        }
      } else if (isa<SelectInst>(FrontU)) {
        Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
        Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
        Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
        continue;
      } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
                 II && isTriviallyVectorizable(II->getIntrinsicID()) &&
                 !II->hasOperandBundles()) {
        for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
          // Scalar args must be identical across all lanes - they are
          // passed through unchanged by generateNewInstTree.
          if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
            if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
                  Value *FrontV = Item.front().first->get();
                  Use *U = IL.first;
                  return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
                                cast<Instruction>(FrontV)->getOperand(Op));
                }))
              return false;
            continue;
          }
          Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
        }
        continue;
      }
    }

    if (isFreeConcat(Item, TTI)) {
      ConcatLeafs.insert(FrontU);
      continue;
    }

    return false;
  }

  // A single visited node means the shuffle was already an identity/leaf and
  // there is nothing to simplify.
  if (NumVisited <= 1)
    return false;

  // If we got this far, we know the shuffles are superfluous and can be
  // removed. Scan through again and generate the new tree of instructions.
  Builder.SetInsertPoint(&I);
  Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
                                 ConcatLeafs, Builder);
  replaceValue(I, *V);
  return true;
}
1997ac9a064cSDimitry Andric
1998145449b1SDimitry Andric /// Given a commutative reduction, the order of the input lanes does not alter
1999145449b1SDimitry Andric /// the results. We can use this to remove certain shuffles feeding the
2000145449b1SDimitry Andric /// reduction, removing the need to shuffle at all.
foldShuffleFromReductions(Instruction & I)2001145449b1SDimitry Andric bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2002145449b1SDimitry Andric auto *II = dyn_cast<IntrinsicInst>(&I);
2003145449b1SDimitry Andric if (!II)
2004145449b1SDimitry Andric return false;
// Only integer commutative reductions are handled; their result does not
// depend on the order of the input lanes.
2005145449b1SDimitry Andric switch (II->getIntrinsicID()) {
2006145449b1SDimitry Andric case Intrinsic::vector_reduce_add:
2007145449b1SDimitry Andric case Intrinsic::vector_reduce_mul:
2008145449b1SDimitry Andric case Intrinsic::vector_reduce_and:
2009145449b1SDimitry Andric case Intrinsic::vector_reduce_or:
2010145449b1SDimitry Andric case Intrinsic::vector_reduce_xor:
2011145449b1SDimitry Andric case Intrinsic::vector_reduce_smin:
2012145449b1SDimitry Andric case Intrinsic::vector_reduce_smax:
2013145449b1SDimitry Andric case Intrinsic::vector_reduce_umin:
2014145449b1SDimitry Andric case Intrinsic::vector_reduce_umax:
2015145449b1SDimitry Andric break;
2016145449b1SDimitry Andric default:
2017145449b1SDimitry Andric return false;
2018145449b1SDimitry Andric }
2019145449b1SDimitry Andric
2020145449b1SDimitry Andric // Find all the inputs when looking through operations that do not alter the
2021145449b1SDimitry Andric // lane order (binops, for example). Currently we look for a single shuffle,
2022145449b1SDimitry Andric // and can ignore splat values.
2023145449b1SDimitry Andric std::queue<Value *> Worklist;
2024145449b1SDimitry Andric SmallPtrSet<Value *, 4> Visited;
2025145449b1SDimitry Andric ShuffleVectorInst *Shuffle = nullptr;
2026145449b1SDimitry Andric if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
2027145449b1SDimitry Andric Worklist.push(Op);
2028145449b1SDimitry Andric
2029145449b1SDimitry Andric while (!Worklist.empty()) {
2030145449b1SDimitry Andric Value *CV = Worklist.front();
2031145449b1SDimitry Andric Worklist.pop();
2032145449b1SDimitry Andric if (Visited.contains(CV))
2033145449b1SDimitry Andric continue;
2034145449b1SDimitry Andric
2035145449b1SDimitry Andric // Splats don't change the order, so can be safely ignored.
2036145449b1SDimitry Andric if (isSplatValue(CV))
2037145449b1SDimitry Andric continue;
2038145449b1SDimitry Andric
2039145449b1SDimitry Andric Visited.insert(CV);
2040145449b1SDimitry Andric
2041145449b1SDimitry Andric if (auto *CI = dyn_cast<Instruction>(CV)) {
// Binary ops are lanewise, so keep walking through their operands.
2042145449b1SDimitry Andric if (CI->isBinaryOp()) {
2043145449b1SDimitry Andric for (auto *Op : CI->operand_values())
2044145449b1SDimitry Andric Worklist.push(Op);
2045145449b1SDimitry Andric continue;
// At most one unique shuffle is allowed to feed the reduction tree.
2046145449b1SDimitry Andric } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
2047145449b1SDimitry Andric if (Shuffle && Shuffle != SV)
2048145449b1SDimitry Andric return false;
2049145449b1SDimitry Andric Shuffle = SV;
2050145449b1SDimitry Andric continue;
2051145449b1SDimitry Andric }
2052145449b1SDimitry Andric }
2053145449b1SDimitry Andric
2054145449b1SDimitry Andric // Anything else is currently an unknown node.
2055145449b1SDimitry Andric return false;
2056145449b1SDimitry Andric }
2057145449b1SDimitry Andric
2058145449b1SDimitry Andric if (!Shuffle)
2059145449b1SDimitry Andric return false;
2060145449b1SDimitry Andric
2061145449b1SDimitry Andric // Check all uses of the binary ops and shuffles are also included in the
2062145449b1SDimitry Andric // lane-invariant operations (Visited should be the list of lanewise
2063145449b1SDimitry Andric // instructions, including the shuffle that we found).
2064145449b1SDimitry Andric for (auto *V : Visited)
2065145449b1SDimitry Andric for (auto *U : V->users())
2066145449b1SDimitry Andric if (!Visited.contains(U) && U != &I)
2067145449b1SDimitry Andric return false;
2068145449b1SDimitry Andric
// Both the reduction input and the shuffle's input must be fixed-width
// vector types.
2069145449b1SDimitry Andric FixedVectorType *VecType =
2070145449b1SDimitry Andric dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
2071145449b1SDimitry Andric if (!VecType)
2072145449b1SDimitry Andric return false;
2073145449b1SDimitry Andric FixedVectorType *ShuffleInputType =
2074145449b1SDimitry Andric dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
2075145449b1SDimitry Andric if (!ShuffleInputType)
2076145449b1SDimitry Andric return false;
2077b1c73532SDimitry Andric unsigned NumInputElts = ShuffleInputType->getNumElements();
2078145449b1SDimitry Andric
2079145449b1SDimitry Andric // Find the mask from sorting the lanes into order. This is most likely to
2080145449b1SDimitry Andric // become an identity or concat mask. Undef elements are pushed to the end.
2081145449b1SDimitry Andric SmallVector<int> ConcatMask;
2082145449b1SDimitry Andric Shuffle->getShuffleMask(ConcatMask);
// The unsigned comparison sends negative (undef) mask entries to the end.
2083145449b1SDimitry Andric sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
2084b1c73532SDimitry Andric // In the case of a truncating shuffle it's possible for the mask
2085b1c73532SDimitry Andric // to have an index greater than the size of the resulting vector.
2086b1c73532SDimitry Andric // This requires special handling.
2087b1c73532SDimitry Andric bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
2088145449b1SDimitry Andric bool UsesSecondVec =
2089b1c73532SDimitry Andric any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
2090b1c73532SDimitry Andric
2091b1c73532SDimitry Andric FixedVectorType *VecTyForCost =
2092b1c73532SDimitry Andric (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType;
// Compare the cost of the original mask against the sorted mask, using the
// same shuffle kind and type for both so only the mask ordering differs.
2093145449b1SDimitry Andric InstructionCost OldCost = TTI.getShuffleCost(
2094b1c73532SDimitry Andric UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2095b1c73532SDimitry Andric VecTyForCost, Shuffle->getShuffleMask());
2096145449b1SDimitry Andric InstructionCost NewCost = TTI.getShuffleCost(
2097b1c73532SDimitry Andric UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2098b1c73532SDimitry Andric VecTyForCost, ConcatMask);
2099145449b1SDimitry Andric
2100145449b1SDimitry Andric LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
2101145449b1SDimitry Andric << "\n");
2102145449b1SDimitry Andric LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
2103145449b1SDimitry Andric << "\n");
// Only replace the shuffle when the sorted mask is strictly cheaper.
2104145449b1SDimitry Andric if (NewCost < OldCost) {
2105145449b1SDimitry Andric Builder.SetInsertPoint(Shuffle);
2106145449b1SDimitry Andric Value *NewShuffle = Builder.CreateShuffleVector(
2107145449b1SDimitry Andric Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
2108145449b1SDimitry Andric LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
2109145449b1SDimitry Andric replaceValue(*Shuffle, *NewShuffle);
2110145449b1SDimitry Andric }
2111145449b1SDimitry Andric
2112145449b1SDimitry Andric // See if we can re-use foldSelectShuffle, getting it to reduce the size of
2113145449b1SDimitry Andric // the shuffle into a nicer order, as it can ignore the order of the shuffles.
2114145449b1SDimitry Andric return foldSelectShuffle(*Shuffle, true);
2115145449b1SDimitry Andric }
2116145449b1SDimitry Andric
2117ac9a064cSDimitry Andric /// Determine if it's more efficient to fold:
2118ac9a064cSDimitry Andric /// reduce(trunc(x)) -> trunc(reduce(x)).
2119ac9a064cSDimitry Andric /// reduce(sext(x)) -> sext(reduce(x)).
2120ac9a064cSDimitry Andric /// reduce(zext(x)) -> zext(reduce(x)).
foldCastFromReductions(Instruction & I)2121ac9a064cSDimitry Andric bool VectorCombine::foldCastFromReductions(Instruction &I) {
2122ac9a064cSDimitry Andric auto *II = dyn_cast<IntrinsicInst>(&I);
2123ac9a064cSDimitry Andric if (!II)
2124ac9a064cSDimitry Andric return false;
2125ac9a064cSDimitry Andric
// TruncOnly: for add/mul only the trunc form of the fold is handled; the
// bitwise reductions (and/or/xor) additionally allow zext/sext.
2126ac9a064cSDimitry Andric bool TruncOnly = false;
2127ac9a064cSDimitry Andric Intrinsic::ID IID = II->getIntrinsicID();
2128ac9a064cSDimitry Andric switch (IID) {
2129ac9a064cSDimitry Andric case Intrinsic::vector_reduce_add:
2130ac9a064cSDimitry Andric case Intrinsic::vector_reduce_mul:
2131ac9a064cSDimitry Andric TruncOnly = true;
2132ac9a064cSDimitry Andric break;
2133ac9a064cSDimitry Andric case Intrinsic::vector_reduce_and:
2134ac9a064cSDimitry Andric case Intrinsic::vector_reduce_or:
2135ac9a064cSDimitry Andric case Intrinsic::vector_reduce_xor:
2136ac9a064cSDimitry Andric break;
2137ac9a064cSDimitry Andric default:
2138ac9a064cSDimitry Andric return false;
2139ac9a064cSDimitry Andric }
2140ac9a064cSDimitry Andric
2141ac9a064cSDimitry Andric unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
2142ac9a064cSDimitry Andric Value *ReductionSrc = I.getOperand(0);
2143ac9a064cSDimitry Andric
// The cast must be single-use (its only user is this reduction), so it can
// be removed after the transform.
2144ac9a064cSDimitry Andric Value *Src;
2145ac9a064cSDimitry Andric if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
2146ac9a064cSDimitry Andric (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
2147ac9a064cSDimitry Andric return false;
2148ac9a064cSDimitry Andric
2149ac9a064cSDimitry Andric auto CastOpc =
2150ac9a064cSDimitry Andric (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
2151ac9a064cSDimitry Andric
2152ac9a064cSDimitry Andric auto *SrcTy = cast<VectorType>(Src->getType());
2153ac9a064cSDimitry Andric auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
2154ac9a064cSDimitry Andric Type *ResultTy = I.getType();
2155ac9a064cSDimitry Andric
// Old cost: reduction over the casted vector type plus the vector cast.
// New cost: reduction over the original source type plus a single scalar
// cast of the reduced value.
2156ac9a064cSDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2157ac9a064cSDimitry Andric InstructionCost OldCost = TTI.getArithmeticReductionCost(
2158ac9a064cSDimitry Andric ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
2159ac9a064cSDimitry Andric OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
2160ac9a064cSDimitry Andric TTI::CastContextHint::None, CostKind,
2161ac9a064cSDimitry Andric cast<CastInst>(ReductionSrc));
2162ac9a064cSDimitry Andric InstructionCost NewCost =
2163ac9a064cSDimitry Andric TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
2164ac9a064cSDimitry Andric CostKind) +
2165ac9a064cSDimitry Andric TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
2166ac9a064cSDimitry Andric TTI::CastContextHint::None, CostKind);
2167ac9a064cSDimitry Andric
// Bail out unless the rewrite is strictly profitable and the cost is valid.
2168ac9a064cSDimitry Andric if (OldCost <= NewCost || !NewCost.isValid())
2169ac9a064cSDimitry Andric return false;
2170ac9a064cSDimitry Andric
// Emit reduce(x) in the source element type, then cast the scalar result.
// NOTE(review): the Builder insertion point is assumed to be set by the
// caller before this fold runs.
2171ac9a064cSDimitry Andric Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
2172ac9a064cSDimitry Andric II->getIntrinsicID(), {Src});
2173ac9a064cSDimitry Andric Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
2174ac9a064cSDimitry Andric replaceValue(I, *NewCast);
2175ac9a064cSDimitry Andric return true;
2176ac9a064cSDimitry Andric }
2177ac9a064cSDimitry Andric
2178145449b1SDimitry Andric /// This method looks for groups of shuffles acting on binops, of the form:
2179145449b1SDimitry Andric /// %x = shuffle ...
2180145449b1SDimitry Andric /// %y = shuffle ...
2181145449b1SDimitry Andric /// %a = binop %x, %y
2182145449b1SDimitry Andric /// %b = binop %x, %y
2183145449b1SDimitry Andric /// shuffle %a, %b, selectmask
2184145449b1SDimitry Andric /// We may, especially if the shuffle is wider than legal, be able to convert
2185145449b1SDimitry Andric /// the shuffle to a form where only parts of a and b need to be computed. On
2186145449b1SDimitry Andric /// architectures with no obvious "select" shuffle, this can reduce the total
2187145449b1SDimitry Andric /// number of operations if the target reports them as cheaper.
foldSelectShuffle(Instruction & I,bool FromReduction)2188145449b1SDimitry Andric bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
// The root must be a shuffle whose two operands are distinct binary ops of
// the same fixed vector type.
2189e3b55780SDimitry Andric auto *SVI = cast<ShuffleVectorInst>(&I);
2190e3b55780SDimitry Andric auto *VT = cast<FixedVectorType>(I.getType());
2191145449b1SDimitry Andric auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
2192145449b1SDimitry Andric auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
2193145449b1SDimitry Andric if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
2194145449b1SDimitry Andric VT != Op0->getType())
2195145449b1SDimitry Andric return false;
2196e3b55780SDimitry Andric
2197
21971f917f69SDimitry Andric auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
21981f917f69SDimitry Andric auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
21991f917f69SDimitry Andric auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
22001f917f69SDimitry Andric auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
22011f917f69SDimitry Andric SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
// Returns true (reject) if an input is not a VT-typed instruction or has
// uses that escape the pattern: anything other than the two binops, other
// input shuffles, or trivially dead shuffles.
2202145449b1SDimitry Andric auto checkSVNonOpUses = [&](Instruction *I) {
2203145449b1SDimitry Andric if (!I || I->getOperand(0)->getType() != VT)
2204145449b1SDimitry Andric return true;
22051f917f69SDimitry Andric return any_of(I->users(), [&](User *U) {
22061f917f69SDimitry Andric return U != Op0 && U != Op1 &&
22071f917f69SDimitry Andric !(isa<ShuffleVectorInst>(U) &&
22081f917f69SDimitry Andric (InputShuffles.contains(cast<Instruction>(U)) ||
22091f917f69SDimitry Andric isInstructionTriviallyDead(cast<Instruction>(U))));
22101f917f69SDimitry Andric });
2211145449b1SDimitry Andric };
2212145449b1SDimitry Andric if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
2213145449b1SDimitry Andric checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
2214145449b1SDimitry Andric return false;
2215145449b1SDimitry Andric
2216145449b1SDimitry Andric // Collect all the uses that are shuffles that we can transform together. We
2217145449b1SDimitry Andric // may not have a single shuffle, but a group that can all be transformed
2218145449b1SDimitry Andric // together profitably.
2219145449b1SDimitry Andric SmallVector<ShuffleVectorInst *> Shuffles;
// Every user of a binop must be a VT-typed shuffle of Op0/Op1 for the whole
// group to be rewritten together.
2220145449b1SDimitry Andric auto collectShuffles = [&](Instruction *I) {
2221145449b1SDimitry Andric for (auto *U : I->users()) {
2222145449b1SDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(U);
2223145449b1SDimitry Andric if (!SV || SV->getType() != VT)
2224145449b1SDimitry Andric return false;
22251f917f69SDimitry Andric if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
22261f917f69SDimitry Andric (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
22271f917f69SDimitry Andric return false;
2228145449b1SDimitry Andric if (!llvm::is_contained(Shuffles, SV))
2229145449b1SDimitry Andric Shuffles.push_back(SV);
2230145449b1SDimitry Andric }
2231145449b1SDimitry Andric return true;
2232145449b1SDimitry Andric };
2233145449b1SDimitry Andric if (!collectShuffles(Op0) || !collectShuffles(Op1))
2234145449b1SDimitry Andric return false;
2235145449b1SDimitry Andric // From a reduction, we need to be processing a single shuffle, otherwise the
2236145449b1SDimitry Andric // other uses will not be lane-invariant.
2237145449b1SDimitry Andric if (FromReduction && Shuffles.size() > 1)
2238145449b1SDimitry Andric return false;
2239145449b1SDimitry Andric
22401f917f69SDimitry Andric // Add any shuffle uses for the shuffles we have found, to include them in our
22411f917f69SDimitry Andric // cost calculations.
22421f917f69SDimitry Andric if (!FromReduction) {
22431f917f69SDimitry Andric for (ShuffleVectorInst *SV : Shuffles) {
2244e3b55780SDimitry Andric for (auto *U : SV->users()) {
22451f917f69SDimitry Andric ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
22464b4fe385SDimitry Andric if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
22471f917f69SDimitry Andric Shuffles.push_back(SSV);
22481f917f69SDimitry Andric }
22491f917f69SDimitry Andric }
22501f917f69SDimitry Andric }
22511f917f69SDimitry Andric
2252145449b1SDimitry Andric // For each of the output shuffles, we try to sort all the first vector
2253145449b1SDimitry Andric // elements to the beginning, followed by the second array elements at the
2254145449b1SDimitry Andric // end. If the binops are legalized to smaller vectors, this may reduce total
2255145449b1SDimitry Andric // number of binops. We compute the ReconstructMask mask needed to convert
2256145449b1SDimitry Andric // back to the original lane order.
// V1/V2 hold (original lane, packed position) pairs for elements drawn from
// Op0 and Op1 respectively.
22571f917f69SDimitry Andric SmallVector<std::pair<int, int>> V1, V2;
22581f917f69SDimitry Andric SmallVector<SmallVector<int>> OrigReconstructMasks;
2259145449b1SDimitry Andric int MaxV1Elt = 0, MaxV2Elt = 0;
2260145449b1SDimitry Andric unsigned NumElts = VT->getNumElements();
2261145449b1SDimitry Andric for (ShuffleVectorInst *SVN : Shuffles) {
2262145449b1SDimitry Andric SmallVector<int> Mask;
2263145449b1SDimitry Andric SVN->getShuffleMask(Mask);
2264145449b1SDimitry Andric
2265145449b1SDimitry Andric // Check the operands are the same as the original, or reversed (in which
2266145449b1SDimitry Andric // case we need to commute the mask).
2267145449b1SDimitry Andric Value *SVOp0 = SVN->getOperand(0);
2268145449b1SDimitry Andric Value *SVOp1 = SVN->getOperand(1);
// Look through a single-operand shuffle of another shuffle, composing the
// inner shuffle's mask into ours.
22691f917f69SDimitry Andric if (isa<UndefValue>(SVOp1)) {
22701f917f69SDimitry Andric auto *SSV = cast<ShuffleVectorInst>(SVOp0);
22711f917f69SDimitry Andric SVOp0 = SSV->getOperand(0);
22721f917f69SDimitry Andric SVOp1 = SSV->getOperand(1);
22731f917f69SDimitry Andric for (unsigned I = 0, E = Mask.size(); I != E; I++) {
22741f917f69SDimitry Andric if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
22751f917f69SDimitry Andric return false;
22761f917f69SDimitry Andric Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
22771f917f69SDimitry Andric }
22781f917f69SDimitry Andric }
2279145449b1SDimitry Andric if (SVOp0 == Op1 && SVOp1 == Op0) {
2280145449b1SDimitry Andric std::swap(SVOp0, SVOp1);
2281145449b1SDimitry Andric ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
2282145449b1SDimitry Andric }
2283145449b1SDimitry Andric if (SVOp0 != Op0 || SVOp1 != Op1)
2284145449b1SDimitry Andric return false;
2285145449b1SDimitry Andric
2286145449b1SDimitry Andric // Calculate the reconstruction mask for this shuffle, as the mask needed to
2287145449b1SDimitry Andric // take the packed values from Op0/Op1 and reconstructing to the original
2288145449b1SDimitry Andric // order.
2289145449b1SDimitry Andric SmallVector<int> ReconstructMask;
2290145449b1SDimitry Andric for (unsigned I = 0; I < Mask.size(); I++) {
2291145449b1SDimitry Andric if (Mask[I] < 0) {
2292145449b1SDimitry Andric ReconstructMask.push_back(-1);
2293145449b1SDimitry Andric } else if (Mask[I] < static_cast<int>(NumElts)) {
2294145449b1SDimitry Andric MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
// Reuse an existing packed slot for this lane if we have one; otherwise
// append a new (lane, position) entry.
22951f917f69SDimitry Andric auto It = find_if(V1, [&](const std::pair<int, int> &A) {
22961f917f69SDimitry Andric return Mask[I] == A.first;
22971f917f69SDimitry Andric });
2298145449b1SDimitry Andric if (It != V1.end())
2299145449b1SDimitry Andric ReconstructMask.push_back(It - V1.begin());
2300145449b1SDimitry Andric else {
2301145449b1SDimitry Andric ReconstructMask.push_back(V1.size());
23021f917f69SDimitry Andric V1.emplace_back(Mask[I], V1.size());
2303145449b1SDimitry Andric }
2304145449b1SDimitry Andric } else {
2305145449b1SDimitry Andric MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
23061f917f69SDimitry Andric auto It = find_if(V2, [&](const std::pair<int, int> &A) {
23071f917f69SDimitry Andric return Mask[I] - static_cast<int>(NumElts) == A.first;
23081f917f69SDimitry Andric });
2309145449b1SDimitry Andric if (It != V2.end())
2310145449b1SDimitry Andric ReconstructMask.push_back(NumElts + It - V2.begin());
2311145449b1SDimitry Andric else {
2312145449b1SDimitry Andric ReconstructMask.push_back(NumElts + V2.size());
23131f917f69SDimitry Andric V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
2314145449b1SDimitry Andric }
2315145449b1SDimitry Andric }
2316145449b1SDimitry Andric }
2317145449b1SDimitry Andric
2318145449b1SDimitry Andric // For reductions, we know that the lane ordering out doesn't alter the
2319145449b1SDimitry Andric // result. In-order can help simplify the shuffle away.
2320145449b1SDimitry Andric if (FromReduction)
2321145449b1SDimitry Andric sort(ReconstructMask);
23221f917f69SDimitry Andric OrigReconstructMasks.push_back(std::move(ReconstructMask));
2323145449b1SDimitry Andric }
2324145449b1SDimitry Andric
2325145449b1SDimitry Andric // If the Maximum element used from V1 and V2 are not larger than the new
2326145449b1SDimitry Andric // vectors, the vectors are already packed and performing the optimization
2327145449b1SDimitry Andric // again will likely not help any further. This also prevents us from getting
2328145449b1SDimitry Andric // stuck in a cycle in case the costs do not also rule it out.
2329145449b1SDimitry Andric if (V1.empty() || V2.empty() ||
2330145449b1SDimitry Andric (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
2331145449b1SDimitry Andric MaxV2Elt == static_cast<int>(V2.size()) - 1))
2332145449b1SDimitry Andric return false;
2333145449b1SDimitry Andric
23341f917f69SDimitry Andric // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
23351f917f69SDimitry Andric // shuffle of another shuffle, or not a shuffle (that is treated like an
23361f917f69SDimitry Andric // identity shuffle).
23371f917f69SDimitry Andric auto GetBaseMaskValue = [&](Instruction *I, int M) {
23381f917f69SDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(I);
23391f917f69SDimitry Andric if (!SV)
23401f917f69SDimitry Andric return M;
23411f917f69SDimitry Andric if (isa<UndefValue>(SV->getOperand(1)))
23421f917f69SDimitry Andric if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
23431f917f69SDimitry Andric if (InputShuffles.contains(SSV))
23441f917f69SDimitry Andric return SSV->getMaskValue(SV->getMaskValue(M));
23451f917f69SDimitry Andric return SV->getMaskValue(M);
23461f917f69SDimitry Andric };
23471f917f69SDimitry Andric
23481f917f69SDimitry Andric // Attempt to sort the inputs by ascending mask values to make simpler input
23491f917f69SDimitry Andric // shuffles and push complex shuffles down to the uses. We sort on the first
23501f917f69SDimitry Andric // of the two input shuffle orders, to try and get at least one input into a
23511f917f69SDimitry Andric // nice order.
23521f917f69SDimitry Andric auto SortBase = [&](Instruction *A, std::pair<int, int> X,
23531f917f69SDimitry Andric std::pair<int, int> Y) {
23541f917f69SDimitry Andric int MXA = GetBaseMaskValue(A, X.first);
23551f917f69SDimitry Andric int MYA = GetBaseMaskValue(A, Y.first);
23561f917f69SDimitry Andric return MXA < MYA;
23571f917f69SDimitry Andric };
23581f917f69SDimitry Andric stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
23591f917f69SDimitry Andric return SortBase(SVI0A, A, B);
23601f917f69SDimitry Andric });
23611f917f69SDimitry Andric stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
23621f917f69SDimitry Andric return SortBase(SVI1A, A, B);
23631f917f69SDimitry Andric });
23641f917f69SDimitry Andric // Calculate our ReconstructMasks from the OrigReconstructMasks and the
23651f917f69SDimitry Andric // modified order of the input shuffles.
23661f917f69SDimitry Andric SmallVector<SmallVector<int>> ReconstructMasks;
23677fa27ce4SDimitry Andric for (const auto &Mask : OrigReconstructMasks) {
23681f917f69SDimitry Andric SmallVector<int> ReconstructMask;
23691f917f69SDimitry Andric for (int M : Mask) {
// Map a packed position back to its index in the (now sorted) V1/V2.
23701f917f69SDimitry Andric auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
23711f917f69SDimitry Andric auto It = find_if(V, [M](auto A) { return A.second == M; });
23721f917f69SDimitry Andric assert(It != V.end() && "Expected all entries in Mask");
23731f917f69SDimitry Andric return std::distance(V.begin(), It);
23741f917f69SDimitry Andric };
23751f917f69SDimitry Andric if (M < 0)
23761f917f69SDimitry Andric ReconstructMask.push_back(-1);
23771f917f69SDimitry Andric else if (M < static_cast<int>(NumElts)) {
23781f917f69SDimitry Andric ReconstructMask.push_back(FindIndex(V1, M));
23791f917f69SDimitry Andric } else {
23801f917f69SDimitry Andric ReconstructMask.push_back(NumElts + FindIndex(V2, M));
23811f917f69SDimitry Andric }
23821f917f69SDimitry Andric }
23831f917f69SDimitry Andric ReconstructMasks.push_back(std::move(ReconstructMask));
23841f917f69SDimitry Andric }
23851f917f69SDimitry Andric
2386145449b1SDimitry Andric // Calculate the masks needed for the new input shuffles, which get padded
2387145449b1SDimitry Andric // with undef
2388145449b1SDimitry Andric SmallVector<int> V1A, V1B, V2A, V2B;
2389145449b1SDimitry Andric for (unsigned I = 0; I < V1.size(); I++) {
23901f917f69SDimitry Andric V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
23911f917f69SDimitry Andric V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
2392145449b1SDimitry Andric }
2393145449b1SDimitry Andric for (unsigned I = 0; I < V2.size(); I++) {
23941f917f69SDimitry Andric V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
23951f917f69SDimitry Andric V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
2396145449b1SDimitry Andric }
2397145449b1SDimitry Andric while (V1A.size() < NumElts) {
23987fa27ce4SDimitry Andric V1A.push_back(PoisonMaskElem);
23997fa27ce4SDimitry Andric V1B.push_back(PoisonMaskElem);
2400145449b1SDimitry Andric }
2401145449b1SDimitry Andric while (V2A.size() < NumElts) {
24027fa27ce4SDimitry Andric V2A.push_back(PoisonMaskElem);
24037fa27ce4SDimitry Andric V2B.push_back(PoisonMaskElem);
2404145449b1SDimitry Andric }
2405145449b1SDimitry Andric
// Accumulate the cost of an existing shuffle; it is single-source when its
// second operand is undef, two-source otherwise.
24061f917f69SDimitry Andric auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
24071f917f69SDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(I);
24081f917f69SDimitry Andric if (!SV)
24091f917f69SDimitry Andric return C;
24101f917f69SDimitry Andric return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
24111f917f69SDimitry Andric ? TTI::SK_PermuteSingleSrc
24121f917f69SDimitry Andric : TTI::SK_PermuteTwoSrc,
24131f917f69SDimitry Andric VT, SV->getShuffleMask());
2414145449b1SDimitry Andric };
// Accumulate the cost of a new two-source shuffle we would create.
2415145449b1SDimitry Andric auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
2416145449b1SDimitry Andric return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask);
2417145449b1SDimitry Andric };
2418145449b1SDimitry Andric
2419145449b1SDimitry Andric // Get the costs of the shuffles + binops before and after with the new
2420145449b1SDimitry Andric // shuffle masks.
2421145449b1SDimitry Andric InstructionCost CostBefore =
2422145449b1SDimitry Andric TTI.getArithmeticInstrCost(Op0->getOpcode(), VT) +
2423145449b1SDimitry Andric TTI.getArithmeticInstrCost(Op1->getOpcode(), VT);
2424145449b1SDimitry Andric CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
2425145449b1SDimitry Andric InstructionCost(0), AddShuffleCost);
2426145449b1SDimitry Andric CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
2427145449b1SDimitry Andric InstructionCost(0), AddShuffleCost);
2428145449b1SDimitry Andric
2429145449b1SDimitry Andric // The new binops will be unused for lanes past the used shuffle lengths.
2430145449b1SDimitry Andric // These types attempt to get the correct cost for that from the target.
2431145449b1SDimitry Andric FixedVectorType *Op0SmallVT =
2432145449b1SDimitry Andric FixedVectorType::get(VT->getScalarType(), V1.size());
2433145449b1SDimitry Andric FixedVectorType *Op1SmallVT =
2434145449b1SDimitry Andric FixedVectorType::get(VT->getScalarType(), V2.size());
2435145449b1SDimitry Andric InstructionCost CostAfter =
2436145449b1SDimitry Andric TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT) +
2437145449b1SDimitry Andric TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT);
2438145449b1SDimitry Andric CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
2439145449b1SDimitry Andric InstructionCost(0), AddShuffleMaskCost);
// A std::set deduplicates identical new input masks so they are costed once.
2440145449b1SDimitry Andric std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
2441145449b1SDimitry Andric CostAfter +=
2442145449b1SDimitry Andric std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
2443145449b1SDimitry Andric InstructionCost(0), AddShuffleMaskCost);
2444145449b1SDimitry Andric
24451f917f69SDimitry Andric LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
24461f917f69SDimitry Andric LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
24471f917f69SDimitry Andric << " vs CostAfter: " << CostAfter << "\n");
2448145449b1SDimitry Andric if (CostBefore <= CostAfter)
2449145449b1SDimitry Andric return false;
2450145449b1SDimitry Andric
2451145449b1SDimitry Andric // The cost model has passed, create the new instructions.
// Fetch operand Op of an input, looking through a shuffle-of-shuffle when
// the inner shuffle is one of our collected input shuffles.
24521f917f69SDimitry Andric auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
24531f917f69SDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(I);
24541f917f69SDimitry Andric if (!SV)
24551f917f69SDimitry Andric return I;
24561f917f69SDimitry Andric if (isa<UndefValue>(SV->getOperand(1)))
24571f917f69SDimitry Andric if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
24581f917f69SDimitry Andric if (InputShuffles.contains(SSV))
24591f917f69SDimitry Andric return SSV->getOperand(Op);
24601f917f69SDimitry Andric return SV->getOperand(Op);
24611f917f69SDimitry Andric };
// Emit the four packed input shuffles just after the defs of the originals.
2462b1c73532SDimitry Andric Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
24631f917f69SDimitry Andric Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
24641f917f69SDimitry Andric GetShuffleOperand(SVI0A, 1), V1A);
2465b1c73532SDimitry Andric Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
24661f917f69SDimitry Andric Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
24671f917f69SDimitry Andric GetShuffleOperand(SVI0B, 1), V1B);
2468b1c73532SDimitry Andric Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
24691f917f69SDimitry Andric Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
24701f917f69SDimitry Andric GetShuffleOperand(SVI1A, 1), V2A);
2471b1c73532SDimitry Andric Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
24721f917f69SDimitry Andric Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
24731f917f69SDimitry Andric GetShuffleOperand(SVI1B, 1), V2B);
// Recreate the two binops on the packed operands, copying IR flags from the
// originals.
2474145449b1SDimitry Andric Builder.SetInsertPoint(Op0);
2475145449b1SDimitry Andric Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
2476145449b1SDimitry Andric NSV0A, NSV0B);
2477145449b1SDimitry Andric if (auto *I = dyn_cast<Instruction>(NOp0))
2478145449b1SDimitry Andric I->copyIRFlags(Op0, true);
2479145449b1SDimitry Andric Builder.SetInsertPoint(Op1);
2480145449b1SDimitry Andric Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
2481145449b1SDimitry Andric NSV1A, NSV1B);
2482145449b1SDimitry Andric if (auto *I = dyn_cast<Instruction>(NOp1))
2483145449b1SDimitry Andric I->copyIRFlags(Op1, true);
2484145449b1SDimitry Andric
// Replace each original output shuffle with a reconstruction shuffle of the
// new binops, restoring the original lane order.
2485145449b1SDimitry Andric for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
2486145449b1SDimitry Andric Builder.SetInsertPoint(Shuffles[S]);
2487145449b1SDimitry Andric Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
2488145449b1SDimitry Andric replaceValue(*Shuffles[S], *NSV);
2489145449b1SDimitry Andric }
2490145449b1SDimitry Andric
// Queue the new values so the combine worklist revisits them.
2491145449b1SDimitry Andric Worklist.pushValue(NSV0A);
2492145449b1SDimitry Andric Worklist.pushValue(NSV0B);
2493145449b1SDimitry Andric Worklist.pushValue(NSV1A);
2494145449b1SDimitry Andric Worklist.pushValue(NSV1B);
2495145449b1SDimitry Andric for (auto *S : Shuffles)
2496145449b1SDimitry Andric Worklist.add(S);
2497145449b1SDimitry Andric return true;
2498145449b1SDimitry Andric }
2499145449b1SDimitry Andric
2500cfca06d7SDimitry Andric /// This is the entry point for all transforms. Pass manager differences are
2501cfca06d7SDimitry Andric /// handled in the callers of this function.
run()2502cfca06d7SDimitry Andric bool VectorCombine::run() {
2503cfca06d7SDimitry Andric if (DisableVectorCombine)
2504cfca06d7SDimitry Andric return false;
2505cfca06d7SDimitry Andric
2506b60736ecSDimitry Andric // Don't attempt vectorization if the target does not support vectors.
2507b60736ecSDimitry Andric if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
2508b60736ecSDimitry Andric return false;
2509b60736ecSDimitry Andric
2510cfca06d7SDimitry Andric bool MadeChange = false;
2511c0981da4SDimitry Andric auto FoldInst = [this, &MadeChange](Instruction &I) {
2512c0981da4SDimitry Andric Builder.SetInsertPoint(&I);
2513e3b55780SDimitry Andric bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
2514e3b55780SDimitry Andric auto Opcode = I.getOpcode();
2515e3b55780SDimitry Andric
2516e3b55780SDimitry Andric // These folds should be beneficial regardless of when this pass is run
2517e3b55780SDimitry Andric // in the optimization pipeline.
2518e3b55780SDimitry Andric // The type checking is for run-time efficiency. We can avoid wasting time
2519e3b55780SDimitry Andric // dispatching to folding functions if there's no chance of matching.
2520e3b55780SDimitry Andric if (IsFixedVectorType) {
2521e3b55780SDimitry Andric switch (Opcode) {
2522e3b55780SDimitry Andric case Instruction::InsertElement:
2523c0981da4SDimitry Andric MadeChange |= vectorizeLoadInsert(I);
2524e3b55780SDimitry Andric break;
2525e3b55780SDimitry Andric case Instruction::ShuffleVector:
2526e3b55780SDimitry Andric MadeChange |= widenSubvectorLoad(I);
2527e3b55780SDimitry Andric break;
2528e3b55780SDimitry Andric default:
2529e3b55780SDimitry Andric break;
2530e3b55780SDimitry Andric }
2531e3b55780SDimitry Andric }
2532e3b55780SDimitry Andric
2533e3b55780SDimitry Andric // This transform works with scalable and fixed vectors
2534e3b55780SDimitry Andric // TODO: Identify and allow other scalable transforms
2535b1c73532SDimitry Andric if (isa<VectorType>(I.getType())) {
2536e3b55780SDimitry Andric MadeChange |= scalarizeBinopOrCmp(I);
2537b1c73532SDimitry Andric MadeChange |= scalarizeLoadExtract(I);
2538b1c73532SDimitry Andric MadeChange |= scalarizeVPIntrinsic(I);
2539b1c73532SDimitry Andric }
2540e3b55780SDimitry Andric
2541e3b55780SDimitry Andric if (Opcode == Instruction::Store)
2542c0981da4SDimitry Andric MadeChange |= foldSingleElementStore(I);
2543e3b55780SDimitry Andric
2544e3b55780SDimitry Andric // If this is an early pipeline invocation of this pass, we are done.
2545e3b55780SDimitry Andric if (TryEarlyFoldsOnly)
2546e3b55780SDimitry Andric return;
2547e3b55780SDimitry Andric
2548e3b55780SDimitry Andric // Otherwise, try folds that improve codegen but may interfere with
2549e3b55780SDimitry Andric // early IR canonicalizations.
2550e3b55780SDimitry Andric // The type checking is for run-time efficiency. We can avoid wasting time
2551e3b55780SDimitry Andric // dispatching to folding functions if there's no chance of matching.
2552e3b55780SDimitry Andric if (IsFixedVectorType) {
2553e3b55780SDimitry Andric switch (Opcode) {
2554e3b55780SDimitry Andric case Instruction::InsertElement:
2555e3b55780SDimitry Andric MadeChange |= foldInsExtFNeg(I);
2556e3b55780SDimitry Andric break;
2557e3b55780SDimitry Andric case Instruction::ShuffleVector:
2558e3b55780SDimitry Andric MadeChange |= foldShuffleOfBinops(I);
2559ac9a064cSDimitry Andric MadeChange |= foldShuffleOfCastops(I);
2560ac9a064cSDimitry Andric MadeChange |= foldShuffleOfShuffles(I);
2561e3b55780SDimitry Andric MadeChange |= foldSelectShuffle(I);
2562ac9a064cSDimitry Andric MadeChange |= foldShuffleToIdentity(I);
2563e3b55780SDimitry Andric break;
2564e3b55780SDimitry Andric case Instruction::BitCast:
2565b1c73532SDimitry Andric MadeChange |= foldBitcastShuffle(I);
2566e3b55780SDimitry Andric break;
2567e3b55780SDimitry Andric }
2568e3b55780SDimitry Andric } else {
2569e3b55780SDimitry Andric switch (Opcode) {
2570e3b55780SDimitry Andric case Instruction::Call:
2571e3b55780SDimitry Andric MadeChange |= foldShuffleFromReductions(I);
2572ac9a064cSDimitry Andric MadeChange |= foldCastFromReductions(I);
2573e3b55780SDimitry Andric break;
2574e3b55780SDimitry Andric case Instruction::ICmp:
2575e3b55780SDimitry Andric case Instruction::FCmp:
2576e3b55780SDimitry Andric MadeChange |= foldExtractExtract(I);
2577e3b55780SDimitry Andric break;
2578e3b55780SDimitry Andric default:
2579e3b55780SDimitry Andric if (Instruction::isBinaryOp(Opcode)) {
2580e3b55780SDimitry Andric MadeChange |= foldExtractExtract(I);
2581e3b55780SDimitry Andric MadeChange |= foldExtractedCmps(I);
2582e3b55780SDimitry Andric }
2583e3b55780SDimitry Andric break;
2584e3b55780SDimitry Andric }
2585e3b55780SDimitry Andric }
2586c0981da4SDimitry Andric };
2587e3b55780SDimitry Andric
2588cfca06d7SDimitry Andric for (BasicBlock &BB : F) {
2589cfca06d7SDimitry Andric // Ignore unreachable basic blocks.
2590cfca06d7SDimitry Andric if (!DT.isReachableFromEntry(&BB))
2591cfca06d7SDimitry Andric continue;
2592344a3780SDimitry Andric // Use early increment range so that we can erase instructions in loop.
2593344a3780SDimitry Andric for (Instruction &I : make_early_inc_range(BB)) {
2594c0981da4SDimitry Andric if (I.isDebugOrPseudoInst())
2595cfca06d7SDimitry Andric continue;
2596c0981da4SDimitry Andric FoldInst(I);
2597cfca06d7SDimitry Andric }
2598cfca06d7SDimitry Andric }
2599cfca06d7SDimitry Andric
2600c0981da4SDimitry Andric while (!Worklist.isEmpty()) {
2601c0981da4SDimitry Andric Instruction *I = Worklist.removeOne();
2602c0981da4SDimitry Andric if (!I)
2603c0981da4SDimitry Andric continue;
2604c0981da4SDimitry Andric
2605c0981da4SDimitry Andric if (isInstructionTriviallyDead(I)) {
2606c0981da4SDimitry Andric eraseInstruction(*I);
2607c0981da4SDimitry Andric continue;
2608c0981da4SDimitry Andric }
2609c0981da4SDimitry Andric
2610c0981da4SDimitry Andric FoldInst(*I);
2611c0981da4SDimitry Andric }
2612cfca06d7SDimitry Andric
2613cfca06d7SDimitry Andric return MadeChange;
2614cfca06d7SDimitry Andric }
2615cfca06d7SDimitry Andric
run(Function & F,FunctionAnalysisManager & FAM)2616cfca06d7SDimitry Andric PreservedAnalyses VectorCombinePass::run(Function &F,
2617cfca06d7SDimitry Andric FunctionAnalysisManager &FAM) {
2618344a3780SDimitry Andric auto &AC = FAM.getResult<AssumptionAnalysis>(F);
2619cfca06d7SDimitry Andric TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
2620cfca06d7SDimitry Andric DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
2621344a3780SDimitry Andric AAResults &AA = FAM.getResult<AAManager>(F);
2622ac9a064cSDimitry Andric const DataLayout *DL = &F.getDataLayout();
2623ac9a064cSDimitry Andric VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TryEarlyFoldsOnly);
2624cfca06d7SDimitry Andric if (!Combiner.run())
2625cfca06d7SDimitry Andric return PreservedAnalyses::all();
2626cfca06d7SDimitry Andric PreservedAnalyses PA;
2627cfca06d7SDimitry Andric PA.preserveSet<CFGAnalyses>();
2628cfca06d7SDimitry Andric return PA;
2629cfca06d7SDimitry Andric }
2630