Skip to content

Commit b41987b

Browse files
authored
[SandboxVec][DAG] Fix MemDGNode chain maintenance when move destination is non-mem (llvm#124227)
This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever we move a memory instruction, the DAG gets notified about the move and maintains the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, then the memory node's next node would end up pointing to itself.
1 parent 73b4623 commit b41987b

File tree

3 files changed

+103
-26
lines changed

3 files changed

+103
-26
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,14 @@ class MemDGNode final : public DGNode {
218218
friend class PredIterator; // For MemPreds.
219219
/// Creates both edges: this<->N.
220220
void setNextNode(MemDGNode *N) {
221+
assert(N != this && "About to point to self!");
221222
NextMemN = N;
222223
if (NextMemN != nullptr)
223224
NextMemN->PrevMemN = this;
224225
}
225226
/// Creates both edges: N<->this.
226227
void setPrevNode(MemDGNode *N) {
228+
assert(N != this && "About to point to self!");
227229
PrevMemN = N;
228230
if (PrevMemN != nullptr)
229231
PrevMemN->NextMemN = this;
@@ -348,13 +350,15 @@ class DependencyGraph {
348350
void createNewNodes(const Interval<Instruction> &NewInterval);
349351

350352
/// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
351-
/// before \p N, including or excluding \p N based on \p IncludingN, or
352-
/// nullptr if not found.
353-
MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
353+
/// before \p N, skipping \p SkipN, including or excluding \p N based on
354+
/// \p IncludingN, or nullptr if not found.
355+
MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN,
356+
MemDGNode *SkipN = nullptr) const;
354357
/// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
355-
/// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
356-
/// if not found.
357-
MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
358+
/// after \p N, skipping \p SkipN, including or excluding \p N based on \p
359+
/// IncludingN, or nullptr if not found.
360+
MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN,
361+
MemDGNode *SkipN = nullptr) const;
358362

359363
/// Called by the callbacks when a new instruction \p I has been created.
360364
void notifyCreateInstr(Instruction *I);

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -325,29 +325,31 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
325325
setDefUseUnscheduledSuccs(NewInterval);
326326
}
327327

328-
MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
329-
bool IncludingN) const {
328+
MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
329+
MemDGNode *SkipN) const {
330330
auto *I = N->getInstruction();
331331
for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
332332
PrevI = PrevI->getPrevNode()) {
333333
auto *PrevN = getNodeOrNull(PrevI);
334334
if (PrevN == nullptr)
335335
return nullptr;
336-
if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
336+
auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
337+
if (PrevMemN != nullptr && PrevMemN != SkipN)
337338
return PrevMemN;
338339
}
339340
return nullptr;
340341
}
341342

342-
MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
343-
bool IncludingN) const {
343+
MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
344+
MemDGNode *SkipN) const {
344345
auto *I = N->getInstruction();
345346
for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
346347
NextI = NextI->getNextNode()) {
347348
auto *NextN = getNodeOrNull(NextI);
348349
if (NextN == nullptr)
349350
return nullptr;
350-
if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
351+
auto *NextMemN = dyn_cast<MemDGNode>(NextN);
352+
if (NextMemN != nullptr && NextMemN != SkipN)
351353
return NextMemN;
352354
}
353355
return nullptr;
@@ -377,6 +379,20 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
377379
!(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
378380
"Should not have been called if destination is same as origin.");
379381

