Commit 9269aae
[mlir][mesh] fixes for 0d tensors (llvm#132948)
In some cases 0d tensors have no sharding. This PR provides a few minor fixes to account for such cases.
1 parent e8dfd70 commit 9269aae
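
For context, the motivating case looks like the regression test added at the bottom of this commit: a rank-0 tensor has no dimensions to split across a mesh, so it often carries no sharding annotation at all. A minimal illustration (a sketch; only the tensor<f32> empty-tensor case is taken from the commit's test, the function name is made up):

func.func @zero_dim_example() -> tensor<f32> {
  // A 0d tensor has no axes to shard, so there is no mesh.shard
  // annotation on this value; spmdization must pass it through unchanged.
  %scalar = tensor.empty() : tensor<f32>
  return %scalar : tensor<f32>
}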

File tree

6 files changed: +54 -26 lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

+2

@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
 inline mesh::MeshOp
 getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
               SymbolTableCollection &symbolTableCollection) {
+  if (!meshSymbol)
+    return nullptr;
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
       op, meshSymbol);
 }

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

+1 -1

@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,

 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
-  if (rankedTensorType) {
+  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
     return shardShapedType(rankedTensorType, mesh, sharding);
   }
   return type;

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

+3 -3

@@ -716,8 +716,8 @@ void mesh::spmdizeTriviallyShardableOperation(
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
-    newResult.setType(
-        shardType(newResult.getType(),
-                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+    newResult.setType(shardType(
+        newResult.getType(),
+        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
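
Taken together with the getMeshOrNull and shardType changes above, a result carrying a default-constructed MeshSharding (null mesh symbol) now flows through cleanly. A rough sketch of the composed behavior, assuming op is an Operation* and resultType is a rank-0 tensor type (illustrative only, not code from this commit):

// A default-constructed sharding has no mesh attribute.
MeshSharding sharding;

// getMeshOrNull now bails out early on the null symbol instead of
// handing it to lookupNearestSymbolFrom.
mesh::MeshOp mesh = getMeshOrNull(op, sharding.getMeshAttr(), symbolTable);
// mesh == nullptr here.

// shardType leaves rank-0 tensors (empty shape) untouched, so the value
// keeps its original type, e.g. tensor<f32>.
Type newType = mesh::shardType(resultType, mesh, sharding);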

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

+29 -16

@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
           return arg.getType();
         }

@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
-    if (!rankedTensor) {
+    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
       return MeshSharding();
     }

@@ -689,20 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
 static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
-  llvm::transform(op.getResults(), std::back_inserter(res),
-                  [](OpResult result) {
-                    TypedValue<RankedTensorType> rankedTensor =
-                        dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor) {
-                      return MeshSharding();
-                    }
-                    if (!result.hasOneUse()) {
-                      return MeshSharding();
-                    }
-                    Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    return MeshSharding(shardOp.getSharding());
-                  });
+  llvm::transform(
+      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
+        if (!result.hasOneUse() || result.use_empty()) {
+          return MeshSharding();
+        }
+        TypedValue<RankedTensorType> rankedTensor =
+            dyn_cast<TypedValue<RankedTensorType>>(result);
+        if (!rankedTensor) {
+          return MeshSharding();
+        }
+        Operation *userOp = *result.getUsers().begin();
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+        if (shardOp) {
+          return MeshSharding(shardOp.getSharding());
+        }
+        if (rankedTensor.getType().getRank() == 0) {
+          // This is a 0d tensor result without explicit sharding.
+          // Find mesh symbol from operands, if any.
+          // Shardings without mesh are not always fully supported yet.
+          for (auto operand : op.getOperands()) {
+            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+              return MeshSharding(sharding.getMeshAttr());
+            }
+          }
+        }
+        return MeshSharding();
+      });
   return res;
 }
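
The switch from llvm::cast to llvm::dyn_cast in getResultShardings matters when a result's single user is not a mesh.shard op, which is the usual shape for 0d results. A hypothetical input illustrating this (function name made up for illustration):

func.func @zero_dim_no_shard() -> tensor<f32> {
  %0 = tensor.empty() : tensor<f32>
  // The only user of %0 is the return rather than a mesh.shard op, so the
  // lambda now returns an empty MeshSharding (with the mesh inferred from
  // an operand-defining ShardingOp when one exists) instead of asserting.
  return %0 : tensor<f32>
}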

mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp

+12 -6

@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
-    auto mesh =
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
-    auto shardType = cast<ShapedType>(
-        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+    assert(resultShardings.size() == 1);
+    auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+    mlir::mesh::MeshOp mesh;
+    ShapedType shardType;
+    if (resType.getRank() > 0) {
+      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+      shardType =
+          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+    } else {
+      shardType = resType;
+    }
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
     // provided.
-    assert(resultShardings.size() == 1);
     if (!shardType.hasStaticShape()) {
       assert(op->getResult(0).hasOneUse());
       SmallVector<Value> newOperands;
-      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      auto oldType = cast<ShapedType>(resType);
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
mlir/test/Dialect/Tensor/mesh-spmdization.mlir

+7

@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {

   return
 }
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+  tensor.empty() : tensor<f32>
+  // CHECK-NEXT: tensor.empty() : tensor<f32>
+  return
+}
