Skip to content

Commit 86ada1b

Browse files
Zero size matmul mitigation (#102)
1 parent a73f1b3 commit 86ada1b

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

ndonnx/_opset_extensions.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,25 @@ def clip(input: _CoreArray, min: _CoreArray, max: _CoreArray) -> _CoreArray:
437437

438438
@eager_propagate
439439
def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray:
440-
return _CoreArray(op.matmul(a.var, b.var))
440+
# TODO(adityagoel4512): this requires an upstream patch in onnxruntime
441+
# onnxruntime goes into UB with zero size inputs
442+
(out,) = op.if_(
443+
op.equal(op.size(a.var), op.const(0, dtype=np.int64)),
444+
then_branch=lambda: [
445+
op.const(
446+
np.zeros(
447+
(),
448+
dtype=np.result_type(
449+
a.var.unwrap_tensor().dtype, b.var.unwrap_tensor().dtype
450+
),
451+
)
452+
),
453+
],
454+
else_branch=lambda: [
455+
op.matmul(a.var, b.var),
456+
],
457+
)
458+
return _CoreArray(out)
441459

442460

443461
@eager_propagate

tests/test_core.py

+23
Original file line numberDiff line numberDiff line change
@@ -1067,3 +1067,26 @@ def test_repeat_raises(a, repeats, axis):
10671067
def test_getitem_bool_raises(x, index):
10681068
with pytest.raises(IndexError):
10691069
x[index]
1070+
1071+
1072+
@pytest.mark.parametrize(
1073+
"x, y",
1074+
[
1075+
(
1076+
ndx.asarray([], dtype=ndx.uint8),
1077+
ndx.asarray([], dtype=ndx.uint8),
1078+
),
1079+
(
1080+
ndx.asarray([], dtype=ndx.float32),
1081+
ndx.asarray([], dtype=ndx.int16),
1082+
),
1083+
(
1084+
ndx.asarray([1, 2, 3], dtype=ndx.uint8),
1085+
ndx.asarray([1, 2, 3], dtype=ndx.float32),
1086+
),
1087+
],
1088+
)
1089+
def test_matmul_zero_dims(x, y):
1090+
ndx_result = x @ y
1091+
np_result = x.to_numpy() @ y.to_numpy()
1092+
assert_array_equal(ndx_result.to_numpy(), np_result)

0 commit comments

Comments
 (0)