
Commit ea7c521

Mitigate onnxruntime segfault in matmul due to zero-sized inputs (#135)
1 parent 2eb52ff commit ea7c521


ndonnx/_typed_array/onnx.py

Lines changed: 93 additions & 4 deletions
@@ -10,7 +10,7 @@
 from typing import Literal, TypeAlias, TypeGuard, TypeVar, cast, overload
 
 import numpy as np
-from spox import Tensor, Var, argument
+from spox import Tensor, Var, argument, build, inline
 from typing_extensions import Self
 
 from ndonnx import DType
@@ -19,7 +19,7 @@
 from .._schema import DTypeInfoV1
 from . import TyArrayBase, safe_cast
 from . import ort_compat as op
-from .dtype_independent_funcs import maximum, minimum, where
+from .dtype_independent_funcs import maximum, minimum, where, zeros
 from .indexing import FancySlice
 
 _ScalarInt: TypeAlias = "TyArrayInteger"
@@ -1394,8 +1394,14 @@ def __lt__(self, other: TyArrayBase | PyScalar) -> TyArrayBase: ...
     def __lt__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
         return self._apply(other, op.less, forward=True, result_type=TyArrayBool)
 
-    def __matmul__(self, other) -> TyArrayBase:
-        return self._apply(other, op.matmul, forward=True, result_type=TyArrayNumber)
+    def __matmul__(self, other: TyArrayBase) -> TyArrayNumber:
+        if isinstance(other, TyArrayNumber):
+            a, b = promote(self, other)
+            a = safe_cast(TyArrayNumber, a)
+            b = safe_cast(TyArrayNumber, b)
+            return _matmul_mitigate_zero_sized(a, b)
+
+        return NotImplemented
 
     def __mul__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
         return self._apply(other, op.mul, forward=True, result_type=TyArrayNumber)
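
For context (not part of the diff): once the `_matmul_mitigate_zero_sized` helper from the next hunk is in place, zero-sized operands that previously crashed onnxruntime are routed into the zero-size branch instead. A hypothetical eager-mode sketch, assuming ndonnx's public array namespace (`ndx.asarray`):

import numpy as np
import ndonnx as ndx

# (2, 0) @ (0, 3) contracts over an empty axis. Instead of reaching
# onnxruntime's MatMul kernel, which segfaults on zero-sized inputs,
# the runtime check selects the branch returning zeros of the
# broadcast output shape.
a = ndx.asarray(np.ones((2, 0)))
b = ndx.asarray(np.ones((0, 3)))
print((a @ b).shape)  # expected: (2, 3)

Returning zeros is not just a placeholder: a contraction over an empty axis is an empty sum, so an all-zeros array of the correct output shape is the mathematically right result.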
@@ -2643,3 +2649,86 @@ def numeric_like(x: TyArrayBase) -> TyArrayInt32:
     target_shape = ret.dynamic_shape
 
     return tuple(a.broadcast_to(target_shape) for a in arrays)
+
+
+def _matmul_mitigate_zero_sized(
+    a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER
+) -> TY_ARRAY_NUMBER:
+    """Mitigate an onnxruntime bug for zero-sized inputs to matmul."""
+
+    # Mitigations of this kind would be better situated in
+    # ort_compat.py, but doing the below in spox-operators is just
+    # needless self-inflicted pain.
+
+    # The core idea of this mitigation is to do a runtime check on the
+    # dimensions of the inputs and then hide the matmul inside an
+    # if-subgraph. However, there are a couple of complications:
+
+    # 1. The output of the zero-size branch has to match the shape of
+    # the output of the regular matmul. The semantics of that shape
+    # operation are involved
+    # (https://data-apis.org/array-api/draft/API_specification/generated/array_api.matmul.html#array_api.matmul).
+
+    # 2. We have to disable value propagation or else onnxruntime
+    # will segfault when executing the matmul branch (which is
+    # always executed during propagation).
+
+    # We grind our way through 1. The second point is solved by creating
+    # a small ONNX model and inlining it. We could consider doing this
+    # in a more generic fashion to enable faster build times.
+
+    # TODO: Fix upstream in onnxruntime!
+    def dummy_n_1(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
+        return a.dynamic_shape[:-1]
+
+    def dummy_1_n(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
+        out_shape = b.dynamic_shape
+        out_shape = out_shape[:-2].concat([out_shape[-1:]], axis=0)
+        return out_shape
+
+    def dummy_n_m(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
+        a_shape = a.dynamic_shape
+        dummy_a_shape = a_shape.copy()
+        dummy_a_shape[-2:] = const(1)
+
+        b_shape = b.dynamic_shape
+        dummy_b_shape = b_shape.copy()
+        dummy_b_shape[-2:] = const(1)
+
+        dummy_a = zeros(shape=dummy_a_shape, dtype=uint8)
+        dummy_b = zeros(shape=dummy_b_shape, dtype=uint8)
+
+        out_shape = (dummy_a + dummy_b).dynamic_shape
+        out_shape[-2] = a_shape[-2]
+        out_shape[-1] = b_shape[-1]
+        return out_shape
+
+    a_ = type(a)(argument(a._var.unwrap_type()))
+    b_ = type(b)(argument(b._var.unwrap_type()))
+
+    if a_.ndim == 1 and b_.ndim == 1:
+        out_shape = const(np.asarray([], np.int64), int64)
+    elif a_.ndim >= 2 and b_.ndim == 1:
+        out_shape = dummy_n_1(a_, b_)
+    elif a_.ndim == 1 and b_.ndim >= 2:
+        out_shape = dummy_1_n(a_, b_)
+    elif a_.ndim >= 2 and b_.ndim >= 2:
+        out_shape = dummy_n_m(a_, b_)
+    else:
+        raise ValueError(
+            f"unsupported input ranks for 'matmul': `{a_.ndim}` and `{b_.ndim}`"
+        )
+
+    dummy_out = zeros(shape=out_shape, dtype=a.dtype)._var
+    res = op.matmul(a_._var, b_._var)
+
+    (var,) = op.if_(
+        safe_cast(TyArrayBool, (a_.dynamic_size * b_.dynamic_size == 0))._var,
+        then_branch=lambda: [dummy_out],
+        else_branch=lambda: [res],
+    )
+    model = build({"a": a_._var, "b": b_._var}, {"res": var})
+    (var_with_propagation,) = inline(model)(a=a._var, b=b._var).values()
+
+    arr = type(a)(var_with_propagation)
+    return arr
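
The dummy_* helpers above reproduce the Array API's matmul shape semantics for the zero-size branch. As a sanity check (not part of the commit), numpy implements the same rules, including for zero-sized operands:

import numpy as np

# n-D @ 1-D drops the last axis (dummy_n_1): (2, 0, 3) @ (3,) -> (2, 0)
print((np.ones((2, 0, 3)) @ np.ones(3)).shape)

# 1-D @ n-D drops the second-to-last axis (dummy_1_n):
# (3,) @ (2, 3, 0) -> (2, 0)
print((np.ones(3) @ np.ones((2, 3, 0))).shape)

# n-D @ m-D broadcasts the batch dims and keeps axis -2 from `a` and
# axis -1 from `b` (dummy_n_m): (2, 1, 0, 3) @ (5, 3, 4) -> (2, 5, 0, 4).
# This is why dummy_n_m sets the trailing two dims to 1 and adds two
# uint8 dummies: the addition performs exactly the batch-dim broadcast.
print((np.ones((2, 1, 0, 3)) @ np.ones((5, 3, 4))).shape)

# 1-D @ 1-D contracts to a scalar; two zero-length vectors give 0.0.
# This matches the empty shape vector (rank-0 output) in the first
# branch of the rank dispatch above.
print(np.zeros(0) @ np.zeros(0))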

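The second complication, value propagation, is what the build-and-inline step at the end addresses: per the commit's comments, spox's propagation asks onnxruntime to evaluate operators eagerly, which reaches the bare MatMul even though it sits inside an If branch. Wrapping the graph in a standalone model and inlining it makes propagation treat it as one unit, where the If guards the MatMul at runtime. A minimal sketch of that pattern, assuming only spox's public argument/build/inline API and the ai.onnx v17 opset bindings (not taken from the commit):

import numpy as np
from spox import Tensor, argument, build, inline
import spox.opset.ai.onnx.v17 as op

# Build a small model around the sensitive operator using fresh,
# value-less `argument` placeholders, so nothing is evaluated eagerly
# while the model is assembled.
a_ = argument(Tensor(np.float64, ("N", "K")))
b_ = argument(Tensor(np.float64, ("K", "M")))
model = build(inputs={"a": a_, "b": b_}, outputs={"res": op.matmul(a_, b_)})

# Inlining reconnects the finished model to the actual inputs. The
# result is an ordinary Var, but any propagation now runs over the
# inlined model as a whole rather than node by node.
a = op.const(np.ones((2, 3)))
b = op.const(np.ones((3, 4)))
(res,) = inline(model)(a=a, b=b).values()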