1b60736ecSDimitry Andric //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
2b60736ecSDimitry Andric //
3b60736ecSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b60736ecSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5b60736ecSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b60736ecSDimitry Andric //
7b60736ecSDimitry Andric //===----------------------------------------------------------------------===//
8b60736ecSDimitry Andric //
9b60736ecSDimitry Andric // Utilities for generating tiled loops for matrix operations.
10b60736ecSDimitry Andric //
11b60736ecSDimitry Andric //===----------------------------------------------------------------------===//
12b60736ecSDimitry Andric
13b60736ecSDimitry Andric #include "llvm/Transforms/Utils/MatrixUtils.h"
14b60736ecSDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
15b60736ecSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
16b60736ecSDimitry Andric #include "llvm/IR/BasicBlock.h"
17b60736ecSDimitry Andric #include "llvm/IR/Dominators.h"
18b60736ecSDimitry Andric #include "llvm/IR/IRBuilder.h"
19b60736ecSDimitry Andric #include "llvm/IR/Type.h"
20b60736ecSDimitry Andric
21b60736ecSDimitry Andric using namespace llvm;
22b60736ecSDimitry Andric
CreateLoop(BasicBlock * Preheader,BasicBlock * Exit,Value * Bound,Value * Step,StringRef Name,IRBuilderBase & B,DomTreeUpdater & DTU,Loop * L,LoopInfo & LI)23b60736ecSDimitry Andric BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24b60736ecSDimitry Andric Value *Bound, Value *Step, StringRef Name,
25b60736ecSDimitry Andric IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26b60736ecSDimitry Andric LoopInfo &LI) {
27b60736ecSDimitry Andric LLVMContext &Ctx = Preheader->getContext();
28b60736ecSDimitry Andric BasicBlock *Header = BasicBlock::Create(
29b60736ecSDimitry Andric Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
30b60736ecSDimitry Andric BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
31b60736ecSDimitry Andric Header->getParent(), Exit);
32b60736ecSDimitry Andric BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
33b60736ecSDimitry Andric Header->getParent(), Exit);
34b60736ecSDimitry Andric
35b60736ecSDimitry Andric Type *I32Ty = Type::getInt64Ty(Ctx);
36b60736ecSDimitry Andric BranchInst::Create(Body, Header);
37b60736ecSDimitry Andric BranchInst::Create(Latch, Body);
38b60736ecSDimitry Andric PHINode *IV =
39ac9a064cSDimitry Andric PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator());
40b60736ecSDimitry Andric IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
41b60736ecSDimitry Andric
42b60736ecSDimitry Andric B.SetInsertPoint(Latch);
43b60736ecSDimitry Andric Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
44b60736ecSDimitry Andric Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
45b60736ecSDimitry Andric BranchInst::Create(Header, Exit, Cond, Latch);
46b60736ecSDimitry Andric IV->addIncoming(Inc, Latch);
47b60736ecSDimitry Andric
48b60736ecSDimitry Andric BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
49b60736ecSDimitry Andric BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
50b60736ecSDimitry Andric PreheaderBr->setSuccessor(0, Header);
51b60736ecSDimitry Andric DTU.applyUpdatesPermissive({
52b60736ecSDimitry Andric {DominatorTree::Delete, Preheader, Tmp},
53b60736ecSDimitry Andric {DominatorTree::Insert, Header, Body},
54b60736ecSDimitry Andric {DominatorTree::Insert, Body, Latch},
55b60736ecSDimitry Andric {DominatorTree::Insert, Latch, Header},
56b60736ecSDimitry Andric {DominatorTree::Insert, Latch, Exit},
57b60736ecSDimitry Andric {DominatorTree::Insert, Preheader, Header},
58b60736ecSDimitry Andric });
59b60736ecSDimitry Andric
60b60736ecSDimitry Andric L->addBasicBlockToLoop(Header, LI);
61b60736ecSDimitry Andric L->addBasicBlockToLoop(Body, LI);
62b60736ecSDimitry Andric L->addBasicBlockToLoop(Latch, LI);
63b60736ecSDimitry Andric return Body;
64b60736ecSDimitry Andric }
65b60736ecSDimitry Andric
66b60736ecSDimitry Andric // Creates the following loop nest skeleton:
67b60736ecSDimitry Andric // for C = 0; C < NumColumns; C += TileSize
68b60736ecSDimitry Andric // for R = 0; R < NumRows; R += TileSize
69b60736ecSDimitry Andric // for K = 0; K < Inner ; K += TileSize
CreateTiledLoops(BasicBlock * Start,BasicBlock * End,IRBuilderBase & B,DomTreeUpdater & DTU,LoopInfo & LI)70b60736ecSDimitry Andric BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71b60736ecSDimitry Andric IRBuilderBase &B, DomTreeUpdater &DTU,
72b60736ecSDimitry Andric LoopInfo &LI) {
7308e8dd7bSDimitry Andric Loop *ColumnLoopInfo = LI.AllocateLoop();
7408e8dd7bSDimitry Andric Loop *RowLoopInfo = LI.AllocateLoop();
7508e8dd7bSDimitry Andric Loop *KLoopInfo = LI.AllocateLoop();
7608e8dd7bSDimitry Andric RowLoopInfo->addChildLoop(KLoopInfo);
7708e8dd7bSDimitry Andric ColumnLoopInfo->addChildLoop(RowLoopInfo);
78b60736ecSDimitry Andric if (Loop *ParentL = LI.getLoopFor(Start))
7908e8dd7bSDimitry Andric ParentL->addChildLoop(ColumnLoopInfo);
80b60736ecSDimitry Andric else
8108e8dd7bSDimitry Andric LI.addTopLevelLoop(ColumnLoopInfo);
82b60736ecSDimitry Andric
83b60736ecSDimitry Andric BasicBlock *ColBody =
84b60736ecSDimitry Andric CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
8508e8dd7bSDimitry Andric "cols", B, DTU, ColumnLoopInfo, LI);
8608e8dd7bSDimitry Andric ColumnLoop.Latch = ColBody->getSingleSuccessor();
87b60736ecSDimitry Andric BasicBlock *RowBody =
8808e8dd7bSDimitry Andric CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
8908e8dd7bSDimitry Andric B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
9008e8dd7bSDimitry Andric RowLoop.Latch = RowBody->getSingleSuccessor();
91b60736ecSDimitry Andric
92b60736ecSDimitry Andric BasicBlock *InnerBody =
9308e8dd7bSDimitry Andric CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
9408e8dd7bSDimitry Andric B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
9508e8dd7bSDimitry Andric KLoop.Latch = InnerBody->getSingleSuccessor();
9608e8dd7bSDimitry Andric ColumnLoop.Header = ColBody->getSinglePredecessor();
9708e8dd7bSDimitry Andric RowLoop.Header = RowBody->getSinglePredecessor();
9808e8dd7bSDimitry Andric KLoop.Header = InnerBody->getSinglePredecessor();
9908e8dd7bSDimitry Andric RowLoop.Index = &*RowLoop.Header->begin();
10008e8dd7bSDimitry Andric ColumnLoop.Index = &*ColumnLoop.Header->begin();
10108e8dd7bSDimitry Andric KLoop.Index = &*KLoop.Header->begin();
102b60736ecSDimitry Andric
103b60736ecSDimitry Andric return InnerBody;
104b60736ecSDimitry Andric }
105