xref: /src/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
101095a5dSDimitry Andric //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2d7f7719eSRoman Divacky //
3e6d15924SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e6d15924SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5e6d15924SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d7f7719eSRoman Divacky //
7d7f7719eSRoman Divacky //===----------------------------------------------------------------------===//
8d7f7719eSRoman Divacky //
9d7f7719eSRoman Divacky // This file implements utilities for working with "normalized" expressions.
10d7f7719eSRoman Divacky // See the comments at the top of ScalarEvolutionNormalization.h for details.
11d7f7719eSRoman Divacky //
12d7f7719eSRoman Divacky //===----------------------------------------------------------------------===//
13d7f7719eSRoman Divacky 
147ab83427SDimitry Andric #include "llvm/Analysis/ScalarEvolutionNormalization.h"
15d7f7719eSRoman Divacky #include "llvm/Analysis/LoopInfo.h"
16145449b1SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
17d7f7719eSRoman Divacky #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18d7f7719eSRoman Divacky using namespace llvm;
19d7f7719eSRoman Divacky 
namespace {
/// TransformKind - Which direction NormalizeDenormalizeRewriter moves the add
/// recurrences selected by its predicate.
///
/// The enum is private to this file, so it lives in an anonymous namespace:
/// a global-scope `enum TransformKind` would clash (an ODR violation) with
/// any other translation unit that defines a type with the same name.
enum TransformKind {
  /// Normalize - Normalize according to the given loops
  /// ("partial decrement" of the selected recurrences).
  Normalize,
  /// Denormalize - Perform the inverse transform on the expression with the
  /// given loop set ("partial increment" of the selected recurrences).
  Denormalize
};
} // namespace
2930815c53SDimitry Andric 
3071d5a254SDimitry Andric namespace {
3171d5a254SDimitry Andric struct NormalizeDenormalizeRewriter
3271d5a254SDimitry Andric     : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
3371d5a254SDimitry Andric   const TransformKind Kind;
3471d5a254SDimitry Andric 
3571d5a254SDimitry Andric   // NB! Pred is a function_ref.  Storing it here is okay only because
3671d5a254SDimitry Andric   // we're careful about the lifetime of NormalizeDenormalizeRewriter.
3771d5a254SDimitry Andric   const NormalizePredTy Pred;
3871d5a254SDimitry Andric 
NormalizeDenormalizeRewriter__anonc06632ba0111::NormalizeDenormalizeRewriter3971d5a254SDimitry Andric   NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
4071d5a254SDimitry Andric                                ScalarEvolution &SE)
4171d5a254SDimitry Andric       : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
4271d5a254SDimitry Andric         Pred(Pred) {}
4371d5a254SDimitry Andric   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
4471d5a254SDimitry Andric };
4530815c53SDimitry Andric } // namespace
4630815c53SDimitry Andric 
4771d5a254SDimitry Andric const SCEV *
visitAddRecExpr(const SCEVAddRecExpr * AR)4871d5a254SDimitry Andric NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
49d39c594dSDimitry Andric   SmallVector<const SCEV *, 8> Operands;
5071d5a254SDimitry Andric 
5171d5a254SDimitry Andric   transform(AR->operands(), std::back_inserter(Operands),
5271d5a254SDimitry Andric             [&](const SCEV *Op) { return visit(Op); });
5371d5a254SDimitry Andric 
5412f3ca4cSDimitry Andric   if (!Pred(AR))
5512f3ca4cSDimitry Andric     return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
5612f3ca4cSDimitry Andric 
5712f3ca4cSDimitry Andric   // Normalization and denormalization are fancy names for decrementing and
5812f3ca4cSDimitry Andric   // incrementing a SCEV expression with respect to a set of loops.  Since
5912f3ca4cSDimitry Andric   // Pred(AR) has returned true, we know we need to normalize or denormalize AR
6012f3ca4cSDimitry Andric   // with respect to its loop.
6112f3ca4cSDimitry Andric 
6212f3ca4cSDimitry Andric   if (Kind == Denormalize) {
6312f3ca4cSDimitry Andric     // Denormalization / "partial increment" is essentially the same as \c
6412f3ca4cSDimitry Andric     // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the
6512f3ca4cSDimitry Andric     // symmetry with Normalization clear.
6612f3ca4cSDimitry Andric     for (int i = 0, e = Operands.size() - 1; i < e; i++)
6712f3ca4cSDimitry Andric       Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
6812f3ca4cSDimitry Andric   } else {
6912f3ca4cSDimitry Andric     assert(Kind == Normalize && "Only two possibilities!");
7012f3ca4cSDimitry Andric 
7112f3ca4cSDimitry Andric     // Normalization / "partial decrement" is a bit more subtle.  Since
7212f3ca4cSDimitry Andric     // incrementing a SCEV expression (in general) changes the step of the SCEV
7312f3ca4cSDimitry Andric     // expression as well, we cannot use the step of the current expression.
7412f3ca4cSDimitry Andric     // Instead, we have to use the step of the very expression we're trying to
7512f3ca4cSDimitry Andric     // compute!
765ca98fd9SDimitry Andric     //
7712f3ca4cSDimitry Andric     // We solve the issue by recursively building up the result, starting from
7812f3ca4cSDimitry Andric     // the "least significant" operand in the add recurrence:
7912f3ca4cSDimitry Andric     //
8012f3ca4cSDimitry Andric     // Base case:
8112f3ca4cSDimitry Andric     //   Single operand add recurrence.  It's its own normalization.
8212f3ca4cSDimitry Andric     //
8312f3ca4cSDimitry Andric     // N-operand case:
8412f3ca4cSDimitry Andric     //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
8512f3ca4cSDimitry Andric     //
8612f3ca4cSDimitry Andric     //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
8712f3ca4cSDimitry Andric     //   normalization by induction.  We subtract the normalized step
8812f3ca4cSDimitry Andric     //   recurrence from S_{N-1} to get the normalization of S.
8912f3ca4cSDimitry Andric 
9012f3ca4cSDimitry Andric     for (int i = Operands.size() - 2; i >= 0; i--)
9112f3ca4cSDimitry Andric       Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
92d7f7719eSRoman Divacky   }
9312f3ca4cSDimitry Andric 
9412f3ca4cSDimitry Andric   return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
95d7f7719eSRoman Divacky }
96d39c594dSDimitry Andric 
/// Normalize \p S with respect to every add recurrence whose loop is in
/// \p Loops.
///
/// \param S               Expression to normalize.
/// \param Loops           Loops whose recurrences are shifted; when empty,
///                        \p S is returned unchanged.
/// \param SE              ScalarEvolution used to build the new expressions.
/// \param CheckInvertible When true, the result is denormalized again with
///                        the same loop set, and nullptr is returned unless
///                        that round trip yields \p S exactly.
/// \returns the normalized expression, or nullptr when \p CheckInvertible
///          is set and the normalization cannot be inverted.
const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
                                         const PostIncLoopSet &Loops,
                                         ScalarEvolution &SE,
                                         bool CheckInvertible) {
  if (Loops.empty())
    return S;
  auto Pred = [&](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  const SCEV *Normalized =
      NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
  // The denormalized expression is only needed for the invertibility check,
  // so skip the whole second rewrite when the caller did not ask for it.
  // Reject the result if denormalizing it does not give back S.
  if (CheckInvertible &&
      denormalizeForPostIncUse(Normalized, Loops, SE) != S)
    return nullptr;
  return Normalized;
}
114d39c594dSDimitry Andric 
/// Normalize \p S with respect to every add recurrence for which \p Pred
/// returns true.  Unlike normalizeForPostIncUse there is no empty-set shortcut
/// and no invertibility check.
const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
                                           ScalarEvolution &SE) {
  // The rewriter only borrows Pred; both live until this function returns.
  NormalizeDenormalizeRewriter Rewriter(Normalize, Pred, SE);
  return Rewriter.visit(S);
}
119d39c594dSDimitry Andric 
/// Denormalize \p S, undoing a normalization performed with respect to the
/// recurrences whose loop is in \p Loops.
///
/// \returns \p S unchanged when \p Loops is empty, otherwise the rewritten
///          expression.
const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
                                           const PostIncLoopSet &Loops,
                                           ScalarEvolution &SE) {
  // With no loops selected there is nothing to shift.
  if (Loops.empty())
    return S;

  // Select exactly the recurrences that belong to one of the given loops.
  auto InLoopSet = [&Loops](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  // The rewriter stores a function_ref to InLoopSet; both die together here.
  NormalizeDenormalizeRewriter Rewriter(Denormalize, InLoopSet, SE);
  return Rewriter.visit(S);
}
130