-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[LV] Get FMFs from VectorBuilder in createSimpleReduction. NFC #132017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LV] Get FMFs from VectorBuilder in createSimpleReduction. NFC #132017
Conversation
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms Author: Luke Lau (lukel97) ChangesThe other createSimpleReduction takes the FMFs from the IRBuilder, so this aligns the VectorBuilder variant to do the same and reduce the possibility of there being a mismatch in flags. Stacked on #132014 Full diff: https://github.com/llvm/llvm-project/pull/132017.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/VectorBuilder.h b/llvm/include/llvm/IR/VectorBuilder.h
index b0277c2b52595..35d8e57576817 100644
--- a/llvm/include/llvm/IR/VectorBuilder.h
+++ b/llvm/include/llvm/IR/VectorBuilder.h
@@ -87,6 +87,9 @@ class VectorBuilder {
StaticVectorLength = ElementCount::getFixed(NewFixedVL);
return *this;
}
+
+ FastMathFlags &getFastMathFlags() const { return Builder.getFastMathFlags(); }
+
// TODO: setStaticVL(ElementCount) for scalable types.
// Emit a VP intrinsic call that mimics a regular instruction.
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 1818ee03d2ec8..193f505fb03fe 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,7 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
RecurKind RdxKind);
/// Overloaded function to generate vector-predication intrinsics for
/// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
- const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind);
/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -428,14 +427,12 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);
/// Create an ordered reduction intrinsic using the given recurrence
-/// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc, Value *Src,
+/// kind \p RdxKind.
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,
Value *Start);
/// Overloaded function to generate vector-predication intrinsics for ordered
/// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
- const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind, Value *Src,
Value *Start);
/// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 185af8631454a..96d7b3342349e 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,24 +1333,22 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
}
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
- const RecurrenceDescriptor &Desc) {
- RecurKind Kind = Desc.getRecurrenceKind();
+ RecurKind Kind) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
"AnyOf or FindLastIV reductions are not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
- Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+ Value *Iden =
+ getRecurrenceIdentity(Kind, SrcEltTy, VBuilder.getFastMathFlags());
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
}
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1358,11 +1356,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d315dbe9b4170..af15b6b6d3dba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2311,7 +2311,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
if (IsOrdered) {
if (State.VF.isVector())
NewRed =
- createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+ createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
else
NewRed = State.Builder.CreateBinOp(
(Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
@@ -2356,9 +2356,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
Value *NewRed;
if (isOrdered()) {
- NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+ NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
} else {
- NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+ NewRed = createSimpleReduction(VBuilder, VecOp, Kind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
else
|
@llvm/pr-subscribers-llvm-ir Author: Luke Lau (lukel97) ChangesThe other createSimpleReduction takes the FMFs from the IRBuilder, so this aligns the VectorBuilder variant to do the same and reduce the possibility of there being a mismatch in flags. Stacked on #132014 Full diff: https://github.com/llvm/llvm-project/pull/132017.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/VectorBuilder.h b/llvm/include/llvm/IR/VectorBuilder.h
index b0277c2b52595..35d8e57576817 100644
--- a/llvm/include/llvm/IR/VectorBuilder.h
+++ b/llvm/include/llvm/IR/VectorBuilder.h
@@ -87,6 +87,9 @@ class VectorBuilder {
StaticVectorLength = ElementCount::getFixed(NewFixedVL);
return *this;
}
+
+ FastMathFlags &getFastMathFlags() const { return Builder.getFastMathFlags(); }
+
// TODO: setStaticVL(ElementCount) for scalable types.
// Emit a VP intrinsic call that mimics a regular instruction.
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 1818ee03d2ec8..193f505fb03fe 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,7 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
RecurKind RdxKind);
/// Overloaded function to generate vector-predication intrinsics for
/// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
- const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind);
/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -428,14 +427,12 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);
/// Create an ordered reduction intrinsic using the given recurrence
-/// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc, Value *Src,
+/// kind \p RdxKind.
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,
Value *Start);
/// Overloaded function to generate vector-predication intrinsics for ordered
/// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
- const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind, Value *Src,
Value *Start);
/// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 185af8631454a..96d7b3342349e 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,24 +1333,22 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
}
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
- const RecurrenceDescriptor &Desc) {
- RecurKind Kind = Desc.getRecurrenceKind();
+ RecurKind Kind) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
"AnyOf or FindLastIV reductions are not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
- Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+ Value *Iden =
+ getRecurrenceIdentity(Kind, SrcEltTy, VBuilder.getFastMathFlags());
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
}
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1358,11 +1356,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
- const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
Value *Src, Value *Start) {
- assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
- Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+ assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
"Unexpected reduction kind");
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d315dbe9b4170..af15b6b6d3dba 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2311,7 +2311,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
if (IsOrdered) {
if (State.VF.isVector())
NewRed =
- createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+ createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
else
NewRed = State.Builder.CreateBinOp(
(Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
@@ -2356,9 +2356,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
Value *NewRed;
if (isOrdered()) {
- NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+ NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
} else {
- NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+ NewRed = createSimpleReduction(VBuilder, VecOp, Kind);
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
else
|
The other createSimpleReduction takes the FMFs from the IRBuilder, so this aligns the VectorBuilder variant to do the same and reduce the possibility of there being a mismatch in flags. Stacked on llvm#132014
9e815ce
to
a7c5a13
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
llvm/include/llvm/IR/VectorBuilder.h
Outdated
@@ -87,6 +87,9 @@ class VectorBuilder { | |||
StaticVectorLength = ElementCount::getFixed(NewFixedVL); | |||
return *this; | |||
} | |||
|
|||
FastMathFlags &getFastMathFlags() const { return Builder.getFastMathFlags(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be good to return const FastMathFlags
and document
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks!
The other createSimpleReduction takes the FMFs from the IRBuilder, so this aligns the VectorBuilder variant to do the same and reduce the possibility of there being a mismatch in flags.
Stacked on #132014