xref: /src/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utilities for generating tiled loops for matrix operations.
//
//===----------------------------------------------------------------------===//
12b60736ecSDimitry Andric 
13b60736ecSDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h"
14b60736ecSDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
15b60736ecSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
16b60736ecSDimitry Andric #include "llvm/IR/BasicBlock.h"
17b60736ecSDimitry Andric #include "llvm/IR/Dominators.h"
18b60736ecSDimitry Andric #include "llvm/IR/IRBuilder.h"
19b60736ecSDimitry Andric #include "llvm/IR/Type.h"
20b60736ecSDimitry Andric 
21b60736ecSDimitry Andric using namespace llvm;
22b60736ecSDimitry Andric 
CreateLoop(BasicBlock * Preheader,BasicBlock * Exit,Value * Bound,Value * Step,StringRef Name,IRBuilderBase & B,DomTreeUpdater & DTU,Loop * L,LoopInfo & LI)23b60736ecSDimitry Andric BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24b60736ecSDimitry Andric                                  Value *Bound, Value *Step, StringRef Name,
25b60736ecSDimitry Andric                                  IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26b60736ecSDimitry Andric                                  LoopInfo &LI) {
27b60736ecSDimitry Andric   LLVMContext &Ctx = Preheader->getContext();
28b60736ecSDimitry Andric   BasicBlock *Header = BasicBlock::Create(
29b60736ecSDimitry Andric       Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
30b60736ecSDimitry Andric   BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
31b60736ecSDimitry Andric                                         Header->getParent(), Exit);
32b60736ecSDimitry Andric   BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
33b60736ecSDimitry Andric                                          Header->getParent(), Exit);
34b60736ecSDimitry Andric 
35b60736ecSDimitry Andric   Type *I32Ty = Type::getInt64Ty(Ctx);
36b60736ecSDimitry Andric   BranchInst::Create(Body, Header);
37b60736ecSDimitry Andric   BranchInst::Create(Latch, Body);
38b60736ecSDimitry Andric   PHINode *IV =
39ac9a064cSDimitry Andric       PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
40b60736ecSDimitry Andric   IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
41b60736ecSDimitry Andric 
42b60736ecSDimitry Andric   B.SetInsertPoint(Latch);
43b60736ecSDimitry Andric   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
44b60736ecSDimitry Andric   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
45b60736ecSDimitry Andric   BranchInst::Create(Header, Exit, Cond, Latch);
46b60736ecSDimitry Andric   IV->addIncoming(Inc, Latch);
47b60736ecSDimitry Andric 
48b60736ecSDimitry Andric   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
49b60736ecSDimitry Andric   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
50b60736ecSDimitry Andric   PreheaderBr->setSuccessor(0, Header);
51b60736ecSDimitry Andric   DTU.applyUpdatesPermissive({
52b60736ecSDimitry Andric       {DominatorTree::Delete, Preheader, Tmp},
53b60736ecSDimitry Andric       {DominatorTree::Insert, Header, Body},
54b60736ecSDimitry Andric       {DominatorTree::Insert, Body, Latch},
55b60736ecSDimitry Andric       {DominatorTree::Insert, Latch, Header},
56b60736ecSDimitry Andric       {DominatorTree::Insert, Latch, Exit},
57b60736ecSDimitry Andric       {DominatorTree::Insert, Preheader, Header},
58b60736ecSDimitry Andric   });
59b60736ecSDimitry Andric 
60b60736ecSDimitry Andric   L->addBasicBlockToLoop(Header, LI);
61b60736ecSDimitry Andric   L->addBasicBlockToLoop(Body, LI);
62b60736ecSDimitry Andric   L->addBasicBlockToLoop(Latch, LI);
63b60736ecSDimitry Andric   return Body;
64b60736ecSDimitry Andric }
65b60736ecSDimitry Andric 
66b60736ecSDimitry Andric // Creates the following loop nest skeleton:
67b60736ecSDimitry Andric //  for C = 0; C < NumColumns; C += TileSize
68b60736ecSDimitry Andric //    for R = 0; R < NumRows; R += TileSize
69b60736ecSDimitry Andric //      for K = 0; K < Inner ; K += TileSize
CreateTiledLoops(BasicBlock * Start,BasicBlock * End,IRBuilderBase & B,DomTreeUpdater & DTU,LoopInfo & LI)70b60736ecSDimitry Andric BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71b60736ecSDimitry Andric                                        IRBuilderBase &B, DomTreeUpdater &DTU,
72b60736ecSDimitry Andric                                        LoopInfo &LI) {
7308e8dd7bSDimitry Andric   Loop *ColumnLoopInfo = LI.AllocateLoop();
7408e8dd7bSDimitry Andric   Loop *RowLoopInfo = LI.AllocateLoop();
7508e8dd7bSDimitry Andric   Loop *KLoopInfo = LI.AllocateLoop();
7608e8dd7bSDimitry Andric   RowLoopInfo->addChildLoop(KLoopInfo);
7708e8dd7bSDimitry Andric   ColumnLoopInfo->addChildLoop(RowLoopInfo);
78b60736ecSDimitry Andric   if (Loop *ParentL = LI.getLoopFor(Start))
7908e8dd7bSDimitry Andric     ParentL->addChildLoop(ColumnLoopInfo);
80b60736ecSDimitry Andric   else
8108e8dd7bSDimitry Andric     LI.addTopLevelLoop(ColumnLoopInfo);
82b60736ecSDimitry Andric 
83b60736ecSDimitry Andric   BasicBlock *ColBody =
84b60736ecSDimitry Andric       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
8508e8dd7bSDimitry Andric                  "cols", B, DTU, ColumnLoopInfo, LI);
8608e8dd7bSDimitry Andric   ColumnLoop.Latch = ColBody->getSingleSuccessor();
87b60736ecSDimitry Andric   BasicBlock *RowBody =
8808e8dd7bSDimitry Andric       CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
8908e8dd7bSDimitry Andric                  B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
9008e8dd7bSDimitry Andric   RowLoop.Latch = RowBody->getSingleSuccessor();
91b60736ecSDimitry Andric 
92b60736ecSDimitry Andric   BasicBlock *InnerBody =
9308e8dd7bSDimitry Andric       CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
9408e8dd7bSDimitry Andric                  B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
9508e8dd7bSDimitry Andric   KLoop.Latch = InnerBody->getSingleSuccessor();
9608e8dd7bSDimitry Andric   ColumnLoop.Header = ColBody->getSinglePredecessor();
9708e8dd7bSDimitry Andric   RowLoop.Header = RowBody->getSinglePredecessor();
9808e8dd7bSDimitry Andric   KLoop.Header = InnerBody->getSinglePredecessor();
9908e8dd7bSDimitry Andric   RowLoop.Index = &*RowLoop.Header->begin();
10008e8dd7bSDimitry Andric   ColumnLoop.Index = &*ColumnLoop.Header->begin();
10108e8dd7bSDimitry Andric   KLoop.Index = &*KLoop.Header->begin();
102b60736ecSDimitry Andric 
103b60736ecSDimitry Andric   return InnerBody;
104b60736ecSDimitry Andric }
105