|
| 1 | +//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===---------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "DXILDataScalarization.h" |
| 10 | +#include "DirectX.h" |
| 11 | +#include "llvm/ADT/PostOrderIterator.h" |
| 12 | +#include "llvm/ADT/STLExtras.h" |
| 13 | +#include "llvm/Analysis/DXILResource.h" |
| 14 | +#include "llvm/IR/GlobalVariable.h" |
| 15 | +#include "llvm/IR/IRBuilder.h" |
| 16 | +#include "llvm/IR/InstVisitor.h" |
| 17 | +#include "llvm/IR/Module.h" |
| 18 | +#include "llvm/IR/Operator.h" |
| 19 | +#include "llvm/IR/PassManager.h" |
| 20 | +#include "llvm/IR/ReplaceConstant.h" |
| 21 | +#include "llvm/IR/Type.h" |
| 22 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 23 | +#include "llvm/Transforms/Utils/Local.h" |
| 24 | + |
| 25 | +#define DEBUG_TYPE "dxil-data-scalarization" |
| 26 | +static const int MaxVecSize = 4; |
| 27 | + |
| 28 | +using namespace llvm; |
| 29 | + |
| 30 | +class DXILDataScalarizationLegacy : public ModulePass { |
| 31 | + |
| 32 | +public: |
| 33 | + bool runOnModule(Module &M) override; |
| 34 | + DXILDataScalarizationLegacy() : ModulePass(ID) {} |
| 35 | + |
| 36 | + void getAnalysisUsage(AnalysisUsage &AU) const override; |
| 37 | + static char ID; // Pass identification. |
| 38 | +}; |
| 39 | + |
| 40 | +static bool findAndReplaceVectors(Module &M); |
| 41 | + |
| 42 | +class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> { |
| 43 | +public: |
| 44 | + DataScalarizerVisitor() : GlobalMap() {} |
| 45 | + bool visit(Function &F); |
| 46 | + // InstVisitor methods. They return true if the instruction was scalarized, |
| 47 | + // false if nothing changed. |
| 48 | + bool visitInstruction(Instruction &I) { return false; } |
| 49 | + bool visitSelectInst(SelectInst &SI) { return false; } |
| 50 | + bool visitICmpInst(ICmpInst &ICI) { return false; } |
| 51 | + bool visitFCmpInst(FCmpInst &FCI) { return false; } |
| 52 | + bool visitUnaryOperator(UnaryOperator &UO) { return false; } |
| 53 | + bool visitBinaryOperator(BinaryOperator &BO) { return false; } |
| 54 | + bool visitGetElementPtrInst(GetElementPtrInst &GEPI); |
| 55 | + bool visitCastInst(CastInst &CI) { return false; } |
| 56 | + bool visitBitCastInst(BitCastInst &BCI) { return false; } |
| 57 | + bool visitInsertElementInst(InsertElementInst &IEI) { return false; } |
| 58 | + bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } |
| 59 | + bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } |
| 60 | + bool visitPHINode(PHINode &PHI) { return false; } |
| 61 | + bool visitLoadInst(LoadInst &LI); |
| 62 | + bool visitStoreInst(StoreInst &SI); |
| 63 | + bool visitCallInst(CallInst &ICI) { return false; } |
| 64 | + bool visitFreezeInst(FreezeInst &FI) { return false; } |
| 65 | + friend bool findAndReplaceVectors(llvm::Module &M); |
| 66 | + |
| 67 | +private: |
| 68 | + GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); |
| 69 | + DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; |
| 70 | + SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; |
| 71 | + bool finish(); |
| 72 | +}; |
| 73 | + |
| 74 | +bool DataScalarizerVisitor::visit(Function &F) { |
| 75 | + assert(!GlobalMap.empty()); |
| 76 | + ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock()); |
| 77 | + for (BasicBlock *BB : RPOT) { |
| 78 | + for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { |
| 79 | + Instruction *I = &*II; |
| 80 | + bool Done = InstVisitor::visit(I); |
| 81 | + ++II; |
| 82 | + if (Done && I->getType()->isVoidTy()) |
| 83 | + I->eraseFromParent(); |
| 84 | + } |
| 85 | + } |
| 86 | + return finish(); |
| 87 | +} |
| 88 | + |
| 89 | +bool DataScalarizerVisitor::finish() { |
| 90 | + RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); |
| 91 | + return true; |
| 92 | +} |
| 93 | + |
| 94 | +GlobalVariable * |
| 95 | +DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { |
| 96 | + if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) { |
| 97 | + auto It = GlobalMap.find(OldGlobal); |
| 98 | + if (It != GlobalMap.end()) { |
| 99 | + return It->second; // Found, return the new global |
| 100 | + } |
| 101 | + } |
| 102 | + return nullptr; // Not found |
| 103 | +} |
| 104 | + |
| 105 | +bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { |
| 106 | + unsigned NumOperands = LI.getNumOperands(); |
| 107 | + for (unsigned I = 0; I < NumOperands; ++I) { |
| 108 | + Value *CurrOpperand = LI.getOperand(I); |
| 109 | + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) |
| 110 | + LI.setOperand(I, NewGlobal); |
| 111 | + } |
| 112 | + return false; |
| 113 | +} |
| 114 | + |
| 115 | +bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { |
| 116 | + unsigned NumOperands = SI.getNumOperands(); |
| 117 | + for (unsigned I = 0; I < NumOperands; ++I) { |
| 118 | + Value *CurrOpperand = SI.getOperand(I); |
| 119 | + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) { |
| 120 | + SI.setOperand(I, NewGlobal); |
| 121 | + } |
| 122 | + } |
| 123 | + return false; |
| 124 | +} |
| 125 | + |
| 126 | +bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { |
| 127 | + unsigned NumOperands = GEPI.getNumOperands(); |
| 128 | + for (unsigned I = 0; I < NumOperands; ++I) { |
| 129 | + Value *CurrOpperand = GEPI.getOperand(I); |
| 130 | + GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand); |
| 131 | + if (!NewGlobal) |
| 132 | + continue; |
| 133 | + IRBuilder<> Builder(&GEPI); |
| 134 | + |
| 135 | + SmallVector<Value *, MaxVecSize> Indices; |
| 136 | + for (auto &Index : GEPI.indices()) |
| 137 | + Indices.push_back(Index); |
| 138 | + |
| 139 | + Value *NewGEP = |
| 140 | + Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices); |
| 141 | + |
| 142 | + GEPI.replaceAllUsesWith(NewGEP); |
| 143 | + PotentiallyDeadInstrs.emplace_back(&GEPI); |
| 144 | + } |
| 145 | + return true; |
| 146 | +} |
| 147 | + |
| 148 | +// Recursively Creates and Array like version of the given vector like type. |
| 149 | +static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { |
| 150 | + if (auto *VecTy = dyn_cast<VectorType>(T)) |
| 151 | + return ArrayType::get(VecTy->getElementType(), |
| 152 | + dyn_cast<FixedVectorType>(VecTy)->getNumElements()); |
| 153 | + if (auto *ArrayTy = dyn_cast<ArrayType>(T)) { |
| 154 | + Type *NewElementType = |
| 155 | + replaceVectorWithArray(ArrayTy->getElementType(), Ctx); |
| 156 | + return ArrayType::get(NewElementType, ArrayTy->getNumElements()); |
| 157 | + } |
| 158 | + // If it's not a vector or array, return the original type. |
| 159 | + return T; |
| 160 | +} |
| 161 | + |
| 162 | +Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, |
| 163 | + LLVMContext &Ctx) { |
| 164 | + // Handle ConstantAggregateZero (zero-initialized constants) |
| 165 | + if (isa<ConstantAggregateZero>(Init)) { |
| 166 | + return ConstantAggregateZero::get(NewType); |
| 167 | + } |
| 168 | + |
| 169 | + // Handle UndefValue (undefined constants) |
| 170 | + if (isa<UndefValue>(Init)) { |
| 171 | + return UndefValue::get(NewType); |
| 172 | + } |
| 173 | + |
| 174 | + // Handle vector to array transformation |
| 175 | + if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) { |
| 176 | + // Convert vector initializer to array initializer |
| 177 | + SmallVector<Constant *, MaxVecSize> ArrayElements; |
| 178 | + if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) { |
| 179 | + for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) |
| 180 | + ArrayElements.push_back(ConstVecInit->getOperand(I)); |
| 181 | + } else if (ConstantDataVector *ConstDataVecInit = |
| 182 | + llvm::dyn_cast<llvm::ConstantDataVector>(Init)) { |
| 183 | + for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) |
| 184 | + ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); |
| 185 | + } else { |
| 186 | + assert(false && "Expected a ConstantVector or ConstantDataVector for " |
| 187 | + "vector initializer!"); |
| 188 | + } |
| 189 | + |
| 190 | + return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements); |
| 191 | + } |
| 192 | + |
| 193 | + // Handle array of vectors transformation |
| 194 | + if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) { |
| 195 | + auto *ArrayInit = dyn_cast<ConstantArray>(Init); |
| 196 | + assert(ArrayInit && "Expected a ConstantArray for array initializer!"); |
| 197 | + |
| 198 | + SmallVector<Constant *, MaxVecSize> NewArrayElements; |
| 199 | + for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { |
| 200 | + // Recursively transform array elements |
| 201 | + Constant *NewElemInit = transformInitializer( |
| 202 | + ArrayInit->getOperand(I), ArrayTy->getElementType(), |
| 203 | + cast<ArrayType>(NewType)->getElementType(), Ctx); |
| 204 | + NewArrayElements.push_back(NewElemInit); |
| 205 | + } |
| 206 | + |
| 207 | + return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements); |
| 208 | + } |
| 209 | + |
| 210 | + // If not a vector or array, return the original initializer |
| 211 | + return Init; |
| 212 | +} |
| 213 | + |
| 214 | +static bool findAndReplaceVectors(Module &M) { |
| 215 | + bool MadeChange = false; |
| 216 | + LLVMContext &Ctx = M.getContext(); |
| 217 | + IRBuilder<> Builder(Ctx); |
| 218 | + DataScalarizerVisitor Impl; |
| 219 | + for (GlobalVariable &G : M.globals()) { |
| 220 | + Type *OrigType = G.getValueType(); |
| 221 | + |
| 222 | + Type *NewType = replaceVectorWithArray(OrigType, Ctx); |
| 223 | + if (OrigType != NewType) { |
| 224 | + // Create a new global variable with the updated type |
| 225 | + // Note: Initializer is set via transformInitializer |
| 226 | + GlobalVariable *NewGlobal = new GlobalVariable( |
| 227 | + M, NewType, G.isConstant(), G.getLinkage(), |
| 228 | + /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, |
| 229 | + G.getThreadLocalMode(), G.getAddressSpace(), |
| 230 | + G.isExternallyInitialized()); |
| 231 | + |
| 232 | + // Copy relevant attributes |
| 233 | + NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); |
| 234 | + if (G.getAlignment() > 0) { |
| 235 | + NewGlobal->setAlignment(G.getAlign()); |
| 236 | + } |
| 237 | + |
| 238 | + if (G.hasInitializer()) { |
| 239 | + Constant *Init = G.getInitializer(); |
| 240 | + Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx); |
| 241 | + NewGlobal->setInitializer(NewInit); |
| 242 | + } |
| 243 | + |
| 244 | + // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes |
| 245 | + // type equality. Instead we will use the visitor pattern. |
| 246 | + Impl.GlobalMap[&G] = NewGlobal; |
| 247 | + for (User *U : make_early_inc_range(G.users())) { |
| 248 | + if (isa<ConstantExpr>(U) && isa<Operator>(U)) { |
| 249 | + ConstantExpr *CE = cast<ConstantExpr>(U); |
| 250 | + convertUsersOfConstantsToInstructions(CE, |
| 251 | + /*RestrictToFunc=*/nullptr, |
| 252 | + /*RemoveDeadConstants=*/false, |
| 253 | + /*IncludeSelf=*/true); |
| 254 | + } |
| 255 | + if (isa<Instruction>(U)) { |
| 256 | + Instruction *Inst = cast<Instruction>(U); |
| 257 | + Function *F = Inst->getFunction(); |
| 258 | + if (F) |
| 259 | + Impl.visit(*F); |
| 260 | + } |
| 261 | + } |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + // Remove the old globals after the iteration |
| 266 | + for (auto &[Old, New] : Impl.GlobalMap) { |
| 267 | + Old->eraseFromParent(); |
| 268 | + MadeChange = true; |
| 269 | + } |
| 270 | + return MadeChange; |
| 271 | +} |
| 272 | + |
| 273 | +PreservedAnalyses DXILDataScalarization::run(Module &M, |
| 274 | + ModuleAnalysisManager &) { |
| 275 | + bool MadeChanges = findAndReplaceVectors(M); |
| 276 | + if (!MadeChanges) |
| 277 | + return PreservedAnalyses::all(); |
| 278 | + PreservedAnalyses PA; |
| 279 | + PA.preserve<DXILResourceAnalysis>(); |
| 280 | + return PA; |
| 281 | +} |
| 282 | + |
| 283 | +bool DXILDataScalarizationLegacy::runOnModule(Module &M) { |
| 284 | + return findAndReplaceVectors(M); |
| 285 | +} |
| 286 | + |
| 287 | +void DXILDataScalarizationLegacy::getAnalysisUsage(AnalysisUsage &AU) const { |
| 288 | + AU.addPreserved<DXILResourceWrapperPass>(); |
| 289 | +} |
| 290 | + |
| 291 | +char DXILDataScalarizationLegacy::ID = 0; |
| 292 | + |
| 293 | +INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, |
| 294 | + "DXIL Data Scalarization", false, false) |
| 295 | +INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, |
| 296 | + "DXIL Data Scalarization", false, false) |
| 297 | + |
| 298 | +ModulePass *llvm::createDXILDataScalarizationLegacyPass() { |
| 299 | + return new DXILDataScalarizationLegacy(); |
| 300 | +} |
0 commit comments