Skip to content

Commit 81ff2d3

Browse files
author
Krzysztof Parzyszek
committed
[DSE] Handle masked stores
1 parent a2cb544 commit 81ff2d3

File tree

2 files changed

+46
-19
lines changed

2 files changed

+46
-19
lines changed

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ static bool hasAnalyzableMemoryWrite(Instruction *I,
234234
case Intrinsic::memset_element_unordered_atomic:
235235
case Intrinsic::init_trampoline:
236236
case Intrinsic::lifetime_end:
237+
case Intrinsic::masked_store:
237238
return true;
238239
}
239240
}
@@ -257,8 +258,8 @@ static bool hasAnalyzableMemoryWrite(Instruction *I,
257258
/// Return a Location stored to by the specified instruction. If isRemovable
258259
/// returns true, this function and getLocForRead completely describe the memory
259260
/// operations for this instruction.
260-
static MemoryLocation getLocForWrite(Instruction *Inst) {
261-
261+
static MemoryLocation getLocForWrite(Instruction *Inst,
262+
const TargetLibraryInfo &TLI) {
262263
if (StoreInst *SI = dyn_cast<StoreInst>(Inst))
263264
return MemoryLocation::get(SI);
264265

@@ -274,6 +275,8 @@ static MemoryLocation getLocForWrite(Instruction *Inst) {
274275
return MemoryLocation(); // Unhandled intrinsic.
275276
case Intrinsic::init_trampoline:
276277
return MemoryLocation(II->getArgOperand(0));
278+
case Intrinsic::masked_store:
279+
return MemoryLocation::getForArgument(II, 1, TLI);
277280
case Intrinsic::lifetime_end: {
278281
uint64_t Len = cast<ConstantInt>(II->getArgOperand(0))->getZExtValue();
279282
return MemoryLocation(II->getArgOperand(1), Len);
@@ -325,6 +328,7 @@ static bool isRemovable(Instruction *I) {
325328
case Intrinsic::memcpy_element_unordered_atomic:
326329
case Intrinsic::memmove_element_unordered_atomic:
327330
case Intrinsic::memset_element_unordered_atomic:
331+
case Intrinsic::masked_store:
328332
return true;
329333
}
330334
}
@@ -370,9 +374,10 @@ static bool isShortenableAtTheBeginning(Instruction *I) {
370374
}
371375

372376
/// Return the pointer that is being written to.
373-
static Value *getStoredPointerOperand(Instruction *I) {
377+
static Value *getStoredPointerOperand(Instruction *I,
378+
const TargetLibraryInfo &TLI) {
374379
//TODO: factor this to reuse getLocForWrite
375-
MemoryLocation Loc = getLocForWrite(I);
380+
MemoryLocation Loc = getLocForWrite(I, TLI);
376381
assert(Loc.Ptr &&
377382
"unable to find pointer written for analyzable instruction?");
378383
// TODO: most APIs don't expect const Value *
@@ -487,6 +492,24 @@ isOverwrite(const MemoryLocation &Later, const MemoryLocation &Earlier,
487492
return OW_MaybePartial;
488493
}
489494

495+
static OverwriteResult isMaskedStoreOverwrite(Instruction *Later,
496+
Instruction *Earlier) {
497+
auto *IIL = dyn_cast<IntrinsicInst>(Later);
498+
auto *IIE = dyn_cast<IntrinsicInst>(Earlier);
499+
if (IIL == nullptr || IIE == nullptr)
500+
return OW_Unknown;
501+
if (IIL->getIntrinsicID() != Intrinsic::masked_store ||
502+
IIE->getIntrinsicID() != Intrinsic::masked_store)
503+
return OW_Unknown;
504+
// Pointers.
505+
if (IIL->getArgOperand(1) != IIE->getArgOperand(1))
506+
return OW_Unknown;
507+
// Masks.
508+
if (IIL->getArgOperand(3) != IIE->getArgOperand(3))
509+
return OW_Unknown;
510+
return OW_Complete;
511+
}
512+
490513
/// Return 'OW_Complete' if a store to the 'Later' location completely
491514
/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the
492515
/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the
@@ -796,7 +819,7 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA,
796819
break;
797820

798821
Value *DepPointer =
799-
getUnderlyingObject(getStoredPointerOperand(Dependency));
822+
getUnderlyingObject(getStoredPointerOperand(Dependency, *TLI));
800823

801824
// Check for aliasing.
802825
if (!AA->isMustAlias(F->getArgOperand(0), DepPointer))
@@ -902,7 +925,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,
902925
if (hasAnalyzableMemoryWrite(&*BBI, *TLI) && isRemovable(&*BBI)) {
903926
// See through pointer-to-pointer bitcasts
904927
SmallVector<const Value *, 4> Pointers;
905-
getUnderlyingObjects(getStoredPointerOperand(&*BBI), Pointers);
928+
getUnderlyingObjects(getStoredPointerOperand(&*BBI, *TLI), Pointers);
906929

907930
// Stores to stack values are valid candidates for removal.
908931
bool AllDead = true;
@@ -1119,11 +1142,12 @@ static bool tryToShortenBegin(Instruction *EarlierWrite,
11191142
}
11201143

11211144
static bool removePartiallyOverlappedStores(const DataLayout &DL,
1122-
InstOverlapIntervalsTy &IOL) {
1145+
InstOverlapIntervalsTy &IOL,
1146+
const TargetLibraryInfo &TLI) {
11231147
bool Changed = false;
11241148
for (auto OI : IOL) {
11251149
Instruction *EarlierWrite = OI.first;
1126-
MemoryLocation Loc = getLocForWrite(EarlierWrite);
1150+
MemoryLocation Loc = getLocForWrite(EarlierWrite, TLI);
11271151
assert(isRemovable(EarlierWrite) && "Expect only removable instruction");
11281152

11291153
const Value *Ptr = Loc.Ptr->stripPointerCasts();
@@ -1284,7 +1308,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
12841308
continue;
12851309

12861310
// Figure out what location is being stored to.
1287-
MemoryLocation Loc = getLocForWrite(Inst);
1311+
MemoryLocation Loc = getLocForWrite(Inst, *TLI);
12881312

12891313
// If we didn't get a useful location, fail.
12901314
if (!Loc.Ptr)
@@ -1308,7 +1332,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
13081332
Instruction *DepWrite = InstDep.getInst();
13091333
if (!hasAnalyzableMemoryWrite(DepWrite, *TLI))
13101334
break;
1311-
MemoryLocation DepLoc = getLocForWrite(DepWrite);
1335+
MemoryLocation DepLoc = getLocForWrite(DepWrite, *TLI);
13121336
// If we didn't get a useful location, or if it isn't a size, bail out.
13131337
if (!DepLoc.Ptr)
13141338
break;
@@ -1352,6 +1376,11 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
13521376
int64_t InstWriteOffset, DepWriteOffset;
13531377
OverwriteResult OR = isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset,
13541378
InstWriteOffset, *AA, BB.getParent());
1379+
if (OR == OW_Unknown) {
1380+
// isOverwrite punts on MemoryLocations with an imprecise size, such
1381+
// as masked stores. Handle this here, somewhat inelegantly.
1382+
OR = isMaskedStoreOverwrite(Inst, DepWrite);
1383+
}
13551384
if (OR == OW_MaybePartial)
13561385
OR = isPartialOverwrite(Loc, DepLoc, DepWriteOffset, InstWriteOffset,
13571386
DepWrite, IOL);
@@ -1433,7 +1462,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,
14331462
}
14341463

14351464
if (EnablePartialOverwriteTracking)
1436-
MadeChange |= removePartiallyOverlappedStores(DL, IOL);
1465+
MadeChange |= removePartiallyOverlappedStores(DL, IOL, *TLI);
14371466

14381467
// If this block ends in a return, unwind, or unreachable, all allocas are
14391468
// dead at its end, which means stores to them are also dead.
@@ -2494,7 +2523,7 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
24942523

24952524
if (EnablePartialOverwriteTracking)
24962525
for (auto &KV : State.IOLs)
2497-
MadeChange |= removePartiallyOverlappedStores(State.DL, KV.second);
2526+
MadeChange |= removePartiallyOverlappedStores(State.DL, KV.second, TLI);
24982527

24992528
MadeChange |= State.eliminateDeadWritesAtEndOfFunction();
25002529
return MadeChange;

0 commit comments

Comments
 (0)