Skip to content

[LV] Split RecurrenceDescriptor into RecurKind + FastMathFlags in LoopUtils. NFC #132014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 19, 2025

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Mar 19, 2025

Split off from #131300, this splits up RecurrenceDescriptor arguments so that arbitrary recurrence kinds may be used down the line.

…pUtils. NFC

Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so that
arbitrary recurrence kinds may be used down the line.
@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2025

@llvm/pr-subscribers-vectorizers

Author: Luke Lau (lukel97)

Changes

Split off from #131300, this splits up RecurrenceDescriptor arguments so that
arbitrary recurrence kinds may be used down the line.


Full diff: https://github.com/llvm/llvm-project/pull/132014.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+4-6)
  • (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+6-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+3-3)
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 1818ee03d2ec8..29e9f5458923d 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
                              RecurKind RdxKind);
 /// Overloaded function to generate vector-predication intrinsics for
 /// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
-                             const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
+                             FastMathFlags FMFs);
 
 /// Create a reduction of the given vector \p Src for a reduction of the
 /// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -429,13 +429,11 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
 
 /// Create an ordered reduction intrinsic using the given recurrence
 /// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,
                               Value *Start);
 /// Overloaded function to generate vector-predication intrinsics for ordered
 /// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind, Value *Src,
                               Value *Start);
 
 /// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 185af8631454a..41f43a24e19e6 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,24 +1333,21 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
 }
 
 Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
-                                   const RecurrenceDescriptor &Desc) {
-  RecurKind Kind = Desc.getRecurrenceKind();
+                                   RecurKind Kind, FastMathFlags FMFs) {
   assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
          !RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
          "AnyOf or FindLastIV reductions are not supported.");
   Intrinsic::ID Id = getReductionIntrinsicID(Kind);
   auto *SrcTy = cast<VectorType>(Src->getType());
   Type *SrcEltTy = SrcTy->getElementType();
-  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
   Value *Ops[] = {Iden, Src};
   return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
 }
 
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1358,11 +1355,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
   return B.CreateFAddReduce(Start, Src);
 }
 
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d315dbe9b4170..d5b6b47e21c3b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2311,7 +2311,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
   if (IsOrdered) {
     if (State.VF.isVector())
       NewRed =
-          createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+          createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
     else
       NewRed = State.Builder.CreateBinOp(
           (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
@@ -2356,9 +2356,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
 
   Value *NewRed;
   if (isOrdered()) {
-    NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+    NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
   } else {
-    NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+    NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
     if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
       NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
     else

@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

Changes

Split off from #131300, this splits up RecurrenceDescriptor arguments so that
arbitrary recurrence kinds may be used down the line.


Full diff: https://github.com/llvm/llvm-project/pull/132014.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+4-6)
  • (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+6-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+3-3)
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 1818ee03d2ec8..29e9f5458923d 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
                              RecurKind RdxKind);
 /// Overloaded function to generate vector-predication intrinsics for
 /// reduction.
-Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
-                             const RecurrenceDescriptor &Desc);
+Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind,
+                             FastMathFlags FMFs);
 
 /// Create a reduction of the given vector \p Src for a reduction of the
 /// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -429,13 +429,11 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
 
 /// Create an ordered reduction intrinsic using the given recurrence
 /// descriptor \p Desc.
-Value *createOrderedReduction(IRBuilderBase &B,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,
                               Value *Start);
 /// Overloaded function to generate vector-predication intrinsics for ordered
 /// reduction.
-Value *createOrderedReduction(VectorBuilder &VB,
-                              const RecurrenceDescriptor &Desc, Value *Src,
+Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind, Value *Src,
                               Value *Start);
 
 /// Get the intersection (logical and) of all of the potential IR flags
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 185af8631454a..41f43a24e19e6 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1333,24 +1333,21 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
 }
 
 Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
-                                   const RecurrenceDescriptor &Desc) {
-  RecurKind Kind = Desc.getRecurrenceKind();
+                                   RecurKind Kind, FastMathFlags FMFs) {
   assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
          !RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
          "AnyOf or FindLastIV reductions are not supported.");
   Intrinsic::ID Id = getReductionIntrinsicID(Kind);
   auto *SrcTy = cast<VectorType>(Src->getType());
   Type *SrcEltTy = SrcTy->getElementType();
-  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+  Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs);
   Value *Ops[] = {Iden, Src};
   return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
 }
 
-Value *llvm::createOrderedReduction(IRBuilderBase &B,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
@@ -1358,11 +1355,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
   return B.CreateFAddReduce(Start, Src);
 }
 
-Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
-                                    const RecurrenceDescriptor &Desc,
+Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
                                     Value *Src, Value *Start) {
-  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
-          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
+  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
          "Unexpected reduction kind");
   assert(Src->getType()->isVectorTy() && "Expected a vector type");
   assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d315dbe9b4170..d5b6b47e21c3b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2311,7 +2311,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
   if (IsOrdered) {
     if (State.VF.isVector())
       NewRed =
-          createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain);
+          createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain);
     else
       NewRed = State.Builder.CreateBinOp(
           (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp);
@@ -2356,9 +2356,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
 
   Value *NewRed;
   if (isOrdered()) {
-    NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev);
+    NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
   } else {
-    NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc);
+    NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags());
     if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
       NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
     else

@lukel97 lukel97 force-pushed the loop-vectorize/looputils-recurkind branch from 97b8536 to 824559b Compare March 19, 2025 12:35
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 19, 2025
The other createSimpleReduction takes the FMFs from the IRBuilder, so this aligns the VectorBuilder variant to do the same and reduce the possibility of there being a mismatch in flags.

Stacked on llvm#132014
Copy link
Contributor

@david-arm david-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@lukel97 lukel97 merged commit f536f71 into llvm:main Mar 19, 2025
9 of 10 checks passed
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 19, 2025
The other createSimpleReduction takes the FMFs from the IRBuilder, so this aligns the VectorBuilder variant to do the same and reduce the possibility of there being a mismatch in flags.

Stacked on llvm#132014
pawosm-arm pushed a commit to pawosm-arm/llvm-project that referenced this pull request Apr 10, 2025
…pUtils. NFC (llvm#132014)

Split off from llvm#131300, this splits up RecurrenceDescriptor arguments so
that arbitrary recurrence kinds may be used down the line.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants