@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
622
622
block.getArguments (), std::back_inserter (res),
623
623
[&symbolTableCollection](BlockArgument arg) {
624
624
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
625
- if (!rankedTensorArg) {
625
+ if (!rankedTensorArg || rankedTensorArg. getType (). getRank () == 0 ) {
626
626
return arg.getType ();
627
627
}
628
628
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
672
672
llvm::transform (op.getOperands (), std::back_inserter (res), [](Value operand) {
673
673
TypedValue<RankedTensorType> rankedTensor =
674
674
dyn_cast<TypedValue<RankedTensorType>>(operand);
675
- if (!rankedTensor) {
675
+ if (!rankedTensor || rankedTensor. getType (). getRank () == 0 ) {
676
676
return MeshSharding ();
677
677
}
678
678
@@ -689,20 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
689
689
/// Collects one MeshSharding per result of `op`, in result order.
///
/// For each result the sharding is chosen as follows:
///  - A result that does not have exactly one use, or that is not a ranked
///    tensor, yields a default-constructed MeshSharding.
///  - If the single user is a ShardOp, the sharding carried by that ShardOp
///    is returned.
///  - Otherwise, for a rank-0 tensor result, the first operand of `op` that
///    is defined by a ShardingOp supplies the mesh attribute from which the
///    returned MeshSharding is built.
///  - In every remaining case a default-constructed MeshSharding is returned.
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        // hasOneUse() is already false for a result without any use, so a
        // separate use_empty() test would be redundant here.
        if (!result.hasOneUse())
          return MeshSharding();

        auto rankedTensor = dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor)
          return MeshSharding();

        // Exactly one use is guaranteed above, so begin() is dereferenceable.
        Operation *userOp = *result.getUsers().begin();
        if (auto shardOp = llvm::dyn_cast<ShardOp>(userOp))
          return MeshSharding(shardOp.getSharding());

        if (rankedTensor.getType().getRank() == 0) {
          // This is a 0d tensor result without explicit sharding.
          // Find mesh symbol from operands, if any.
          // Shardings without mesh are not always fully supported yet.
          for (Value operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>())
              return MeshSharding(sharding.getMeshAttr());
          }
        }
        return MeshSharding();
      });
  return res;
}
708
721
0 commit comments