diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index 1818ee03d2ec8..6823967ebca16 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -411,8 +411,8 @@ Value *createSimpleReduction(IRBuilderBase &B, Value *Src, RecurKind RdxKind); /// Overloaded function to generate vector-predication intrinsics for /// reduction. -Value *createSimpleReduction(VectorBuilder &VB, Value *Src, - const RecurrenceDescriptor &Desc); +Value *createSimpleReduction(VectorBuilder &VB, Value *Src, RecurKind RdxKind, + FastMathFlags FMFs); /// Create a reduction of the given vector \p Src for a reduction of the /// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is @@ -428,14 +428,12 @@ Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, const RecurrenceDescriptor &Desc); /// Create an ordered reduction intrinsic using the given recurrence -/// descriptor \p Desc. -Value *createOrderedReduction(IRBuilderBase &B, - const RecurrenceDescriptor &Desc, Value *Src, +/// kind \p RdxKind. +Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src, Value *Start); /// Overloaded function to generate vector-predication intrinsics for ordered /// reduction. -Value *createOrderedReduction(VectorBuilder &VB, - const RecurrenceDescriptor &Desc, Value *Src, +Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind, Value *Src, Value *Start); /// Get the intersection (logical and) of all of the potential IR flags diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 185af8631454a..41f43a24e19e6 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1333,24 +1333,21 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src, } Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src, - const RecurrenceDescriptor &Desc) { - RecurKind Kind = Desc.getRecurrenceKind(); + RecurKind Kind, FastMathFlags FMFs) { assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) && !RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) && "AnyOf or FindLastIV reductions are not supported."); Intrinsic::ID Id = getReductionIntrinsicID(Kind); auto *SrcTy = cast(Src->getType()); Type *SrcEltTy = SrcTy->getElementType(); - Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); + Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, FMFs); Value *Ops[] = {Iden, Src}; return VBuilder.createSimpleReduction(Id, SrcTy, Ops); } -Value *llvm::createOrderedReduction(IRBuilderBase &B, - const RecurrenceDescriptor &Desc, +Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind, Value *Src, Value *Start) { - assert((Desc.getRecurrenceKind() == RecurKind::FAdd || - Desc.getRecurrenceKind() == RecurKind::FMulAdd) && + assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) && "Unexpected reduction kind"); assert(Src->getType()->isVectorTy() && "Expected a vector type"); assert(!Start->getType()->isVectorTy() && "Expected a scalar type"); @@ -1358,11 +1355,9 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B, return B.CreateFAddReduce(Start, Src); } -Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, - const RecurrenceDescriptor &Desc, +Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind, Value *Src, Value *Start) { - assert((Desc.getRecurrenceKind() == RecurKind::FAdd || - Desc.getRecurrenceKind() == RecurKind::FMulAdd) && + assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) && "Unexpected reduction kind"); assert(Src->getType()->isVectorTy() && "Expected a vector type"); assert(!Start->getType()->isVectorTy() && "Expected a scalar type"); diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index d315dbe9b4170..d5b6b47e21c3b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2311,7 +2311,7 @@ void VPReductionRecipe::execute(VPTransformState &State) { if (IsOrdered) { if (State.VF.isVector()) NewRed = - createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain); + createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain); else NewRed = State.Builder.CreateBinOp( (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp); @@ -2356,9 +2356,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) { Value *NewRed; if (isOrdered()) { - NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev); + NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev); } else { - NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc); + NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags()); if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev); else