 from typing import Literal, TypeAlias, TypeGuard, TypeVar, cast, overload
 
 import numpy as np
-from spox import Tensor, Var, argument
+from spox import Tensor, Var, argument, build, inline
 from typing_extensions import Self
 
 from ndonnx import DType
 
 from .._schema import DTypeInfoV1
 from . import TyArrayBase, safe_cast
 from . import ort_compat as op
-from .dtype_independent_funcs import maximum, minimum, where
+from .dtype_independent_funcs import maximum, minimum, where, zeros
 from .indexing import FancySlice
 
 _ScalarInt: TypeAlias = "TyArrayInteger"
@@ -1394,8 +1394,14 @@ def __lt__(self, other: TyArrayBase | PyScalar) -> TyArrayBase: ...
     def __lt__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
         return self._apply(other, op.less, forward=True, result_type=TyArrayBool)
 
-    def __matmul__(self, other) -> TyArrayBase:
-        return self._apply(other, op.matmul, forward=True, result_type=TyArrayNumber)
+    def __matmul__(self, other: TyArrayBase) -> TyArrayNumber:
+        if isinstance(other, TyArrayNumber):
+            a, b = promote(self, other)
+            a = safe_cast(TyArrayNumber, a)
+            b = safe_cast(TyArrayNumber, b)
+            return _matmul_mitigate_zero_sized(a, b)
+
+        return NotImplemented
 
     def __mul__(self, other: TyArrayBase | PyScalar) -> TyArrayBase:
         return self._apply(other, op.mul, forward=True, result_type=TyArrayNumber)
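The `NotImplemented` return added above follows Python's standard binary-operator protocol: when the left operand's method defers, the interpreter tries the right operand's reflected hook before raising `TypeError`. A minimal, self-contained sketch of that dispatch; the `Meters` and `Feet` classes are purely illustrative and not part of this change:

```python
class Meters:
    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other: object) -> "Meters":
        # Handle only the types we understand; defer everything else.
        if isinstance(other, Meters):
            return Meters(self.value + other.value)
        return NotImplemented  # Python now tries type(other).__radd__


class Feet:
    def __init__(self, value: float) -> None:
        self.value = value

    def __radd__(self, other: object) -> "Meters":
        # Reflected hook: reached because Meters.__add__ deferred.
        if isinstance(other, Meters):
            return Meters(other.value + self.value * 0.3048)
        return NotImplemented


total = Meters(1.0) + Feet(10.0)  # Meters.__add__ defers; Feet.__radd__ handles it
```

This is why the new `__matmul__` can restrict itself to `TyArrayNumber` operands without closing the door on other array types.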
@@ -2643,3 +2649,86 @@ def numeric_like(x: TyArrayBase) -> TyArrayInt32:
     target_shape = ret.dynamic_shape
 
     return tuple(a.broadcast_to(target_shape) for a in arrays)
+
+
+def _matmul_mitigate_zero_sized(
+    a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER
+) -> TY_ARRAY_NUMBER:
+    """Mitigate an onnxruntime bug for zero-sized inputs to matmul."""
+    # Mitigations of this kind would be better situated in
+    # ort_compat.py, but doing the below with raw spox operators would
+    # be needless, self-inflicted pain.
+
+    # The core idea of this mitigation is to do a runtime check on the
+    # dimensions of the inputs and then hide the matmul inside an
+    # If-subgraph. However, there are a couple of complications:
+
+    # 1. The output of the zero-sized branch has to match the shape of
+    # the output of the regular matmul. The semantics of that shape
+    # operation are involved
+    # (https://data-apis.org/array-api/draft/API_specification/generated/array_api.matmul.html#array_api.matmul).
+
+    # 2. We have to disable value propagation, or else onnxruntime
+    # segfaults when executing the matmul branch (which is executed
+    # unconditionally during propagation).
+
+    # We grind our way through point 1. Point 2 is solved by creating
+    # a small ONNX model and inlining it. We could consider doing this
+    # in a more generic fashion to enable faster build times.
+
+    # TODO: Fix upstream in onnxruntime!
+    def dummy_n_1(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
+        # (..., n, k) @ (k,) -> (..., n)
+        return a.dynamic_shape[:-1]
+
+    def dummy_1_n(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
+        # (k,) @ (..., k, m) -> (..., m)
+        out_shape = b.dynamic_shape
+        out_shape = out_shape[:-2].concat([out_shape[-1:]], axis=0)
+        return out_shape
+
+    def dummy_n_m(a: TY_ARRAY_NUMBER, b: TY_ARRAY_NUMBER) -> TyArrayInt64:
+        # (..., n, k) @ (..., k, m) -> (broadcast(batch dims), n, m);
+        # the batch dimensions are broadcast against each other by
+        # adding zero-valued dummy arrays of the respective shapes.
+        a_shape = a.dynamic_shape
+        dummy_a_shape = a_shape.copy()
+        dummy_a_shape[-2:] = const(1)
+
+        b_shape = b.dynamic_shape
+        dummy_b_shape = b_shape.copy()
+        dummy_b_shape[-2:] = const(1)
+
+        dummy_a = zeros(shape=dummy_a_shape, dtype=uint8)
+        dummy_b = zeros(shape=dummy_b_shape, dtype=uint8)
+
+        out_shape = (dummy_a + dummy_b).dynamic_shape
+        out_shape[-2] = a_shape[-2]
+        out_shape[-1] = b_shape[-1]
+        return out_shape
+
+    # Fresh arguments for the standalone model that is inlined below.
+    a_ = type(a)(argument(a._var.unwrap_type()))
+    b_ = type(b)(argument(b._var.unwrap_type()))
+
+    if a_.ndim == 1 and b_.ndim == 1:
+        out_shape = const(np.asarray([], np.int64), int64)
+    elif a_.ndim >= 2 and b_.ndim == 1:
+        out_shape = dummy_n_1(a_, b_)
+    elif a_.ndim == 1 and b_.ndim >= 2:
+        out_shape = dummy_1_n(a_, b_)
+    elif a_.ndim >= 2 and b_.ndim >= 2:
+        out_shape = dummy_n_m(a_, b_)
+    else:
+        raise ValueError(
+            f"unsupported input ranks for 'matmul': `{a_.ndim}` and `{b_.ndim}`"
+        )
+
+    dummy_out = zeros(shape=out_shape, dtype=a.dtype)._var
+    res = op.matmul(a_._var, b_._var)
+
+    (var,) = op.if_(
+        safe_cast(TyArrayBool, (a_.dynamic_size * b_.dynamic_size == 0))._var,
+        then_branch=lambda: [dummy_out],
+        else_branch=lambda: [res],
+    )
+    # Hide the If-subgraph inside a standalone model so that value
+    # propagation does not execute the matmul branch eagerly (point 2).
+    model = build({"a": a_._var, "b": b_._var}, {"res": var})
+    (var_with_propagation,) = inline(model)(a=a._var, b=b._var).values()
+
+    return type(a)(var_with_propagation)
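The mitigation combines two pieces: an `If` node whose branches are supplied as lambdas, and spox's `build`/`inline` pair to wrap the guarded computation in a standalone model. A minimal sketch of that pattern in isolation, assuming spox's public API and one of its generated opset modules (`v19` here); the names `x`, `is_empty`, and `guarded` are illustrative, and this sketches the technique rather than this module's exact code:

```python
import numpy as np
import spox.opset.ai.onnx.v19 as op
from spox import Tensor, argument, build, inline

# Inner graph: guard a computation behind a runtime zero-size check.
x = argument(Tensor(np.float64, ("N",)))
is_empty = op.equal(op.size(x), op.const(np.array(0, dtype=np.int64)))
(guarded,) = op.if_(
    is_empty,
    # Both branches must produce outputs of matching type and rank.
    then_branch=lambda: [op.const(np.array(0.0))],
    else_branch=lambda: [op.reduce_sum(x, keepdims=0)],
)

# Wrap the guarded computation in a standalone ONNX model ...
model = build(inputs={"x": x}, outputs={"out": guarded})

# ... and inline it into the outer graph, mirroring how the function
# above routes the matmul through `build` and `inline` so that eager
# value propagation does not hand the raw If-branches to onnxruntime.
outer_x = op.const(np.array([1.0, 2.0, 3.0]))
(out,) = inline(model)(x=outer_x).values()
```

The design choice of building and inlining a throwaway model, rather than patching `ort_compat.py`, is exactly what the leading comment in `_matmul_mitigate_zero_sized` argues for.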
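For reference, the output-shape rules that `dummy_n_1`, `dummy_1_n`, and `dummy_n_m` compute at runtime are those of the Array API `matmul` specification linked in the comments. A plain-NumPy restatement of those rules; the helper name `matmul_output_shape` is hypothetical, for illustration only:

```python
import numpy as np


def matmul_output_shape(
    a_shape: tuple[int, ...], b_shape: tuple[int, ...]
) -> tuple[int, ...]:
    """Expected matmul result shape per the Array API broadcasting rules."""
    if len(a_shape) == 1 and len(b_shape) == 1:
        return ()  # vector @ vector -> scalar
    if len(b_shape) == 1:
        return a_shape[:-1]  # (..., n, k) @ (k,) -> (..., n)
    if len(a_shape) == 1:
        return b_shape[:-2] + b_shape[-1:]  # (k,) @ (..., k, m) -> (..., m)
    # Batch dimensions broadcast; the trailing matrix dims give (n, m).
    batch = np.broadcast_shapes(a_shape[:-2], b_shape[:-2])
    return batch + (a_shape[-2], b_shape[-1])


assert matmul_output_shape((0, 3), (3,)) == (0,)
assert matmul_output_shape((2, 0, 4), (1, 4, 5)) == (2, 0, 5)
```

The two asserts cover the zero-sized cases this change targets: the result of a zero-sized matmul is itself zero-sized, which is why the `then_branch` can return an empty dummy of the computed shape.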