xref: /src/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validation of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized for both
// shufflevector instructions and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementValue is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementValue fields of that Node are relevant, where the ReplacementValue
// should be pre-populated.
//
//===----------------------------------------------------------------------===//
61e3b55780SDimitry Andric 
62e3b55780SDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63b1c73532SDimitry Andric #include "llvm/ADT/MapVector.h"
64e3b55780SDimitry Andric #include "llvm/ADT/Statistic.h"
65e3b55780SDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
66e3b55780SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
67e3b55780SDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
68e3b55780SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
69e3b55780SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
70e3b55780SDimitry Andric #include "llvm/IR/IRBuilder.h"
717fa27ce4SDimitry Andric #include "llvm/IR/PatternMatch.h"
72e3b55780SDimitry Andric #include "llvm/InitializePasses.h"
73e3b55780SDimitry Andric #include "llvm/Target/TargetMachine.h"
74e3b55780SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
75e3b55780SDimitry Andric #include <algorithm>
76e3b55780SDimitry Andric 
77e3b55780SDimitry Andric using namespace llvm;
78e3b55780SDimitry Andric using namespace PatternMatch;
79e3b55780SDimitry Andric 
80e3b55780SDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
81e3b55780SDimitry Andric 
82e3b55780SDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83e3b55780SDimitry Andric 
84e3b55780SDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
85e3b55780SDimitry Andric     "enable-complex-deinterleaving",
86e3b55780SDimitry Andric     cl::desc("Enable generation of complex instructions"), cl::init(true),
87e3b55780SDimitry Andric     cl::Hidden);
88e3b55780SDimitry Andric 
89e3b55780SDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving.
90e3b55780SDimitry Andric ///
91e3b55780SDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92e3b55780SDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93e3b55780SDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94e3b55780SDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask);
95e3b55780SDimitry Andric 
96e3b55780SDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving.
97e3b55780SDimitry Andric ///
98e3b55780SDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start
99e3b55780SDimitry Andric /// with 0 or 1.
100e3b55780SDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101e3b55780SDimitry Andric /// <1, 3, 5, 7>).
102e3b55780SDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask);
103e3b55780SDimitry Andric 
1047fa27ce4SDimitry Andric /// Returns true if the operation is a negation of V, and it works for both
1057fa27ce4SDimitry Andric /// integers and floats.
1067fa27ce4SDimitry Andric static bool isNeg(Value *V);
1077fa27ce4SDimitry Andric 
1087fa27ce4SDimitry Andric /// Returns the operand for negation operation.
1097fa27ce4SDimitry Andric static Value *getNegOperand(Value *V);
1107fa27ce4SDimitry Andric 
111e3b55780SDimitry Andric namespace {
112e3b55780SDimitry Andric 
113e3b55780SDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
114e3b55780SDimitry Andric public:
115e3b55780SDimitry Andric   static char ID;
116e3b55780SDimitry Andric 
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)117e3b55780SDimitry Andric   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118e3b55780SDimitry Andric       : FunctionPass(ID), TM(TM) {
119e3b55780SDimitry Andric     initializeComplexDeinterleavingLegacyPassPass(
120e3b55780SDimitry Andric         *PassRegistry::getPassRegistry());
121e3b55780SDimitry Andric   }
122e3b55780SDimitry Andric 
getPassName() const123e3b55780SDimitry Andric   StringRef getPassName() const override {
124e3b55780SDimitry Andric     return "Complex Deinterleaving Pass";
125e3b55780SDimitry Andric   }
126e3b55780SDimitry Andric 
127e3b55780SDimitry Andric   bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const128e3b55780SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
129e3b55780SDimitry Andric     AU.addRequired<TargetLibraryInfoWrapperPass>();
130e3b55780SDimitry Andric     AU.setPreservesCFG();
131e3b55780SDimitry Andric   }
132e3b55780SDimitry Andric 
133e3b55780SDimitry Andric private:
134e3b55780SDimitry Andric   const TargetMachine *TM;
135e3b55780SDimitry Andric };
136e3b55780SDimitry Andric 
137e3b55780SDimitry Andric class ComplexDeinterleavingGraph;
138e3b55780SDimitry Andric struct ComplexDeinterleavingCompositeNode {
139e3b55780SDimitry Andric 
ComplexDeinterleavingCompositeNode__anon446059c90111::ComplexDeinterleavingCompositeNode140e3b55780SDimitry Andric   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
1417fa27ce4SDimitry Andric                                      Value *R, Value *I)
142e3b55780SDimitry Andric       : Operation(Op), Real(R), Imag(I) {}
143e3b55780SDimitry Andric 
144e3b55780SDimitry Andric private:
145e3b55780SDimitry Andric   friend class ComplexDeinterleavingGraph;
146e3b55780SDimitry Andric   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147e3b55780SDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148e3b55780SDimitry Andric 
149e3b55780SDimitry Andric public:
150e3b55780SDimitry Andric   ComplexDeinterleavingOperation Operation;
1517fa27ce4SDimitry Andric   Value *Real;
1527fa27ce4SDimitry Andric   Value *Imag;
153e3b55780SDimitry Andric 
1547fa27ce4SDimitry Andric   // This two members are required exclusively for generating
1557fa27ce4SDimitry Andric   // ComplexDeinterleavingOperation::Symmetric operations.
1567fa27ce4SDimitry Andric   unsigned Opcode;
1577fa27ce4SDimitry Andric   std::optional<FastMathFlags> Flags;
1587fa27ce4SDimitry Andric 
1597fa27ce4SDimitry Andric   ComplexDeinterleavingRotation Rotation =
1607fa27ce4SDimitry Andric       ComplexDeinterleavingRotation::Rotation_0;
161e3b55780SDimitry Andric   SmallVector<RawNodePtr> Operands;
162e3b55780SDimitry Andric   Value *ReplacementNode = nullptr;
163e3b55780SDimitry Andric 
addOperand__anon446059c90111::ComplexDeinterleavingCompositeNode164e3b55780SDimitry Andric   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165e3b55780SDimitry Andric 
dump__anon446059c90111::ComplexDeinterleavingCompositeNode166e3b55780SDimitry Andric   void dump() { dump(dbgs()); }
dump__anon446059c90111::ComplexDeinterleavingCompositeNode167e3b55780SDimitry Andric   void dump(raw_ostream &OS) {
168e3b55780SDimitry Andric     auto PrintValue = [&](Value *V) {
169e3b55780SDimitry Andric       if (V) {
170e3b55780SDimitry Andric         OS << "\"";
171e3b55780SDimitry Andric         V->print(OS, true);
172e3b55780SDimitry Andric         OS << "\"\n";
173e3b55780SDimitry Andric       } else
174e3b55780SDimitry Andric         OS << "nullptr\n";
175e3b55780SDimitry Andric     };
176e3b55780SDimitry Andric     auto PrintNodeRef = [&](RawNodePtr Ptr) {
177e3b55780SDimitry Andric       if (Ptr)
178e3b55780SDimitry Andric         OS << Ptr << "\n";
179e3b55780SDimitry Andric       else
180e3b55780SDimitry Andric         OS << "nullptr\n";
181e3b55780SDimitry Andric     };
182e3b55780SDimitry Andric 
183e3b55780SDimitry Andric     OS << "- CompositeNode: " << this << "\n";
184e3b55780SDimitry Andric     OS << "  Real: ";
185e3b55780SDimitry Andric     PrintValue(Real);
186e3b55780SDimitry Andric     OS << "  Imag: ";
187e3b55780SDimitry Andric     PrintValue(Imag);
188e3b55780SDimitry Andric     OS << "  ReplacementNode: ";
189e3b55780SDimitry Andric     PrintValue(ReplacementNode);
190e3b55780SDimitry Andric     OS << "  Operation: " << (int)Operation << "\n";
191e3b55780SDimitry Andric     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
192e3b55780SDimitry Andric     OS << "  Operands: \n";
193e3b55780SDimitry Andric     for (const auto &Op : Operands) {
194e3b55780SDimitry Andric       OS << "    - ";
195e3b55780SDimitry Andric       PrintNodeRef(Op);
196e3b55780SDimitry Andric     }
197e3b55780SDimitry Andric   }
198e3b55780SDimitry Andric };
199e3b55780SDimitry Andric 
200e3b55780SDimitry Andric class ComplexDeinterleavingGraph {
201e3b55780SDimitry Andric public:
2027fa27ce4SDimitry Andric   struct Product {
2037fa27ce4SDimitry Andric     Value *Multiplier;
2047fa27ce4SDimitry Andric     Value *Multiplicand;
2057fa27ce4SDimitry Andric     bool IsPositive;
2067fa27ce4SDimitry Andric   };
2077fa27ce4SDimitry Andric 
2087fa27ce4SDimitry Andric   using Addend = std::pair<Value *, bool>;
209e3b55780SDimitry Andric   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210e3b55780SDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
2117fa27ce4SDimitry Andric 
2127fa27ce4SDimitry Andric   // Helper struct for holding info about potential partial multiplication
2137fa27ce4SDimitry Andric   // candidates
2147fa27ce4SDimitry Andric   struct PartialMulCandidate {
2157fa27ce4SDimitry Andric     Value *Common;
2167fa27ce4SDimitry Andric     NodePtr Node;
2177fa27ce4SDimitry Andric     unsigned RealIdx;
2187fa27ce4SDimitry Andric     unsigned ImagIdx;
2197fa27ce4SDimitry Andric     bool IsNodeInverted;
2207fa27ce4SDimitry Andric   };
2217fa27ce4SDimitry Andric 
ComplexDeinterleavingGraph(const TargetLowering * TL,const TargetLibraryInfo * TLI)2227fa27ce4SDimitry Andric   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
2237fa27ce4SDimitry Andric                                       const TargetLibraryInfo *TLI)
2247fa27ce4SDimitry Andric       : TL(TL), TLI(TLI) {}
225e3b55780SDimitry Andric 
226e3b55780SDimitry Andric private:
2277fa27ce4SDimitry Andric   const TargetLowering *TL = nullptr;
2287fa27ce4SDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
229e3b55780SDimitry Andric   SmallVector<NodePtr> CompositeNodes;
230b1c73532SDimitry Andric   DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
2317fa27ce4SDimitry Andric 
2327fa27ce4SDimitry Andric   SmallPtrSet<Instruction *, 16> FinalInstructions;
2337fa27ce4SDimitry Andric 
2347fa27ce4SDimitry Andric   /// Root instructions are instructions from which complex computation starts
2357fa27ce4SDimitry Andric   std::map<Instruction *, NodePtr> RootToNode;
2367fa27ce4SDimitry Andric 
2377fa27ce4SDimitry Andric   /// Topologically sorted root instructions
2387fa27ce4SDimitry Andric   SmallVector<Instruction *, 1> OrderedRoots;
2397fa27ce4SDimitry Andric 
2407fa27ce4SDimitry Andric   /// When examining a basic block for complex deinterleaving, if it is a simple
2417fa27ce4SDimitry Andric   /// one-block loop, then the only incoming block is 'Incoming' and the
2427fa27ce4SDimitry Andric   /// 'BackEdge' block is the block itself."
2437fa27ce4SDimitry Andric   BasicBlock *BackEdge = nullptr;
2447fa27ce4SDimitry Andric   BasicBlock *Incoming = nullptr;
2457fa27ce4SDimitry Andric 
2467fa27ce4SDimitry Andric   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
2477fa27ce4SDimitry Andric   /// %OutsideUser as it is shown in the IR:
2487fa27ce4SDimitry Andric   ///
2497fa27ce4SDimitry Andric   /// vector.body:
2507fa27ce4SDimitry Andric   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
2517fa27ce4SDimitry Andric   ///                                [ %ReductionOp, %vector.body ]
2527fa27ce4SDimitry Andric   ///   ...
2537fa27ce4SDimitry Andric   ///   %ReductionOp = fadd i64 ...
2547fa27ce4SDimitry Andric   ///   ...
2557fa27ce4SDimitry Andric   ///   br i1 %condition, label %vector.body, %middle.block
2567fa27ce4SDimitry Andric   ///
2577fa27ce4SDimitry Andric   /// middle.block:
2587fa27ce4SDimitry Andric   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
2597fa27ce4SDimitry Andric   ///
2607fa27ce4SDimitry Andric   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
2617fa27ce4SDimitry Andric   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262b1c73532SDimitry Andric   MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
2637fa27ce4SDimitry Andric 
2647fa27ce4SDimitry Andric   /// In the process of detecting a reduction, we consider a pair of
2657fa27ce4SDimitry Andric   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
2667fa27ce4SDimitry Andric   /// traverse the use-tree to detect complex operations. As this is a reduction
2677fa27ce4SDimitry Andric   /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
2687fa27ce4SDimitry Andric   /// to the %ReductionOPs that we suspect to be complex.
2697fa27ce4SDimitry Andric   /// RealPHI and ImagPHI are used by the identifyPHINode method.
2707fa27ce4SDimitry Andric   PHINode *RealPHI = nullptr;
2717fa27ce4SDimitry Andric   PHINode *ImagPHI = nullptr;
2727fa27ce4SDimitry Andric 
2737fa27ce4SDimitry Andric   /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
2747fa27ce4SDimitry Andric   /// detection.
2757fa27ce4SDimitry Andric   bool PHIsFound = false;
2767fa27ce4SDimitry Andric 
2777fa27ce4SDimitry Andric   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
2787fa27ce4SDimitry Andric   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
2797fa27ce4SDimitry Andric   /// This mapping is populated during
2807fa27ce4SDimitry Andric   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
2817fa27ce4SDimitry Andric   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
2827fa27ce4SDimitry Andric   /// replacement process.
2837fa27ce4SDimitry Andric   std::map<PHINode *, PHINode *> OldToNewPHI;
284e3b55780SDimitry Andric 
prepareCompositeNode(ComplexDeinterleavingOperation Operation,Value * R,Value * I)285e3b55780SDimitry Andric   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
2867fa27ce4SDimitry Andric                                Value *R, Value *I) {
2877fa27ce4SDimitry Andric     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
2887fa27ce4SDimitry Andric              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
2897fa27ce4SDimitry Andric             (R && I)) &&
2907fa27ce4SDimitry Andric            "Reduction related nodes must have Real and Imaginary parts");
291e3b55780SDimitry Andric     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292e3b55780SDimitry Andric                                                                 I);
293e3b55780SDimitry Andric   }
294e3b55780SDimitry Andric 
submitCompositeNode(NodePtr Node)295e3b55780SDimitry Andric   NodePtr submitCompositeNode(NodePtr Node) {
296e3b55780SDimitry Andric     CompositeNodes.push_back(Node);
297b1c73532SDimitry Andric     if (Node->Real && Node->Imag)
298b1c73532SDimitry Andric       CachedResult[{Node->Real, Node->Imag}] = Node;
299e3b55780SDimitry Andric     return Node;
300e3b55780SDimitry Andric   }
301e3b55780SDimitry Andric 
302e3b55780SDimitry Andric   /// Identifies a complex partial multiply pattern and its rotation, based on
303e3b55780SDimitry Andric   /// the following patterns
304e3b55780SDimitry Andric   ///
305e3b55780SDimitry Andric   ///  0:  r: cr + ar * br
306e3b55780SDimitry Andric   ///      i: ci + ar * bi
307e3b55780SDimitry Andric   /// 90:  r: cr - ai * bi
308e3b55780SDimitry Andric   ///      i: ci + ai * br
309e3b55780SDimitry Andric   /// 180: r: cr - ar * br
310e3b55780SDimitry Andric   ///      i: ci - ar * bi
311e3b55780SDimitry Andric   /// 270: r: cr + ai * bi
312e3b55780SDimitry Andric   ///      i: ci - ai * br
313e3b55780SDimitry Andric   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314e3b55780SDimitry Andric 
315e3b55780SDimitry Andric   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316e3b55780SDimitry Andric   /// is partially known from identifyPartialMul, filling in the other half of
317e3b55780SDimitry Andric   /// the complex pair.
3187fa27ce4SDimitry Andric   NodePtr
3197fa27ce4SDimitry Andric   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
3207fa27ce4SDimitry Andric                               std::pair<Value *, Value *> &CommonOperandI);
321e3b55780SDimitry Andric 
322e3b55780SDimitry Andric   /// Identifies a complex add pattern and its rotation, based on the following
323e3b55780SDimitry Andric   /// patterns.
324e3b55780SDimitry Andric   ///
325e3b55780SDimitry Andric   /// 90:  r: ar - bi
326e3b55780SDimitry Andric   ///      i: ai + br
327e3b55780SDimitry Andric   /// 270: r: ar + bi
328e3b55780SDimitry Andric   ///      i: ai - br
329e3b55780SDimitry Andric   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
3307fa27ce4SDimitry Andric   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331e3b55780SDimitry Andric 
3327fa27ce4SDimitry Andric   NodePtr identifyNode(Value *R, Value *I);
333e3b55780SDimitry Andric 
3347fa27ce4SDimitry Andric   /// Determine if a sum of complex numbers can be formed from \p RealAddends
3357fa27ce4SDimitry Andric   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
3367fa27ce4SDimitry Andric   /// Return nullptr if it is not possible to construct a complex number.
3377fa27ce4SDimitry Andric   /// \p Flags are needed to generate symmetric Add and Sub operations.
3387fa27ce4SDimitry Andric   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
3397fa27ce4SDimitry Andric                             std::list<Addend> &ImagAddends,
3407fa27ce4SDimitry Andric                             std::optional<FastMathFlags> Flags,
3417fa27ce4SDimitry Andric                             NodePtr Accumulator);
3427fa27ce4SDimitry Andric 
3437fa27ce4SDimitry Andric   /// Extract one addend that have both real and imaginary parts positive.
3447fa27ce4SDimitry Andric   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
3457fa27ce4SDimitry Andric                                 std::list<Addend> &ImagAddends);
3467fa27ce4SDimitry Andric 
3477fa27ce4SDimitry Andric   /// Determine if sum of multiplications of complex numbers can be formed from
3487fa27ce4SDimitry Andric   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
3497fa27ce4SDimitry Andric   /// to it. Return nullptr if it is not possible to construct a complex number.
3507fa27ce4SDimitry Andric   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
3517fa27ce4SDimitry Andric                                   std::vector<Product> &ImagMuls,
3527fa27ce4SDimitry Andric                                   NodePtr Accumulator);
3537fa27ce4SDimitry Andric 
3547fa27ce4SDimitry Andric   /// Go through pairs of multiplication (one Real and one Imag) and find all
3557fa27ce4SDimitry Andric   /// possible candidates for partial multiplication and put them into \p
3567fa27ce4SDimitry Andric   /// Candidates. Returns true if all Product has pair with common operand
3577fa27ce4SDimitry Andric   bool collectPartialMuls(const std::vector<Product> &RealMuls,
3587fa27ce4SDimitry Andric                           const std::vector<Product> &ImagMuls,
3597fa27ce4SDimitry Andric                           std::vector<PartialMulCandidate> &Candidates);
3607fa27ce4SDimitry Andric 
3617fa27ce4SDimitry Andric   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
3627fa27ce4SDimitry Andric   /// the order of complex computation operations may be significantly altered,
3637fa27ce4SDimitry Andric   /// and the real and imaginary parts may not be executed in parallel. This
3647fa27ce4SDimitry Andric   /// function takes this into consideration and employs a more general approach
3657fa27ce4SDimitry Andric   /// to identify complex computations. Initially, it gathers all the addends
3667fa27ce4SDimitry Andric   /// and multiplicands and then constructs a complex expression from them.
3677fa27ce4SDimitry Andric   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
3687fa27ce4SDimitry Andric 
3697fa27ce4SDimitry Andric   NodePtr identifyRoot(Instruction *I);
3707fa27ce4SDimitry Andric 
3717fa27ce4SDimitry Andric   /// Identifies the Deinterleave operation applied to a vector containing
3727fa27ce4SDimitry Andric   /// complex numbers. There are two ways to represent the Deinterleave
3737fa27ce4SDimitry Andric   /// operation:
3747fa27ce4SDimitry Andric   /// * Using two shufflevectors with even indices for /pReal instruction and
3757fa27ce4SDimitry Andric   /// odd indices for /pImag instructions (only for fixed-width vectors)
3767fa27ce4SDimitry Andric   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
3777fa27ce4SDimitry Andric   /// intrinsic (for both fixed and scalable vectors)
3787fa27ce4SDimitry Andric   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
3797fa27ce4SDimitry Andric 
3807fa27ce4SDimitry Andric   /// identifying the operation that represents a complex number repeated in a
3817fa27ce4SDimitry Andric   /// Splat vector. There are two possible types of splats: ConstantExpr with
3827fa27ce4SDimitry Andric   /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
3837fa27ce4SDimitry Andric   /// initialization mask with all values set to zero.
3847fa27ce4SDimitry Andric   NodePtr identifySplat(Value *Real, Value *Imag);
3857fa27ce4SDimitry Andric 
3867fa27ce4SDimitry Andric   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
3877fa27ce4SDimitry Andric 
3887fa27ce4SDimitry Andric   /// Identifies SelectInsts in a loop that has reduction with predication masks
3897fa27ce4SDimitry Andric   /// and/or predicated tail folding
3907fa27ce4SDimitry Andric   NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
3917fa27ce4SDimitry Andric 
3927fa27ce4SDimitry Andric   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
3937fa27ce4SDimitry Andric 
3947fa27ce4SDimitry Andric   /// Complete IR modifications after producing new reduction operation:
3957fa27ce4SDimitry Andric   /// * Populate the PHINode generated for
3967fa27ce4SDimitry Andric   /// ComplexDeinterleavingOperation::ReductionPHI
3977fa27ce4SDimitry Andric   /// * Deinterleave the final value outside of the loop and repurpose original
3987fa27ce4SDimitry Andric   /// reduction users
3997fa27ce4SDimitry Andric   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400e3b55780SDimitry Andric 
401e3b55780SDimitry Andric public:
dump()402e3b55780SDimitry Andric   void dump() { dump(dbgs()); }
dump(raw_ostream & OS)403e3b55780SDimitry Andric   void dump(raw_ostream &OS) {
404e3b55780SDimitry Andric     for (const auto &Node : CompositeNodes)
405e3b55780SDimitry Andric       Node->dump(OS);
406e3b55780SDimitry Andric   }
407e3b55780SDimitry Andric 
408e3b55780SDimitry Andric   /// Returns false if the deinterleaving operation should be cancelled for the
409e3b55780SDimitry Andric   /// current graph.
410e3b55780SDimitry Andric   bool identifyNodes(Instruction *RootI);
411e3b55780SDimitry Andric 
4127fa27ce4SDimitry Andric   /// In case \pB is one-block loop, this function seeks potential reductions
4137fa27ce4SDimitry Andric   /// and populates ReductionInfo. Returns true if any reductions were
4147fa27ce4SDimitry Andric   /// identified.
4157fa27ce4SDimitry Andric   bool collectPotentialReductions(BasicBlock *B);
4167fa27ce4SDimitry Andric 
4177fa27ce4SDimitry Andric   void identifyReductionNodes();
4187fa27ce4SDimitry Andric 
4197fa27ce4SDimitry Andric   /// Check that every instruction, from the roots to the leaves, has internal
4207fa27ce4SDimitry Andric   /// uses.
4217fa27ce4SDimitry Andric   bool checkNodes();
4227fa27ce4SDimitry Andric 
423e3b55780SDimitry Andric   /// Perform the actual replacement of the underlying instruction graph.
424e3b55780SDimitry Andric   void replaceNodes();
425e3b55780SDimitry Andric };
426e3b55780SDimitry Andric 
427e3b55780SDimitry Andric class ComplexDeinterleaving {
428e3b55780SDimitry Andric public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)429e3b55780SDimitry Andric   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430e3b55780SDimitry Andric       : TL(tl), TLI(tli) {}
431e3b55780SDimitry Andric   bool runOnFunction(Function &F);
432e3b55780SDimitry Andric 
433e3b55780SDimitry Andric private:
434e3b55780SDimitry Andric   bool evaluateBasicBlock(BasicBlock *B);
435e3b55780SDimitry Andric 
436e3b55780SDimitry Andric   const TargetLowering *TL = nullptr;
437e3b55780SDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
438e3b55780SDimitry Andric };
439e3b55780SDimitry Andric 
440e3b55780SDimitry Andric } // namespace
441e3b55780SDimitry Andric 
442e3b55780SDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
443e3b55780SDimitry Andric 
444e3b55780SDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445e3b55780SDimitry Andric                       "Complex Deinterleaving", false, false)
446e3b55780SDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447e3b55780SDimitry Andric                     "Complex Deinterleaving", false, false)
448e3b55780SDimitry Andric 
run(Function & F,FunctionAnalysisManager & AM)449e3b55780SDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450e3b55780SDimitry Andric                                                  FunctionAnalysisManager &AM) {
451e3b55780SDimitry Andric   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452e3b55780SDimitry Andric   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453e3b55780SDimitry Andric   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454e3b55780SDimitry Andric     return PreservedAnalyses::all();
455e3b55780SDimitry Andric 
456e3b55780SDimitry Andric   PreservedAnalyses PA;
457e3b55780SDimitry Andric   PA.preserve<FunctionAnalysisManagerModuleProxy>();
458e3b55780SDimitry Andric   return PA;
459e3b55780SDimitry Andric }
460e3b55780SDimitry Andric 
createComplexDeinterleavingPass(const TargetMachine * TM)461e3b55780SDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462e3b55780SDimitry Andric   return new ComplexDeinterleavingLegacyPass(TM);
463e3b55780SDimitry Andric }
464e3b55780SDimitry Andric 
runOnFunction(Function & F)465e3b55780SDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466e3b55780SDimitry Andric   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467e3b55780SDimitry Andric   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468e3b55780SDimitry Andric   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469e3b55780SDimitry Andric }
470e3b55780SDimitry Andric 
runOnFunction(Function & F)471e3b55780SDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
472e3b55780SDimitry Andric   if (!ComplexDeinterleavingEnabled) {
473e3b55780SDimitry Andric     LLVM_DEBUG(
474e3b55780SDimitry Andric         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475e3b55780SDimitry Andric     return false;
476e3b55780SDimitry Andric   }
477e3b55780SDimitry Andric 
478e3b55780SDimitry Andric   if (!TL->isComplexDeinterleavingSupported()) {
479e3b55780SDimitry Andric     LLVM_DEBUG(
480e3b55780SDimitry Andric         dbgs() << "Complex deinterleaving has been disabled, target does "
481e3b55780SDimitry Andric                   "not support lowering of complex number operations.\n");
482e3b55780SDimitry Andric     return false;
483e3b55780SDimitry Andric   }
484e3b55780SDimitry Andric 
485e3b55780SDimitry Andric   bool Changed = false;
486e3b55780SDimitry Andric   for (auto &B : F)
487e3b55780SDimitry Andric     Changed |= evaluateBasicBlock(&B);
488e3b55780SDimitry Andric 
489e3b55780SDimitry Andric   return Changed;
490e3b55780SDimitry Andric }
491e3b55780SDimitry Andric 
isInterleavingMask(ArrayRef<int> Mask)492e3b55780SDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
493e3b55780SDimitry Andric   // If the size is not even, it's not an interleaving mask
494e3b55780SDimitry Andric   if ((Mask.size() & 1))
495e3b55780SDimitry Andric     return false;
496e3b55780SDimitry Andric 
497e3b55780SDimitry Andric   int HalfNumElements = Mask.size() / 2;
498e3b55780SDimitry Andric   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499e3b55780SDimitry Andric     int MaskIdx = Idx * 2;
500e3b55780SDimitry Andric     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501e3b55780SDimitry Andric       return false;
502e3b55780SDimitry Andric   }
503e3b55780SDimitry Andric 
504e3b55780SDimitry Andric   return true;
505e3b55780SDimitry Andric }
506e3b55780SDimitry Andric 
isDeinterleavingMask(ArrayRef<int> Mask)507e3b55780SDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508e3b55780SDimitry Andric   int Offset = Mask[0];
509e3b55780SDimitry Andric   int HalfNumElements = Mask.size() / 2;
510e3b55780SDimitry Andric 
511e3b55780SDimitry Andric   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512e3b55780SDimitry Andric     if (Mask[Idx] != (Idx * 2) + Offset)
513e3b55780SDimitry Andric       return false;
514e3b55780SDimitry Andric   }
515e3b55780SDimitry Andric 
516e3b55780SDimitry Andric   return true;
517e3b55780SDimitry Andric }
518e3b55780SDimitry Andric 
isNeg(Value * V)5197fa27ce4SDimitry Andric bool isNeg(Value *V) {
5207fa27ce4SDimitry Andric   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
5217fa27ce4SDimitry Andric }
5227fa27ce4SDimitry Andric 
getNegOperand(Value * V)5237fa27ce4SDimitry Andric Value *getNegOperand(Value *V) {
5247fa27ce4SDimitry Andric   assert(isNeg(V));
5257fa27ce4SDimitry Andric   auto *I = cast<Instruction>(V);
5267fa27ce4SDimitry Andric   if (I->getOpcode() == Instruction::FNeg)
5277fa27ce4SDimitry Andric     return I->getOperand(0);
5287fa27ce4SDimitry Andric 
5297fa27ce4SDimitry Andric   return I->getOperand(1);
5307fa27ce4SDimitry Andric }
5317fa27ce4SDimitry Andric 
evaluateBasicBlock(BasicBlock * B)532e3b55780SDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
5337fa27ce4SDimitry Andric   ComplexDeinterleavingGraph Graph(TL, TLI);
5347fa27ce4SDimitry Andric   if (Graph.collectPotentialReductions(B))
5357fa27ce4SDimitry Andric     Graph.identifyReductionNodes();
536e3b55780SDimitry Andric 
5377fa27ce4SDimitry Andric   for (auto &I : *B)
5387fa27ce4SDimitry Andric     Graph.identifyNodes(&I);
539e3b55780SDimitry Andric 
5407fa27ce4SDimitry Andric   if (Graph.checkNodes()) {
541e3b55780SDimitry Andric     Graph.replaceNodes();
5427fa27ce4SDimitry Andric     return true;
543e3b55780SDimitry Andric   }
544e3b55780SDimitry Andric 
5457fa27ce4SDimitry Andric   return false;
546e3b55780SDimitry Andric }
547e3b55780SDimitry Andric 
548e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNodeWithImplicitAdd(Instruction * Real,Instruction * Imag,std::pair<Value *,Value * > & PartialMatch)549e3b55780SDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550e3b55780SDimitry Andric     Instruction *Real, Instruction *Imag,
5517fa27ce4SDimitry Andric     std::pair<Value *, Value *> &PartialMatch) {
552e3b55780SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553e3b55780SDimitry Andric                     << "\n");
554e3b55780SDimitry Andric 
555e3b55780SDimitry Andric   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
557e3b55780SDimitry Andric     return nullptr;
558e3b55780SDimitry Andric   }
559e3b55780SDimitry Andric 
5607fa27ce4SDimitry Andric   if ((Real->getOpcode() != Instruction::FMul &&
5617fa27ce4SDimitry Andric        Real->getOpcode() != Instruction::Mul) ||
5627fa27ce4SDimitry Andric       (Imag->getOpcode() != Instruction::FMul &&
5637fa27ce4SDimitry Andric        Imag->getOpcode() != Instruction::Mul)) {
5647fa27ce4SDimitry Andric     LLVM_DEBUG(
5657fa27ce4SDimitry Andric         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
566e3b55780SDimitry Andric     return nullptr;
567e3b55780SDimitry Andric   }
568e3b55780SDimitry Andric 
5697fa27ce4SDimitry Andric   Value *R0 = Real->getOperand(0);
5707fa27ce4SDimitry Andric   Value *R1 = Real->getOperand(1);
5717fa27ce4SDimitry Andric   Value *I0 = Imag->getOperand(0);
5727fa27ce4SDimitry Andric   Value *I1 = Imag->getOperand(1);
573e3b55780SDimitry Andric 
574e3b55780SDimitry Andric   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575e3b55780SDimitry Andric   // rotations and use the operand.
576e3b55780SDimitry Andric   unsigned Negs = 0;
5777fa27ce4SDimitry Andric   Value *Op;
5787fa27ce4SDimitry Andric   if (match(R0, m_Neg(m_Value(Op)))) {
579e3b55780SDimitry Andric     Negs |= 1;
5807fa27ce4SDimitry Andric     R0 = Op;
5817fa27ce4SDimitry Andric   } else if (match(R1, m_Neg(m_Value(Op)))) {
5827fa27ce4SDimitry Andric     Negs |= 1;
5837fa27ce4SDimitry Andric     R1 = Op;
584e3b55780SDimitry Andric   }
5857fa27ce4SDimitry Andric 
5867fa27ce4SDimitry Andric   if (isNeg(I0)) {
587e3b55780SDimitry Andric     Negs |= 2;
588e3b55780SDimitry Andric     Negs ^= 1;
5897fa27ce4SDimitry Andric     I0 = Op;
5907fa27ce4SDimitry Andric   } else if (match(I1, m_Neg(m_Value(Op)))) {
5917fa27ce4SDimitry Andric     Negs |= 2;
5927fa27ce4SDimitry Andric     Negs ^= 1;
5937fa27ce4SDimitry Andric     I1 = Op;
594e3b55780SDimitry Andric   }
595e3b55780SDimitry Andric 
596e3b55780SDimitry Andric   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597e3b55780SDimitry Andric 
5987fa27ce4SDimitry Andric   Value *CommonOperand;
5997fa27ce4SDimitry Andric   Value *UncommonRealOp;
6007fa27ce4SDimitry Andric   Value *UncommonImagOp;
601e3b55780SDimitry Andric 
602e3b55780SDimitry Andric   if (R0 == I0 || R0 == I1) {
603e3b55780SDimitry Andric     CommonOperand = R0;
604e3b55780SDimitry Andric     UncommonRealOp = R1;
605e3b55780SDimitry Andric   } else if (R1 == I0 || R1 == I1) {
606e3b55780SDimitry Andric     CommonOperand = R1;
607e3b55780SDimitry Andric     UncommonRealOp = R0;
608e3b55780SDimitry Andric   } else {
609e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
610e3b55780SDimitry Andric     return nullptr;
611e3b55780SDimitry Andric   }
612e3b55780SDimitry Andric 
613e3b55780SDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614e3b55780SDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615e3b55780SDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
616e3b55780SDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
617e3b55780SDimitry Andric 
618e3b55780SDimitry Andric   // Between identifyPartialMul and here we need to have found a complete valid
619e3b55780SDimitry Andric   // pair from the CommonOperand of each part.
620e3b55780SDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621e3b55780SDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_180)
622e3b55780SDimitry Andric     PartialMatch.first = CommonOperand;
623e3b55780SDimitry Andric   else
624e3b55780SDimitry Andric     PartialMatch.second = CommonOperand;
625e3b55780SDimitry Andric 
626e3b55780SDimitry Andric   if (!PartialMatch.first || !PartialMatch.second) {
627e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
628e3b55780SDimitry Andric     return nullptr;
629e3b55780SDimitry Andric   }
630e3b55780SDimitry Andric 
631e3b55780SDimitry Andric   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632e3b55780SDimitry Andric   if (!CommonNode) {
633e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
634e3b55780SDimitry Andric     return nullptr;
635e3b55780SDimitry Andric   }
636e3b55780SDimitry Andric 
637e3b55780SDimitry Andric   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638e3b55780SDimitry Andric   if (!UncommonNode) {
639e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
640e3b55780SDimitry Andric     return nullptr;
641e3b55780SDimitry Andric   }
642e3b55780SDimitry Andric 
643e3b55780SDimitry Andric   NodePtr Node = prepareCompositeNode(
644e3b55780SDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645e3b55780SDimitry Andric   Node->Rotation = Rotation;
646e3b55780SDimitry Andric   Node->addOperand(CommonNode);
647e3b55780SDimitry Andric   Node->addOperand(UncommonNode);
648e3b55780SDimitry Andric   return submitCompositeNode(Node);
649e3b55780SDimitry Andric }
650e3b55780SDimitry Andric 
651e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPartialMul(Instruction * Real,Instruction * Imag)652e3b55780SDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653e3b55780SDimitry Andric                                                Instruction *Imag) {
654e3b55780SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655e3b55780SDimitry Andric                     << "\n");
656e3b55780SDimitry Andric   // Determine rotation
6577fa27ce4SDimitry Andric   auto IsAdd = [](unsigned Op) {
6587fa27ce4SDimitry Andric     return Op == Instruction::FAdd || Op == Instruction::Add;
6597fa27ce4SDimitry Andric   };
6607fa27ce4SDimitry Andric   auto IsSub = [](unsigned Op) {
6617fa27ce4SDimitry Andric     return Op == Instruction::FSub || Op == Instruction::Sub;
6627fa27ce4SDimitry Andric   };
663e3b55780SDimitry Andric   ComplexDeinterleavingRotation Rotation;
6647fa27ce4SDimitry Andric   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665e3b55780SDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_0;
6667fa27ce4SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667e3b55780SDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
6687fa27ce4SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669e3b55780SDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_180;
6707fa27ce4SDimitry Andric   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671e3b55780SDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
672e3b55780SDimitry Andric   else {
673e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
674e3b55780SDimitry Andric     return nullptr;
675e3b55780SDimitry Andric   }
676e3b55780SDimitry Andric 
6777fa27ce4SDimitry Andric   if (isa<FPMathOperator>(Real) &&
6787fa27ce4SDimitry Andric       (!Real->getFastMathFlags().allowContract() ||
6797fa27ce4SDimitry Andric        !Imag->getFastMathFlags().allowContract())) {
680e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
681e3b55780SDimitry Andric     return nullptr;
682e3b55780SDimitry Andric   }
683e3b55780SDimitry Andric 
684e3b55780SDimitry Andric   Value *CR = Real->getOperand(0);
685e3b55780SDimitry Andric   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686e3b55780SDimitry Andric   if (!RealMulI)
687e3b55780SDimitry Andric     return nullptr;
688e3b55780SDimitry Andric   Value *CI = Imag->getOperand(0);
689e3b55780SDimitry Andric   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690e3b55780SDimitry Andric   if (!ImagMulI)
691e3b55780SDimitry Andric     return nullptr;
692e3b55780SDimitry Andric 
693e3b55780SDimitry Andric   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
695e3b55780SDimitry Andric     return nullptr;
696e3b55780SDimitry Andric   }
697e3b55780SDimitry Andric 
6987fa27ce4SDimitry Andric   Value *R0 = RealMulI->getOperand(0);
6997fa27ce4SDimitry Andric   Value *R1 = RealMulI->getOperand(1);
7007fa27ce4SDimitry Andric   Value *I0 = ImagMulI->getOperand(0);
7017fa27ce4SDimitry Andric   Value *I1 = ImagMulI->getOperand(1);
702e3b55780SDimitry Andric 
7037fa27ce4SDimitry Andric   Value *CommonOperand;
7047fa27ce4SDimitry Andric   Value *UncommonRealOp;
7057fa27ce4SDimitry Andric   Value *UncommonImagOp;
706e3b55780SDimitry Andric 
707e3b55780SDimitry Andric   if (R0 == I0 || R0 == I1) {
708e3b55780SDimitry Andric     CommonOperand = R0;
709e3b55780SDimitry Andric     UncommonRealOp = R1;
710e3b55780SDimitry Andric   } else if (R1 == I0 || R1 == I1) {
711e3b55780SDimitry Andric     CommonOperand = R1;
712e3b55780SDimitry Andric     UncommonRealOp = R0;
713e3b55780SDimitry Andric   } else {
714e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
715e3b55780SDimitry Andric     return nullptr;
716e3b55780SDimitry Andric   }
717e3b55780SDimitry Andric 
718e3b55780SDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719e3b55780SDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720e3b55780SDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
721e3b55780SDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
722e3b55780SDimitry Andric 
7237fa27ce4SDimitry Andric   std::pair<Value *, Value *> PartialMatch(
724e3b55780SDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725e3b55780SDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_180)
726e3b55780SDimitry Andric           ? CommonOperand
727e3b55780SDimitry Andric           : nullptr,
728e3b55780SDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729e3b55780SDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_270)
730e3b55780SDimitry Andric           ? CommonOperand
731e3b55780SDimitry Andric           : nullptr);
7327fa27ce4SDimitry Andric 
7337fa27ce4SDimitry Andric   auto *CRInst = dyn_cast<Instruction>(CR);
7347fa27ce4SDimitry Andric   auto *CIInst = dyn_cast<Instruction>(CI);
7357fa27ce4SDimitry Andric 
7367fa27ce4SDimitry Andric   if (!CRInst || !CIInst) {
7377fa27ce4SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
7387fa27ce4SDimitry Andric     return nullptr;
7397fa27ce4SDimitry Andric   }
7407fa27ce4SDimitry Andric 
7417fa27ce4SDimitry Andric   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742e3b55780SDimitry Andric   if (!CNode) {
743e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
744e3b55780SDimitry Andric     return nullptr;
745e3b55780SDimitry Andric   }
746e3b55780SDimitry Andric 
747e3b55780SDimitry Andric   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748e3b55780SDimitry Andric   if (!UncommonRes) {
749e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
750e3b55780SDimitry Andric     return nullptr;
751e3b55780SDimitry Andric   }
752e3b55780SDimitry Andric 
753e3b55780SDimitry Andric   assert(PartialMatch.first && PartialMatch.second);
754e3b55780SDimitry Andric   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755e3b55780SDimitry Andric   if (!CommonRes) {
756e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
757e3b55780SDimitry Andric     return nullptr;
758e3b55780SDimitry Andric   }
759e3b55780SDimitry Andric 
760e3b55780SDimitry Andric   NodePtr Node = prepareCompositeNode(
761e3b55780SDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762e3b55780SDimitry Andric   Node->Rotation = Rotation;
763e3b55780SDimitry Andric   Node->addOperand(CommonRes);
764e3b55780SDimitry Andric   Node->addOperand(UncommonRes);
765e3b55780SDimitry Andric   Node->addOperand(CNode);
766e3b55780SDimitry Andric   return submitCompositeNode(Node);
767e3b55780SDimitry Andric }
768e3b55780SDimitry Andric 
769e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)770e3b55780SDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771e3b55780SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772e3b55780SDimitry Andric 
773e3b55780SDimitry Andric   // Determine rotation
774e3b55780SDimitry Andric   ComplexDeinterleavingRotation Rotation;
775e3b55780SDimitry Andric   if ((Real->getOpcode() == Instruction::FSub &&
776e3b55780SDimitry Andric        Imag->getOpcode() == Instruction::FAdd) ||
777e3b55780SDimitry Andric       (Real->getOpcode() == Instruction::Sub &&
778e3b55780SDimitry Andric        Imag->getOpcode() == Instruction::Add))
779e3b55780SDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
780e3b55780SDimitry Andric   else if ((Real->getOpcode() == Instruction::FAdd &&
781e3b55780SDimitry Andric             Imag->getOpcode() == Instruction::FSub) ||
782e3b55780SDimitry Andric            (Real->getOpcode() == Instruction::Add &&
783e3b55780SDimitry Andric             Imag->getOpcode() == Instruction::Sub))
784e3b55780SDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
785e3b55780SDimitry Andric   else {
786e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787e3b55780SDimitry Andric     return nullptr;
788e3b55780SDimitry Andric   }
789e3b55780SDimitry Andric 
790e3b55780SDimitry Andric   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791e3b55780SDimitry Andric   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792e3b55780SDimitry Andric   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793e3b55780SDimitry Andric   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794e3b55780SDimitry Andric 
795e3b55780SDimitry Andric   if (!AR || !AI || !BR || !BI) {
796e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797e3b55780SDimitry Andric     return nullptr;
798e3b55780SDimitry Andric   }
799e3b55780SDimitry Andric 
800e3b55780SDimitry Andric   NodePtr ResA = identifyNode(AR, AI);
801e3b55780SDimitry Andric   if (!ResA) {
802e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803e3b55780SDimitry Andric     return nullptr;
804e3b55780SDimitry Andric   }
805e3b55780SDimitry Andric   NodePtr ResB = identifyNode(BR, BI);
806e3b55780SDimitry Andric   if (!ResB) {
807e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808e3b55780SDimitry Andric     return nullptr;
809e3b55780SDimitry Andric   }
810e3b55780SDimitry Andric 
811e3b55780SDimitry Andric   NodePtr Node =
812e3b55780SDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813e3b55780SDimitry Andric   Node->Rotation = Rotation;
814e3b55780SDimitry Andric   Node->addOperand(ResA);
815e3b55780SDimitry Andric   Node->addOperand(ResB);
816e3b55780SDimitry Andric   return submitCompositeNode(Node);
817e3b55780SDimitry Andric }
818e3b55780SDimitry Andric 
isInstructionPairAdd(Instruction * A,Instruction * B)819e3b55780SDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820e3b55780SDimitry Andric   unsigned OpcA = A->getOpcode();
821e3b55780SDimitry Andric   unsigned OpcB = B->getOpcode();
822e3b55780SDimitry Andric 
823e3b55780SDimitry Andric   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824e3b55780SDimitry Andric          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825e3b55780SDimitry Andric          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826e3b55780SDimitry Andric          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827e3b55780SDimitry Andric }
828e3b55780SDimitry Andric 
isInstructionPairMul(Instruction * A,Instruction * B)829e3b55780SDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830e3b55780SDimitry Andric   auto Pattern =
831e3b55780SDimitry Andric       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832e3b55780SDimitry Andric 
833e3b55780SDimitry Andric   return match(A, Pattern) && match(B, Pattern);
834e3b55780SDimitry Andric }
835e3b55780SDimitry Andric 
isInstructionPotentiallySymmetric(Instruction * I)8367fa27ce4SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) {
8377fa27ce4SDimitry Andric   switch (I->getOpcode()) {
8387fa27ce4SDimitry Andric   case Instruction::FAdd:
8397fa27ce4SDimitry Andric   case Instruction::FSub:
8407fa27ce4SDimitry Andric   case Instruction::FMul:
8417fa27ce4SDimitry Andric   case Instruction::FNeg:
8427fa27ce4SDimitry Andric   case Instruction::Add:
8437fa27ce4SDimitry Andric   case Instruction::Sub:
8447fa27ce4SDimitry Andric   case Instruction::Mul:
8457fa27ce4SDimitry Andric     return true;
8467fa27ce4SDimitry Andric   default:
8477fa27ce4SDimitry Andric     return false;
8487fa27ce4SDimitry Andric   }
8497fa27ce4SDimitry Andric }
8507fa27ce4SDimitry Andric 
851e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySymmetricOperation(Instruction * Real,Instruction * Imag)8527fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
8537fa27ce4SDimitry Andric                                                        Instruction *Imag) {
8547fa27ce4SDimitry Andric   if (Real->getOpcode() != Imag->getOpcode())
8557fa27ce4SDimitry Andric     return nullptr;
8567fa27ce4SDimitry Andric 
8577fa27ce4SDimitry Andric   if (!isInstructionPotentiallySymmetric(Real) ||
8587fa27ce4SDimitry Andric       !isInstructionPotentiallySymmetric(Imag))
8597fa27ce4SDimitry Andric     return nullptr;
8607fa27ce4SDimitry Andric 
8617fa27ce4SDimitry Andric   auto *R0 = Real->getOperand(0);
8627fa27ce4SDimitry Andric   auto *I0 = Imag->getOperand(0);
8637fa27ce4SDimitry Andric 
8647fa27ce4SDimitry Andric   NodePtr Op0 = identifyNode(R0, I0);
8657fa27ce4SDimitry Andric   NodePtr Op1 = nullptr;
8667fa27ce4SDimitry Andric   if (Op0 == nullptr)
8677fa27ce4SDimitry Andric     return nullptr;
8687fa27ce4SDimitry Andric 
8697fa27ce4SDimitry Andric   if (Real->isBinaryOp()) {
8707fa27ce4SDimitry Andric     auto *R1 = Real->getOperand(1);
8717fa27ce4SDimitry Andric     auto *I1 = Imag->getOperand(1);
8727fa27ce4SDimitry Andric     Op1 = identifyNode(R1, I1);
8737fa27ce4SDimitry Andric     if (Op1 == nullptr)
8747fa27ce4SDimitry Andric       return nullptr;
8757fa27ce4SDimitry Andric   }
8767fa27ce4SDimitry Andric 
8777fa27ce4SDimitry Andric   if (isa<FPMathOperator>(Real) &&
8787fa27ce4SDimitry Andric       Real->getFastMathFlags() != Imag->getFastMathFlags())
8797fa27ce4SDimitry Andric     return nullptr;
8807fa27ce4SDimitry Andric 
8817fa27ce4SDimitry Andric   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
8827fa27ce4SDimitry Andric                                    Real, Imag);
8837fa27ce4SDimitry Andric   Node->Opcode = Real->getOpcode();
8847fa27ce4SDimitry Andric   if (isa<FPMathOperator>(Real))
8857fa27ce4SDimitry Andric     Node->Flags = Real->getFastMathFlags();
8867fa27ce4SDimitry Andric 
8877fa27ce4SDimitry Andric   Node->addOperand(Op0);
8887fa27ce4SDimitry Andric   if (Real->isBinaryOp())
8897fa27ce4SDimitry Andric     Node->addOperand(Op1);
8907fa27ce4SDimitry Andric 
8917fa27ce4SDimitry Andric   return submitCompositeNode(Node);
8927fa27ce4SDimitry Andric }
8937fa27ce4SDimitry Andric 
8947fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNode(Value * R,Value * I)8957fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
8967fa27ce4SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
8977fa27ce4SDimitry Andric   assert(R->getType() == I->getType() &&
8987fa27ce4SDimitry Andric          "Real and imaginary parts should not have different types");
899b1c73532SDimitry Andric 
900b1c73532SDimitry Andric   auto It = CachedResult.find({R, I});
901b1c73532SDimitry Andric   if (It != CachedResult.end()) {
902e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903b1c73532SDimitry Andric     return It->second;
904e3b55780SDimitry Andric   }
905e3b55780SDimitry Andric 
9067fa27ce4SDimitry Andric   if (NodePtr CN = identifySplat(R, I))
9077fa27ce4SDimitry Andric     return CN;
9087fa27ce4SDimitry Andric 
9097fa27ce4SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
9107fa27ce4SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
9117fa27ce4SDimitry Andric   if (!Real || !Imag)
9127fa27ce4SDimitry Andric     return nullptr;
9137fa27ce4SDimitry Andric 
9147fa27ce4SDimitry Andric   if (NodePtr CN = identifyDeinterleave(Real, Imag))
9157fa27ce4SDimitry Andric     return CN;
9167fa27ce4SDimitry Andric 
9177fa27ce4SDimitry Andric   if (NodePtr CN = identifyPHINode(Real, Imag))
9187fa27ce4SDimitry Andric     return CN;
9197fa27ce4SDimitry Andric 
9207fa27ce4SDimitry Andric   if (NodePtr CN = identifySelectNode(Real, Imag))
9217fa27ce4SDimitry Andric     return CN;
9227fa27ce4SDimitry Andric 
9237fa27ce4SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
9247fa27ce4SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
9257fa27ce4SDimitry Andric 
9267fa27ce4SDimitry Andric   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
9277fa27ce4SDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
9287fa27ce4SDimitry Andric   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
9297fa27ce4SDimitry Andric       ComplexDeinterleavingOperation::CAdd, NewVTy);
9307fa27ce4SDimitry Andric 
9317fa27ce4SDimitry Andric   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
9327fa27ce4SDimitry Andric     if (NodePtr CN = identifyPartialMul(Real, Imag))
9337fa27ce4SDimitry Andric       return CN;
9347fa27ce4SDimitry Andric   }
9357fa27ce4SDimitry Andric 
9367fa27ce4SDimitry Andric   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
9377fa27ce4SDimitry Andric     if (NodePtr CN = identifyAdd(Real, Imag))
9387fa27ce4SDimitry Andric       return CN;
9397fa27ce4SDimitry Andric   }
9407fa27ce4SDimitry Andric 
9417fa27ce4SDimitry Andric   if (HasCMulSupport && HasCAddSupport) {
9427fa27ce4SDimitry Andric     if (NodePtr CN = identifyReassocNodes(Real, Imag))
9437fa27ce4SDimitry Andric       return CN;
9447fa27ce4SDimitry Andric   }
9457fa27ce4SDimitry Andric 
9467fa27ce4SDimitry Andric   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
9477fa27ce4SDimitry Andric     return CN;
9487fa27ce4SDimitry Andric 
9497fa27ce4SDimitry Andric   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
950b1c73532SDimitry Andric   CachedResult[{R, I}] = nullptr;
9517fa27ce4SDimitry Andric   return nullptr;
9527fa27ce4SDimitry Andric }
9537fa27ce4SDimitry Andric 
9547fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyReassocNodes(Instruction * Real,Instruction * Imag)9557fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
9567fa27ce4SDimitry Andric                                                  Instruction *Imag) {
9577fa27ce4SDimitry Andric   auto IsOperationSupported = [](unsigned Opcode) -> bool {
9587fa27ce4SDimitry Andric     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
9597fa27ce4SDimitry Andric            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
9607fa27ce4SDimitry Andric            Opcode == Instruction::Sub;
9617fa27ce4SDimitry Andric   };
9627fa27ce4SDimitry Andric 
9637fa27ce4SDimitry Andric   if (!IsOperationSupported(Real->getOpcode()) ||
9647fa27ce4SDimitry Andric       !IsOperationSupported(Imag->getOpcode()))
9657fa27ce4SDimitry Andric     return nullptr;
9667fa27ce4SDimitry Andric 
9677fa27ce4SDimitry Andric   std::optional<FastMathFlags> Flags;
9687fa27ce4SDimitry Andric   if (isa<FPMathOperator>(Real)) {
9697fa27ce4SDimitry Andric     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
9707fa27ce4SDimitry Andric       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
9717fa27ce4SDimitry Andric                            "not identical\n");
9727fa27ce4SDimitry Andric       return nullptr;
9737fa27ce4SDimitry Andric     }
9747fa27ce4SDimitry Andric 
9757fa27ce4SDimitry Andric     Flags = Real->getFastMathFlags();
9767fa27ce4SDimitry Andric     if (!Flags->allowReassoc()) {
9777fa27ce4SDimitry Andric       LLVM_DEBUG(
9787fa27ce4SDimitry Andric           dbgs()
9797fa27ce4SDimitry Andric           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
9807fa27ce4SDimitry Andric       return nullptr;
9817fa27ce4SDimitry Andric     }
9827fa27ce4SDimitry Andric   }
9837fa27ce4SDimitry Andric 
9847fa27ce4SDimitry Andric   // Collect multiplications and addend instructions from the given instruction
9857fa27ce4SDimitry Andric   // while traversing it operands. Additionally, verify that all instructions
9867fa27ce4SDimitry Andric   // have the same fast math flags.
9877fa27ce4SDimitry Andric   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
9887fa27ce4SDimitry Andric                           std::list<Addend> &Addends) -> bool {
9897fa27ce4SDimitry Andric     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
9907fa27ce4SDimitry Andric     SmallPtrSet<Value *, 8> Visited;
9917fa27ce4SDimitry Andric     while (!Worklist.empty()) {
9927fa27ce4SDimitry Andric       auto [V, IsPositive] = Worklist.back();
9937fa27ce4SDimitry Andric       Worklist.pop_back();
9947fa27ce4SDimitry Andric       if (!Visited.insert(V).second)
9957fa27ce4SDimitry Andric         continue;
9967fa27ce4SDimitry Andric 
9977fa27ce4SDimitry Andric       Instruction *I = dyn_cast<Instruction>(V);
9987fa27ce4SDimitry Andric       if (!I) {
9997fa27ce4SDimitry Andric         Addends.emplace_back(V, IsPositive);
10007fa27ce4SDimitry Andric         continue;
10017fa27ce4SDimitry Andric       }
10027fa27ce4SDimitry Andric 
10037fa27ce4SDimitry Andric       // If an instruction has more than one user, it indicates that it either
10047fa27ce4SDimitry Andric       // has an external user, which will be later checked by the checkNodes
10057fa27ce4SDimitry Andric       // function, or it is a subexpression utilized by multiple expressions. In
10067fa27ce4SDimitry Andric       // the latter case, we will attempt to separately identify the complex
10077fa27ce4SDimitry Andric       // operation from here in order to create a shared
10087fa27ce4SDimitry Andric       // ComplexDeinterleavingCompositeNode.
10097fa27ce4SDimitry Andric       if (I != Insn && I->getNumUses() > 1) {
10107fa27ce4SDimitry Andric         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
10117fa27ce4SDimitry Andric         Addends.emplace_back(I, IsPositive);
10127fa27ce4SDimitry Andric         continue;
10137fa27ce4SDimitry Andric       }
10147fa27ce4SDimitry Andric       switch (I->getOpcode()) {
10157fa27ce4SDimitry Andric       case Instruction::FAdd:
10167fa27ce4SDimitry Andric       case Instruction::Add:
10177fa27ce4SDimitry Andric         Worklist.emplace_back(I->getOperand(1), IsPositive);
10187fa27ce4SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
10197fa27ce4SDimitry Andric         break;
10207fa27ce4SDimitry Andric       case Instruction::FSub:
10217fa27ce4SDimitry Andric         Worklist.emplace_back(I->getOperand(1), !IsPositive);
10227fa27ce4SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
10237fa27ce4SDimitry Andric         break;
10247fa27ce4SDimitry Andric       case Instruction::Sub:
10257fa27ce4SDimitry Andric         if (isNeg(I)) {
10267fa27ce4SDimitry Andric           Worklist.emplace_back(getNegOperand(I), !IsPositive);
10277fa27ce4SDimitry Andric         } else {
10287fa27ce4SDimitry Andric           Worklist.emplace_back(I->getOperand(1), !IsPositive);
10297fa27ce4SDimitry Andric           Worklist.emplace_back(I->getOperand(0), IsPositive);
10307fa27ce4SDimitry Andric         }
10317fa27ce4SDimitry Andric         break;
10327fa27ce4SDimitry Andric       case Instruction::FMul:
10337fa27ce4SDimitry Andric       case Instruction::Mul: {
10347fa27ce4SDimitry Andric         Value *A, *B;
10357fa27ce4SDimitry Andric         if (isNeg(I->getOperand(0))) {
10367fa27ce4SDimitry Andric           A = getNegOperand(I->getOperand(0));
10377fa27ce4SDimitry Andric           IsPositive = !IsPositive;
10387fa27ce4SDimitry Andric         } else {
10397fa27ce4SDimitry Andric           A = I->getOperand(0);
10407fa27ce4SDimitry Andric         }
10417fa27ce4SDimitry Andric 
10427fa27ce4SDimitry Andric         if (isNeg(I->getOperand(1))) {
10437fa27ce4SDimitry Andric           B = getNegOperand(I->getOperand(1));
10447fa27ce4SDimitry Andric           IsPositive = !IsPositive;
10457fa27ce4SDimitry Andric         } else {
10467fa27ce4SDimitry Andric           B = I->getOperand(1);
10477fa27ce4SDimitry Andric         }
10487fa27ce4SDimitry Andric         Muls.push_back(Product{A, B, IsPositive});
10497fa27ce4SDimitry Andric         break;
10507fa27ce4SDimitry Andric       }
10517fa27ce4SDimitry Andric       case Instruction::FNeg:
10527fa27ce4SDimitry Andric         Worklist.emplace_back(I->getOperand(0), !IsPositive);
10537fa27ce4SDimitry Andric         break;
10547fa27ce4SDimitry Andric       default:
10557fa27ce4SDimitry Andric         Addends.emplace_back(I, IsPositive);
10567fa27ce4SDimitry Andric         continue;
10577fa27ce4SDimitry Andric       }
10587fa27ce4SDimitry Andric 
10597fa27ce4SDimitry Andric       if (Flags && I->getFastMathFlags() != *Flags) {
10607fa27ce4SDimitry Andric         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
10617fa27ce4SDimitry Andric                              "inconsistent with the root instructions' flags: "
10627fa27ce4SDimitry Andric                           << *I << "\n");
10637fa27ce4SDimitry Andric         return false;
10647fa27ce4SDimitry Andric       }
10657fa27ce4SDimitry Andric     }
10667fa27ce4SDimitry Andric     return true;
10677fa27ce4SDimitry Andric   };
10687fa27ce4SDimitry Andric 
10697fa27ce4SDimitry Andric   std::vector<Product> RealMuls, ImagMuls;
10707fa27ce4SDimitry Andric   std::list<Addend> RealAddends, ImagAddends;
10717fa27ce4SDimitry Andric   if (!Collect(Real, RealMuls, RealAddends) ||
10727fa27ce4SDimitry Andric       !Collect(Imag, ImagMuls, ImagAddends))
10737fa27ce4SDimitry Andric     return nullptr;
10747fa27ce4SDimitry Andric 
10757fa27ce4SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
10767fa27ce4SDimitry Andric     return nullptr;
10777fa27ce4SDimitry Andric 
10787fa27ce4SDimitry Andric   NodePtr FinalNode;
10797fa27ce4SDimitry Andric   if (!RealMuls.empty() || !ImagMuls.empty()) {
10807fa27ce4SDimitry Andric     // If there are multiplicands, extract positive addend and use it as an
10817fa27ce4SDimitry Andric     // accumulator
10827fa27ce4SDimitry Andric     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
10837fa27ce4SDimitry Andric     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
10847fa27ce4SDimitry Andric     if (!FinalNode)
10857fa27ce4SDimitry Andric       return nullptr;
10867fa27ce4SDimitry Andric   }
10877fa27ce4SDimitry Andric 
10887fa27ce4SDimitry Andric   // Identify and process remaining additions
10897fa27ce4SDimitry Andric   if (!RealAddends.empty() || !ImagAddends.empty()) {
10907fa27ce4SDimitry Andric     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
10917fa27ce4SDimitry Andric     if (!FinalNode)
10927fa27ce4SDimitry Andric       return nullptr;
10937fa27ce4SDimitry Andric   }
10947fa27ce4SDimitry Andric   assert(FinalNode && "FinalNode can not be nullptr here");
10957fa27ce4SDimitry Andric   // Set the Real and Imag fields of the final node and submit it
10967fa27ce4SDimitry Andric   FinalNode->Real = Real;
10977fa27ce4SDimitry Andric   FinalNode->Imag = Imag;
10987fa27ce4SDimitry Andric   submitCompositeNode(FinalNode);
10997fa27ce4SDimitry Andric   return FinalNode;
11007fa27ce4SDimitry Andric }
11017fa27ce4SDimitry Andric 
collectPartialMuls(const std::vector<Product> & RealMuls,const std::vector<Product> & ImagMuls,std::vector<PartialMulCandidate> & PartialMulCandidates)11027fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::collectPartialMuls(
11037fa27ce4SDimitry Andric     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
11047fa27ce4SDimitry Andric     std::vector<PartialMulCandidate> &PartialMulCandidates) {
11057fa27ce4SDimitry Andric   // Helper function to extract a common operand from two products
11067fa27ce4SDimitry Andric   auto FindCommonInstruction = [](const Product &Real,
11077fa27ce4SDimitry Andric                                   const Product &Imag) -> Value * {
11087fa27ce4SDimitry Andric     if (Real.Multiplicand == Imag.Multiplicand ||
11097fa27ce4SDimitry Andric         Real.Multiplicand == Imag.Multiplier)
11107fa27ce4SDimitry Andric       return Real.Multiplicand;
11117fa27ce4SDimitry Andric 
11127fa27ce4SDimitry Andric     if (Real.Multiplier == Imag.Multiplicand ||
11137fa27ce4SDimitry Andric         Real.Multiplier == Imag.Multiplier)
11147fa27ce4SDimitry Andric       return Real.Multiplier;
11157fa27ce4SDimitry Andric 
11167fa27ce4SDimitry Andric     return nullptr;
11177fa27ce4SDimitry Andric   };
11187fa27ce4SDimitry Andric 
11197fa27ce4SDimitry Andric   // Iterating over real and imaginary multiplications to find common operands
11207fa27ce4SDimitry Andric   // If a common operand is found, a partial multiplication candidate is created
11217fa27ce4SDimitry Andric   // and added to the candidates vector The function returns false if no common
11227fa27ce4SDimitry Andric   // operands are found for any product
11237fa27ce4SDimitry Andric   for (unsigned i = 0; i < RealMuls.size(); ++i) {
11247fa27ce4SDimitry Andric     bool FoundCommon = false;
11257fa27ce4SDimitry Andric     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
11267fa27ce4SDimitry Andric       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
11277fa27ce4SDimitry Andric       if (!Common)
11287fa27ce4SDimitry Andric         continue;
11297fa27ce4SDimitry Andric 
11307fa27ce4SDimitry Andric       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
11317fa27ce4SDimitry Andric                                                    : RealMuls[i].Multiplicand;
11327fa27ce4SDimitry Andric       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
11337fa27ce4SDimitry Andric                                                    : ImagMuls[j].Multiplicand;
11347fa27ce4SDimitry Andric 
11357fa27ce4SDimitry Andric       auto Node = identifyNode(A, B);
11367fa27ce4SDimitry Andric       if (Node) {
11377fa27ce4SDimitry Andric         FoundCommon = true;
11387fa27ce4SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, false});
11397fa27ce4SDimitry Andric       }
11407fa27ce4SDimitry Andric 
11417fa27ce4SDimitry Andric       Node = identifyNode(B, A);
11427fa27ce4SDimitry Andric       if (Node) {
11437fa27ce4SDimitry Andric         FoundCommon = true;
11447fa27ce4SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, true});
11457fa27ce4SDimitry Andric       }
11467fa27ce4SDimitry Andric     }
11477fa27ce4SDimitry Andric     if (!FoundCommon)
11487fa27ce4SDimitry Andric       return false;
11497fa27ce4SDimitry Andric   }
11507fa27ce4SDimitry Andric   return true;
11517fa27ce4SDimitry Andric }
11527fa27ce4SDimitry Andric 
11537fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyMultiplications(std::vector<Product> & RealMuls,std::vector<Product> & ImagMuls,NodePtr Accumulator=nullptr)11547fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyMultiplications(
11557fa27ce4SDimitry Andric     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
11567fa27ce4SDimitry Andric     NodePtr Accumulator = nullptr) {
11577fa27ce4SDimitry Andric   if (RealMuls.size() != ImagMuls.size())
11587fa27ce4SDimitry Andric     return nullptr;
11597fa27ce4SDimitry Andric 
11607fa27ce4SDimitry Andric   std::vector<PartialMulCandidate> Info;
11617fa27ce4SDimitry Andric   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
11627fa27ce4SDimitry Andric     return nullptr;
11637fa27ce4SDimitry Andric 
11647fa27ce4SDimitry Andric   // Map to store common instruction to node pointers
11657fa27ce4SDimitry Andric   std::map<Value *, NodePtr> CommonToNode;
11667fa27ce4SDimitry Andric   std::vector<bool> Processed(Info.size(), false);
11677fa27ce4SDimitry Andric   for (unsigned I = 0; I < Info.size(); ++I) {
11687fa27ce4SDimitry Andric     if (Processed[I])
11697fa27ce4SDimitry Andric       continue;
11707fa27ce4SDimitry Andric 
11717fa27ce4SDimitry Andric     PartialMulCandidate &InfoA = Info[I];
11727fa27ce4SDimitry Andric     for (unsigned J = I + 1; J < Info.size(); ++J) {
11737fa27ce4SDimitry Andric       if (Processed[J])
11747fa27ce4SDimitry Andric         continue;
11757fa27ce4SDimitry Andric 
11767fa27ce4SDimitry Andric       PartialMulCandidate &InfoB = Info[J];
11777fa27ce4SDimitry Andric       auto *InfoReal = &InfoA;
11787fa27ce4SDimitry Andric       auto *InfoImag = &InfoB;
11797fa27ce4SDimitry Andric 
11807fa27ce4SDimitry Andric       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
11817fa27ce4SDimitry Andric       if (!NodeFromCommon) {
11827fa27ce4SDimitry Andric         std::swap(InfoReal, InfoImag);
11837fa27ce4SDimitry Andric         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
11847fa27ce4SDimitry Andric       }
11857fa27ce4SDimitry Andric       if (!NodeFromCommon)
11867fa27ce4SDimitry Andric         continue;
11877fa27ce4SDimitry Andric 
11887fa27ce4SDimitry Andric       CommonToNode[InfoReal->Common] = NodeFromCommon;
11897fa27ce4SDimitry Andric       CommonToNode[InfoImag->Common] = NodeFromCommon;
11907fa27ce4SDimitry Andric       Processed[I] = true;
11917fa27ce4SDimitry Andric       Processed[J] = true;
11927fa27ce4SDimitry Andric     }
11937fa27ce4SDimitry Andric   }
11947fa27ce4SDimitry Andric 
11957fa27ce4SDimitry Andric   std::vector<bool> ProcessedReal(RealMuls.size(), false);
11967fa27ce4SDimitry Andric   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
11977fa27ce4SDimitry Andric   NodePtr Result = Accumulator;
11987fa27ce4SDimitry Andric   for (auto &PMI : Info) {
11997fa27ce4SDimitry Andric     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
12007fa27ce4SDimitry Andric       continue;
12017fa27ce4SDimitry Andric 
12027fa27ce4SDimitry Andric     auto It = CommonToNode.find(PMI.Common);
12037fa27ce4SDimitry Andric     // TODO: Process independent complex multiplications. Cases like this:
12047fa27ce4SDimitry Andric     //  A.real() * B where both A and B are complex numbers.
12057fa27ce4SDimitry Andric     if (It == CommonToNode.end()) {
12067fa27ce4SDimitry Andric       LLVM_DEBUG({
12077fa27ce4SDimitry Andric         dbgs() << "Unprocessed independent partial multiplication:\n";
12087fa27ce4SDimitry Andric         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
12097fa27ce4SDimitry Andric           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
12107fa27ce4SDimitry Andric                            << " multiplied by " << *Mul->Multiplicand << "\n";
12117fa27ce4SDimitry Andric       });
12127fa27ce4SDimitry Andric       return nullptr;
12137fa27ce4SDimitry Andric     }
12147fa27ce4SDimitry Andric 
12157fa27ce4SDimitry Andric     auto &RealMul = RealMuls[PMI.RealIdx];
12167fa27ce4SDimitry Andric     auto &ImagMul = ImagMuls[PMI.ImagIdx];
12177fa27ce4SDimitry Andric 
12187fa27ce4SDimitry Andric     auto NodeA = It->second;
12197fa27ce4SDimitry Andric     auto NodeB = PMI.Node;
12207fa27ce4SDimitry Andric     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
12217fa27ce4SDimitry Andric     // The following table illustrates the relationship between multiplications
12227fa27ce4SDimitry Andric     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
12237fa27ce4SDimitry Andric     // can see:
12247fa27ce4SDimitry Andric     //
12257fa27ce4SDimitry Andric     // Rotation |   Real |   Imag |
12267fa27ce4SDimitry Andric     // ---------+--------+--------+
12277fa27ce4SDimitry Andric     //        0 |  x * u |  x * v |
12287fa27ce4SDimitry Andric     //       90 | -y * v |  y * u |
12297fa27ce4SDimitry Andric     //      180 | -x * u | -x * v |
12307fa27ce4SDimitry Andric     //      270 |  y * v | -y * u |
12317fa27ce4SDimitry Andric     //
12327fa27ce4SDimitry Andric     // Check if the candidate can indeed be represented by partial
12337fa27ce4SDimitry Andric     // multiplication
12347fa27ce4SDimitry Andric     // TODO: Add support for multiplication by complex one
12357fa27ce4SDimitry Andric     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
12367fa27ce4SDimitry Andric         (!IsMultiplicandReal && !PMI.IsNodeInverted))
12377fa27ce4SDimitry Andric       continue;
12387fa27ce4SDimitry Andric 
12397fa27ce4SDimitry Andric     // Determine the rotation based on the multiplications
12407fa27ce4SDimitry Andric     ComplexDeinterleavingRotation Rotation;
12417fa27ce4SDimitry Andric     if (IsMultiplicandReal) {
12427fa27ce4SDimitry Andric       // Detect 0 and 180 degrees rotation
12437fa27ce4SDimitry Andric       if (RealMul.IsPositive && ImagMul.IsPositive)
12447fa27ce4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
12457fa27ce4SDimitry Andric       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
12467fa27ce4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
12477fa27ce4SDimitry Andric       else
12487fa27ce4SDimitry Andric         continue;
12497fa27ce4SDimitry Andric 
12507fa27ce4SDimitry Andric     } else {
12517fa27ce4SDimitry Andric       // Detect 90 and 270 degrees rotation
12527fa27ce4SDimitry Andric       if (!RealMul.IsPositive && ImagMul.IsPositive)
12537fa27ce4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
12547fa27ce4SDimitry Andric       else if (RealMul.IsPositive && !ImagMul.IsPositive)
12557fa27ce4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
12567fa27ce4SDimitry Andric       else
12577fa27ce4SDimitry Andric         continue;
12587fa27ce4SDimitry Andric     }
12597fa27ce4SDimitry Andric 
12607fa27ce4SDimitry Andric     LLVM_DEBUG({
12617fa27ce4SDimitry Andric       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
12627fa27ce4SDimitry Andric       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
12637fa27ce4SDimitry Andric       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
12647fa27ce4SDimitry Andric       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
12657fa27ce4SDimitry Andric       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
12667fa27ce4SDimitry Andric       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
12677fa27ce4SDimitry Andric     });
12687fa27ce4SDimitry Andric 
12697fa27ce4SDimitry Andric     NodePtr NodeMul = prepareCompositeNode(
12707fa27ce4SDimitry Andric         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
12717fa27ce4SDimitry Andric     NodeMul->Rotation = Rotation;
12727fa27ce4SDimitry Andric     NodeMul->addOperand(NodeA);
12737fa27ce4SDimitry Andric     NodeMul->addOperand(NodeB);
12747fa27ce4SDimitry Andric     if (Result)
12757fa27ce4SDimitry Andric       NodeMul->addOperand(Result);
12767fa27ce4SDimitry Andric     submitCompositeNode(NodeMul);
12777fa27ce4SDimitry Andric     Result = NodeMul;
12787fa27ce4SDimitry Andric     ProcessedReal[PMI.RealIdx] = true;
12797fa27ce4SDimitry Andric     ProcessedImag[PMI.ImagIdx] = true;
12807fa27ce4SDimitry Andric   }
12817fa27ce4SDimitry Andric 
12827fa27ce4SDimitry Andric   // Ensure all products have been processed, if not return nullptr.
12837fa27ce4SDimitry Andric   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
12847fa27ce4SDimitry Andric       !all_of(ProcessedImag, [](bool V) { return V; })) {
12857fa27ce4SDimitry Andric 
12867fa27ce4SDimitry Andric     // Dump debug information about which partial multiplications are not
12877fa27ce4SDimitry Andric     // processed.
12887fa27ce4SDimitry Andric     LLVM_DEBUG({
12897fa27ce4SDimitry Andric       dbgs() << "Unprocessed products (Real):\n";
12907fa27ce4SDimitry Andric       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
12917fa27ce4SDimitry Andric         if (!ProcessedReal[i])
12927fa27ce4SDimitry Andric           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
12937fa27ce4SDimitry Andric                            << *RealMuls[i].Multiplier << " multiplied by "
12947fa27ce4SDimitry Andric                            << *RealMuls[i].Multiplicand << "\n";
12957fa27ce4SDimitry Andric       }
12967fa27ce4SDimitry Andric       dbgs() << "Unprocessed products (Imag):\n";
12977fa27ce4SDimitry Andric       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
12987fa27ce4SDimitry Andric         if (!ProcessedImag[i])
12997fa27ce4SDimitry Andric           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
13007fa27ce4SDimitry Andric                            << *ImagMuls[i].Multiplier << " multiplied by "
13017fa27ce4SDimitry Andric                            << *ImagMuls[i].Multiplicand << "\n";
13027fa27ce4SDimitry Andric       }
13037fa27ce4SDimitry Andric     });
13047fa27ce4SDimitry Andric     return nullptr;
13057fa27ce4SDimitry Andric   }
13067fa27ce4SDimitry Andric 
13077fa27ce4SDimitry Andric   return Result;
13087fa27ce4SDimitry Andric }
13097fa27ce4SDimitry Andric 
13107fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdditions(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends,std::optional<FastMathFlags> Flags,NodePtr Accumulator=nullptr)13117fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions(
13127fa27ce4SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
13137fa27ce4SDimitry Andric     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
13147fa27ce4SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
13157fa27ce4SDimitry Andric     return nullptr;
13167fa27ce4SDimitry Andric 
13177fa27ce4SDimitry Andric   NodePtr Result;
13187fa27ce4SDimitry Andric   // If we have accumulator use it as first addend
13197fa27ce4SDimitry Andric   if (Accumulator)
13207fa27ce4SDimitry Andric     Result = Accumulator;
13217fa27ce4SDimitry Andric   // Otherwise find an element with both positive real and imaginary parts.
13227fa27ce4SDimitry Andric   else
13237fa27ce4SDimitry Andric     Result = extractPositiveAddend(RealAddends, ImagAddends);
13247fa27ce4SDimitry Andric 
13257fa27ce4SDimitry Andric   if (!Result)
13267fa27ce4SDimitry Andric     return nullptr;
13277fa27ce4SDimitry Andric 
13287fa27ce4SDimitry Andric   while (!RealAddends.empty()) {
13297fa27ce4SDimitry Andric     auto ItR = RealAddends.begin();
13307fa27ce4SDimitry Andric     auto [R, IsPositiveR] = *ItR;
13317fa27ce4SDimitry Andric 
13327fa27ce4SDimitry Andric     bool FoundImag = false;
13337fa27ce4SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
13347fa27ce4SDimitry Andric       auto [I, IsPositiveI] = *ItI;
13357fa27ce4SDimitry Andric       ComplexDeinterleavingRotation Rotation;
13367fa27ce4SDimitry Andric       if (IsPositiveR && IsPositiveI)
13377fa27ce4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_0;
13387fa27ce4SDimitry Andric       else if (!IsPositiveR && IsPositiveI)
13397fa27ce4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_90;
13407fa27ce4SDimitry Andric       else if (!IsPositiveR && !IsPositiveI)
13417fa27ce4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_180;
13427fa27ce4SDimitry Andric       else
13437fa27ce4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_270;
13447fa27ce4SDimitry Andric 
13457fa27ce4SDimitry Andric       NodePtr AddNode;
13467fa27ce4SDimitry Andric       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
13477fa27ce4SDimitry Andric           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
13487fa27ce4SDimitry Andric         AddNode = identifyNode(R, I);
13497fa27ce4SDimitry Andric       } else {
13507fa27ce4SDimitry Andric         AddNode = identifyNode(I, R);
13517fa27ce4SDimitry Andric       }
13527fa27ce4SDimitry Andric       if (AddNode) {
13537fa27ce4SDimitry Andric         LLVM_DEBUG({
13547fa27ce4SDimitry Andric           dbgs() << "Identified addition:\n";
13557fa27ce4SDimitry Andric           dbgs().indent(4) << "X: " << *R << "\n";
13567fa27ce4SDimitry Andric           dbgs().indent(4) << "Y: " << *I << "\n";
13577fa27ce4SDimitry Andric           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
13587fa27ce4SDimitry Andric         });
13597fa27ce4SDimitry Andric 
13607fa27ce4SDimitry Andric         NodePtr TmpNode;
13617fa27ce4SDimitry Andric         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
13627fa27ce4SDimitry Andric           TmpNode = prepareCompositeNode(
13637fa27ce4SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
13647fa27ce4SDimitry Andric           if (Flags) {
13657fa27ce4SDimitry Andric             TmpNode->Opcode = Instruction::FAdd;
13667fa27ce4SDimitry Andric             TmpNode->Flags = *Flags;
13677fa27ce4SDimitry Andric           } else {
13687fa27ce4SDimitry Andric             TmpNode->Opcode = Instruction::Add;
13697fa27ce4SDimitry Andric           }
13707fa27ce4SDimitry Andric         } else if (Rotation ==
13717fa27ce4SDimitry Andric                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
13727fa27ce4SDimitry Andric           TmpNode = prepareCompositeNode(
13737fa27ce4SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
13747fa27ce4SDimitry Andric           if (Flags) {
13757fa27ce4SDimitry Andric             TmpNode->Opcode = Instruction::FSub;
13767fa27ce4SDimitry Andric             TmpNode->Flags = *Flags;
13777fa27ce4SDimitry Andric           } else {
13787fa27ce4SDimitry Andric             TmpNode->Opcode = Instruction::Sub;
13797fa27ce4SDimitry Andric           }
13807fa27ce4SDimitry Andric         } else {
13817fa27ce4SDimitry Andric           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
13827fa27ce4SDimitry Andric                                          nullptr, nullptr);
13837fa27ce4SDimitry Andric           TmpNode->Rotation = Rotation;
13847fa27ce4SDimitry Andric         }
13857fa27ce4SDimitry Andric 
13867fa27ce4SDimitry Andric         TmpNode->addOperand(Result);
13877fa27ce4SDimitry Andric         TmpNode->addOperand(AddNode);
13887fa27ce4SDimitry Andric         submitCompositeNode(TmpNode);
13897fa27ce4SDimitry Andric         Result = TmpNode;
13907fa27ce4SDimitry Andric         RealAddends.erase(ItR);
13917fa27ce4SDimitry Andric         ImagAddends.erase(ItI);
13927fa27ce4SDimitry Andric         FoundImag = true;
13937fa27ce4SDimitry Andric         break;
13947fa27ce4SDimitry Andric       }
13957fa27ce4SDimitry Andric     }
13967fa27ce4SDimitry Andric     if (!FoundImag)
13977fa27ce4SDimitry Andric       return nullptr;
13987fa27ce4SDimitry Andric   }
13997fa27ce4SDimitry Andric   return Result;
14007fa27ce4SDimitry Andric }
14017fa27ce4SDimitry Andric 
14027fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
extractPositiveAddend(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends)14037fa27ce4SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend(
14047fa27ce4SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
14057fa27ce4SDimitry Andric   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
14067fa27ce4SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
14077fa27ce4SDimitry Andric       auto [R, IsPositiveR] = *ItR;
14087fa27ce4SDimitry Andric       auto [I, IsPositiveI] = *ItI;
14097fa27ce4SDimitry Andric       if (IsPositiveR && IsPositiveI) {
14107fa27ce4SDimitry Andric         auto Result = identifyNode(R, I);
14117fa27ce4SDimitry Andric         if (Result) {
14127fa27ce4SDimitry Andric           RealAddends.erase(ItR);
14137fa27ce4SDimitry Andric           ImagAddends.erase(ItI);
14147fa27ce4SDimitry Andric           return Result;
14157fa27ce4SDimitry Andric         }
14167fa27ce4SDimitry Andric       }
14177fa27ce4SDimitry Andric     }
14187fa27ce4SDimitry Andric   }
14197fa27ce4SDimitry Andric   return nullptr;
14207fa27ce4SDimitry Andric }
14217fa27ce4SDimitry Andric 
identifyNodes(Instruction * RootI)14227fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
14237fa27ce4SDimitry Andric   // This potential root instruction might already have been recognized as
14247fa27ce4SDimitry Andric   // reduction. Because RootToNode maps both Real and Imaginary parts to
14257fa27ce4SDimitry Andric   // CompositeNode we should choose only one either Real or Imag instruction to
14267fa27ce4SDimitry Andric   // use as an anchor for generating complex instruction.
14277fa27ce4SDimitry Andric   auto It = RootToNode.find(RootI);
1428b1c73532SDimitry Andric   if (It != RootToNode.end()) {
1429b1c73532SDimitry Andric     auto RootNode = It->second;
1430b1c73532SDimitry Andric     assert(RootNode->Operation ==
1431b1c73532SDimitry Andric            ComplexDeinterleavingOperation::ReductionOperation);
1432b1c73532SDimitry Andric     // Find out which part, Real or Imag, comes later, and only if we come to
1433b1c73532SDimitry Andric     // the latest part, add it to OrderedRoots.
1434b1c73532SDimitry Andric     auto *R = cast<Instruction>(RootNode->Real);
1435b1c73532SDimitry Andric     auto *I = cast<Instruction>(RootNode->Imag);
1436b1c73532SDimitry Andric     auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437b1c73532SDimitry Andric     if (ReplacementAnchor != RootI)
1438b1c73532SDimitry Andric       return false;
14397fa27ce4SDimitry Andric     OrderedRoots.push_back(RootI);
14407fa27ce4SDimitry Andric     return true;
14417fa27ce4SDimitry Andric   }
14427fa27ce4SDimitry Andric 
14437fa27ce4SDimitry Andric   auto RootNode = identifyRoot(RootI);
14447fa27ce4SDimitry Andric   if (!RootNode)
14457fa27ce4SDimitry Andric     return false;
14467fa27ce4SDimitry Andric 
14477fa27ce4SDimitry Andric   LLVM_DEBUG({
14487fa27ce4SDimitry Andric     Function *F = RootI->getFunction();
14497fa27ce4SDimitry Andric     BasicBlock *B = RootI->getParent();
14507fa27ce4SDimitry Andric     dbgs() << "Complex deinterleaving graph for " << F->getName()
14517fa27ce4SDimitry Andric            << "::" << B->getName() << ".\n";
14527fa27ce4SDimitry Andric     dump(dbgs());
14537fa27ce4SDimitry Andric     dbgs() << "\n";
14547fa27ce4SDimitry Andric   });
14557fa27ce4SDimitry Andric   RootToNode[RootI] = RootNode;
14567fa27ce4SDimitry Andric   OrderedRoots.push_back(RootI);
14577fa27ce4SDimitry Andric   return true;
14587fa27ce4SDimitry Andric }
14597fa27ce4SDimitry Andric 
collectPotentialReductions(BasicBlock * B)14607fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
14617fa27ce4SDimitry Andric   bool FoundPotentialReduction = false;
14627fa27ce4SDimitry Andric 
14637fa27ce4SDimitry Andric   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
14647fa27ce4SDimitry Andric   if (!Br || Br->getNumSuccessors() != 2)
14657fa27ce4SDimitry Andric     return false;
14667fa27ce4SDimitry Andric 
14677fa27ce4SDimitry Andric   // Identify simple one-block loop
14687fa27ce4SDimitry Andric   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
14697fa27ce4SDimitry Andric     return false;
14707fa27ce4SDimitry Andric 
14717fa27ce4SDimitry Andric   SmallVector<PHINode *> PHIs;
14727fa27ce4SDimitry Andric   for (auto &PHI : B->phis()) {
14737fa27ce4SDimitry Andric     if (PHI.getNumIncomingValues() != 2)
14747fa27ce4SDimitry Andric       continue;
14757fa27ce4SDimitry Andric 
14767fa27ce4SDimitry Andric     if (!PHI.getType()->isVectorTy())
14777fa27ce4SDimitry Andric       continue;
14787fa27ce4SDimitry Andric 
14797fa27ce4SDimitry Andric     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
14807fa27ce4SDimitry Andric     if (!ReductionOp)
14817fa27ce4SDimitry Andric       continue;
14827fa27ce4SDimitry Andric 
14837fa27ce4SDimitry Andric     // Check if final instruction is reduced outside of current block
14847fa27ce4SDimitry Andric     Instruction *FinalReduction = nullptr;
14857fa27ce4SDimitry Andric     auto NumUsers = 0u;
14867fa27ce4SDimitry Andric     for (auto *U : ReductionOp->users()) {
14877fa27ce4SDimitry Andric       ++NumUsers;
14887fa27ce4SDimitry Andric       if (U == &PHI)
14897fa27ce4SDimitry Andric         continue;
14907fa27ce4SDimitry Andric       FinalReduction = dyn_cast<Instruction>(U);
14917fa27ce4SDimitry Andric     }
14927fa27ce4SDimitry Andric 
14937fa27ce4SDimitry Andric     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
14947fa27ce4SDimitry Andric         isa<PHINode>(FinalReduction))
14957fa27ce4SDimitry Andric       continue;
14967fa27ce4SDimitry Andric 
14977fa27ce4SDimitry Andric     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
14987fa27ce4SDimitry Andric     BackEdge = B;
14997fa27ce4SDimitry Andric     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
15007fa27ce4SDimitry Andric     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
15017fa27ce4SDimitry Andric     Incoming = PHI.getIncomingBlock(IncomingIdx);
15027fa27ce4SDimitry Andric     FoundPotentialReduction = true;
15037fa27ce4SDimitry Andric 
15047fa27ce4SDimitry Andric     // If the initial value of PHINode is an Instruction, consider it a leaf
15057fa27ce4SDimitry Andric     // value of a complex deinterleaving graph.
15067fa27ce4SDimitry Andric     if (auto *InitPHI =
15077fa27ce4SDimitry Andric             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
15087fa27ce4SDimitry Andric       FinalInstructions.insert(InitPHI);
15097fa27ce4SDimitry Andric   }
15107fa27ce4SDimitry Andric   return FoundPotentialReduction;
15117fa27ce4SDimitry Andric }
15127fa27ce4SDimitry Andric 
identifyReductionNodes()15137fa27ce4SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() {
15147fa27ce4SDimitry Andric   SmallVector<bool> Processed(ReductionInfo.size(), false);
15157fa27ce4SDimitry Andric   SmallVector<Instruction *> OperationInstruction;
15167fa27ce4SDimitry Andric   for (auto &P : ReductionInfo)
15177fa27ce4SDimitry Andric     OperationInstruction.push_back(P.first);
15187fa27ce4SDimitry Andric 
15197fa27ce4SDimitry Andric   // Identify a complex computation by evaluating two reduction operations that
15207fa27ce4SDimitry Andric   // potentially could be involved
15217fa27ce4SDimitry Andric   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
15227fa27ce4SDimitry Andric     if (Processed[i])
15237fa27ce4SDimitry Andric       continue;
15247fa27ce4SDimitry Andric     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
15257fa27ce4SDimitry Andric       if (Processed[j])
15267fa27ce4SDimitry Andric         continue;
15277fa27ce4SDimitry Andric 
15287fa27ce4SDimitry Andric       auto *Real = OperationInstruction[i];
15297fa27ce4SDimitry Andric       auto *Imag = OperationInstruction[j];
15307fa27ce4SDimitry Andric       if (Real->getType() != Imag->getType())
15317fa27ce4SDimitry Andric         continue;
15327fa27ce4SDimitry Andric 
15337fa27ce4SDimitry Andric       RealPHI = ReductionInfo[Real].first;
15347fa27ce4SDimitry Andric       ImagPHI = ReductionInfo[Imag].first;
15357fa27ce4SDimitry Andric       PHIsFound = false;
15367fa27ce4SDimitry Andric       auto Node = identifyNode(Real, Imag);
15377fa27ce4SDimitry Andric       if (!Node) {
15387fa27ce4SDimitry Andric         std::swap(Real, Imag);
15397fa27ce4SDimitry Andric         std::swap(RealPHI, ImagPHI);
15407fa27ce4SDimitry Andric         Node = identifyNode(Real, Imag);
15417fa27ce4SDimitry Andric       }
15427fa27ce4SDimitry Andric 
15437fa27ce4SDimitry Andric       // If a node is identified and reduction PHINode is used in the chain of
15447fa27ce4SDimitry Andric       // operations, mark its operation instructions as used to prevent
15457fa27ce4SDimitry Andric       // re-identification and attach the node to the real part
15467fa27ce4SDimitry Andric       if (Node && PHIsFound) {
15477fa27ce4SDimitry Andric         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
15487fa27ce4SDimitry Andric                           << *Real << " / " << *Imag << "\n");
15497fa27ce4SDimitry Andric         Processed[i] = true;
15507fa27ce4SDimitry Andric         Processed[j] = true;
15517fa27ce4SDimitry Andric         auto RootNode = prepareCompositeNode(
15527fa27ce4SDimitry Andric             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
15537fa27ce4SDimitry Andric         RootNode->addOperand(Node);
15547fa27ce4SDimitry Andric         RootToNode[Real] = RootNode;
15557fa27ce4SDimitry Andric         RootToNode[Imag] = RootNode;
15567fa27ce4SDimitry Andric         submitCompositeNode(RootNode);
15577fa27ce4SDimitry Andric         break;
15587fa27ce4SDimitry Andric       }
15597fa27ce4SDimitry Andric     }
15607fa27ce4SDimitry Andric   }
15617fa27ce4SDimitry Andric 
15627fa27ce4SDimitry Andric   RealPHI = nullptr;
15637fa27ce4SDimitry Andric   ImagPHI = nullptr;
15647fa27ce4SDimitry Andric }
15657fa27ce4SDimitry Andric 
checkNodes()15667fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() {
15677fa27ce4SDimitry Andric   // Collect all instructions from roots to leaves
15687fa27ce4SDimitry Andric   SmallPtrSet<Instruction *, 16> AllInstructions;
15697fa27ce4SDimitry Andric   SmallVector<Instruction *, 8> Worklist;
15707fa27ce4SDimitry Andric   for (auto &Pair : RootToNode)
15717fa27ce4SDimitry Andric     Worklist.push_back(Pair.first);
15727fa27ce4SDimitry Andric 
15737fa27ce4SDimitry Andric   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
15747fa27ce4SDimitry Andric   // chains
15757fa27ce4SDimitry Andric   while (!Worklist.empty()) {
15767fa27ce4SDimitry Andric     auto *I = Worklist.back();
15777fa27ce4SDimitry Andric     Worklist.pop_back();
15787fa27ce4SDimitry Andric 
15797fa27ce4SDimitry Andric     if (!AllInstructions.insert(I).second)
15807fa27ce4SDimitry Andric       continue;
15817fa27ce4SDimitry Andric 
15827fa27ce4SDimitry Andric     for (Value *Op : I->operands()) {
15837fa27ce4SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op)) {
15847fa27ce4SDimitry Andric         if (!FinalInstructions.count(I))
15857fa27ce4SDimitry Andric           Worklist.emplace_back(OpI);
15867fa27ce4SDimitry Andric       }
15877fa27ce4SDimitry Andric     }
15887fa27ce4SDimitry Andric   }
15897fa27ce4SDimitry Andric 
15907fa27ce4SDimitry Andric   // Find instructions that have users outside of chain
15917fa27ce4SDimitry Andric   SmallVector<Instruction *, 2> OuterInstructions;
15927fa27ce4SDimitry Andric   for (auto *I : AllInstructions) {
15937fa27ce4SDimitry Andric     // Skip root nodes
15947fa27ce4SDimitry Andric     if (RootToNode.count(I))
15957fa27ce4SDimitry Andric       continue;
15967fa27ce4SDimitry Andric 
15977fa27ce4SDimitry Andric     for (User *U : I->users()) {
15987fa27ce4SDimitry Andric       if (AllInstructions.count(cast<Instruction>(U)))
15997fa27ce4SDimitry Andric         continue;
16007fa27ce4SDimitry Andric 
16017fa27ce4SDimitry Andric       // Found an instruction that is not used by XCMLA/XCADD chain
16027fa27ce4SDimitry Andric       Worklist.emplace_back(I);
16037fa27ce4SDimitry Andric       break;
16047fa27ce4SDimitry Andric     }
16057fa27ce4SDimitry Andric   }
16067fa27ce4SDimitry Andric 
16077fa27ce4SDimitry Andric   // If any instructions are found to be used outside, find and remove roots
16087fa27ce4SDimitry Andric   // that somehow connect to those instructions.
16097fa27ce4SDimitry Andric   SmallPtrSet<Instruction *, 16> Visited;
16107fa27ce4SDimitry Andric   while (!Worklist.empty()) {
16117fa27ce4SDimitry Andric     auto *I = Worklist.back();
16127fa27ce4SDimitry Andric     Worklist.pop_back();
16137fa27ce4SDimitry Andric     if (!Visited.insert(I).second)
16147fa27ce4SDimitry Andric       continue;
16157fa27ce4SDimitry Andric 
16167fa27ce4SDimitry Andric     // Found an impacted root node. Removing it from the nodes to be
16177fa27ce4SDimitry Andric     // deinterleaved
16187fa27ce4SDimitry Andric     if (RootToNode.count(I)) {
16197fa27ce4SDimitry Andric       LLVM_DEBUG(dbgs() << "Instruction " << *I
16207fa27ce4SDimitry Andric                         << " could be deinterleaved but its chain of complex "
16217fa27ce4SDimitry Andric                            "operations have an outside user\n");
16227fa27ce4SDimitry Andric       RootToNode.erase(I);
16237fa27ce4SDimitry Andric     }
16247fa27ce4SDimitry Andric 
16257fa27ce4SDimitry Andric     if (!AllInstructions.count(I) || FinalInstructions.count(I))
16267fa27ce4SDimitry Andric       continue;
16277fa27ce4SDimitry Andric 
16287fa27ce4SDimitry Andric     for (User *U : I->users())
16297fa27ce4SDimitry Andric       Worklist.emplace_back(cast<Instruction>(U));
16307fa27ce4SDimitry Andric 
16317fa27ce4SDimitry Andric     for (Value *Op : I->operands()) {
16327fa27ce4SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op))
16337fa27ce4SDimitry Andric         Worklist.emplace_back(OpI);
16347fa27ce4SDimitry Andric     }
16357fa27ce4SDimitry Andric   }
16367fa27ce4SDimitry Andric   return !RootToNode.empty();
16377fa27ce4SDimitry Andric }
16387fa27ce4SDimitry Andric 
16397fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyRoot(Instruction * RootI)16407fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
16417fa27ce4SDimitry Andric   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642ac9a064cSDimitry Andric     if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
16437fa27ce4SDimitry Andric       return nullptr;
16447fa27ce4SDimitry Andric 
16457fa27ce4SDimitry Andric     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
16467fa27ce4SDimitry Andric     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
16477fa27ce4SDimitry Andric     if (!Real || !Imag)
16487fa27ce4SDimitry Andric       return nullptr;
16497fa27ce4SDimitry Andric 
16507fa27ce4SDimitry Andric     return identifyNode(Real, Imag);
16517fa27ce4SDimitry Andric   }
16527fa27ce4SDimitry Andric 
16537fa27ce4SDimitry Andric   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
16547fa27ce4SDimitry Andric   if (!SVI)
16557fa27ce4SDimitry Andric     return nullptr;
16567fa27ce4SDimitry Andric 
16577fa27ce4SDimitry Andric   // Look for a shufflevector that takes separate vectors of the real and
16587fa27ce4SDimitry Andric   // imaginary components and recombines them into a single vector.
16597fa27ce4SDimitry Andric   if (!isInterleavingMask(SVI->getShuffleMask()))
16607fa27ce4SDimitry Andric     return nullptr;
16617fa27ce4SDimitry Andric 
16627fa27ce4SDimitry Andric   Instruction *Real;
16637fa27ce4SDimitry Andric   Instruction *Imag;
16647fa27ce4SDimitry Andric   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
16657fa27ce4SDimitry Andric     return nullptr;
16667fa27ce4SDimitry Andric 
16677fa27ce4SDimitry Andric   return identifyNode(Real, Imag);
16687fa27ce4SDimitry Andric }
16697fa27ce4SDimitry Andric 
16707fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyDeinterleave(Instruction * Real,Instruction * Imag)16717fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
16727fa27ce4SDimitry Andric                                                  Instruction *Imag) {
16737fa27ce4SDimitry Andric   Instruction *I = nullptr;
16747fa27ce4SDimitry Andric   Value *FinalValue = nullptr;
16757fa27ce4SDimitry Andric   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
16767fa27ce4SDimitry Andric       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1677ac9a064cSDimitry Andric       match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
16787fa27ce4SDimitry Andric                    m_Value(FinalValue)))) {
16797fa27ce4SDimitry Andric     NodePtr PlaceholderNode = prepareCompositeNode(
16807fa27ce4SDimitry Andric         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
16817fa27ce4SDimitry Andric     PlaceholderNode->ReplacementNode = FinalValue;
16827fa27ce4SDimitry Andric     FinalInstructions.insert(Real);
16837fa27ce4SDimitry Andric     FinalInstructions.insert(Imag);
16847fa27ce4SDimitry Andric     return submitCompositeNode(PlaceholderNode);
16857fa27ce4SDimitry Andric   }
16867fa27ce4SDimitry Andric 
1687e3b55780SDimitry Andric   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1688e3b55780SDimitry Andric   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
16897fa27ce4SDimitry Andric   if (!RealShuffle || !ImagShuffle) {
16907fa27ce4SDimitry Andric     if (RealShuffle || ImagShuffle)
16917fa27ce4SDimitry Andric       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
16927fa27ce4SDimitry Andric     return nullptr;
16937fa27ce4SDimitry Andric   }
16947fa27ce4SDimitry Andric 
1695e3b55780SDimitry Andric   Value *RealOp1 = RealShuffle->getOperand(1);
1696e3b55780SDimitry Andric   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1697e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1698e3b55780SDimitry Andric     return nullptr;
1699e3b55780SDimitry Andric   }
1700e3b55780SDimitry Andric   Value *ImagOp1 = ImagShuffle->getOperand(1);
1701e3b55780SDimitry Andric   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1702e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1703e3b55780SDimitry Andric     return nullptr;
1704e3b55780SDimitry Andric   }
1705e3b55780SDimitry Andric 
1706e3b55780SDimitry Andric   Value *RealOp0 = RealShuffle->getOperand(0);
1707e3b55780SDimitry Andric   Value *ImagOp0 = ImagShuffle->getOperand(0);
1708e3b55780SDimitry Andric 
1709e3b55780SDimitry Andric   if (RealOp0 != ImagOp0) {
1710e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1711e3b55780SDimitry Andric     return nullptr;
1712e3b55780SDimitry Andric   }
1713e3b55780SDimitry Andric 
1714e3b55780SDimitry Andric   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1715e3b55780SDimitry Andric   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1716e3b55780SDimitry Andric   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1717e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1718e3b55780SDimitry Andric     return nullptr;
1719e3b55780SDimitry Andric   }
1720e3b55780SDimitry Andric 
1721e3b55780SDimitry Andric   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1722e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1723e3b55780SDimitry Andric     return nullptr;
1724e3b55780SDimitry Andric   }
1725e3b55780SDimitry Andric 
1726e3b55780SDimitry Andric   // Type checking, the shuffle type should be a vector type of the same
1727e3b55780SDimitry Andric   // scalar type, but half the size
1728e3b55780SDimitry Andric   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1729e3b55780SDimitry Andric     Value *Op = Shuffle->getOperand(0);
1730e3b55780SDimitry Andric     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1731e3b55780SDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1732e3b55780SDimitry Andric 
1733e3b55780SDimitry Andric     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1734e3b55780SDimitry Andric       return false;
1735e3b55780SDimitry Andric     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1736e3b55780SDimitry Andric       return false;
1737e3b55780SDimitry Andric 
1738e3b55780SDimitry Andric     return true;
1739e3b55780SDimitry Andric   };
1740e3b55780SDimitry Andric 
1741e3b55780SDimitry Andric   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1742e3b55780SDimitry Andric     if (!CheckType(Shuffle))
1743e3b55780SDimitry Andric       return false;
1744e3b55780SDimitry Andric 
1745e3b55780SDimitry Andric     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1746e3b55780SDimitry Andric     int Last = *Mask.rbegin();
1747e3b55780SDimitry Andric 
1748e3b55780SDimitry Andric     Value *Op = Shuffle->getOperand(0);
1749e3b55780SDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1750e3b55780SDimitry Andric     int NumElements = OpTy->getNumElements();
1751e3b55780SDimitry Andric 
1752e3b55780SDimitry Andric     // Ensure that the deinterleaving shuffle only pulls from the first
1753e3b55780SDimitry Andric     // shuffle operand.
1754e3b55780SDimitry Andric     return Last < NumElements;
1755e3b55780SDimitry Andric   };
1756e3b55780SDimitry Andric 
1757e3b55780SDimitry Andric   if (RealShuffle->getType() != ImagShuffle->getType()) {
1758e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1759e3b55780SDimitry Andric     return nullptr;
1760e3b55780SDimitry Andric   }
1761e3b55780SDimitry Andric   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1762e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1763e3b55780SDimitry Andric     return nullptr;
1764e3b55780SDimitry Andric   }
1765e3b55780SDimitry Andric   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1766e3b55780SDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1767e3b55780SDimitry Andric     return nullptr;
1768e3b55780SDimitry Andric   }
1769e3b55780SDimitry Andric 
1770e3b55780SDimitry Andric   NodePtr PlaceholderNode =
17717fa27ce4SDimitry Andric       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1772e3b55780SDimitry Andric                            RealShuffle, ImagShuffle);
1773e3b55780SDimitry Andric   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
17747fa27ce4SDimitry Andric   FinalInstructions.insert(RealShuffle);
17757fa27ce4SDimitry Andric   FinalInstructions.insert(ImagShuffle);
1776e3b55780SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1777e3b55780SDimitry Andric }
1778e3b55780SDimitry Andric 
17797fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySplat(Value * R,Value * I)17807fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
17817fa27ce4SDimitry Andric   auto IsSplat = [](Value *V) -> bool {
17827fa27ce4SDimitry Andric     // Fixed-width vector with constants
17837fa27ce4SDimitry Andric     if (isa<ConstantDataVector>(V))
17847fa27ce4SDimitry Andric       return true;
1785e3b55780SDimitry Andric 
17867fa27ce4SDimitry Andric     VectorType *VTy;
17877fa27ce4SDimitry Andric     ArrayRef<int> Mask;
17887fa27ce4SDimitry Andric     // Splats are represented differently depending on whether the repeated
17897fa27ce4SDimitry Andric     // value is a constant or an Instruction
17907fa27ce4SDimitry Andric     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
17917fa27ce4SDimitry Andric       if (Const->getOpcode() != Instruction::ShuffleVector)
1792e3b55780SDimitry Andric         return false;
17937fa27ce4SDimitry Andric       VTy = cast<VectorType>(Const->getType());
17947fa27ce4SDimitry Andric       Mask = Const->getShuffleMask();
17957fa27ce4SDimitry Andric     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
17967fa27ce4SDimitry Andric       VTy = Shuf->getType();
17977fa27ce4SDimitry Andric       Mask = Shuf->getShuffleMask();
17987fa27ce4SDimitry Andric     } else {
1799e3b55780SDimitry Andric       return false;
1800e3b55780SDimitry Andric     }
18017fa27ce4SDimitry Andric 
18027fa27ce4SDimitry Andric     // When the data type is <1 x Type>, it's not possible to differentiate
18037fa27ce4SDimitry Andric     // between the ComplexDeinterleaving::Deinterleave and
18047fa27ce4SDimitry Andric     // ComplexDeinterleaving::Splat operations.
18057fa27ce4SDimitry Andric     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
18067fa27ce4SDimitry Andric       return false;
18077fa27ce4SDimitry Andric 
18087fa27ce4SDimitry Andric     return all_equal(Mask) && Mask[0] == 0;
18097fa27ce4SDimitry Andric   };
18107fa27ce4SDimitry Andric 
18117fa27ce4SDimitry Andric   if (!IsSplat(R) || !IsSplat(I))
18127fa27ce4SDimitry Andric     return nullptr;
18137fa27ce4SDimitry Andric 
18147fa27ce4SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
18157fa27ce4SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
18167fa27ce4SDimitry Andric   if ((!Real && Imag) || (Real && !Imag))
18177fa27ce4SDimitry Andric     return nullptr;
18187fa27ce4SDimitry Andric 
18197fa27ce4SDimitry Andric   if (Real && Imag) {
18207fa27ce4SDimitry Andric     // Non-constant splats should be in the same basic block
18217fa27ce4SDimitry Andric     if (Real->getParent() != Imag->getParent())
18227fa27ce4SDimitry Andric       return nullptr;
18237fa27ce4SDimitry Andric 
18247fa27ce4SDimitry Andric     FinalInstructions.insert(Real);
18257fa27ce4SDimitry Andric     FinalInstructions.insert(Imag);
1826e3b55780SDimitry Andric   }
18277fa27ce4SDimitry Andric   NodePtr PlaceholderNode =
18287fa27ce4SDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
18297fa27ce4SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1830e3b55780SDimitry Andric }
1831e3b55780SDimitry Andric 
18327fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPHINode(Instruction * Real,Instruction * Imag)18337fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
18347fa27ce4SDimitry Andric                                             Instruction *Imag) {
18357fa27ce4SDimitry Andric   if (Real != RealPHI || Imag != ImagPHI)
18367fa27ce4SDimitry Andric     return nullptr;
18377fa27ce4SDimitry Andric 
18387fa27ce4SDimitry Andric   PHIsFound = true;
18397fa27ce4SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
18407fa27ce4SDimitry Andric       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
18417fa27ce4SDimitry Andric   return submitCompositeNode(PlaceholderNode);
18427fa27ce4SDimitry Andric }
18437fa27ce4SDimitry Andric 
18447fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySelectNode(Instruction * Real,Instruction * Imag)18457fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
18467fa27ce4SDimitry Andric                                                Instruction *Imag) {
18477fa27ce4SDimitry Andric   auto *SelectReal = dyn_cast<SelectInst>(Real);
18487fa27ce4SDimitry Andric   auto *SelectImag = dyn_cast<SelectInst>(Imag);
18497fa27ce4SDimitry Andric   if (!SelectReal || !SelectImag)
18507fa27ce4SDimitry Andric     return nullptr;
18517fa27ce4SDimitry Andric 
18527fa27ce4SDimitry Andric   Instruction *MaskA, *MaskB;
18537fa27ce4SDimitry Andric   Instruction *AR, *AI, *RA, *BI;
18547fa27ce4SDimitry Andric   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
18557fa27ce4SDimitry Andric                             m_Instruction(RA))) ||
18567fa27ce4SDimitry Andric       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
18577fa27ce4SDimitry Andric                             m_Instruction(BI))))
18587fa27ce4SDimitry Andric     return nullptr;
18597fa27ce4SDimitry Andric 
18607fa27ce4SDimitry Andric   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
18617fa27ce4SDimitry Andric     return nullptr;
18627fa27ce4SDimitry Andric 
18637fa27ce4SDimitry Andric   if (!MaskA->getType()->isVectorTy())
18647fa27ce4SDimitry Andric     return nullptr;
18657fa27ce4SDimitry Andric 
18667fa27ce4SDimitry Andric   auto NodeA = identifyNode(AR, AI);
18677fa27ce4SDimitry Andric   if (!NodeA)
18687fa27ce4SDimitry Andric     return nullptr;
18697fa27ce4SDimitry Andric 
18707fa27ce4SDimitry Andric   auto NodeB = identifyNode(RA, BI);
18717fa27ce4SDimitry Andric   if (!NodeB)
18727fa27ce4SDimitry Andric     return nullptr;
18737fa27ce4SDimitry Andric 
18747fa27ce4SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
18757fa27ce4SDimitry Andric       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
18767fa27ce4SDimitry Andric   PlaceholderNode->addOperand(NodeA);
18777fa27ce4SDimitry Andric   PlaceholderNode->addOperand(NodeB);
18787fa27ce4SDimitry Andric   FinalInstructions.insert(MaskA);
18797fa27ce4SDimitry Andric   FinalInstructions.insert(MaskB);
18807fa27ce4SDimitry Andric   return submitCompositeNode(PlaceholderNode);
18817fa27ce4SDimitry Andric }
18827fa27ce4SDimitry Andric 
replaceSymmetricNode(IRBuilderBase & B,unsigned Opcode,std::optional<FastMathFlags> Flags,Value * InputA,Value * InputB)18837fa27ce4SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
18847fa27ce4SDimitry Andric                                    std::optional<FastMathFlags> Flags,
18857fa27ce4SDimitry Andric                                    Value *InputA, Value *InputB) {
18867fa27ce4SDimitry Andric   Value *I;
18877fa27ce4SDimitry Andric   switch (Opcode) {
18887fa27ce4SDimitry Andric   case Instruction::FNeg:
18897fa27ce4SDimitry Andric     I = B.CreateFNeg(InputA);
18907fa27ce4SDimitry Andric     break;
18917fa27ce4SDimitry Andric   case Instruction::FAdd:
18927fa27ce4SDimitry Andric     I = B.CreateFAdd(InputA, InputB);
18937fa27ce4SDimitry Andric     break;
18947fa27ce4SDimitry Andric   case Instruction::Add:
18957fa27ce4SDimitry Andric     I = B.CreateAdd(InputA, InputB);
18967fa27ce4SDimitry Andric     break;
18977fa27ce4SDimitry Andric   case Instruction::FSub:
18987fa27ce4SDimitry Andric     I = B.CreateFSub(InputA, InputB);
18997fa27ce4SDimitry Andric     break;
19007fa27ce4SDimitry Andric   case Instruction::Sub:
19017fa27ce4SDimitry Andric     I = B.CreateSub(InputA, InputB);
19027fa27ce4SDimitry Andric     break;
19037fa27ce4SDimitry Andric   case Instruction::FMul:
19047fa27ce4SDimitry Andric     I = B.CreateFMul(InputA, InputB);
19057fa27ce4SDimitry Andric     break;
19067fa27ce4SDimitry Andric   case Instruction::Mul:
19077fa27ce4SDimitry Andric     I = B.CreateMul(InputA, InputB);
19087fa27ce4SDimitry Andric     break;
19097fa27ce4SDimitry Andric   default:
19107fa27ce4SDimitry Andric     llvm_unreachable("Incorrect symmetric opcode");
19117fa27ce4SDimitry Andric   }
19127fa27ce4SDimitry Andric   if (Flags)
19137fa27ce4SDimitry Andric     cast<Instruction>(I)->setFastMathFlags(*Flags);
19147fa27ce4SDimitry Andric   return I;
19157fa27ce4SDimitry Andric }
19167fa27ce4SDimitry Andric 
replaceNode(IRBuilderBase & Builder,RawNodePtr Node)19177fa27ce4SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
19187fa27ce4SDimitry Andric                                                RawNodePtr Node) {
1919e3b55780SDimitry Andric   if (Node->ReplacementNode)
1920e3b55780SDimitry Andric     return Node->ReplacementNode;
1921e3b55780SDimitry Andric 
19227fa27ce4SDimitry Andric   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
19237fa27ce4SDimitry Andric     return Node->Operands.size() > Idx
19247fa27ce4SDimitry Andric                ? replaceNode(Builder, Node->Operands[Idx])
19257fa27ce4SDimitry Andric                : nullptr;
19267fa27ce4SDimitry Andric   };
1927e3b55780SDimitry Andric 
19287fa27ce4SDimitry Andric   Value *ReplacementNode;
19297fa27ce4SDimitry Andric   switch (Node->Operation) {
19307fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::CAdd:
19317fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::CMulPartial:
19327fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::Symmetric: {
19337fa27ce4SDimitry Andric     Value *Input0 = ReplaceOperandIfExist(Node, 0);
19347fa27ce4SDimitry Andric     Value *Input1 = ReplaceOperandIfExist(Node, 1);
19357fa27ce4SDimitry Andric     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
19367fa27ce4SDimitry Andric     assert(!Input1 || (Input0->getType() == Input1->getType() &&
19377fa27ce4SDimitry Andric                        "Node inputs need to be of the same type"));
19387fa27ce4SDimitry Andric     assert(!Accumulator ||
19397fa27ce4SDimitry Andric            (Input0->getType() == Accumulator->getType() &&
19407fa27ce4SDimitry Andric             "Accumulator and input need to be of the same type"));
19417fa27ce4SDimitry Andric     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
19427fa27ce4SDimitry Andric       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
19437fa27ce4SDimitry Andric                                              Input0, Input1);
19447fa27ce4SDimitry Andric     else
19457fa27ce4SDimitry Andric       ReplacementNode = TL->createComplexDeinterleavingIR(
19467fa27ce4SDimitry Andric           Builder, Node->Operation, Node->Rotation, Input0, Input1,
19477fa27ce4SDimitry Andric           Accumulator);
19487fa27ce4SDimitry Andric     break;
19497fa27ce4SDimitry Andric   }
19507fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::Deinterleave:
19517fa27ce4SDimitry Andric     llvm_unreachable("Deinterleave node should already have ReplacementNode");
19527fa27ce4SDimitry Andric     break;
19537fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::Splat: {
19547fa27ce4SDimitry Andric     auto *NewTy = VectorType::getDoubleElementsVectorType(
19557fa27ce4SDimitry Andric         cast<VectorType>(Node->Real->getType()));
19567fa27ce4SDimitry Andric     auto *R = dyn_cast<Instruction>(Node->Real);
19577fa27ce4SDimitry Andric     auto *I = dyn_cast<Instruction>(Node->Imag);
19587fa27ce4SDimitry Andric     if (R && I) {
19597fa27ce4SDimitry Andric       // Splats that are not constant are interleaved where they are located
19607fa27ce4SDimitry Andric       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
19617fa27ce4SDimitry Andric       IRBuilder<> IRB(InsertPoint);
1962ac9a064cSDimitry Andric       ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
19637fa27ce4SDimitry Andric                                             NewTy, {Node->Real, Node->Imag});
1964ac9a064cSDimitry Andric     } else {
1965ac9a064cSDimitry Andric       ReplacementNode = Builder.CreateIntrinsic(
1966ac9a064cSDimitry Andric           Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
19677fa27ce4SDimitry Andric     }
19687fa27ce4SDimitry Andric     break;
19697fa27ce4SDimitry Andric   }
19707fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::ReductionPHI: {
19717fa27ce4SDimitry Andric     // If Operation is ReductionPHI, a new empty PHINode is created.
19727fa27ce4SDimitry Andric     // It is filled later when the ReductionOperation is processed.
19737fa27ce4SDimitry Andric     auto *VTy = cast<VectorType>(Node->Real->getType());
19747fa27ce4SDimitry Andric     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1975ac9a064cSDimitry Andric     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
19767fa27ce4SDimitry Andric     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
19777fa27ce4SDimitry Andric     ReplacementNode = NewPHI;
19787fa27ce4SDimitry Andric     break;
19797fa27ce4SDimitry Andric   }
19807fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::ReductionOperation:
19817fa27ce4SDimitry Andric     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
19827fa27ce4SDimitry Andric     processReductionOperation(ReplacementNode, Node);
19837fa27ce4SDimitry Andric     break;
19847fa27ce4SDimitry Andric   case ComplexDeinterleavingOperation::ReductionSelect: {
19857fa27ce4SDimitry Andric     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
19867fa27ce4SDimitry Andric     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
19877fa27ce4SDimitry Andric     auto *A = replaceNode(Builder, Node->Operands[0]);
19887fa27ce4SDimitry Andric     auto *B = replaceNode(Builder, Node->Operands[1]);
19897fa27ce4SDimitry Andric     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
19907fa27ce4SDimitry Andric         cast<VectorType>(MaskReal->getType()));
1991ac9a064cSDimitry Andric     auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
19927fa27ce4SDimitry Andric                                             NewMaskTy, {MaskReal, MaskImag});
19937fa27ce4SDimitry Andric     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
19947fa27ce4SDimitry Andric     break;
19957fa27ce4SDimitry Andric   }
19967fa27ce4SDimitry Andric   }
1997e3b55780SDimitry Andric 
19987fa27ce4SDimitry Andric   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1999e3b55780SDimitry Andric   NumComplexTransformations += 1;
20007fa27ce4SDimitry Andric   Node->ReplacementNode = ReplacementNode;
20017fa27ce4SDimitry Andric   return ReplacementNode;
20027fa27ce4SDimitry Andric }
20037fa27ce4SDimitry Andric 
processReductionOperation(Value * OperationReplacement,RawNodePtr Node)20047fa27ce4SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation(
20057fa27ce4SDimitry Andric     Value *OperationReplacement, RawNodePtr Node) {
20067fa27ce4SDimitry Andric   auto *Real = cast<Instruction>(Node->Real);
20077fa27ce4SDimitry Andric   auto *Imag = cast<Instruction>(Node->Imag);
20087fa27ce4SDimitry Andric   auto *OldPHIReal = ReductionInfo[Real].first;
20097fa27ce4SDimitry Andric   auto *OldPHIImag = ReductionInfo[Imag].first;
20107fa27ce4SDimitry Andric   auto *NewPHI = OldToNewPHI[OldPHIReal];
20117fa27ce4SDimitry Andric 
20127fa27ce4SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
20137fa27ce4SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
20147fa27ce4SDimitry Andric 
20157fa27ce4SDimitry Andric   // We have to interleave initial origin values coming from IncomingBlock
20167fa27ce4SDimitry Andric   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
20177fa27ce4SDimitry Andric   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
20187fa27ce4SDimitry Andric 
20197fa27ce4SDimitry Andric   IRBuilder<> Builder(Incoming->getTerminator());
2020ac9a064cSDimitry Andric   auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2021ac9a064cSDimitry Andric                                           {InitReal, InitImag});
20227fa27ce4SDimitry Andric 
20237fa27ce4SDimitry Andric   NewPHI->addIncoming(NewInit, Incoming);
20247fa27ce4SDimitry Andric   NewPHI->addIncoming(OperationReplacement, BackEdge);
20257fa27ce4SDimitry Andric 
20267fa27ce4SDimitry Andric   // Deinterleave complex vector outside of loop so that it can be finally
20277fa27ce4SDimitry Andric   // reduced
20287fa27ce4SDimitry Andric   auto *FinalReductionReal = ReductionInfo[Real].second;
20297fa27ce4SDimitry Andric   auto *FinalReductionImag = ReductionInfo[Imag].second;
20307fa27ce4SDimitry Andric 
20317fa27ce4SDimitry Andric   Builder.SetInsertPoint(
20327fa27ce4SDimitry Andric       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2033ac9a064cSDimitry Andric   auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2034ac9a064cSDimitry Andric                                                OperationReplacement->getType(),
2035ac9a064cSDimitry Andric                                                OperationReplacement);
20367fa27ce4SDimitry Andric 
20377fa27ce4SDimitry Andric   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
20387fa27ce4SDimitry Andric   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
20397fa27ce4SDimitry Andric 
20407fa27ce4SDimitry Andric   Builder.SetInsertPoint(FinalReductionImag);
20417fa27ce4SDimitry Andric   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
20427fa27ce4SDimitry Andric   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2043e3b55780SDimitry Andric }
2044e3b55780SDimitry Andric 
replaceNodes()2045e3b55780SDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
20467fa27ce4SDimitry Andric   SmallVector<Instruction *, 16> DeadInstrRoots;
20477fa27ce4SDimitry Andric   for (auto *RootInstruction : OrderedRoots) {
20487fa27ce4SDimitry Andric     // Check if this potential root went through check process and we can
20497fa27ce4SDimitry Andric     // deinterleave it
20507fa27ce4SDimitry Andric     if (!RootToNode.count(RootInstruction))
20517fa27ce4SDimitry Andric       continue;
20527fa27ce4SDimitry Andric 
20537fa27ce4SDimitry Andric     IRBuilder<> Builder(RootInstruction);
20547fa27ce4SDimitry Andric     auto RootNode = RootToNode[RootInstruction];
20557fa27ce4SDimitry Andric     Value *R = replaceNode(Builder, RootNode.get());
20567fa27ce4SDimitry Andric 
20577fa27ce4SDimitry Andric     if (RootNode->Operation ==
20587fa27ce4SDimitry Andric         ComplexDeinterleavingOperation::ReductionOperation) {
20597fa27ce4SDimitry Andric       auto *RootReal = cast<Instruction>(RootNode->Real);
20607fa27ce4SDimitry Andric       auto *RootImag = cast<Instruction>(RootNode->Imag);
20617fa27ce4SDimitry Andric       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
20627fa27ce4SDimitry Andric       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
20637fa27ce4SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
20647fa27ce4SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
20657fa27ce4SDimitry Andric     } else {
20667fa27ce4SDimitry Andric       assert(R && "Unable to find replacement for RootInstruction");
20677fa27ce4SDimitry Andric       DeadInstrRoots.push_back(RootInstruction);
20687fa27ce4SDimitry Andric       RootInstruction->replaceAllUsesWith(R);
20697fa27ce4SDimitry Andric     }
2070e3b55780SDimitry Andric   }
2071e3b55780SDimitry Andric 
20727fa27ce4SDimitry Andric   for (auto *I : DeadInstrRoots)
20737fa27ce4SDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2074e3b55780SDimitry Andric }
2075