//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1))
///
/// On success, MatchInfo is filled with (opcode, element type, vector source
/// register) for applyExtractVecEltPairwiseAdd to consume. Operands 1/2 of
/// \p MI are read as (vector, index), i.e. this expects a
/// G_EXTRACT_VECTOR_ELT root (selected by the tablegen rule).
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  // Only extracts of lane 0 are rewritten.
  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  // The shuffle may be on either side of the fadd; try RHS first, then LHS.
  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  // It must also shuffle the same vector that feeds the other fadd operand.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

/// Apply the pairwise-fadd rewrite matched above: extract lanes 0 and 1 of
/// the vector source and fadd them, replacing \p MI.
bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

/// \returns true if the defining instruction of \p R is a sign extension
/// (G_SEXT or G_SEXT_INREG).
static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

/// \returns true if the defining instruction of \p R is a G_ZEXT.
static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

/// Match a G_MUL by a constant that can be lowered more cheaply as
/// shift+add/sub (optionally followed by an extra shift or a negate).
/// On success, \p ApplyFn is set to a callback that builds the replacement
/// sequence into a given destination register.
bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals to (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  // The negate-after-add form and the extra trailing shift are not combined;
  // bail out rather than emit an incorrect sequence.
  if (NegateResult && TrailingZeroes)
    return false;

  // Captures by value: all inputs are plain values/registers, so the lambda
  // stays valid after this function returns.
  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

/// Run the callback produced by matchAArch64MulConstCombine and delete the
/// original G_MUL.
bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

/// Mutate the matched G_MERGE_VALUES in place into a G_ZEXT of its first
/// source, dropping the now-unused zero operand.
void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  // %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.RemoveOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
static bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead of
  // a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
275 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT); 276 Register Dst = MI.getOperand(0).getReg(); 277 Register Src = MI.getOperand(1).getReg(); 278 return MRI.getType(Dst).isScalar() && 279 mi_match(Src, MRI, 280 m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()), 281 m_GFCmp(m_Pred(), m_Reg(), m_Reg()))); 282 } 283 284 static void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI, 285 MachineIRBuilder &B, 286 GISelChangeObserver &Observer) { 287 Observer.changingInstr(MI); 288 MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT)); 289 Observer.changedInstr(MI); 290 } 291 292 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS 293 #include "AArch64GenPostLegalizeGICombiner.inc" 294 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS 295 296 namespace { 297 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H 298 #include "AArch64GenPostLegalizeGICombiner.inc" 299 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H 300 301 class AArch64PostLegalizerCombinerInfo : public CombinerInfo { 302 GISelKnownBits *KB; 303 MachineDominatorTree *MDT; 304 305 public: 306 AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg; 307 308 AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize, 309 GISelKnownBits *KB, 310 MachineDominatorTree *MDT) 311 : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false, 312 /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize), 313 KB(KB), MDT(MDT) { 314 if (!GeneratedRuleCfg.parseCommandLineOption()) 315 report_fatal_error("Invalid rule identifier"); 316 } 317 318 virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI, 319 MachineIRBuilder &B) const override; 320 }; 321 322 bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer, 323 MachineInstr &MI, 324 MachineIRBuilder &B) const { 325 const auto *LI = 326 MI.getParent()->getParent()->getSubtarget().getLegalizerInfo(); 327 CombinerHelper Helper(Observer, 
B, KB, MDT, LI); 328 AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg); 329 return Generated.tryCombineAll(Observer, MI, B, Helper); 330 } 331 332 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP 333 #include "AArch64GenPostLegalizeGICombiner.inc" 334 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP 335 336 class AArch64PostLegalizerCombiner : public MachineFunctionPass { 337 public: 338 static char ID; 339 340 AArch64PostLegalizerCombiner(bool IsOptNone = false); 341 342 StringRef getPassName() const override { 343 return "AArch64PostLegalizerCombiner"; 344 } 345 346 bool runOnMachineFunction(MachineFunction &MF) override; 347 void getAnalysisUsage(AnalysisUsage &AU) const override; 348 349 private: 350 bool IsOptNone; 351 }; 352 } // end anonymous namespace 353 354 void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const { 355 AU.addRequired<TargetPassConfig>(); 356 AU.setPreservesCFG(); 357 getSelectionDAGFallbackAnalysisUsage(AU); 358 AU.addRequired<GISelKnownBitsAnalysis>(); 359 AU.addPreserved<GISelKnownBitsAnalysis>(); 360 if (!IsOptNone) { 361 AU.addRequired<MachineDominatorTree>(); 362 AU.addPreserved<MachineDominatorTree>(); 363 AU.addRequired<GISelCSEAnalysisWrapperPass>(); 364 AU.addPreserved<GISelCSEAnalysisWrapperPass>(); 365 } 366 MachineFunctionPass::getAnalysisUsage(AU); 367 } 368 369 AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone) 370 : MachineFunctionPass(ID), IsOptNone(IsOptNone) { 371 initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry()); 372 } 373 374 bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) { 375 if (MF.getProperties().hasProperty( 376 MachineFunctionProperties::Property::FailedISel)) 377 return false; 378 assert(MF.getProperties().hasProperty( 379 MachineFunctionProperties::Property::Legalized) && 380 "Expected a legalized function?"); 381 auto *TPC = &getAnalysis<TargetPassConfig>(); 382 
const Function &F = MF.getFunction(); 383 bool EnableOpt = 384 MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F); 385 GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF); 386 MachineDominatorTree *MDT = 387 IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>(); 388 AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(), 389 F.hasMinSize(), KB, MDT); 390 GISelCSEAnalysisWrapper &Wrapper = 391 getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper(); 392 auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig()); 393 Combiner C(PCInfo, TPC); 394 return C.combineMachineInstrs(MF, CSEInfo); 395 } 396 397 char AArch64PostLegalizerCombiner::ID = 0; 398 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE, 399 "Combine AArch64 MachineInstrs after legalization", false, 400 false) 401 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 402 INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis) 403 INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE, 404 "Combine AArch64 MachineInstrs after legalization", false, 405 false) 406 407 namespace llvm { 408 FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) { 409 return new AArch64PostLegalizerCombiner(IsOptNone); 410 } 411 } // end namespace llvm 412