382+
// TODO: We can only handle fully internal movements within DAGInterval or at
383+
// the borders, i.e., right before the top or right after the bottom.
384+
assert(To.getNodeParent() == I->getParent() &&
385+
"TODO: We don't support movement across BBs!");
386+
assert(
387+
(To == std::next(DAGInterval.bottom()->getIterator()) ||
388+
(To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
389+
(To != BB->end() && DAGInterval.contains(&*To))) &&
390+
"TODO: To should be either within the DAGInterval or right "
391+
"before/after it.");
392+
393+
// Make a copy of the DAGInterval before we update it.
394+
auto OrigDAGInterval = DAGInterval;
395+
380396
// Maintain the DAGInterval.
381397
DAGInterval.notifyMoveInstr(I, To);
382398

@@ -389,23 +405,37 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
389405
MemDGNode *MemN = dyn_cast<MemDGNode>(N);
390406
if (MemN == nullptr)
391407
return;
392-
// First detach it from the existing chain.
408+
409+
// First safely detach it from the existing chain.
393410
MemN->detachFromChain();
411+
394412
// Now insert it back into the chain at the new location.
395-
if (To != BB->end()) {
396-
DGNode *ToN = getNodeOrNull(&*To);
397-
if (ToN != nullptr) {
398-
MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
399-
MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
400-
}
413+
//
414+
// We won't always have a DGNode to insert before it. If `To` is BB->end() or
415+
// if it points to an instr after DAGInterval.bottom() then we will have to
416+
// find a node to insert *after*.
417+
//
418+
// BB: BB:
419+
// I1 I1 ^
420+
// I2 I2 | DAGInteval [I1 to I3]
421+
// I3 I3 V
422+
// I4 I4 <- `To` == right after DAGInterval
423+
// <- `To` == BB->end()
424+
//
425+
if (To == BB->end() ||
426+
To == std::next(OrigDAGInterval.bottom()->getIterator())) {
427+
// If we don't have a node to insert before, find a node to insert after and
428+
// update the chain.
429+
DGNode *InsertAfterN = getNode(&*std::prev(To));
430+
MemN->setPrevNode(
431+
getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
401432
} else {
402-
// MemN becomes the last instruction in the BB.
403-
auto *TermN = getNodeOrNull(BB->getTerminator());
404-
if (TermN != nullptr) {
405-
MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false));
406-
} else {
407-
// The terminator is outside the DAG interval so do nothing.
408-
}
433+
// We have a node to insert before, so update the chain.
434+
DGNode *BeforeToN = getNode(&*To);
435+
MemN->setPrevNode(
436+
getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
437+
MemN->setNextNode(
438+
getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
409439
}
410440
}
411441

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,3 +926,46 @@ define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
926926
EXPECT_EQ(LdN->getPrevNode(), S1N);
927927
EXPECT_EQ(LdN->getNextNode(), S2N);
928928
}
929+
930+
// Check that the mem chain is maintained correctly when the move destination is
931+
// not a mem node.
932+
TEST_F(DependencyGraphTest, MoveInstrCallbackWithNonMemInstrs) {
933+
parseIR(C, R"IR(
934+
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
935+
%ld = load i8, ptr %ptr
936+
%zext1 = zext i8 %arg to i32
937+
%zext2 = zext i8 %arg to i32
938+
store i8 %v1, ptr %ptr
939+
store i8 %v2, ptr %ptr
940+
ret void
941+
}
942+
)IR");
943+
llvm::Function *LLVMF = &*M->getFunction("foo");
944+
sandboxir::Context Ctx(C);
945+
auto *F = Ctx.createFunction(LLVMF);
946+
auto *BB = &*F->begin();
947+
auto It = BB->begin();
948+
auto *Ld = cast<sandboxir::LoadInst>(&*It++);
949+
[[maybe_unused]] auto *Zext1 = cast<sandboxir::CastInst>(&*It++);
950+
auto *Zext2 = cast<sandboxir::CastInst>(&*It++);
951+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
952+
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
953+
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
954+
955+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
956+
DAG.extend({Ld, S2});
957+
auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
958+
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
959+
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
960+
EXPECT_EQ(LdN->getNextNode(), S1N);
961+
EXPECT_EQ(S1N->getNextNode(), S2N);
962+
963+
S1->moveBefore(Zext2);
964+
EXPECT_EQ(LdN->getNextNode(), S1N);
965+
EXPECT_EQ(S1N->getNextNode(), S2N);
966+
967+
// Try move right after the end of the DAGInterval.
968+
S1->moveBefore(Ret);
969+
EXPECT_EQ(S2N->getNextNode(), S1N);
970+
EXPECT_EQ(S1N->getNextNode(), nullptr);
971+
}

0 commit comments

Comments
 (0)