1e3b55780SDimitry Andric //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2e3b55780SDimitry Andric //
3e3b55780SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e3b55780SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5e3b55780SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e3b55780SDimitry Andric //
7e3b55780SDimitry Andric //===----------------------------------------------------------------------===//
8e3b55780SDimitry Andric //
9e3b55780SDimitry Andric // Identification:
10e3b55780SDimitry Andric // This step is responsible for finding the patterns that can be lowered to
11e3b55780SDimitry Andric // complex instructions, and building a graph to represent the complex
12e3b55780SDimitry Andric // structures. Starting from the "Converging Shuffle" (a shuffle that
13e3b55780SDimitry Andric // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14e3b55780SDimitry Andric // operands are evaluated and identified as "Composite Nodes" (collections of
15e3b55780SDimitry Andric // instructions that can potentially be lowered to a single complex
16e3b55780SDimitry Andric // instruction). This is performed by checking the real and imaginary components
17e3b55780SDimitry Andric // and tracking the data flow for each component while following the operand
18e3b55780SDimitry Andric // pairs. Validity of each node is expected to be done upon creation, and any
19e3b55780SDimitry Andric // validation errors should halt traversal and prevent further graph
20e3b55780SDimitry Andric // construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// shufflevector instructions and intrinsics.
26e3b55780SDimitry Andric //
27e3b55780SDimitry Andric // Replacement:
28e3b55780SDimitry Andric // This step traverses the graph built up by identification, delegating to the
29e3b55780SDimitry Andric // target to validate and generate the correct intrinsics, and plumbs them
30e3b55780SDimitry Andric // together connecting each end of the new intrinsics graph to the existing
31e3b55780SDimitry Andric // use-def chain. This step is assumed to finish successfully, as all
32e3b55780SDimitry Andric // information is expected to be correct by this point.
33e3b55780SDimitry Andric //
34e3b55780SDimitry Andric //
35e3b55780SDimitry Andric // Internal data structure:
36e3b55780SDimitry Andric // ComplexDeinterleavingGraph:
37e3b55780SDimitry Andric // Keeps references to all the valid CompositeNodes formed as part of the
38e3b55780SDimitry Andric // transformation, and every Instruction contained within said nodes. It also
39e3b55780SDimitry Andric // holds onto a reference to the root Instruction, and the root node that should
40e3b55780SDimitry Andric // replace it.
41e3b55780SDimitry Andric //
42e3b55780SDimitry Andric // ComplexDeinterleavingCompositeNode:
43e3b55780SDimitry Andric // A CompositeNode represents a single transformation point; each node should
44e3b55780SDimitry Andric // transform into a single complex instruction (ignoring vector splitting, which
45e3b55780SDimitry Andric // would generate more instructions per node). They are identified in a
46e3b55780SDimitry Andric // depth-first manner, traversing and identifying the operands of each
47e3b55780SDimitry Andric // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary values (Value*,
// not necessarily instructions), as well as any additional instructions that
// make up the identified operation (Internal instructions should only have
// uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
55e3b55780SDimitry Andric //
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
59e3b55780SDimitry Andric //
60e3b55780SDimitry Andric //===----------------------------------------------------------------------===//
61e3b55780SDimitry Andric
62e3b55780SDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63b1c73532SDimitry Andric #include "llvm/ADT/MapVector.h"
64e3b55780SDimitry Andric #include "llvm/ADT/Statistic.h"
65e3b55780SDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
66e3b55780SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
67e3b55780SDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
68e3b55780SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
69e3b55780SDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
70e3b55780SDimitry Andric #include "llvm/IR/IRBuilder.h"
717fa27ce4SDimitry Andric #include "llvm/IR/PatternMatch.h"
72e3b55780SDimitry Andric #include "llvm/InitializePasses.h"
73e3b55780SDimitry Andric #include "llvm/Target/TargetMachine.h"
74e3b55780SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
75e3b55780SDimitry Andric #include <algorithm>
76e3b55780SDimitry Andric
77e3b55780SDimitry Andric using namespace llvm;
78e3b55780SDimitry Andric using namespace PatternMatch;
79e3b55780SDimitry Andric
80e3b55780SDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
81e3b55780SDimitry Andric
82e3b55780SDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83e3b55780SDimitry Andric
84e3b55780SDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
85e3b55780SDimitry Andric "enable-complex-deinterleaving",
86e3b55780SDimitry Andric cl::desc("Enable generation of complex instructions"), cl::init(true),
87e3b55780SDimitry Andric cl::Hidden);
88e3b55780SDimitry Andric
89e3b55780SDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving.
90e3b55780SDimitry Andric ///
91e3b55780SDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92e3b55780SDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93e3b55780SDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94e3b55780SDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask);
95e3b55780SDimitry Andric
96e3b55780SDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving.
97e3b55780SDimitry Andric ///
98e3b55780SDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start
99e3b55780SDimitry Andric /// with 0 or 1.
100e3b55780SDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101e3b55780SDimitry Andric /// <1, 3, 5, 7>).
102e3b55780SDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask);
103e3b55780SDimitry Andric
1047fa27ce4SDimitry Andric /// Returns true if the operation is a negation of V, and it works for both
1057fa27ce4SDimitry Andric /// integers and floats.
1067fa27ce4SDimitry Andric static bool isNeg(Value *V);
1077fa27ce4SDimitry Andric
1087fa27ce4SDimitry Andric /// Returns the operand for negation operation.
1097fa27ce4SDimitry Andric static Value *getNegOperand(Value *V);
1107fa27ce4SDimitry Andric
111e3b55780SDimitry Andric namespace {
112e3b55780SDimitry Andric
113e3b55780SDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
114e3b55780SDimitry Andric public:
115e3b55780SDimitry Andric static char ID;
116e3b55780SDimitry Andric
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)117e3b55780SDimitry Andric ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118e3b55780SDimitry Andric : FunctionPass(ID), TM(TM) {
119e3b55780SDimitry Andric initializeComplexDeinterleavingLegacyPassPass(
120e3b55780SDimitry Andric *PassRegistry::getPassRegistry());
121e3b55780SDimitry Andric }
122e3b55780SDimitry Andric
getPassName() const123e3b55780SDimitry Andric StringRef getPassName() const override {
124e3b55780SDimitry Andric return "Complex Deinterleaving Pass";
125e3b55780SDimitry Andric }
126e3b55780SDimitry Andric
127e3b55780SDimitry Andric bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const128e3b55780SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
129e3b55780SDimitry Andric AU.addRequired<TargetLibraryInfoWrapperPass>();
130e3b55780SDimitry Andric AU.setPreservesCFG();
131e3b55780SDimitry Andric }
132e3b55780SDimitry Andric
133e3b55780SDimitry Andric private:
134e3b55780SDimitry Andric const TargetMachine *TM;
135e3b55780SDimitry Andric };
136e3b55780SDimitry Andric
137e3b55780SDimitry Andric class ComplexDeinterleavingGraph;
138e3b55780SDimitry Andric struct ComplexDeinterleavingCompositeNode {
139e3b55780SDimitry Andric
ComplexDeinterleavingCompositeNode__anon446059c90111::ComplexDeinterleavingCompositeNode140e3b55780SDimitry Andric ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
1417fa27ce4SDimitry Andric Value *R, Value *I)
142e3b55780SDimitry Andric : Operation(Op), Real(R), Imag(I) {}
143e3b55780SDimitry Andric
144e3b55780SDimitry Andric private:
145e3b55780SDimitry Andric friend class ComplexDeinterleavingGraph;
146e3b55780SDimitry Andric using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147e3b55780SDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148e3b55780SDimitry Andric
149e3b55780SDimitry Andric public:
150e3b55780SDimitry Andric ComplexDeinterleavingOperation Operation;
1517fa27ce4SDimitry Andric Value *Real;
1527fa27ce4SDimitry Andric Value *Imag;
153e3b55780SDimitry Andric
1547fa27ce4SDimitry Andric // This two members are required exclusively for generating
1557fa27ce4SDimitry Andric // ComplexDeinterleavingOperation::Symmetric operations.
1567fa27ce4SDimitry Andric unsigned Opcode;
1577fa27ce4SDimitry Andric std::optional<FastMathFlags> Flags;
1587fa27ce4SDimitry Andric
1597fa27ce4SDimitry Andric ComplexDeinterleavingRotation Rotation =
1607fa27ce4SDimitry Andric ComplexDeinterleavingRotation::Rotation_0;
161e3b55780SDimitry Andric SmallVector<RawNodePtr> Operands;
162e3b55780SDimitry Andric Value *ReplacementNode = nullptr;
163e3b55780SDimitry Andric
addOperand__anon446059c90111::ComplexDeinterleavingCompositeNode164e3b55780SDimitry Andric void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165e3b55780SDimitry Andric
dump__anon446059c90111::ComplexDeinterleavingCompositeNode166e3b55780SDimitry Andric void dump() { dump(dbgs()); }
dump__anon446059c90111::ComplexDeinterleavingCompositeNode167e3b55780SDimitry Andric void dump(raw_ostream &OS) {
168e3b55780SDimitry Andric auto PrintValue = [&](Value *V) {
169e3b55780SDimitry Andric if (V) {
170e3b55780SDimitry Andric OS << "\"";
171e3b55780SDimitry Andric V->print(OS, true);
172e3b55780SDimitry Andric OS << "\"\n";
173e3b55780SDimitry Andric } else
174e3b55780SDimitry Andric OS << "nullptr\n";
175e3b55780SDimitry Andric };
176e3b55780SDimitry Andric auto PrintNodeRef = [&](RawNodePtr Ptr) {
177e3b55780SDimitry Andric if (Ptr)
178e3b55780SDimitry Andric OS << Ptr << "\n";
179e3b55780SDimitry Andric else
180e3b55780SDimitry Andric OS << "nullptr\n";
181e3b55780SDimitry Andric };
182e3b55780SDimitry Andric
183e3b55780SDimitry Andric OS << "- CompositeNode: " << this << "\n";
184e3b55780SDimitry Andric OS << " Real: ";
185e3b55780SDimitry Andric PrintValue(Real);
186e3b55780SDimitry Andric OS << " Imag: ";
187e3b55780SDimitry Andric PrintValue(Imag);
188e3b55780SDimitry Andric OS << " ReplacementNode: ";
189e3b55780SDimitry Andric PrintValue(ReplacementNode);
190e3b55780SDimitry Andric OS << " Operation: " << (int)Operation << "\n";
191e3b55780SDimitry Andric OS << " Rotation: " << ((int)Rotation * 90) << "\n";
192e3b55780SDimitry Andric OS << " Operands: \n";
193e3b55780SDimitry Andric for (const auto &Op : Operands) {
194e3b55780SDimitry Andric OS << " - ";
195e3b55780SDimitry Andric PrintNodeRef(Op);
196e3b55780SDimitry Andric }
197e3b55780SDimitry Andric }
198e3b55780SDimitry Andric };
199e3b55780SDimitry Andric
200e3b55780SDimitry Andric class ComplexDeinterleavingGraph {
201e3b55780SDimitry Andric public:
2027fa27ce4SDimitry Andric struct Product {
2037fa27ce4SDimitry Andric Value *Multiplier;
2047fa27ce4SDimitry Andric Value *Multiplicand;
2057fa27ce4SDimitry Andric bool IsPositive;
2067fa27ce4SDimitry Andric };
2077fa27ce4SDimitry Andric
2087fa27ce4SDimitry Andric using Addend = std::pair<Value *, bool>;
209e3b55780SDimitry Andric using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210e3b55780SDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
2117fa27ce4SDimitry Andric
2127fa27ce4SDimitry Andric // Helper struct for holding info about potential partial multiplication
2137fa27ce4SDimitry Andric // candidates
2147fa27ce4SDimitry Andric struct PartialMulCandidate {
2157fa27ce4SDimitry Andric Value *Common;
2167fa27ce4SDimitry Andric NodePtr Node;
2177fa27ce4SDimitry Andric unsigned RealIdx;
2187fa27ce4SDimitry Andric unsigned ImagIdx;
2197fa27ce4SDimitry Andric bool IsNodeInverted;
2207fa27ce4SDimitry Andric };
2217fa27ce4SDimitry Andric
ComplexDeinterleavingGraph(const TargetLowering * TL,const TargetLibraryInfo * TLI)2227fa27ce4SDimitry Andric explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
2237fa27ce4SDimitry Andric const TargetLibraryInfo *TLI)
2247fa27ce4SDimitry Andric : TL(TL), TLI(TLI) {}
225e3b55780SDimitry Andric
226e3b55780SDimitry Andric private:
2277fa27ce4SDimitry Andric const TargetLowering *TL = nullptr;
2287fa27ce4SDimitry Andric const TargetLibraryInfo *TLI = nullptr;
229e3b55780SDimitry Andric SmallVector<NodePtr> CompositeNodes;
230b1c73532SDimitry Andric DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
2317fa27ce4SDimitry Andric
2327fa27ce4SDimitry Andric SmallPtrSet<Instruction *, 16> FinalInstructions;
2337fa27ce4SDimitry Andric
2347fa27ce4SDimitry Andric /// Root instructions are instructions from which complex computation starts
2357fa27ce4SDimitry Andric std::map<Instruction *, NodePtr> RootToNode;
2367fa27ce4SDimitry Andric
2377fa27ce4SDimitry Andric /// Topologically sorted root instructions
2387fa27ce4SDimitry Andric SmallVector<Instruction *, 1> OrderedRoots;
2397fa27ce4SDimitry Andric
2407fa27ce4SDimitry Andric /// When examining a basic block for complex deinterleaving, if it is a simple
2417fa27ce4SDimitry Andric /// one-block loop, then the only incoming block is 'Incoming' and the
2427fa27ce4SDimitry Andric /// 'BackEdge' block is the block itself."
2437fa27ce4SDimitry Andric BasicBlock *BackEdge = nullptr;
2447fa27ce4SDimitry Andric BasicBlock *Incoming = nullptr;
2457fa27ce4SDimitry Andric
2467fa27ce4SDimitry Andric /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
2477fa27ce4SDimitry Andric /// %OutsideUser as it is shown in the IR:
2487fa27ce4SDimitry Andric ///
2497fa27ce4SDimitry Andric /// vector.body:
2507fa27ce4SDimitry Andric /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
2517fa27ce4SDimitry Andric /// [ %ReductionOp, %vector.body ]
2527fa27ce4SDimitry Andric /// ...
2537fa27ce4SDimitry Andric /// %ReductionOp = fadd i64 ...
2547fa27ce4SDimitry Andric /// ...
2557fa27ce4SDimitry Andric /// br i1 %condition, label %vector.body, %middle.block
2567fa27ce4SDimitry Andric ///
2577fa27ce4SDimitry Andric /// middle.block:
2587fa27ce4SDimitry Andric /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
2597fa27ce4SDimitry Andric ///
2607fa27ce4SDimitry Andric /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
2617fa27ce4SDimitry Andric /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262b1c73532SDimitry Andric MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
2637fa27ce4SDimitry Andric
2647fa27ce4SDimitry Andric /// In the process of detecting a reduction, we consider a pair of
2657fa27ce4SDimitry Andric /// %ReductionOP, which we refer to as real and imag (or vice versa), and
2667fa27ce4SDimitry Andric /// traverse the use-tree to detect complex operations. As this is a reduction
2677fa27ce4SDimitry Andric /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
2687fa27ce4SDimitry Andric /// to the %ReductionOPs that we suspect to be complex.
2697fa27ce4SDimitry Andric /// RealPHI and ImagPHI are used by the identifyPHINode method.
2707fa27ce4SDimitry Andric PHINode *RealPHI = nullptr;
2717fa27ce4SDimitry Andric PHINode *ImagPHI = nullptr;
2727fa27ce4SDimitry Andric
2737fa27ce4SDimitry Andric /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
2747fa27ce4SDimitry Andric /// detection.
2757fa27ce4SDimitry Andric bool PHIsFound = false;
2767fa27ce4SDimitry Andric
2777fa27ce4SDimitry Andric /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
2787fa27ce4SDimitry Andric /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
2797fa27ce4SDimitry Andric /// This mapping is populated during
2807fa27ce4SDimitry Andric /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
2817fa27ce4SDimitry Andric /// used in the ComplexDeinterleavingOperation::ReductionOperation node
2827fa27ce4SDimitry Andric /// replacement process.
2837fa27ce4SDimitry Andric std::map<PHINode *, PHINode *> OldToNewPHI;
284e3b55780SDimitry Andric
prepareCompositeNode(ComplexDeinterleavingOperation Operation,Value * R,Value * I)285e3b55780SDimitry Andric NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
2867fa27ce4SDimitry Andric Value *R, Value *I) {
2877fa27ce4SDimitry Andric assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
2887fa27ce4SDimitry Andric Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
2897fa27ce4SDimitry Andric (R && I)) &&
2907fa27ce4SDimitry Andric "Reduction related nodes must have Real and Imaginary parts");
291e3b55780SDimitry Andric return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292e3b55780SDimitry Andric I);
293e3b55780SDimitry Andric }
294e3b55780SDimitry Andric
submitCompositeNode(NodePtr Node)295e3b55780SDimitry Andric NodePtr submitCompositeNode(NodePtr Node) {
296e3b55780SDimitry Andric CompositeNodes.push_back(Node);
297b1c73532SDimitry Andric if (Node->Real && Node->Imag)
298b1c73532SDimitry Andric CachedResult[{Node->Real, Node->Imag}] = Node;
299e3b55780SDimitry Andric return Node;
300e3b55780SDimitry Andric }
301e3b55780SDimitry Andric
302e3b55780SDimitry Andric /// Identifies a complex partial multiply pattern and its rotation, based on
303e3b55780SDimitry Andric /// the following patterns
304e3b55780SDimitry Andric ///
305e3b55780SDimitry Andric /// 0: r: cr + ar * br
306e3b55780SDimitry Andric /// i: ci + ar * bi
307e3b55780SDimitry Andric /// 90: r: cr - ai * bi
308e3b55780SDimitry Andric /// i: ci + ai * br
309e3b55780SDimitry Andric /// 180: r: cr - ar * br
310e3b55780SDimitry Andric /// i: ci - ar * bi
311e3b55780SDimitry Andric /// 270: r: cr + ai * bi
312e3b55780SDimitry Andric /// i: ci - ai * br
313e3b55780SDimitry Andric NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314e3b55780SDimitry Andric
315e3b55780SDimitry Andric /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316e3b55780SDimitry Andric /// is partially known from identifyPartialMul, filling in the other half of
317e3b55780SDimitry Andric /// the complex pair.
3187fa27ce4SDimitry Andric NodePtr
3197fa27ce4SDimitry Andric identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
3207fa27ce4SDimitry Andric std::pair<Value *, Value *> &CommonOperandI);
321e3b55780SDimitry Andric
322e3b55780SDimitry Andric /// Identifies a complex add pattern and its rotation, based on the following
323e3b55780SDimitry Andric /// patterns.
324e3b55780SDimitry Andric ///
325e3b55780SDimitry Andric /// 90: r: ar - bi
326e3b55780SDimitry Andric /// i: ai + br
327e3b55780SDimitry Andric /// 270: r: ar + bi
328e3b55780SDimitry Andric /// i: ai - br
329e3b55780SDimitry Andric NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
3307fa27ce4SDimitry Andric NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331e3b55780SDimitry Andric
3327fa27ce4SDimitry Andric NodePtr identifyNode(Value *R, Value *I);
333e3b55780SDimitry Andric
3347fa27ce4SDimitry Andric /// Determine if a sum of complex numbers can be formed from \p RealAddends
3357fa27ce4SDimitry Andric /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
3367fa27ce4SDimitry Andric /// Return nullptr if it is not possible to construct a complex number.
3377fa27ce4SDimitry Andric /// \p Flags are needed to generate symmetric Add and Sub operations.
3387fa27ce4SDimitry Andric NodePtr identifyAdditions(std::list<Addend> &RealAddends,
3397fa27ce4SDimitry Andric std::list<Addend> &ImagAddends,
3407fa27ce4SDimitry Andric std::optional<FastMathFlags> Flags,
3417fa27ce4SDimitry Andric NodePtr Accumulator);
3427fa27ce4SDimitry Andric
3437fa27ce4SDimitry Andric /// Extract one addend that have both real and imaginary parts positive.
3447fa27ce4SDimitry Andric NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
3457fa27ce4SDimitry Andric std::list<Addend> &ImagAddends);
3467fa27ce4SDimitry Andric
3477fa27ce4SDimitry Andric /// Determine if sum of multiplications of complex numbers can be formed from
3487fa27ce4SDimitry Andric /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
3497fa27ce4SDimitry Andric /// to it. Return nullptr if it is not possible to construct a complex number.
3507fa27ce4SDimitry Andric NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
3517fa27ce4SDimitry Andric std::vector<Product> &ImagMuls,
3527fa27ce4SDimitry Andric NodePtr Accumulator);
3537fa27ce4SDimitry Andric
3547fa27ce4SDimitry Andric /// Go through pairs of multiplication (one Real and one Imag) and find all
3557fa27ce4SDimitry Andric /// possible candidates for partial multiplication and put them into \p
3567fa27ce4SDimitry Andric /// Candidates. Returns true if all Product has pair with common operand
3577fa27ce4SDimitry Andric bool collectPartialMuls(const std::vector<Product> &RealMuls,
3587fa27ce4SDimitry Andric const std::vector<Product> &ImagMuls,
3597fa27ce4SDimitry Andric std::vector<PartialMulCandidate> &Candidates);
3607fa27ce4SDimitry Andric
3617fa27ce4SDimitry Andric /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
3627fa27ce4SDimitry Andric /// the order of complex computation operations may be significantly altered,
3637fa27ce4SDimitry Andric /// and the real and imaginary parts may not be executed in parallel. This
3647fa27ce4SDimitry Andric /// function takes this into consideration and employs a more general approach
3657fa27ce4SDimitry Andric /// to identify complex computations. Initially, it gathers all the addends
3667fa27ce4SDimitry Andric /// and multiplicands and then constructs a complex expression from them.
3677fa27ce4SDimitry Andric NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
3687fa27ce4SDimitry Andric
3697fa27ce4SDimitry Andric NodePtr identifyRoot(Instruction *I);
3707fa27ce4SDimitry Andric
3717fa27ce4SDimitry Andric /// Identifies the Deinterleave operation applied to a vector containing
3727fa27ce4SDimitry Andric /// complex numbers. There are two ways to represent the Deinterleave
3737fa27ce4SDimitry Andric /// operation:
3747fa27ce4SDimitry Andric /// * Using two shufflevectors with even indices for /pReal instruction and
3757fa27ce4SDimitry Andric /// odd indices for /pImag instructions (only for fixed-width vectors)
3767fa27ce4SDimitry Andric /// * Using two extractvalue instructions applied to `vector.deinterleave2`
3777fa27ce4SDimitry Andric /// intrinsic (for both fixed and scalable vectors)
3787fa27ce4SDimitry Andric NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
3797fa27ce4SDimitry Andric
3807fa27ce4SDimitry Andric /// identifying the operation that represents a complex number repeated in a
3817fa27ce4SDimitry Andric /// Splat vector. There are two possible types of splats: ConstantExpr with
3827fa27ce4SDimitry Andric /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
3837fa27ce4SDimitry Andric /// initialization mask with all values set to zero.
3847fa27ce4SDimitry Andric NodePtr identifySplat(Value *Real, Value *Imag);
3857fa27ce4SDimitry Andric
3867fa27ce4SDimitry Andric NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
3877fa27ce4SDimitry Andric
3887fa27ce4SDimitry Andric /// Identifies SelectInsts in a loop that has reduction with predication masks
3897fa27ce4SDimitry Andric /// and/or predicated tail folding
3907fa27ce4SDimitry Andric NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
3917fa27ce4SDimitry Andric
3927fa27ce4SDimitry Andric Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
3937fa27ce4SDimitry Andric
3947fa27ce4SDimitry Andric /// Complete IR modifications after producing new reduction operation:
3957fa27ce4SDimitry Andric /// * Populate the PHINode generated for
3967fa27ce4SDimitry Andric /// ComplexDeinterleavingOperation::ReductionPHI
3977fa27ce4SDimitry Andric /// * Deinterleave the final value outside of the loop and repurpose original
3987fa27ce4SDimitry Andric /// reduction users
3997fa27ce4SDimitry Andric void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400e3b55780SDimitry Andric
401e3b55780SDimitry Andric public:
dump()402e3b55780SDimitry Andric void dump() { dump(dbgs()); }
dump(raw_ostream & OS)403e3b55780SDimitry Andric void dump(raw_ostream &OS) {
404e3b55780SDimitry Andric for (const auto &Node : CompositeNodes)
405e3b55780SDimitry Andric Node->dump(OS);
406e3b55780SDimitry Andric }
407e3b55780SDimitry Andric
408e3b55780SDimitry Andric /// Returns false if the deinterleaving operation should be cancelled for the
409e3b55780SDimitry Andric /// current graph.
410e3b55780SDimitry Andric bool identifyNodes(Instruction *RootI);
411e3b55780SDimitry Andric
4127fa27ce4SDimitry Andric /// In case \pB is one-block loop, this function seeks potential reductions
4137fa27ce4SDimitry Andric /// and populates ReductionInfo. Returns true if any reductions were
4147fa27ce4SDimitry Andric /// identified.
4157fa27ce4SDimitry Andric bool collectPotentialReductions(BasicBlock *B);
4167fa27ce4SDimitry Andric
4177fa27ce4SDimitry Andric void identifyReductionNodes();
4187fa27ce4SDimitry Andric
4197fa27ce4SDimitry Andric /// Check that every instruction, from the roots to the leaves, has internal
4207fa27ce4SDimitry Andric /// uses.
4217fa27ce4SDimitry Andric bool checkNodes();
4227fa27ce4SDimitry Andric
423e3b55780SDimitry Andric /// Perform the actual replacement of the underlying instruction graph.
424e3b55780SDimitry Andric void replaceNodes();
425e3b55780SDimitry Andric };
426e3b55780SDimitry Andric
427e3b55780SDimitry Andric class ComplexDeinterleaving {
428e3b55780SDimitry Andric public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)429e3b55780SDimitry Andric ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430e3b55780SDimitry Andric : TL(tl), TLI(tli) {}
431e3b55780SDimitry Andric bool runOnFunction(Function &F);
432e3b55780SDimitry Andric
433e3b55780SDimitry Andric private:
434e3b55780SDimitry Andric bool evaluateBasicBlock(BasicBlock *B);
435e3b55780SDimitry Andric
436e3b55780SDimitry Andric const TargetLowering *TL = nullptr;
437e3b55780SDimitry Andric const TargetLibraryInfo *TLI = nullptr;
438e3b55780SDimitry Andric };
439e3b55780SDimitry Andric
440e3b55780SDimitry Andric } // namespace
441e3b55780SDimitry Andric
442e3b55780SDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
443e3b55780SDimitry Andric
444e3b55780SDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445e3b55780SDimitry Andric "Complex Deinterleaving", false, false)
446e3b55780SDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447e3b55780SDimitry Andric "Complex Deinterleaving", false, false)
448e3b55780SDimitry Andric
run(Function & F,FunctionAnalysisManager & AM)449e3b55780SDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450e3b55780SDimitry Andric FunctionAnalysisManager &AM) {
451e3b55780SDimitry Andric const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452e3b55780SDimitry Andric auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453e3b55780SDimitry Andric if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454e3b55780SDimitry Andric return PreservedAnalyses::all();
455e3b55780SDimitry Andric
456e3b55780SDimitry Andric PreservedAnalyses PA;
457e3b55780SDimitry Andric PA.preserve<FunctionAnalysisManagerModuleProxy>();
458e3b55780SDimitry Andric return PA;
459e3b55780SDimitry Andric }
460e3b55780SDimitry Andric
createComplexDeinterleavingPass(const TargetMachine * TM)461e3b55780SDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462e3b55780SDimitry Andric return new ComplexDeinterleavingLegacyPass(TM);
463e3b55780SDimitry Andric }
464e3b55780SDimitry Andric
runOnFunction(Function & F)465e3b55780SDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466e3b55780SDimitry Andric const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467e3b55780SDimitry Andric auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468e3b55780SDimitry Andric return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469e3b55780SDimitry Andric }
470e3b55780SDimitry Andric
runOnFunction(Function & F)471e3b55780SDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
472e3b55780SDimitry Andric if (!ComplexDeinterleavingEnabled) {
473e3b55780SDimitry Andric LLVM_DEBUG(
474e3b55780SDimitry Andric dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475e3b55780SDimitry Andric return false;
476e3b55780SDimitry Andric }
477e3b55780SDimitry Andric
478e3b55780SDimitry Andric if (!TL->isComplexDeinterleavingSupported()) {
479e3b55780SDimitry Andric LLVM_DEBUG(
480e3b55780SDimitry Andric dbgs() << "Complex deinterleaving has been disabled, target does "
481e3b55780SDimitry Andric "not support lowering of complex number operations.\n");
482e3b55780SDimitry Andric return false;
483e3b55780SDimitry Andric }
484e3b55780SDimitry Andric
485e3b55780SDimitry Andric bool Changed = false;
486e3b55780SDimitry Andric for (auto &B : F)
487e3b55780SDimitry Andric Changed |= evaluateBasicBlock(&B);
488e3b55780SDimitry Andric
489e3b55780SDimitry Andric return Changed;
490e3b55780SDimitry Andric }
491e3b55780SDimitry Andric
isInterleavingMask(ArrayRef<int> Mask)492e3b55780SDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
493e3b55780SDimitry Andric // If the size is not even, it's not an interleaving mask
494e3b55780SDimitry Andric if ((Mask.size() & 1))
495e3b55780SDimitry Andric return false;
496e3b55780SDimitry Andric
497e3b55780SDimitry Andric int HalfNumElements = Mask.size() / 2;
498e3b55780SDimitry Andric for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499e3b55780SDimitry Andric int MaskIdx = Idx * 2;
500e3b55780SDimitry Andric if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501e3b55780SDimitry Andric return false;
502e3b55780SDimitry Andric }
503e3b55780SDimitry Andric
504e3b55780SDimitry Andric return true;
505e3b55780SDimitry Andric }
506e3b55780SDimitry Andric
isDeinterleavingMask(ArrayRef<int> Mask)507e3b55780SDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508e3b55780SDimitry Andric int Offset = Mask[0];
509e3b55780SDimitry Andric int HalfNumElements = Mask.size() / 2;
510e3b55780SDimitry Andric
511e3b55780SDimitry Andric for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512e3b55780SDimitry Andric if (Mask[Idx] != (Idx * 2) + Offset)
513e3b55780SDimitry Andric return false;
514e3b55780SDimitry Andric }
515e3b55780SDimitry Andric
516e3b55780SDimitry Andric return true;
517e3b55780SDimitry Andric }
518e3b55780SDimitry Andric
isNeg(Value * V)5197fa27ce4SDimitry Andric bool isNeg(Value *V) {
5207fa27ce4SDimitry Andric return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
5217fa27ce4SDimitry Andric }
5227fa27ce4SDimitry Andric
getNegOperand(Value * V)5237fa27ce4SDimitry Andric Value *getNegOperand(Value *V) {
5247fa27ce4SDimitry Andric assert(isNeg(V));
5257fa27ce4SDimitry Andric auto *I = cast<Instruction>(V);
5267fa27ce4SDimitry Andric if (I->getOpcode() == Instruction::FNeg)
5277fa27ce4SDimitry Andric return I->getOperand(0);
5287fa27ce4SDimitry Andric
5297fa27ce4SDimitry Andric return I->getOperand(1);
5307fa27ce4SDimitry Andric }
5317fa27ce4SDimitry Andric
evaluateBasicBlock(BasicBlock * B)532e3b55780SDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
5337fa27ce4SDimitry Andric ComplexDeinterleavingGraph Graph(TL, TLI);
5347fa27ce4SDimitry Andric if (Graph.collectPotentialReductions(B))
5357fa27ce4SDimitry Andric Graph.identifyReductionNodes();
536e3b55780SDimitry Andric
5377fa27ce4SDimitry Andric for (auto &I : *B)
5387fa27ce4SDimitry Andric Graph.identifyNodes(&I);
539e3b55780SDimitry Andric
5407fa27ce4SDimitry Andric if (Graph.checkNodes()) {
541e3b55780SDimitry Andric Graph.replaceNodes();
5427fa27ce4SDimitry Andric return true;
543e3b55780SDimitry Andric }
544e3b55780SDimitry Andric
5457fa27ce4SDimitry Andric return false;
546e3b55780SDimitry Andric }
547e3b55780SDimitry Andric
548e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNodeWithImplicitAdd(Instruction * Real,Instruction * Imag,std::pair<Value *,Value * > & PartialMatch)549e3b55780SDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550e3b55780SDimitry Andric Instruction *Real, Instruction *Imag,
5517fa27ce4SDimitry Andric std::pair<Value *, Value *> &PartialMatch) {
552e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553e3b55780SDimitry Andric << "\n");
554e3b55780SDimitry Andric
555e3b55780SDimitry Andric if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
557e3b55780SDimitry Andric return nullptr;
558e3b55780SDimitry Andric }
559e3b55780SDimitry Andric
5607fa27ce4SDimitry Andric if ((Real->getOpcode() != Instruction::FMul &&
5617fa27ce4SDimitry Andric Real->getOpcode() != Instruction::Mul) ||
5627fa27ce4SDimitry Andric (Imag->getOpcode() != Instruction::FMul &&
5637fa27ce4SDimitry Andric Imag->getOpcode() != Instruction::Mul)) {
5647fa27ce4SDimitry Andric LLVM_DEBUG(
5657fa27ce4SDimitry Andric dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
566e3b55780SDimitry Andric return nullptr;
567e3b55780SDimitry Andric }
568e3b55780SDimitry Andric
5697fa27ce4SDimitry Andric Value *R0 = Real->getOperand(0);
5707fa27ce4SDimitry Andric Value *R1 = Real->getOperand(1);
5717fa27ce4SDimitry Andric Value *I0 = Imag->getOperand(0);
5727fa27ce4SDimitry Andric Value *I1 = Imag->getOperand(1);
573e3b55780SDimitry Andric
574e3b55780SDimitry Andric // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575e3b55780SDimitry Andric // rotations and use the operand.
576e3b55780SDimitry Andric unsigned Negs = 0;
5777fa27ce4SDimitry Andric Value *Op;
5787fa27ce4SDimitry Andric if (match(R0, m_Neg(m_Value(Op)))) {
579e3b55780SDimitry Andric Negs |= 1;
5807fa27ce4SDimitry Andric R0 = Op;
5817fa27ce4SDimitry Andric } else if (match(R1, m_Neg(m_Value(Op)))) {
5827fa27ce4SDimitry Andric Negs |= 1;
5837fa27ce4SDimitry Andric R1 = Op;
584e3b55780SDimitry Andric }
5857fa27ce4SDimitry Andric
5867fa27ce4SDimitry Andric if (isNeg(I0)) {
587e3b55780SDimitry Andric Negs |= 2;
588e3b55780SDimitry Andric Negs ^= 1;
5897fa27ce4SDimitry Andric I0 = Op;
5907fa27ce4SDimitry Andric } else if (match(I1, m_Neg(m_Value(Op)))) {
5917fa27ce4SDimitry Andric Negs |= 2;
5927fa27ce4SDimitry Andric Negs ^= 1;
5937fa27ce4SDimitry Andric I1 = Op;
594e3b55780SDimitry Andric }
595e3b55780SDimitry Andric
596e3b55780SDimitry Andric ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597e3b55780SDimitry Andric
5987fa27ce4SDimitry Andric Value *CommonOperand;
5997fa27ce4SDimitry Andric Value *UncommonRealOp;
6007fa27ce4SDimitry Andric Value *UncommonImagOp;
601e3b55780SDimitry Andric
602e3b55780SDimitry Andric if (R0 == I0 || R0 == I1) {
603e3b55780SDimitry Andric CommonOperand = R0;
604e3b55780SDimitry Andric UncommonRealOp = R1;
605e3b55780SDimitry Andric } else if (R1 == I0 || R1 == I1) {
606e3b55780SDimitry Andric CommonOperand = R1;
607e3b55780SDimitry Andric UncommonRealOp = R0;
608e3b55780SDimitry Andric } else {
609e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n");
610e3b55780SDimitry Andric return nullptr;
611e3b55780SDimitry Andric }
612e3b55780SDimitry Andric
613e3b55780SDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614e3b55780SDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615e3b55780SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270)
616e3b55780SDimitry Andric std::swap(UncommonRealOp, UncommonImagOp);
617e3b55780SDimitry Andric
618e3b55780SDimitry Andric // Between identifyPartialMul and here we need to have found a complete valid
619e3b55780SDimitry Andric // pair from the CommonOperand of each part.
620e3b55780SDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621e3b55780SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180)
622e3b55780SDimitry Andric PartialMatch.first = CommonOperand;
623e3b55780SDimitry Andric else
624e3b55780SDimitry Andric PartialMatch.second = CommonOperand;
625e3b55780SDimitry Andric
626e3b55780SDimitry Andric if (!PartialMatch.first || !PartialMatch.second) {
627e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
628e3b55780SDimitry Andric return nullptr;
629e3b55780SDimitry Andric }
630e3b55780SDimitry Andric
631e3b55780SDimitry Andric NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632e3b55780SDimitry Andric if (!CommonNode) {
633e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
634e3b55780SDimitry Andric return nullptr;
635e3b55780SDimitry Andric }
636e3b55780SDimitry Andric
637e3b55780SDimitry Andric NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638e3b55780SDimitry Andric if (!UncommonNode) {
639e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
640e3b55780SDimitry Andric return nullptr;
641e3b55780SDimitry Andric }
642e3b55780SDimitry Andric
643e3b55780SDimitry Andric NodePtr Node = prepareCompositeNode(
644e3b55780SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645e3b55780SDimitry Andric Node->Rotation = Rotation;
646e3b55780SDimitry Andric Node->addOperand(CommonNode);
647e3b55780SDimitry Andric Node->addOperand(UncommonNode);
648e3b55780SDimitry Andric return submitCompositeNode(Node);
649e3b55780SDimitry Andric }
650e3b55780SDimitry Andric
651e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPartialMul(Instruction * Real,Instruction * Imag)652e3b55780SDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653e3b55780SDimitry Andric Instruction *Imag) {
654e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655e3b55780SDimitry Andric << "\n");
656e3b55780SDimitry Andric // Determine rotation
6577fa27ce4SDimitry Andric auto IsAdd = [](unsigned Op) {
6587fa27ce4SDimitry Andric return Op == Instruction::FAdd || Op == Instruction::Add;
6597fa27ce4SDimitry Andric };
6607fa27ce4SDimitry Andric auto IsSub = [](unsigned Op) {
6617fa27ce4SDimitry Andric return Op == Instruction::FSub || Op == Instruction::Sub;
6627fa27ce4SDimitry Andric };
663e3b55780SDimitry Andric ComplexDeinterleavingRotation Rotation;
6647fa27ce4SDimitry Andric if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665e3b55780SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_0;
6667fa27ce4SDimitry Andric else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667e3b55780SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90;
6687fa27ce4SDimitry Andric else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669e3b55780SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_180;
6707fa27ce4SDimitry Andric else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671e3b55780SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270;
672e3b55780SDimitry Andric else {
673e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
674e3b55780SDimitry Andric return nullptr;
675e3b55780SDimitry Andric }
676e3b55780SDimitry Andric
6777fa27ce4SDimitry Andric if (isa<FPMathOperator>(Real) &&
6787fa27ce4SDimitry Andric (!Real->getFastMathFlags().allowContract() ||
6797fa27ce4SDimitry Andric !Imag->getFastMathFlags().allowContract())) {
680e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
681e3b55780SDimitry Andric return nullptr;
682e3b55780SDimitry Andric }
683e3b55780SDimitry Andric
684e3b55780SDimitry Andric Value *CR = Real->getOperand(0);
685e3b55780SDimitry Andric Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686e3b55780SDimitry Andric if (!RealMulI)
687e3b55780SDimitry Andric return nullptr;
688e3b55780SDimitry Andric Value *CI = Imag->getOperand(0);
689e3b55780SDimitry Andric Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690e3b55780SDimitry Andric if (!ImagMulI)
691e3b55780SDimitry Andric return nullptr;
692e3b55780SDimitry Andric
693e3b55780SDimitry Andric if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
695e3b55780SDimitry Andric return nullptr;
696e3b55780SDimitry Andric }
697e3b55780SDimitry Andric
6987fa27ce4SDimitry Andric Value *R0 = RealMulI->getOperand(0);
6997fa27ce4SDimitry Andric Value *R1 = RealMulI->getOperand(1);
7007fa27ce4SDimitry Andric Value *I0 = ImagMulI->getOperand(0);
7017fa27ce4SDimitry Andric Value *I1 = ImagMulI->getOperand(1);
702e3b55780SDimitry Andric
7037fa27ce4SDimitry Andric Value *CommonOperand;
7047fa27ce4SDimitry Andric Value *UncommonRealOp;
7057fa27ce4SDimitry Andric Value *UncommonImagOp;
706e3b55780SDimitry Andric
707e3b55780SDimitry Andric if (R0 == I0 || R0 == I1) {
708e3b55780SDimitry Andric CommonOperand = R0;
709e3b55780SDimitry Andric UncommonRealOp = R1;
710e3b55780SDimitry Andric } else if (R1 == I0 || R1 == I1) {
711e3b55780SDimitry Andric CommonOperand = R1;
712e3b55780SDimitry Andric UncommonRealOp = R0;
713e3b55780SDimitry Andric } else {
714e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n");
715e3b55780SDimitry Andric return nullptr;
716e3b55780SDimitry Andric }
717e3b55780SDimitry Andric
718e3b55780SDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719e3b55780SDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720e3b55780SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270)
721e3b55780SDimitry Andric std::swap(UncommonRealOp, UncommonImagOp);
722e3b55780SDimitry Andric
7237fa27ce4SDimitry Andric std::pair<Value *, Value *> PartialMatch(
724e3b55780SDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725e3b55780SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180)
726e3b55780SDimitry Andric ? CommonOperand
727e3b55780SDimitry Andric : nullptr,
728e3b55780SDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729e3b55780SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270)
730e3b55780SDimitry Andric ? CommonOperand
731e3b55780SDimitry Andric : nullptr);
7327fa27ce4SDimitry Andric
7337fa27ce4SDimitry Andric auto *CRInst = dyn_cast<Instruction>(CR);
7347fa27ce4SDimitry Andric auto *CIInst = dyn_cast<Instruction>(CI);
7357fa27ce4SDimitry Andric
7367fa27ce4SDimitry Andric if (!CRInst || !CIInst) {
7377fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
7387fa27ce4SDimitry Andric return nullptr;
7397fa27ce4SDimitry Andric }
7407fa27ce4SDimitry Andric
7417fa27ce4SDimitry Andric NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742e3b55780SDimitry Andric if (!CNode) {
743e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No cnode identified\n");
744e3b55780SDimitry Andric return nullptr;
745e3b55780SDimitry Andric }
746e3b55780SDimitry Andric
747e3b55780SDimitry Andric NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748e3b55780SDimitry Andric if (!UncommonRes) {
749e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
750e3b55780SDimitry Andric return nullptr;
751e3b55780SDimitry Andric }
752e3b55780SDimitry Andric
753e3b55780SDimitry Andric assert(PartialMatch.first && PartialMatch.second);
754e3b55780SDimitry Andric NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755e3b55780SDimitry Andric if (!CommonRes) {
756e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
757e3b55780SDimitry Andric return nullptr;
758e3b55780SDimitry Andric }
759e3b55780SDimitry Andric
760e3b55780SDimitry Andric NodePtr Node = prepareCompositeNode(
761e3b55780SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762e3b55780SDimitry Andric Node->Rotation = Rotation;
763e3b55780SDimitry Andric Node->addOperand(CommonRes);
764e3b55780SDimitry Andric Node->addOperand(UncommonRes);
765e3b55780SDimitry Andric Node->addOperand(CNode);
766e3b55780SDimitry Andric return submitCompositeNode(Node);
767e3b55780SDimitry Andric }
768e3b55780SDimitry Andric
769e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)770e3b55780SDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772e3b55780SDimitry Andric
773e3b55780SDimitry Andric // Determine rotation
774e3b55780SDimitry Andric ComplexDeinterleavingRotation Rotation;
775e3b55780SDimitry Andric if ((Real->getOpcode() == Instruction::FSub &&
776e3b55780SDimitry Andric Imag->getOpcode() == Instruction::FAdd) ||
777e3b55780SDimitry Andric (Real->getOpcode() == Instruction::Sub &&
778e3b55780SDimitry Andric Imag->getOpcode() == Instruction::Add))
779e3b55780SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90;
780e3b55780SDimitry Andric else if ((Real->getOpcode() == Instruction::FAdd &&
781e3b55780SDimitry Andric Imag->getOpcode() == Instruction::FSub) ||
782e3b55780SDimitry Andric (Real->getOpcode() == Instruction::Add &&
783e3b55780SDimitry Andric Imag->getOpcode() == Instruction::Sub))
784e3b55780SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270;
785e3b55780SDimitry Andric else {
786e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787e3b55780SDimitry Andric return nullptr;
788e3b55780SDimitry Andric }
789e3b55780SDimitry Andric
790e3b55780SDimitry Andric auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791e3b55780SDimitry Andric auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792e3b55780SDimitry Andric auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793e3b55780SDimitry Andric auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794e3b55780SDimitry Andric
795e3b55780SDimitry Andric if (!AR || !AI || !BR || !BI) {
796e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797e3b55780SDimitry Andric return nullptr;
798e3b55780SDimitry Andric }
799e3b55780SDimitry Andric
800e3b55780SDimitry Andric NodePtr ResA = identifyNode(AR, AI);
801e3b55780SDimitry Andric if (!ResA) {
802e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803e3b55780SDimitry Andric return nullptr;
804e3b55780SDimitry Andric }
805e3b55780SDimitry Andric NodePtr ResB = identifyNode(BR, BI);
806e3b55780SDimitry Andric if (!ResB) {
807e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808e3b55780SDimitry Andric return nullptr;
809e3b55780SDimitry Andric }
810e3b55780SDimitry Andric
811e3b55780SDimitry Andric NodePtr Node =
812e3b55780SDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813e3b55780SDimitry Andric Node->Rotation = Rotation;
814e3b55780SDimitry Andric Node->addOperand(ResA);
815e3b55780SDimitry Andric Node->addOperand(ResB);
816e3b55780SDimitry Andric return submitCompositeNode(Node);
817e3b55780SDimitry Andric }
818e3b55780SDimitry Andric
isInstructionPairAdd(Instruction * A,Instruction * B)819e3b55780SDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820e3b55780SDimitry Andric unsigned OpcA = A->getOpcode();
821e3b55780SDimitry Andric unsigned OpcB = B->getOpcode();
822e3b55780SDimitry Andric
823e3b55780SDimitry Andric return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824e3b55780SDimitry Andric (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825e3b55780SDimitry Andric (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826e3b55780SDimitry Andric (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827e3b55780SDimitry Andric }
828e3b55780SDimitry Andric
isInstructionPairMul(Instruction * A,Instruction * B)829e3b55780SDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830e3b55780SDimitry Andric auto Pattern =
831e3b55780SDimitry Andric m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832e3b55780SDimitry Andric
833e3b55780SDimitry Andric return match(A, Pattern) && match(B, Pattern);
834e3b55780SDimitry Andric }
835e3b55780SDimitry Andric
isInstructionPotentiallySymmetric(Instruction * I)8367fa27ce4SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) {
8377fa27ce4SDimitry Andric switch (I->getOpcode()) {
8387fa27ce4SDimitry Andric case Instruction::FAdd:
8397fa27ce4SDimitry Andric case Instruction::FSub:
8407fa27ce4SDimitry Andric case Instruction::FMul:
8417fa27ce4SDimitry Andric case Instruction::FNeg:
8427fa27ce4SDimitry Andric case Instruction::Add:
8437fa27ce4SDimitry Andric case Instruction::Sub:
8447fa27ce4SDimitry Andric case Instruction::Mul:
8457fa27ce4SDimitry Andric return true;
8467fa27ce4SDimitry Andric default:
8477fa27ce4SDimitry Andric return false;
8487fa27ce4SDimitry Andric }
8497fa27ce4SDimitry Andric }
8507fa27ce4SDimitry Andric
851e3b55780SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySymmetricOperation(Instruction * Real,Instruction * Imag)8527fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
8537fa27ce4SDimitry Andric Instruction *Imag) {
8547fa27ce4SDimitry Andric if (Real->getOpcode() != Imag->getOpcode())
8557fa27ce4SDimitry Andric return nullptr;
8567fa27ce4SDimitry Andric
8577fa27ce4SDimitry Andric if (!isInstructionPotentiallySymmetric(Real) ||
8587fa27ce4SDimitry Andric !isInstructionPotentiallySymmetric(Imag))
8597fa27ce4SDimitry Andric return nullptr;
8607fa27ce4SDimitry Andric
8617fa27ce4SDimitry Andric auto *R0 = Real->getOperand(0);
8627fa27ce4SDimitry Andric auto *I0 = Imag->getOperand(0);
8637fa27ce4SDimitry Andric
8647fa27ce4SDimitry Andric NodePtr Op0 = identifyNode(R0, I0);
8657fa27ce4SDimitry Andric NodePtr Op1 = nullptr;
8667fa27ce4SDimitry Andric if (Op0 == nullptr)
8677fa27ce4SDimitry Andric return nullptr;
8687fa27ce4SDimitry Andric
8697fa27ce4SDimitry Andric if (Real->isBinaryOp()) {
8707fa27ce4SDimitry Andric auto *R1 = Real->getOperand(1);
8717fa27ce4SDimitry Andric auto *I1 = Imag->getOperand(1);
8727fa27ce4SDimitry Andric Op1 = identifyNode(R1, I1);
8737fa27ce4SDimitry Andric if (Op1 == nullptr)
8747fa27ce4SDimitry Andric return nullptr;
8757fa27ce4SDimitry Andric }
8767fa27ce4SDimitry Andric
8777fa27ce4SDimitry Andric if (isa<FPMathOperator>(Real) &&
8787fa27ce4SDimitry Andric Real->getFastMathFlags() != Imag->getFastMathFlags())
8797fa27ce4SDimitry Andric return nullptr;
8807fa27ce4SDimitry Andric
8817fa27ce4SDimitry Andric auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
8827fa27ce4SDimitry Andric Real, Imag);
8837fa27ce4SDimitry Andric Node->Opcode = Real->getOpcode();
8847fa27ce4SDimitry Andric if (isa<FPMathOperator>(Real))
8857fa27ce4SDimitry Andric Node->Flags = Real->getFastMathFlags();
8867fa27ce4SDimitry Andric
8877fa27ce4SDimitry Andric Node->addOperand(Op0);
8887fa27ce4SDimitry Andric if (Real->isBinaryOp())
8897fa27ce4SDimitry Andric Node->addOperand(Op1);
8907fa27ce4SDimitry Andric
8917fa27ce4SDimitry Andric return submitCompositeNode(Node);
8927fa27ce4SDimitry Andric }
8937fa27ce4SDimitry Andric
8947fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNode(Value * R,Value * I)8957fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
8967fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
8977fa27ce4SDimitry Andric assert(R->getType() == I->getType() &&
8987fa27ce4SDimitry Andric "Real and imaginary parts should not have different types");
899b1c73532SDimitry Andric
900b1c73532SDimitry Andric auto It = CachedResult.find({R, I});
901b1c73532SDimitry Andric if (It != CachedResult.end()) {
902e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903b1c73532SDimitry Andric return It->second;
904e3b55780SDimitry Andric }
905e3b55780SDimitry Andric
9067fa27ce4SDimitry Andric if (NodePtr CN = identifySplat(R, I))
9077fa27ce4SDimitry Andric return CN;
9087fa27ce4SDimitry Andric
9097fa27ce4SDimitry Andric auto *Real = dyn_cast<Instruction>(R);
9107fa27ce4SDimitry Andric auto *Imag = dyn_cast<Instruction>(I);
9117fa27ce4SDimitry Andric if (!Real || !Imag)
9127fa27ce4SDimitry Andric return nullptr;
9137fa27ce4SDimitry Andric
9147fa27ce4SDimitry Andric if (NodePtr CN = identifyDeinterleave(Real, Imag))
9157fa27ce4SDimitry Andric return CN;
9167fa27ce4SDimitry Andric
9177fa27ce4SDimitry Andric if (NodePtr CN = identifyPHINode(Real, Imag))
9187fa27ce4SDimitry Andric return CN;
9197fa27ce4SDimitry Andric
9207fa27ce4SDimitry Andric if (NodePtr CN = identifySelectNode(Real, Imag))
9217fa27ce4SDimitry Andric return CN;
9227fa27ce4SDimitry Andric
9237fa27ce4SDimitry Andric auto *VTy = cast<VectorType>(Real->getType());
9247fa27ce4SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
9257fa27ce4SDimitry Andric
9267fa27ce4SDimitry Andric bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
9277fa27ce4SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, NewVTy);
9287fa27ce4SDimitry Andric bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
9297fa27ce4SDimitry Andric ComplexDeinterleavingOperation::CAdd, NewVTy);
9307fa27ce4SDimitry Andric
9317fa27ce4SDimitry Andric if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
9327fa27ce4SDimitry Andric if (NodePtr CN = identifyPartialMul(Real, Imag))
9337fa27ce4SDimitry Andric return CN;
9347fa27ce4SDimitry Andric }
9357fa27ce4SDimitry Andric
9367fa27ce4SDimitry Andric if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
9377fa27ce4SDimitry Andric if (NodePtr CN = identifyAdd(Real, Imag))
9387fa27ce4SDimitry Andric return CN;
9397fa27ce4SDimitry Andric }
9407fa27ce4SDimitry Andric
9417fa27ce4SDimitry Andric if (HasCMulSupport && HasCAddSupport) {
9427fa27ce4SDimitry Andric if (NodePtr CN = identifyReassocNodes(Real, Imag))
9437fa27ce4SDimitry Andric return CN;
9447fa27ce4SDimitry Andric }
9457fa27ce4SDimitry Andric
9467fa27ce4SDimitry Andric if (NodePtr CN = identifySymmetricOperation(Real, Imag))
9477fa27ce4SDimitry Andric return CN;
9487fa27ce4SDimitry Andric
9497fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
950b1c73532SDimitry Andric CachedResult[{R, I}] = nullptr;
9517fa27ce4SDimitry Andric return nullptr;
9527fa27ce4SDimitry Andric }
9537fa27ce4SDimitry Andric
9547fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyReassocNodes(Instruction * Real,Instruction * Imag)9557fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
9567fa27ce4SDimitry Andric Instruction *Imag) {
9577fa27ce4SDimitry Andric auto IsOperationSupported = [](unsigned Opcode) -> bool {
9587fa27ce4SDimitry Andric return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
9597fa27ce4SDimitry Andric Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
9607fa27ce4SDimitry Andric Opcode == Instruction::Sub;
9617fa27ce4SDimitry Andric };
9627fa27ce4SDimitry Andric
9637fa27ce4SDimitry Andric if (!IsOperationSupported(Real->getOpcode()) ||
9647fa27ce4SDimitry Andric !IsOperationSupported(Imag->getOpcode()))
9657fa27ce4SDimitry Andric return nullptr;
9667fa27ce4SDimitry Andric
9677fa27ce4SDimitry Andric std::optional<FastMathFlags> Flags;
9687fa27ce4SDimitry Andric if (isa<FPMathOperator>(Real)) {
9697fa27ce4SDimitry Andric if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
9707fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
9717fa27ce4SDimitry Andric "not identical\n");
9727fa27ce4SDimitry Andric return nullptr;
9737fa27ce4SDimitry Andric }
9747fa27ce4SDimitry Andric
9757fa27ce4SDimitry Andric Flags = Real->getFastMathFlags();
9767fa27ce4SDimitry Andric if (!Flags->allowReassoc()) {
9777fa27ce4SDimitry Andric LLVM_DEBUG(
9787fa27ce4SDimitry Andric dbgs()
9797fa27ce4SDimitry Andric << "the 'Reassoc' attribute is missing in the FastMath flags\n");
9807fa27ce4SDimitry Andric return nullptr;
9817fa27ce4SDimitry Andric }
9827fa27ce4SDimitry Andric }
9837fa27ce4SDimitry Andric
9847fa27ce4SDimitry Andric // Collect multiplications and addend instructions from the given instruction
9857fa27ce4SDimitry Andric // while traversing it operands. Additionally, verify that all instructions
9867fa27ce4SDimitry Andric // have the same fast math flags.
9877fa27ce4SDimitry Andric auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
9887fa27ce4SDimitry Andric std::list<Addend> &Addends) -> bool {
9897fa27ce4SDimitry Andric SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
9907fa27ce4SDimitry Andric SmallPtrSet<Value *, 8> Visited;
9917fa27ce4SDimitry Andric while (!Worklist.empty()) {
9927fa27ce4SDimitry Andric auto [V, IsPositive] = Worklist.back();
9937fa27ce4SDimitry Andric Worklist.pop_back();
9947fa27ce4SDimitry Andric if (!Visited.insert(V).second)
9957fa27ce4SDimitry Andric continue;
9967fa27ce4SDimitry Andric
9977fa27ce4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V);
9987fa27ce4SDimitry Andric if (!I) {
9997fa27ce4SDimitry Andric Addends.emplace_back(V, IsPositive);
10007fa27ce4SDimitry Andric continue;
10017fa27ce4SDimitry Andric }
10027fa27ce4SDimitry Andric
10037fa27ce4SDimitry Andric // If an instruction has more than one user, it indicates that it either
10047fa27ce4SDimitry Andric // has an external user, which will be later checked by the checkNodes
10057fa27ce4SDimitry Andric // function, or it is a subexpression utilized by multiple expressions. In
10067fa27ce4SDimitry Andric // the latter case, we will attempt to separately identify the complex
10077fa27ce4SDimitry Andric // operation from here in order to create a shared
10087fa27ce4SDimitry Andric // ComplexDeinterleavingCompositeNode.
10097fa27ce4SDimitry Andric if (I != Insn && I->getNumUses() > 1) {
10107fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
10117fa27ce4SDimitry Andric Addends.emplace_back(I, IsPositive);
10127fa27ce4SDimitry Andric continue;
10137fa27ce4SDimitry Andric }
10147fa27ce4SDimitry Andric switch (I->getOpcode()) {
10157fa27ce4SDimitry Andric case Instruction::FAdd:
10167fa27ce4SDimitry Andric case Instruction::Add:
10177fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(1), IsPositive);
10187fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive);
10197fa27ce4SDimitry Andric break;
10207fa27ce4SDimitry Andric case Instruction::FSub:
10217fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(1), !IsPositive);
10227fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive);
10237fa27ce4SDimitry Andric break;
10247fa27ce4SDimitry Andric case Instruction::Sub:
10257fa27ce4SDimitry Andric if (isNeg(I)) {
10267fa27ce4SDimitry Andric Worklist.emplace_back(getNegOperand(I), !IsPositive);
10277fa27ce4SDimitry Andric } else {
10287fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(1), !IsPositive);
10297fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive);
10307fa27ce4SDimitry Andric }
10317fa27ce4SDimitry Andric break;
10327fa27ce4SDimitry Andric case Instruction::FMul:
10337fa27ce4SDimitry Andric case Instruction::Mul: {
10347fa27ce4SDimitry Andric Value *A, *B;
10357fa27ce4SDimitry Andric if (isNeg(I->getOperand(0))) {
10367fa27ce4SDimitry Andric A = getNegOperand(I->getOperand(0));
10377fa27ce4SDimitry Andric IsPositive = !IsPositive;
10387fa27ce4SDimitry Andric } else {
10397fa27ce4SDimitry Andric A = I->getOperand(0);
10407fa27ce4SDimitry Andric }
10417fa27ce4SDimitry Andric
10427fa27ce4SDimitry Andric if (isNeg(I->getOperand(1))) {
10437fa27ce4SDimitry Andric B = getNegOperand(I->getOperand(1));
10447fa27ce4SDimitry Andric IsPositive = !IsPositive;
10457fa27ce4SDimitry Andric } else {
10467fa27ce4SDimitry Andric B = I->getOperand(1);
10477fa27ce4SDimitry Andric }
10487fa27ce4SDimitry Andric Muls.push_back(Product{A, B, IsPositive});
10497fa27ce4SDimitry Andric break;
10507fa27ce4SDimitry Andric }
10517fa27ce4SDimitry Andric case Instruction::FNeg:
10527fa27ce4SDimitry Andric Worklist.emplace_back(I->getOperand(0), !IsPositive);
10537fa27ce4SDimitry Andric break;
10547fa27ce4SDimitry Andric default:
10557fa27ce4SDimitry Andric Addends.emplace_back(I, IsPositive);
10567fa27ce4SDimitry Andric continue;
10577fa27ce4SDimitry Andric }
10587fa27ce4SDimitry Andric
10597fa27ce4SDimitry Andric if (Flags && I->getFastMathFlags() != *Flags) {
10607fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
10617fa27ce4SDimitry Andric "inconsistent with the root instructions' flags: "
10627fa27ce4SDimitry Andric << *I << "\n");
10637fa27ce4SDimitry Andric return false;
10647fa27ce4SDimitry Andric }
10657fa27ce4SDimitry Andric }
10667fa27ce4SDimitry Andric return true;
10677fa27ce4SDimitry Andric };
10687fa27ce4SDimitry Andric
10697fa27ce4SDimitry Andric std::vector<Product> RealMuls, ImagMuls;
10707fa27ce4SDimitry Andric std::list<Addend> RealAddends, ImagAddends;
10717fa27ce4SDimitry Andric if (!Collect(Real, RealMuls, RealAddends) ||
10727fa27ce4SDimitry Andric !Collect(Imag, ImagMuls, ImagAddends))
10737fa27ce4SDimitry Andric return nullptr;
10747fa27ce4SDimitry Andric
10757fa27ce4SDimitry Andric if (RealAddends.size() != ImagAddends.size())
10767fa27ce4SDimitry Andric return nullptr;
10777fa27ce4SDimitry Andric
10787fa27ce4SDimitry Andric NodePtr FinalNode;
10797fa27ce4SDimitry Andric if (!RealMuls.empty() || !ImagMuls.empty()) {
10807fa27ce4SDimitry Andric // If there are multiplicands, extract positive addend and use it as an
10817fa27ce4SDimitry Andric // accumulator
10827fa27ce4SDimitry Andric FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
10837fa27ce4SDimitry Andric FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
10847fa27ce4SDimitry Andric if (!FinalNode)
10857fa27ce4SDimitry Andric return nullptr;
10867fa27ce4SDimitry Andric }
10877fa27ce4SDimitry Andric
10887fa27ce4SDimitry Andric // Identify and process remaining additions
10897fa27ce4SDimitry Andric if (!RealAddends.empty() || !ImagAddends.empty()) {
10907fa27ce4SDimitry Andric FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
10917fa27ce4SDimitry Andric if (!FinalNode)
10927fa27ce4SDimitry Andric return nullptr;
10937fa27ce4SDimitry Andric }
10947fa27ce4SDimitry Andric assert(FinalNode && "FinalNode can not be nullptr here");
10957fa27ce4SDimitry Andric // Set the Real and Imag fields of the final node and submit it
10967fa27ce4SDimitry Andric FinalNode->Real = Real;
10977fa27ce4SDimitry Andric FinalNode->Imag = Imag;
10987fa27ce4SDimitry Andric submitCompositeNode(FinalNode);
10997fa27ce4SDimitry Andric return FinalNode;
11007fa27ce4SDimitry Andric }
11017fa27ce4SDimitry Andric
collectPartialMuls(const std::vector<Product> & RealMuls,const std::vector<Product> & ImagMuls,std::vector<PartialMulCandidate> & PartialMulCandidates)11027fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::collectPartialMuls(
11037fa27ce4SDimitry Andric const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
11047fa27ce4SDimitry Andric std::vector<PartialMulCandidate> &PartialMulCandidates) {
11057fa27ce4SDimitry Andric // Helper function to extract a common operand from two products
11067fa27ce4SDimitry Andric auto FindCommonInstruction = [](const Product &Real,
11077fa27ce4SDimitry Andric const Product &Imag) -> Value * {
11087fa27ce4SDimitry Andric if (Real.Multiplicand == Imag.Multiplicand ||
11097fa27ce4SDimitry Andric Real.Multiplicand == Imag.Multiplier)
11107fa27ce4SDimitry Andric return Real.Multiplicand;
11117fa27ce4SDimitry Andric
11127fa27ce4SDimitry Andric if (Real.Multiplier == Imag.Multiplicand ||
11137fa27ce4SDimitry Andric Real.Multiplier == Imag.Multiplier)
11147fa27ce4SDimitry Andric return Real.Multiplier;
11157fa27ce4SDimitry Andric
11167fa27ce4SDimitry Andric return nullptr;
11177fa27ce4SDimitry Andric };
11187fa27ce4SDimitry Andric
11197fa27ce4SDimitry Andric // Iterating over real and imaginary multiplications to find common operands
11207fa27ce4SDimitry Andric // If a common operand is found, a partial multiplication candidate is created
11217fa27ce4SDimitry Andric // and added to the candidates vector The function returns false if no common
11227fa27ce4SDimitry Andric // operands are found for any product
11237fa27ce4SDimitry Andric for (unsigned i = 0; i < RealMuls.size(); ++i) {
11247fa27ce4SDimitry Andric bool FoundCommon = false;
11257fa27ce4SDimitry Andric for (unsigned j = 0; j < ImagMuls.size(); ++j) {
11267fa27ce4SDimitry Andric auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
11277fa27ce4SDimitry Andric if (!Common)
11287fa27ce4SDimitry Andric continue;
11297fa27ce4SDimitry Andric
11307fa27ce4SDimitry Andric auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
11317fa27ce4SDimitry Andric : RealMuls[i].Multiplicand;
11327fa27ce4SDimitry Andric auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
11337fa27ce4SDimitry Andric : ImagMuls[j].Multiplicand;
11347fa27ce4SDimitry Andric
11357fa27ce4SDimitry Andric auto Node = identifyNode(A, B);
11367fa27ce4SDimitry Andric if (Node) {
11377fa27ce4SDimitry Andric FoundCommon = true;
11387fa27ce4SDimitry Andric PartialMulCandidates.push_back({Common, Node, i, j, false});
11397fa27ce4SDimitry Andric }
11407fa27ce4SDimitry Andric
11417fa27ce4SDimitry Andric Node = identifyNode(B, A);
11427fa27ce4SDimitry Andric if (Node) {
11437fa27ce4SDimitry Andric FoundCommon = true;
11447fa27ce4SDimitry Andric PartialMulCandidates.push_back({Common, Node, i, j, true});
11457fa27ce4SDimitry Andric }
11467fa27ce4SDimitry Andric }
11477fa27ce4SDimitry Andric if (!FoundCommon)
11487fa27ce4SDimitry Andric return false;
11497fa27ce4SDimitry Andric }
11507fa27ce4SDimitry Andric return true;
11517fa27ce4SDimitry Andric }
11527fa27ce4SDimitry Andric
11537fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyMultiplications(std::vector<Product> & RealMuls,std::vector<Product> & ImagMuls,NodePtr Accumulator=nullptr)11547fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyMultiplications(
11557fa27ce4SDimitry Andric std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
11567fa27ce4SDimitry Andric NodePtr Accumulator = nullptr) {
11577fa27ce4SDimitry Andric if (RealMuls.size() != ImagMuls.size())
11587fa27ce4SDimitry Andric return nullptr;
11597fa27ce4SDimitry Andric
11607fa27ce4SDimitry Andric std::vector<PartialMulCandidate> Info;
11617fa27ce4SDimitry Andric if (!collectPartialMuls(RealMuls, ImagMuls, Info))
11627fa27ce4SDimitry Andric return nullptr;
11637fa27ce4SDimitry Andric
11647fa27ce4SDimitry Andric // Map to store common instruction to node pointers
11657fa27ce4SDimitry Andric std::map<Value *, NodePtr> CommonToNode;
11667fa27ce4SDimitry Andric std::vector<bool> Processed(Info.size(), false);
11677fa27ce4SDimitry Andric for (unsigned I = 0; I < Info.size(); ++I) {
11687fa27ce4SDimitry Andric if (Processed[I])
11697fa27ce4SDimitry Andric continue;
11707fa27ce4SDimitry Andric
11717fa27ce4SDimitry Andric PartialMulCandidate &InfoA = Info[I];
11727fa27ce4SDimitry Andric for (unsigned J = I + 1; J < Info.size(); ++J) {
11737fa27ce4SDimitry Andric if (Processed[J])
11747fa27ce4SDimitry Andric continue;
11757fa27ce4SDimitry Andric
11767fa27ce4SDimitry Andric PartialMulCandidate &InfoB = Info[J];
11777fa27ce4SDimitry Andric auto *InfoReal = &InfoA;
11787fa27ce4SDimitry Andric auto *InfoImag = &InfoB;
11797fa27ce4SDimitry Andric
11807fa27ce4SDimitry Andric auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
11817fa27ce4SDimitry Andric if (!NodeFromCommon) {
11827fa27ce4SDimitry Andric std::swap(InfoReal, InfoImag);
11837fa27ce4SDimitry Andric NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
11847fa27ce4SDimitry Andric }
11857fa27ce4SDimitry Andric if (!NodeFromCommon)
11867fa27ce4SDimitry Andric continue;
11877fa27ce4SDimitry Andric
11887fa27ce4SDimitry Andric CommonToNode[InfoReal->Common] = NodeFromCommon;
11897fa27ce4SDimitry Andric CommonToNode[InfoImag->Common] = NodeFromCommon;
11907fa27ce4SDimitry Andric Processed[I] = true;
11917fa27ce4SDimitry Andric Processed[J] = true;
11927fa27ce4SDimitry Andric }
11937fa27ce4SDimitry Andric }
11947fa27ce4SDimitry Andric
11957fa27ce4SDimitry Andric std::vector<bool> ProcessedReal(RealMuls.size(), false);
11967fa27ce4SDimitry Andric std::vector<bool> ProcessedImag(ImagMuls.size(), false);
11977fa27ce4SDimitry Andric NodePtr Result = Accumulator;
11987fa27ce4SDimitry Andric for (auto &PMI : Info) {
11997fa27ce4SDimitry Andric if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
12007fa27ce4SDimitry Andric continue;
12017fa27ce4SDimitry Andric
12027fa27ce4SDimitry Andric auto It = CommonToNode.find(PMI.Common);
12037fa27ce4SDimitry Andric // TODO: Process independent complex multiplications. Cases like this:
12047fa27ce4SDimitry Andric // A.real() * B where both A and B are complex numbers.
12057fa27ce4SDimitry Andric if (It == CommonToNode.end()) {
12067fa27ce4SDimitry Andric LLVM_DEBUG({
12077fa27ce4SDimitry Andric dbgs() << "Unprocessed independent partial multiplication:\n";
12087fa27ce4SDimitry Andric for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
12097fa27ce4SDimitry Andric dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
12107fa27ce4SDimitry Andric << " multiplied by " << *Mul->Multiplicand << "\n";
12117fa27ce4SDimitry Andric });
12127fa27ce4SDimitry Andric return nullptr;
12137fa27ce4SDimitry Andric }
12147fa27ce4SDimitry Andric
12157fa27ce4SDimitry Andric auto &RealMul = RealMuls[PMI.RealIdx];
12167fa27ce4SDimitry Andric auto &ImagMul = ImagMuls[PMI.ImagIdx];
12177fa27ce4SDimitry Andric
12187fa27ce4SDimitry Andric auto NodeA = It->second;
12197fa27ce4SDimitry Andric auto NodeB = PMI.Node;
12207fa27ce4SDimitry Andric auto IsMultiplicandReal = PMI.Common == NodeA->Real;
12217fa27ce4SDimitry Andric // The following table illustrates the relationship between multiplications
12227fa27ce4SDimitry Andric // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
12237fa27ce4SDimitry Andric // can see:
12247fa27ce4SDimitry Andric //
12257fa27ce4SDimitry Andric // Rotation | Real | Imag |
12267fa27ce4SDimitry Andric // ---------+--------+--------+
12277fa27ce4SDimitry Andric // 0 | x * u | x * v |
12287fa27ce4SDimitry Andric // 90 | -y * v | y * u |
12297fa27ce4SDimitry Andric // 180 | -x * u | -x * v |
12307fa27ce4SDimitry Andric // 270 | y * v | -y * u |
12317fa27ce4SDimitry Andric //
12327fa27ce4SDimitry Andric // Check if the candidate can indeed be represented by partial
12337fa27ce4SDimitry Andric // multiplication
12347fa27ce4SDimitry Andric // TODO: Add support for multiplication by complex one
12357fa27ce4SDimitry Andric if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
12367fa27ce4SDimitry Andric (!IsMultiplicandReal && !PMI.IsNodeInverted))
12377fa27ce4SDimitry Andric continue;
12387fa27ce4SDimitry Andric
12397fa27ce4SDimitry Andric // Determine the rotation based on the multiplications
12407fa27ce4SDimitry Andric ComplexDeinterleavingRotation Rotation;
12417fa27ce4SDimitry Andric if (IsMultiplicandReal) {
12427fa27ce4SDimitry Andric // Detect 0 and 180 degrees rotation
12437fa27ce4SDimitry Andric if (RealMul.IsPositive && ImagMul.IsPositive)
12447fa27ce4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
12457fa27ce4SDimitry Andric else if (!RealMul.IsPositive && !ImagMul.IsPositive)
12467fa27ce4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
12477fa27ce4SDimitry Andric else
12487fa27ce4SDimitry Andric continue;
12497fa27ce4SDimitry Andric
12507fa27ce4SDimitry Andric } else {
12517fa27ce4SDimitry Andric // Detect 90 and 270 degrees rotation
12527fa27ce4SDimitry Andric if (!RealMul.IsPositive && ImagMul.IsPositive)
12537fa27ce4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
12547fa27ce4SDimitry Andric else if (RealMul.IsPositive && !ImagMul.IsPositive)
12557fa27ce4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
12567fa27ce4SDimitry Andric else
12577fa27ce4SDimitry Andric continue;
12587fa27ce4SDimitry Andric }
12597fa27ce4SDimitry Andric
12607fa27ce4SDimitry Andric LLVM_DEBUG({
12617fa27ce4SDimitry Andric dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
12627fa27ce4SDimitry Andric dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
12637fa27ce4SDimitry Andric dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
12647fa27ce4SDimitry Andric dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
12657fa27ce4SDimitry Andric dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
12667fa27ce4SDimitry Andric dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
12677fa27ce4SDimitry Andric });
12687fa27ce4SDimitry Andric
12697fa27ce4SDimitry Andric NodePtr NodeMul = prepareCompositeNode(
12707fa27ce4SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
12717fa27ce4SDimitry Andric NodeMul->Rotation = Rotation;
12727fa27ce4SDimitry Andric NodeMul->addOperand(NodeA);
12737fa27ce4SDimitry Andric NodeMul->addOperand(NodeB);
12747fa27ce4SDimitry Andric if (Result)
12757fa27ce4SDimitry Andric NodeMul->addOperand(Result);
12767fa27ce4SDimitry Andric submitCompositeNode(NodeMul);
12777fa27ce4SDimitry Andric Result = NodeMul;
12787fa27ce4SDimitry Andric ProcessedReal[PMI.RealIdx] = true;
12797fa27ce4SDimitry Andric ProcessedImag[PMI.ImagIdx] = true;
12807fa27ce4SDimitry Andric }
12817fa27ce4SDimitry Andric
12827fa27ce4SDimitry Andric // Ensure all products have been processed, if not return nullptr.
12837fa27ce4SDimitry Andric if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
12847fa27ce4SDimitry Andric !all_of(ProcessedImag, [](bool V) { return V; })) {
12857fa27ce4SDimitry Andric
12867fa27ce4SDimitry Andric // Dump debug information about which partial multiplications are not
12877fa27ce4SDimitry Andric // processed.
12887fa27ce4SDimitry Andric LLVM_DEBUG({
12897fa27ce4SDimitry Andric dbgs() << "Unprocessed products (Real):\n";
12907fa27ce4SDimitry Andric for (size_t i = 0; i < ProcessedReal.size(); ++i) {
12917fa27ce4SDimitry Andric if (!ProcessedReal[i])
12927fa27ce4SDimitry Andric dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
12937fa27ce4SDimitry Andric << *RealMuls[i].Multiplier << " multiplied by "
12947fa27ce4SDimitry Andric << *RealMuls[i].Multiplicand << "\n";
12957fa27ce4SDimitry Andric }
12967fa27ce4SDimitry Andric dbgs() << "Unprocessed products (Imag):\n";
12977fa27ce4SDimitry Andric for (size_t i = 0; i < ProcessedImag.size(); ++i) {
12987fa27ce4SDimitry Andric if (!ProcessedImag[i])
12997fa27ce4SDimitry Andric dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
13007fa27ce4SDimitry Andric << *ImagMuls[i].Multiplier << " multiplied by "
13017fa27ce4SDimitry Andric << *ImagMuls[i].Multiplicand << "\n";
13027fa27ce4SDimitry Andric }
13037fa27ce4SDimitry Andric });
13047fa27ce4SDimitry Andric return nullptr;
13057fa27ce4SDimitry Andric }
13067fa27ce4SDimitry Andric
13077fa27ce4SDimitry Andric return Result;
13087fa27ce4SDimitry Andric }
13097fa27ce4SDimitry Andric
13107fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdditions(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends,std::optional<FastMathFlags> Flags,NodePtr Accumulator=nullptr)13117fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions(
13127fa27ce4SDimitry Andric std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
13137fa27ce4SDimitry Andric std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
13147fa27ce4SDimitry Andric if (RealAddends.size() != ImagAddends.size())
13157fa27ce4SDimitry Andric return nullptr;
13167fa27ce4SDimitry Andric
13177fa27ce4SDimitry Andric NodePtr Result;
13187fa27ce4SDimitry Andric // If we have accumulator use it as first addend
13197fa27ce4SDimitry Andric if (Accumulator)
13207fa27ce4SDimitry Andric Result = Accumulator;
13217fa27ce4SDimitry Andric // Otherwise find an element with both positive real and imaginary parts.
13227fa27ce4SDimitry Andric else
13237fa27ce4SDimitry Andric Result = extractPositiveAddend(RealAddends, ImagAddends);
13247fa27ce4SDimitry Andric
13257fa27ce4SDimitry Andric if (!Result)
13267fa27ce4SDimitry Andric return nullptr;
13277fa27ce4SDimitry Andric
13287fa27ce4SDimitry Andric while (!RealAddends.empty()) {
13297fa27ce4SDimitry Andric auto ItR = RealAddends.begin();
13307fa27ce4SDimitry Andric auto [R, IsPositiveR] = *ItR;
13317fa27ce4SDimitry Andric
13327fa27ce4SDimitry Andric bool FoundImag = false;
13337fa27ce4SDimitry Andric for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
13347fa27ce4SDimitry Andric auto [I, IsPositiveI] = *ItI;
13357fa27ce4SDimitry Andric ComplexDeinterleavingRotation Rotation;
13367fa27ce4SDimitry Andric if (IsPositiveR && IsPositiveI)
13377fa27ce4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_0;
13387fa27ce4SDimitry Andric else if (!IsPositiveR && IsPositiveI)
13397fa27ce4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90;
13407fa27ce4SDimitry Andric else if (!IsPositiveR && !IsPositiveI)
13417fa27ce4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_180;
13427fa27ce4SDimitry Andric else
13437fa27ce4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270;
13447fa27ce4SDimitry Andric
13457fa27ce4SDimitry Andric NodePtr AddNode;
13467fa27ce4SDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
13477fa27ce4SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180) {
13487fa27ce4SDimitry Andric AddNode = identifyNode(R, I);
13497fa27ce4SDimitry Andric } else {
13507fa27ce4SDimitry Andric AddNode = identifyNode(I, R);
13517fa27ce4SDimitry Andric }
13527fa27ce4SDimitry Andric if (AddNode) {
13537fa27ce4SDimitry Andric LLVM_DEBUG({
13547fa27ce4SDimitry Andric dbgs() << "Identified addition:\n";
13557fa27ce4SDimitry Andric dbgs().indent(4) << "X: " << *R << "\n";
13567fa27ce4SDimitry Andric dbgs().indent(4) << "Y: " << *I << "\n";
13577fa27ce4SDimitry Andric dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
13587fa27ce4SDimitry Andric });
13597fa27ce4SDimitry Andric
13607fa27ce4SDimitry Andric NodePtr TmpNode;
13617fa27ce4SDimitry Andric if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
13627fa27ce4SDimitry Andric TmpNode = prepareCompositeNode(
13637fa27ce4SDimitry Andric ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
13647fa27ce4SDimitry Andric if (Flags) {
13657fa27ce4SDimitry Andric TmpNode->Opcode = Instruction::FAdd;
13667fa27ce4SDimitry Andric TmpNode->Flags = *Flags;
13677fa27ce4SDimitry Andric } else {
13687fa27ce4SDimitry Andric TmpNode->Opcode = Instruction::Add;
13697fa27ce4SDimitry Andric }
13707fa27ce4SDimitry Andric } else if (Rotation ==
13717fa27ce4SDimitry Andric llvm::ComplexDeinterleavingRotation::Rotation_180) {
13727fa27ce4SDimitry Andric TmpNode = prepareCompositeNode(
13737fa27ce4SDimitry Andric ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
13747fa27ce4SDimitry Andric if (Flags) {
13757fa27ce4SDimitry Andric TmpNode->Opcode = Instruction::FSub;
13767fa27ce4SDimitry Andric TmpNode->Flags = *Flags;
13777fa27ce4SDimitry Andric } else {
13787fa27ce4SDimitry Andric TmpNode->Opcode = Instruction::Sub;
13797fa27ce4SDimitry Andric }
13807fa27ce4SDimitry Andric } else {
13817fa27ce4SDimitry Andric TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
13827fa27ce4SDimitry Andric nullptr, nullptr);
13837fa27ce4SDimitry Andric TmpNode->Rotation = Rotation;
13847fa27ce4SDimitry Andric }
13857fa27ce4SDimitry Andric
13867fa27ce4SDimitry Andric TmpNode->addOperand(Result);
13877fa27ce4SDimitry Andric TmpNode->addOperand(AddNode);
13887fa27ce4SDimitry Andric submitCompositeNode(TmpNode);
13897fa27ce4SDimitry Andric Result = TmpNode;
13907fa27ce4SDimitry Andric RealAddends.erase(ItR);
13917fa27ce4SDimitry Andric ImagAddends.erase(ItI);
13927fa27ce4SDimitry Andric FoundImag = true;
13937fa27ce4SDimitry Andric break;
13947fa27ce4SDimitry Andric }
13957fa27ce4SDimitry Andric }
13967fa27ce4SDimitry Andric if (!FoundImag)
13977fa27ce4SDimitry Andric return nullptr;
13987fa27ce4SDimitry Andric }
13997fa27ce4SDimitry Andric return Result;
14007fa27ce4SDimitry Andric }
14017fa27ce4SDimitry Andric
14027fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
extractPositiveAddend(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends)14037fa27ce4SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend(
14047fa27ce4SDimitry Andric std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
14057fa27ce4SDimitry Andric for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
14067fa27ce4SDimitry Andric for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
14077fa27ce4SDimitry Andric auto [R, IsPositiveR] = *ItR;
14087fa27ce4SDimitry Andric auto [I, IsPositiveI] = *ItI;
14097fa27ce4SDimitry Andric if (IsPositiveR && IsPositiveI) {
14107fa27ce4SDimitry Andric auto Result = identifyNode(R, I);
14117fa27ce4SDimitry Andric if (Result) {
14127fa27ce4SDimitry Andric RealAddends.erase(ItR);
14137fa27ce4SDimitry Andric ImagAddends.erase(ItI);
14147fa27ce4SDimitry Andric return Result;
14157fa27ce4SDimitry Andric }
14167fa27ce4SDimitry Andric }
14177fa27ce4SDimitry Andric }
14187fa27ce4SDimitry Andric }
14197fa27ce4SDimitry Andric return nullptr;
14207fa27ce4SDimitry Andric }
14217fa27ce4SDimitry Andric
identifyNodes(Instruction * RootI)14227fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
14237fa27ce4SDimitry Andric // This potential root instruction might already have been recognized as
14247fa27ce4SDimitry Andric // reduction. Because RootToNode maps both Real and Imaginary parts to
14257fa27ce4SDimitry Andric // CompositeNode we should choose only one either Real or Imag instruction to
14267fa27ce4SDimitry Andric // use as an anchor for generating complex instruction.
14277fa27ce4SDimitry Andric auto It = RootToNode.find(RootI);
1428b1c73532SDimitry Andric if (It != RootToNode.end()) {
1429b1c73532SDimitry Andric auto RootNode = It->second;
1430b1c73532SDimitry Andric assert(RootNode->Operation ==
1431b1c73532SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation);
1432b1c73532SDimitry Andric // Find out which part, Real or Imag, comes later, and only if we come to
1433b1c73532SDimitry Andric // the latest part, add it to OrderedRoots.
1434b1c73532SDimitry Andric auto *R = cast<Instruction>(RootNode->Real);
1435b1c73532SDimitry Andric auto *I = cast<Instruction>(RootNode->Imag);
1436b1c73532SDimitry Andric auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437b1c73532SDimitry Andric if (ReplacementAnchor != RootI)
1438b1c73532SDimitry Andric return false;
14397fa27ce4SDimitry Andric OrderedRoots.push_back(RootI);
14407fa27ce4SDimitry Andric return true;
14417fa27ce4SDimitry Andric }
14427fa27ce4SDimitry Andric
14437fa27ce4SDimitry Andric auto RootNode = identifyRoot(RootI);
14447fa27ce4SDimitry Andric if (!RootNode)
14457fa27ce4SDimitry Andric return false;
14467fa27ce4SDimitry Andric
14477fa27ce4SDimitry Andric LLVM_DEBUG({
14487fa27ce4SDimitry Andric Function *F = RootI->getFunction();
14497fa27ce4SDimitry Andric BasicBlock *B = RootI->getParent();
14507fa27ce4SDimitry Andric dbgs() << "Complex deinterleaving graph for " << F->getName()
14517fa27ce4SDimitry Andric << "::" << B->getName() << ".\n";
14527fa27ce4SDimitry Andric dump(dbgs());
14537fa27ce4SDimitry Andric dbgs() << "\n";
14547fa27ce4SDimitry Andric });
14557fa27ce4SDimitry Andric RootToNode[RootI] = RootNode;
14567fa27ce4SDimitry Andric OrderedRoots.push_back(RootI);
14577fa27ce4SDimitry Andric return true;
14587fa27ce4SDimitry Andric }
14597fa27ce4SDimitry Andric
collectPotentialReductions(BasicBlock * B)14607fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
14617fa27ce4SDimitry Andric bool FoundPotentialReduction = false;
14627fa27ce4SDimitry Andric
14637fa27ce4SDimitry Andric auto *Br = dyn_cast<BranchInst>(B->getTerminator());
14647fa27ce4SDimitry Andric if (!Br || Br->getNumSuccessors() != 2)
14657fa27ce4SDimitry Andric return false;
14667fa27ce4SDimitry Andric
14677fa27ce4SDimitry Andric // Identify simple one-block loop
14687fa27ce4SDimitry Andric if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
14697fa27ce4SDimitry Andric return false;
14707fa27ce4SDimitry Andric
14717fa27ce4SDimitry Andric SmallVector<PHINode *> PHIs;
14727fa27ce4SDimitry Andric for (auto &PHI : B->phis()) {
14737fa27ce4SDimitry Andric if (PHI.getNumIncomingValues() != 2)
14747fa27ce4SDimitry Andric continue;
14757fa27ce4SDimitry Andric
14767fa27ce4SDimitry Andric if (!PHI.getType()->isVectorTy())
14777fa27ce4SDimitry Andric continue;
14787fa27ce4SDimitry Andric
14797fa27ce4SDimitry Andric auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
14807fa27ce4SDimitry Andric if (!ReductionOp)
14817fa27ce4SDimitry Andric continue;
14827fa27ce4SDimitry Andric
14837fa27ce4SDimitry Andric // Check if final instruction is reduced outside of current block
14847fa27ce4SDimitry Andric Instruction *FinalReduction = nullptr;
14857fa27ce4SDimitry Andric auto NumUsers = 0u;
14867fa27ce4SDimitry Andric for (auto *U : ReductionOp->users()) {
14877fa27ce4SDimitry Andric ++NumUsers;
14887fa27ce4SDimitry Andric if (U == &PHI)
14897fa27ce4SDimitry Andric continue;
14907fa27ce4SDimitry Andric FinalReduction = dyn_cast<Instruction>(U);
14917fa27ce4SDimitry Andric }
14927fa27ce4SDimitry Andric
14937fa27ce4SDimitry Andric if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
14947fa27ce4SDimitry Andric isa<PHINode>(FinalReduction))
14957fa27ce4SDimitry Andric continue;
14967fa27ce4SDimitry Andric
14977fa27ce4SDimitry Andric ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
14987fa27ce4SDimitry Andric BackEdge = B;
14997fa27ce4SDimitry Andric auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
15007fa27ce4SDimitry Andric auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
15017fa27ce4SDimitry Andric Incoming = PHI.getIncomingBlock(IncomingIdx);
15027fa27ce4SDimitry Andric FoundPotentialReduction = true;
15037fa27ce4SDimitry Andric
15047fa27ce4SDimitry Andric // If the initial value of PHINode is an Instruction, consider it a leaf
15057fa27ce4SDimitry Andric // value of a complex deinterleaving graph.
15067fa27ce4SDimitry Andric if (auto *InitPHI =
15077fa27ce4SDimitry Andric dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
15087fa27ce4SDimitry Andric FinalInstructions.insert(InitPHI);
15097fa27ce4SDimitry Andric }
15107fa27ce4SDimitry Andric return FoundPotentialReduction;
15117fa27ce4SDimitry Andric }
15127fa27ce4SDimitry Andric
identifyReductionNodes()15137fa27ce4SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() {
15147fa27ce4SDimitry Andric SmallVector<bool> Processed(ReductionInfo.size(), false);
15157fa27ce4SDimitry Andric SmallVector<Instruction *> OperationInstruction;
15167fa27ce4SDimitry Andric for (auto &P : ReductionInfo)
15177fa27ce4SDimitry Andric OperationInstruction.push_back(P.first);
15187fa27ce4SDimitry Andric
15197fa27ce4SDimitry Andric // Identify a complex computation by evaluating two reduction operations that
15207fa27ce4SDimitry Andric // potentially could be involved
15217fa27ce4SDimitry Andric for (size_t i = 0; i < OperationInstruction.size(); ++i) {
15227fa27ce4SDimitry Andric if (Processed[i])
15237fa27ce4SDimitry Andric continue;
15247fa27ce4SDimitry Andric for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
15257fa27ce4SDimitry Andric if (Processed[j])
15267fa27ce4SDimitry Andric continue;
15277fa27ce4SDimitry Andric
15287fa27ce4SDimitry Andric auto *Real = OperationInstruction[i];
15297fa27ce4SDimitry Andric auto *Imag = OperationInstruction[j];
15307fa27ce4SDimitry Andric if (Real->getType() != Imag->getType())
15317fa27ce4SDimitry Andric continue;
15327fa27ce4SDimitry Andric
15337fa27ce4SDimitry Andric RealPHI = ReductionInfo[Real].first;
15347fa27ce4SDimitry Andric ImagPHI = ReductionInfo[Imag].first;
15357fa27ce4SDimitry Andric PHIsFound = false;
15367fa27ce4SDimitry Andric auto Node = identifyNode(Real, Imag);
15377fa27ce4SDimitry Andric if (!Node) {
15387fa27ce4SDimitry Andric std::swap(Real, Imag);
15397fa27ce4SDimitry Andric std::swap(RealPHI, ImagPHI);
15407fa27ce4SDimitry Andric Node = identifyNode(Real, Imag);
15417fa27ce4SDimitry Andric }
15427fa27ce4SDimitry Andric
15437fa27ce4SDimitry Andric // If a node is identified and reduction PHINode is used in the chain of
15447fa27ce4SDimitry Andric // operations, mark its operation instructions as used to prevent
15457fa27ce4SDimitry Andric // re-identification and attach the node to the real part
15467fa27ce4SDimitry Andric if (Node && PHIsFound) {
15477fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
15487fa27ce4SDimitry Andric << *Real << " / " << *Imag << "\n");
15497fa27ce4SDimitry Andric Processed[i] = true;
15507fa27ce4SDimitry Andric Processed[j] = true;
15517fa27ce4SDimitry Andric auto RootNode = prepareCompositeNode(
15527fa27ce4SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
15537fa27ce4SDimitry Andric RootNode->addOperand(Node);
15547fa27ce4SDimitry Andric RootToNode[Real] = RootNode;
15557fa27ce4SDimitry Andric RootToNode[Imag] = RootNode;
15567fa27ce4SDimitry Andric submitCompositeNode(RootNode);
15577fa27ce4SDimitry Andric break;
15587fa27ce4SDimitry Andric }
15597fa27ce4SDimitry Andric }
15607fa27ce4SDimitry Andric }
15617fa27ce4SDimitry Andric
15627fa27ce4SDimitry Andric RealPHI = nullptr;
15637fa27ce4SDimitry Andric ImagPHI = nullptr;
15647fa27ce4SDimitry Andric }
15657fa27ce4SDimitry Andric
checkNodes()15667fa27ce4SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() {
15677fa27ce4SDimitry Andric // Collect all instructions from roots to leaves
15687fa27ce4SDimitry Andric SmallPtrSet<Instruction *, 16> AllInstructions;
15697fa27ce4SDimitry Andric SmallVector<Instruction *, 8> Worklist;
15707fa27ce4SDimitry Andric for (auto &Pair : RootToNode)
15717fa27ce4SDimitry Andric Worklist.push_back(Pair.first);
15727fa27ce4SDimitry Andric
15737fa27ce4SDimitry Andric // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
15747fa27ce4SDimitry Andric // chains
15757fa27ce4SDimitry Andric while (!Worklist.empty()) {
15767fa27ce4SDimitry Andric auto *I = Worklist.back();
15777fa27ce4SDimitry Andric Worklist.pop_back();
15787fa27ce4SDimitry Andric
15797fa27ce4SDimitry Andric if (!AllInstructions.insert(I).second)
15807fa27ce4SDimitry Andric continue;
15817fa27ce4SDimitry Andric
15827fa27ce4SDimitry Andric for (Value *Op : I->operands()) {
15837fa27ce4SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(Op)) {
15847fa27ce4SDimitry Andric if (!FinalInstructions.count(I))
15857fa27ce4SDimitry Andric Worklist.emplace_back(OpI);
15867fa27ce4SDimitry Andric }
15877fa27ce4SDimitry Andric }
15887fa27ce4SDimitry Andric }
15897fa27ce4SDimitry Andric
15907fa27ce4SDimitry Andric // Find instructions that have users outside of chain
15917fa27ce4SDimitry Andric SmallVector<Instruction *, 2> OuterInstructions;
15927fa27ce4SDimitry Andric for (auto *I : AllInstructions) {
15937fa27ce4SDimitry Andric // Skip root nodes
15947fa27ce4SDimitry Andric if (RootToNode.count(I))
15957fa27ce4SDimitry Andric continue;
15967fa27ce4SDimitry Andric
15977fa27ce4SDimitry Andric for (User *U : I->users()) {
15987fa27ce4SDimitry Andric if (AllInstructions.count(cast<Instruction>(U)))
15997fa27ce4SDimitry Andric continue;
16007fa27ce4SDimitry Andric
16017fa27ce4SDimitry Andric // Found an instruction that is not used by XCMLA/XCADD chain
16027fa27ce4SDimitry Andric Worklist.emplace_back(I);
16037fa27ce4SDimitry Andric break;
16047fa27ce4SDimitry Andric }
16057fa27ce4SDimitry Andric }
16067fa27ce4SDimitry Andric
16077fa27ce4SDimitry Andric // If any instructions are found to be used outside, find and remove roots
16087fa27ce4SDimitry Andric // that somehow connect to those instructions.
16097fa27ce4SDimitry Andric SmallPtrSet<Instruction *, 16> Visited;
16107fa27ce4SDimitry Andric while (!Worklist.empty()) {
16117fa27ce4SDimitry Andric auto *I = Worklist.back();
16127fa27ce4SDimitry Andric Worklist.pop_back();
16137fa27ce4SDimitry Andric if (!Visited.insert(I).second)
16147fa27ce4SDimitry Andric continue;
16157fa27ce4SDimitry Andric
16167fa27ce4SDimitry Andric // Found an impacted root node. Removing it from the nodes to be
16177fa27ce4SDimitry Andric // deinterleaved
16187fa27ce4SDimitry Andric if (RootToNode.count(I)) {
16197fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << "Instruction " << *I
16207fa27ce4SDimitry Andric << " could be deinterleaved but its chain of complex "
16217fa27ce4SDimitry Andric "operations have an outside user\n");
16227fa27ce4SDimitry Andric RootToNode.erase(I);
16237fa27ce4SDimitry Andric }
16247fa27ce4SDimitry Andric
16257fa27ce4SDimitry Andric if (!AllInstructions.count(I) || FinalInstructions.count(I))
16267fa27ce4SDimitry Andric continue;
16277fa27ce4SDimitry Andric
16287fa27ce4SDimitry Andric for (User *U : I->users())
16297fa27ce4SDimitry Andric Worklist.emplace_back(cast<Instruction>(U));
16307fa27ce4SDimitry Andric
16317fa27ce4SDimitry Andric for (Value *Op : I->operands()) {
16327fa27ce4SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(Op))
16337fa27ce4SDimitry Andric Worklist.emplace_back(OpI);
16347fa27ce4SDimitry Andric }
16357fa27ce4SDimitry Andric }
16367fa27ce4SDimitry Andric return !RootToNode.empty();
16377fa27ce4SDimitry Andric }
16387fa27ce4SDimitry Andric
16397fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyRoot(Instruction * RootI)16407fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
16417fa27ce4SDimitry Andric if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642ac9a064cSDimitry Andric if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
16437fa27ce4SDimitry Andric return nullptr;
16447fa27ce4SDimitry Andric
16457fa27ce4SDimitry Andric auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
16467fa27ce4SDimitry Andric auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
16477fa27ce4SDimitry Andric if (!Real || !Imag)
16487fa27ce4SDimitry Andric return nullptr;
16497fa27ce4SDimitry Andric
16507fa27ce4SDimitry Andric return identifyNode(Real, Imag);
16517fa27ce4SDimitry Andric }
16527fa27ce4SDimitry Andric
16537fa27ce4SDimitry Andric auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
16547fa27ce4SDimitry Andric if (!SVI)
16557fa27ce4SDimitry Andric return nullptr;
16567fa27ce4SDimitry Andric
16577fa27ce4SDimitry Andric // Look for a shufflevector that takes separate vectors of the real and
16587fa27ce4SDimitry Andric // imaginary components and recombines them into a single vector.
16597fa27ce4SDimitry Andric if (!isInterleavingMask(SVI->getShuffleMask()))
16607fa27ce4SDimitry Andric return nullptr;
16617fa27ce4SDimitry Andric
16627fa27ce4SDimitry Andric Instruction *Real;
16637fa27ce4SDimitry Andric Instruction *Imag;
16647fa27ce4SDimitry Andric if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
16657fa27ce4SDimitry Andric return nullptr;
16667fa27ce4SDimitry Andric
16677fa27ce4SDimitry Andric return identifyNode(Real, Imag);
16687fa27ce4SDimitry Andric }
16697fa27ce4SDimitry Andric
16707fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyDeinterleave(Instruction * Real,Instruction * Imag)16717fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
16727fa27ce4SDimitry Andric Instruction *Imag) {
16737fa27ce4SDimitry Andric Instruction *I = nullptr;
16747fa27ce4SDimitry Andric Value *FinalValue = nullptr;
16757fa27ce4SDimitry Andric if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
16767fa27ce4SDimitry Andric match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1677ac9a064cSDimitry Andric match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
16787fa27ce4SDimitry Andric m_Value(FinalValue)))) {
16797fa27ce4SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode(
16807fa27ce4SDimitry Andric llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
16817fa27ce4SDimitry Andric PlaceholderNode->ReplacementNode = FinalValue;
16827fa27ce4SDimitry Andric FinalInstructions.insert(Real);
16837fa27ce4SDimitry Andric FinalInstructions.insert(Imag);
16847fa27ce4SDimitry Andric return submitCompositeNode(PlaceholderNode);
16857fa27ce4SDimitry Andric }
16867fa27ce4SDimitry Andric
1687e3b55780SDimitry Andric auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1688e3b55780SDimitry Andric auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
16897fa27ce4SDimitry Andric if (!RealShuffle || !ImagShuffle) {
16907fa27ce4SDimitry Andric if (RealShuffle || ImagShuffle)
16917fa27ce4SDimitry Andric LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
16927fa27ce4SDimitry Andric return nullptr;
16937fa27ce4SDimitry Andric }
16947fa27ce4SDimitry Andric
1695e3b55780SDimitry Andric Value *RealOp1 = RealShuffle->getOperand(1);
1696e3b55780SDimitry Andric if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1697e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1698e3b55780SDimitry Andric return nullptr;
1699e3b55780SDimitry Andric }
1700e3b55780SDimitry Andric Value *ImagOp1 = ImagShuffle->getOperand(1);
1701e3b55780SDimitry Andric if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1702e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1703e3b55780SDimitry Andric return nullptr;
1704e3b55780SDimitry Andric }
1705e3b55780SDimitry Andric
1706e3b55780SDimitry Andric Value *RealOp0 = RealShuffle->getOperand(0);
1707e3b55780SDimitry Andric Value *ImagOp0 = ImagShuffle->getOperand(0);
1708e3b55780SDimitry Andric
1709e3b55780SDimitry Andric if (RealOp0 != ImagOp0) {
1710e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1711e3b55780SDimitry Andric return nullptr;
1712e3b55780SDimitry Andric }
1713e3b55780SDimitry Andric
1714e3b55780SDimitry Andric ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1715e3b55780SDimitry Andric ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1716e3b55780SDimitry Andric if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1717e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1718e3b55780SDimitry Andric return nullptr;
1719e3b55780SDimitry Andric }
1720e3b55780SDimitry Andric
1721e3b55780SDimitry Andric if (RealMask[0] != 0 || ImagMask[0] != 1) {
1722e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1723e3b55780SDimitry Andric return nullptr;
1724e3b55780SDimitry Andric }
1725e3b55780SDimitry Andric
1726e3b55780SDimitry Andric // Type checking, the shuffle type should be a vector type of the same
1727e3b55780SDimitry Andric // scalar type, but half the size
1728e3b55780SDimitry Andric auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1729e3b55780SDimitry Andric Value *Op = Shuffle->getOperand(0);
1730e3b55780SDimitry Andric auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1731e3b55780SDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType());
1732e3b55780SDimitry Andric
1733e3b55780SDimitry Andric if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1734e3b55780SDimitry Andric return false;
1735e3b55780SDimitry Andric if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1736e3b55780SDimitry Andric return false;
1737e3b55780SDimitry Andric
1738e3b55780SDimitry Andric return true;
1739e3b55780SDimitry Andric };
1740e3b55780SDimitry Andric
1741e3b55780SDimitry Andric auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1742e3b55780SDimitry Andric if (!CheckType(Shuffle))
1743e3b55780SDimitry Andric return false;
1744e3b55780SDimitry Andric
1745e3b55780SDimitry Andric ArrayRef<int> Mask = Shuffle->getShuffleMask();
1746e3b55780SDimitry Andric int Last = *Mask.rbegin();
1747e3b55780SDimitry Andric
1748e3b55780SDimitry Andric Value *Op = Shuffle->getOperand(0);
1749e3b55780SDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType());
1750e3b55780SDimitry Andric int NumElements = OpTy->getNumElements();
1751e3b55780SDimitry Andric
1752e3b55780SDimitry Andric // Ensure that the deinterleaving shuffle only pulls from the first
1753e3b55780SDimitry Andric // shuffle operand.
1754e3b55780SDimitry Andric return Last < NumElements;
1755e3b55780SDimitry Andric };
1756e3b55780SDimitry Andric
1757e3b55780SDimitry Andric if (RealShuffle->getType() != ImagShuffle->getType()) {
1758e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1759e3b55780SDimitry Andric return nullptr;
1760e3b55780SDimitry Andric }
1761e3b55780SDimitry Andric if (!CheckDeinterleavingShuffle(RealShuffle)) {
1762e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1763e3b55780SDimitry Andric return nullptr;
1764e3b55780SDimitry Andric }
1765e3b55780SDimitry Andric if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1766e3b55780SDimitry Andric LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1767e3b55780SDimitry Andric return nullptr;
1768e3b55780SDimitry Andric }
1769e3b55780SDimitry Andric
1770e3b55780SDimitry Andric NodePtr PlaceholderNode =
17717fa27ce4SDimitry Andric prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1772e3b55780SDimitry Andric RealShuffle, ImagShuffle);
1773e3b55780SDimitry Andric PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
17747fa27ce4SDimitry Andric FinalInstructions.insert(RealShuffle);
17757fa27ce4SDimitry Andric FinalInstructions.insert(ImagShuffle);
1776e3b55780SDimitry Andric return submitCompositeNode(PlaceholderNode);
1777e3b55780SDimitry Andric }
1778e3b55780SDimitry Andric
17797fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySplat(Value * R,Value * I)17807fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
17817fa27ce4SDimitry Andric auto IsSplat = [](Value *V) -> bool {
17827fa27ce4SDimitry Andric // Fixed-width vector with constants
17837fa27ce4SDimitry Andric if (isa<ConstantDataVector>(V))
17847fa27ce4SDimitry Andric return true;
1785e3b55780SDimitry Andric
17867fa27ce4SDimitry Andric VectorType *VTy;
17877fa27ce4SDimitry Andric ArrayRef<int> Mask;
17887fa27ce4SDimitry Andric // Splats are represented differently depending on whether the repeated
17897fa27ce4SDimitry Andric // value is a constant or an Instruction
17907fa27ce4SDimitry Andric if (auto *Const = dyn_cast<ConstantExpr>(V)) {
17917fa27ce4SDimitry Andric if (Const->getOpcode() != Instruction::ShuffleVector)
1792e3b55780SDimitry Andric return false;
17937fa27ce4SDimitry Andric VTy = cast<VectorType>(Const->getType());
17947fa27ce4SDimitry Andric Mask = Const->getShuffleMask();
17957fa27ce4SDimitry Andric } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
17967fa27ce4SDimitry Andric VTy = Shuf->getType();
17977fa27ce4SDimitry Andric Mask = Shuf->getShuffleMask();
17987fa27ce4SDimitry Andric } else {
1799e3b55780SDimitry Andric return false;
1800e3b55780SDimitry Andric }
18017fa27ce4SDimitry Andric
18027fa27ce4SDimitry Andric // When the data type is <1 x Type>, it's not possible to differentiate
18037fa27ce4SDimitry Andric // between the ComplexDeinterleaving::Deinterleave and
18047fa27ce4SDimitry Andric // ComplexDeinterleaving::Splat operations.
18057fa27ce4SDimitry Andric if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
18067fa27ce4SDimitry Andric return false;
18077fa27ce4SDimitry Andric
18087fa27ce4SDimitry Andric return all_equal(Mask) && Mask[0] == 0;
18097fa27ce4SDimitry Andric };
18107fa27ce4SDimitry Andric
18117fa27ce4SDimitry Andric if (!IsSplat(R) || !IsSplat(I))
18127fa27ce4SDimitry Andric return nullptr;
18137fa27ce4SDimitry Andric
18147fa27ce4SDimitry Andric auto *Real = dyn_cast<Instruction>(R);
18157fa27ce4SDimitry Andric auto *Imag = dyn_cast<Instruction>(I);
18167fa27ce4SDimitry Andric if ((!Real && Imag) || (Real && !Imag))
18177fa27ce4SDimitry Andric return nullptr;
18187fa27ce4SDimitry Andric
18197fa27ce4SDimitry Andric if (Real && Imag) {
18207fa27ce4SDimitry Andric // Non-constant splats should be in the same basic block
18217fa27ce4SDimitry Andric if (Real->getParent() != Imag->getParent())
18227fa27ce4SDimitry Andric return nullptr;
18237fa27ce4SDimitry Andric
18247fa27ce4SDimitry Andric FinalInstructions.insert(Real);
18257fa27ce4SDimitry Andric FinalInstructions.insert(Imag);
1826e3b55780SDimitry Andric }
18277fa27ce4SDimitry Andric NodePtr PlaceholderNode =
18287fa27ce4SDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
18297fa27ce4SDimitry Andric return submitCompositeNode(PlaceholderNode);
1830e3b55780SDimitry Andric }
1831e3b55780SDimitry Andric
18327fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPHINode(Instruction * Real,Instruction * Imag)18337fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
18347fa27ce4SDimitry Andric Instruction *Imag) {
18357fa27ce4SDimitry Andric if (Real != RealPHI || Imag != ImagPHI)
18367fa27ce4SDimitry Andric return nullptr;
18377fa27ce4SDimitry Andric
18387fa27ce4SDimitry Andric PHIsFound = true;
18397fa27ce4SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode(
18407fa27ce4SDimitry Andric ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
18417fa27ce4SDimitry Andric return submitCompositeNode(PlaceholderNode);
18427fa27ce4SDimitry Andric }
18437fa27ce4SDimitry Andric
18447fa27ce4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySelectNode(Instruction * Real,Instruction * Imag)18457fa27ce4SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
18467fa27ce4SDimitry Andric Instruction *Imag) {
18477fa27ce4SDimitry Andric auto *SelectReal = dyn_cast<SelectInst>(Real);
18487fa27ce4SDimitry Andric auto *SelectImag = dyn_cast<SelectInst>(Imag);
18497fa27ce4SDimitry Andric if (!SelectReal || !SelectImag)
18507fa27ce4SDimitry Andric return nullptr;
18517fa27ce4SDimitry Andric
18527fa27ce4SDimitry Andric Instruction *MaskA, *MaskB;
18537fa27ce4SDimitry Andric Instruction *AR, *AI, *RA, *BI;
18547fa27ce4SDimitry Andric if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
18557fa27ce4SDimitry Andric m_Instruction(RA))) ||
18567fa27ce4SDimitry Andric !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
18577fa27ce4SDimitry Andric m_Instruction(BI))))
18587fa27ce4SDimitry Andric return nullptr;
18597fa27ce4SDimitry Andric
18607fa27ce4SDimitry Andric if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
18617fa27ce4SDimitry Andric return nullptr;
18627fa27ce4SDimitry Andric
18637fa27ce4SDimitry Andric if (!MaskA->getType()->isVectorTy())
18647fa27ce4SDimitry Andric return nullptr;
18657fa27ce4SDimitry Andric
18667fa27ce4SDimitry Andric auto NodeA = identifyNode(AR, AI);
18677fa27ce4SDimitry Andric if (!NodeA)
18687fa27ce4SDimitry Andric return nullptr;
18697fa27ce4SDimitry Andric
18707fa27ce4SDimitry Andric auto NodeB = identifyNode(RA, BI);
18717fa27ce4SDimitry Andric if (!NodeB)
18727fa27ce4SDimitry Andric return nullptr;
18737fa27ce4SDimitry Andric
18747fa27ce4SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode(
18757fa27ce4SDimitry Andric ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
18767fa27ce4SDimitry Andric PlaceholderNode->addOperand(NodeA);
18777fa27ce4SDimitry Andric PlaceholderNode->addOperand(NodeB);
18787fa27ce4SDimitry Andric FinalInstructions.insert(MaskA);
18797fa27ce4SDimitry Andric FinalInstructions.insert(MaskB);
18807fa27ce4SDimitry Andric return submitCompositeNode(PlaceholderNode);
18817fa27ce4SDimitry Andric }
18827fa27ce4SDimitry Andric
replaceSymmetricNode(IRBuilderBase & B,unsigned Opcode,std::optional<FastMathFlags> Flags,Value * InputA,Value * InputB)18837fa27ce4SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
18847fa27ce4SDimitry Andric std::optional<FastMathFlags> Flags,
18857fa27ce4SDimitry Andric Value *InputA, Value *InputB) {
18867fa27ce4SDimitry Andric Value *I;
18877fa27ce4SDimitry Andric switch (Opcode) {
18887fa27ce4SDimitry Andric case Instruction::FNeg:
18897fa27ce4SDimitry Andric I = B.CreateFNeg(InputA);
18907fa27ce4SDimitry Andric break;
18917fa27ce4SDimitry Andric case Instruction::FAdd:
18927fa27ce4SDimitry Andric I = B.CreateFAdd(InputA, InputB);
18937fa27ce4SDimitry Andric break;
18947fa27ce4SDimitry Andric case Instruction::Add:
18957fa27ce4SDimitry Andric I = B.CreateAdd(InputA, InputB);
18967fa27ce4SDimitry Andric break;
18977fa27ce4SDimitry Andric case Instruction::FSub:
18987fa27ce4SDimitry Andric I = B.CreateFSub(InputA, InputB);
18997fa27ce4SDimitry Andric break;
19007fa27ce4SDimitry Andric case Instruction::Sub:
19017fa27ce4SDimitry Andric I = B.CreateSub(InputA, InputB);
19027fa27ce4SDimitry Andric break;
19037fa27ce4SDimitry Andric case Instruction::FMul:
19047fa27ce4SDimitry Andric I = B.CreateFMul(InputA, InputB);
19057fa27ce4SDimitry Andric break;
19067fa27ce4SDimitry Andric case Instruction::Mul:
19077fa27ce4SDimitry Andric I = B.CreateMul(InputA, InputB);
19087fa27ce4SDimitry Andric break;
19097fa27ce4SDimitry Andric default:
19107fa27ce4SDimitry Andric llvm_unreachable("Incorrect symmetric opcode");
19117fa27ce4SDimitry Andric }
19127fa27ce4SDimitry Andric if (Flags)
19137fa27ce4SDimitry Andric cast<Instruction>(I)->setFastMathFlags(*Flags);
19147fa27ce4SDimitry Andric return I;
19157fa27ce4SDimitry Andric }
19167fa27ce4SDimitry Andric
replaceNode(IRBuilderBase & Builder,RawNodePtr Node)19177fa27ce4SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
19187fa27ce4SDimitry Andric RawNodePtr Node) {
1919e3b55780SDimitry Andric if (Node->ReplacementNode)
1920e3b55780SDimitry Andric return Node->ReplacementNode;
1921e3b55780SDimitry Andric
19227fa27ce4SDimitry Andric auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
19237fa27ce4SDimitry Andric return Node->Operands.size() > Idx
19247fa27ce4SDimitry Andric ? replaceNode(Builder, Node->Operands[Idx])
19257fa27ce4SDimitry Andric : nullptr;
19267fa27ce4SDimitry Andric };
1927e3b55780SDimitry Andric
19287fa27ce4SDimitry Andric Value *ReplacementNode;
19297fa27ce4SDimitry Andric switch (Node->Operation) {
19307fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::CAdd:
19317fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::CMulPartial:
19327fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::Symmetric: {
19337fa27ce4SDimitry Andric Value *Input0 = ReplaceOperandIfExist(Node, 0);
19347fa27ce4SDimitry Andric Value *Input1 = ReplaceOperandIfExist(Node, 1);
19357fa27ce4SDimitry Andric Value *Accumulator = ReplaceOperandIfExist(Node, 2);
19367fa27ce4SDimitry Andric assert(!Input1 || (Input0->getType() == Input1->getType() &&
19377fa27ce4SDimitry Andric "Node inputs need to be of the same type"));
19387fa27ce4SDimitry Andric assert(!Accumulator ||
19397fa27ce4SDimitry Andric (Input0->getType() == Accumulator->getType() &&
19407fa27ce4SDimitry Andric "Accumulator and input need to be of the same type"));
19417fa27ce4SDimitry Andric if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
19427fa27ce4SDimitry Andric ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
19437fa27ce4SDimitry Andric Input0, Input1);
19447fa27ce4SDimitry Andric else
19457fa27ce4SDimitry Andric ReplacementNode = TL->createComplexDeinterleavingIR(
19467fa27ce4SDimitry Andric Builder, Node->Operation, Node->Rotation, Input0, Input1,
19477fa27ce4SDimitry Andric Accumulator);
19487fa27ce4SDimitry Andric break;
19497fa27ce4SDimitry Andric }
19507fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::Deinterleave:
19517fa27ce4SDimitry Andric llvm_unreachable("Deinterleave node should already have ReplacementNode");
19527fa27ce4SDimitry Andric break;
19537fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::Splat: {
19547fa27ce4SDimitry Andric auto *NewTy = VectorType::getDoubleElementsVectorType(
19557fa27ce4SDimitry Andric cast<VectorType>(Node->Real->getType()));
19567fa27ce4SDimitry Andric auto *R = dyn_cast<Instruction>(Node->Real);
19577fa27ce4SDimitry Andric auto *I = dyn_cast<Instruction>(Node->Imag);
19587fa27ce4SDimitry Andric if (R && I) {
19597fa27ce4SDimitry Andric // Splats that are not constant are interleaved where they are located
19607fa27ce4SDimitry Andric Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
19617fa27ce4SDimitry Andric IRBuilder<> IRB(InsertPoint);
1962ac9a064cSDimitry Andric ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
19637fa27ce4SDimitry Andric NewTy, {Node->Real, Node->Imag});
1964ac9a064cSDimitry Andric } else {
1965ac9a064cSDimitry Andric ReplacementNode = Builder.CreateIntrinsic(
1966ac9a064cSDimitry Andric Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
19677fa27ce4SDimitry Andric }
19687fa27ce4SDimitry Andric break;
19697fa27ce4SDimitry Andric }
19707fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::ReductionPHI: {
19717fa27ce4SDimitry Andric // If Operation is ReductionPHI, a new empty PHINode is created.
19727fa27ce4SDimitry Andric // It is filled later when the ReductionOperation is processed.
19737fa27ce4SDimitry Andric auto *VTy = cast<VectorType>(Node->Real->getType());
19747fa27ce4SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1975ac9a064cSDimitry Andric auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
19767fa27ce4SDimitry Andric OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
19777fa27ce4SDimitry Andric ReplacementNode = NewPHI;
19787fa27ce4SDimitry Andric break;
19797fa27ce4SDimitry Andric }
19807fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::ReductionOperation:
19817fa27ce4SDimitry Andric ReplacementNode = replaceNode(Builder, Node->Operands[0]);
19827fa27ce4SDimitry Andric processReductionOperation(ReplacementNode, Node);
19837fa27ce4SDimitry Andric break;
19847fa27ce4SDimitry Andric case ComplexDeinterleavingOperation::ReductionSelect: {
19857fa27ce4SDimitry Andric auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
19867fa27ce4SDimitry Andric auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
19877fa27ce4SDimitry Andric auto *A = replaceNode(Builder, Node->Operands[0]);
19887fa27ce4SDimitry Andric auto *B = replaceNode(Builder, Node->Operands[1]);
19897fa27ce4SDimitry Andric auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
19907fa27ce4SDimitry Andric cast<VectorType>(MaskReal->getType()));
1991ac9a064cSDimitry Andric auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
19927fa27ce4SDimitry Andric NewMaskTy, {MaskReal, MaskImag});
19937fa27ce4SDimitry Andric ReplacementNode = Builder.CreateSelect(NewMask, A, B);
19947fa27ce4SDimitry Andric break;
19957fa27ce4SDimitry Andric }
19967fa27ce4SDimitry Andric }
1997e3b55780SDimitry Andric
19987fa27ce4SDimitry Andric assert(ReplacementNode && "Target failed to create Intrinsic call.");
1999e3b55780SDimitry Andric NumComplexTransformations += 1;
20007fa27ce4SDimitry Andric Node->ReplacementNode = ReplacementNode;
20017fa27ce4SDimitry Andric return ReplacementNode;
20027fa27ce4SDimitry Andric }
20037fa27ce4SDimitry Andric
processReductionOperation(Value * OperationReplacement,RawNodePtr Node)20047fa27ce4SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation(
20057fa27ce4SDimitry Andric Value *OperationReplacement, RawNodePtr Node) {
20067fa27ce4SDimitry Andric auto *Real = cast<Instruction>(Node->Real);
20077fa27ce4SDimitry Andric auto *Imag = cast<Instruction>(Node->Imag);
20087fa27ce4SDimitry Andric auto *OldPHIReal = ReductionInfo[Real].first;
20097fa27ce4SDimitry Andric auto *OldPHIImag = ReductionInfo[Imag].first;
20107fa27ce4SDimitry Andric auto *NewPHI = OldToNewPHI[OldPHIReal];
20117fa27ce4SDimitry Andric
20127fa27ce4SDimitry Andric auto *VTy = cast<VectorType>(Real->getType());
20137fa27ce4SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
20147fa27ce4SDimitry Andric
20157fa27ce4SDimitry Andric // We have to interleave initial origin values coming from IncomingBlock
20167fa27ce4SDimitry Andric Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
20177fa27ce4SDimitry Andric Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
20187fa27ce4SDimitry Andric
20197fa27ce4SDimitry Andric IRBuilder<> Builder(Incoming->getTerminator());
2020ac9a064cSDimitry Andric auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2021ac9a064cSDimitry Andric {InitReal, InitImag});
20227fa27ce4SDimitry Andric
20237fa27ce4SDimitry Andric NewPHI->addIncoming(NewInit, Incoming);
20247fa27ce4SDimitry Andric NewPHI->addIncoming(OperationReplacement, BackEdge);
20257fa27ce4SDimitry Andric
20267fa27ce4SDimitry Andric // Deinterleave complex vector outside of loop so that it can be finally
20277fa27ce4SDimitry Andric // reduced
20287fa27ce4SDimitry Andric auto *FinalReductionReal = ReductionInfo[Real].second;
20297fa27ce4SDimitry Andric auto *FinalReductionImag = ReductionInfo[Imag].second;
20307fa27ce4SDimitry Andric
20317fa27ce4SDimitry Andric Builder.SetInsertPoint(
20327fa27ce4SDimitry Andric &*FinalReductionReal->getParent()->getFirstInsertionPt());
2033ac9a064cSDimitry Andric auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2034ac9a064cSDimitry Andric OperationReplacement->getType(),
2035ac9a064cSDimitry Andric OperationReplacement);
20367fa27ce4SDimitry Andric
20377fa27ce4SDimitry Andric auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
20387fa27ce4SDimitry Andric FinalReductionReal->replaceUsesOfWith(Real, NewReal);
20397fa27ce4SDimitry Andric
20407fa27ce4SDimitry Andric Builder.SetInsertPoint(FinalReductionImag);
20417fa27ce4SDimitry Andric auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
20427fa27ce4SDimitry Andric FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2043e3b55780SDimitry Andric }
2044e3b55780SDimitry Andric
replaceNodes()2045e3b55780SDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
20467fa27ce4SDimitry Andric SmallVector<Instruction *, 16> DeadInstrRoots;
20477fa27ce4SDimitry Andric for (auto *RootInstruction : OrderedRoots) {
20487fa27ce4SDimitry Andric // Check if this potential root went through check process and we can
20497fa27ce4SDimitry Andric // deinterleave it
20507fa27ce4SDimitry Andric if (!RootToNode.count(RootInstruction))
20517fa27ce4SDimitry Andric continue;
20527fa27ce4SDimitry Andric
20537fa27ce4SDimitry Andric IRBuilder<> Builder(RootInstruction);
20547fa27ce4SDimitry Andric auto RootNode = RootToNode[RootInstruction];
20557fa27ce4SDimitry Andric Value *R = replaceNode(Builder, RootNode.get());
20567fa27ce4SDimitry Andric
20577fa27ce4SDimitry Andric if (RootNode->Operation ==
20587fa27ce4SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation) {
20597fa27ce4SDimitry Andric auto *RootReal = cast<Instruction>(RootNode->Real);
20607fa27ce4SDimitry Andric auto *RootImag = cast<Instruction>(RootNode->Imag);
20617fa27ce4SDimitry Andric ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
20627fa27ce4SDimitry Andric ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
20637fa27ce4SDimitry Andric DeadInstrRoots.push_back(cast<Instruction>(RootReal));
20647fa27ce4SDimitry Andric DeadInstrRoots.push_back(cast<Instruction>(RootImag));
20657fa27ce4SDimitry Andric } else {
20667fa27ce4SDimitry Andric assert(R && "Unable to find replacement for RootInstruction");
20677fa27ce4SDimitry Andric DeadInstrRoots.push_back(RootInstruction);
20687fa27ce4SDimitry Andric RootInstruction->replaceAllUsesWith(R);
20697fa27ce4SDimitry Andric }
2070e3b55780SDimitry Andric }
2071e3b55780SDimitry Andric
20727fa27ce4SDimitry Andric for (auto *I : DeadInstrRoots)
20737fa27ce4SDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2074e3b55780SDimitry Andric }
2075