@@ -4757,9 +4757,10 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
4757
4757
};
4758
4758
4759
4759
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast or ShapeCast.
4760
- // / This only applies when the shape of the broadcast source is a suffix of the
4761
- // / shape of the result (i.e. when broadcast without reshape is expressive
4762
- // / enough to capture the result in a single op).
4760
+ /// This only applies when the broadcast source
4761
+ /// 1. has a shape that is a suffix of the result shape (i.e. a broadcast
4762
+ /// without reshape can capture the result in a single op; emits a broadcast), or
4763
+ /// 2. is a vector with the same element count as the shape cast result
+ /// (emits a shape cast of the broadcast source).
4763
4764
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4764
4765
public:
4765
4766
using OpRewritePattern::OpRewritePattern;
@@ -4771,23 +4772,35 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4771
4772
if (!broadcastOp)
4772
4773
return failure ();
4773
4774
4774
- auto broadcastSourceVectorType =
4775
- llvm::dyn_cast<VectorType>(broadcastOp.getSourceType ());
4776
- auto broadcastSourceShape = broadcastSourceVectorType
4777
- ? broadcastSourceVectorType.getShape ()
4778
- : ArrayRef<int64_t >{};
4779
- auto shapeCastTargetShape = shapeCastOp.getResultVectorType ().getShape ();
4780
-
4781
- // Bail if `broadcastSourceShape` is not a suffix of the result.
4782
- bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back (
4783
- broadcastSourceShape.size ()));
4784
- if (!isSuffix)
4785
- return failure ();
4775
+ ArrayRef<int64_t > broadcastSourceShape;
4776
+ if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ()))
4777
+ broadcastSourceShape = srcType.getShape ();
4778
+ ArrayRef<int64_t > shapeCastTargetShape =
4779
+ shapeCastOp.getResultVectorType ().getShape ();
4786
4780
4787
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
4788
- shapeCastOp, shapeCastOp.getResultVectorType (),
4789
- broadcastOp.getSource ());
4790
- return success ();
4781
+ // If `broadcastSourceShape` is a suffix of the result, we can just replace
4782
+ // with a broadcast to the final shape.
4783
+ if (broadcastSourceShape ==
4784
+ shapeCastTargetShape.take_back (broadcastSourceShape.size ())) {
4785
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
4786
+ shapeCastOp, shapeCastOp.getResultVectorType (),
4787
+ broadcastOp.getSource ());
4788
+ return success ();
4789
+ }
4790
+
4791
+ // Otherwise, if the final result has the same element count, we can replace
4792
+ // with a shape cast.
4793
+ if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ())) {
4794
+ if (srcType.getNumElements () ==
4795
+ shapeCastOp.getResultVectorType ().getNumElements ()) {
4796
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
4797
+ shapeCastOp, shapeCastOp.getResultVectorType (),
4798
+ broadcastOp.getSource ());
4799
+ return success ();
4800
+ }
4801
+ }
4802
+
4803
+ return failure ();
4791
4804
}
4792
4805
};
4793
4806
0 commit comments