Skip to content

Commit 6d829bd

Browse files
authored
Fix concat error (#139)
1 parent ea7c521 commit 6d829bd

File tree

4 files changed

+48
-19
lines changed

4 files changed

+48
-19
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ Changelog
1212

1313
**Bug fixes**
1414

15+
- :func:`ndonnx.concat` no longer raises an error if ``axis=None``, the resulting data type is ``int32`` or ``int64``, and one of the provided arrays is zero-sized.
1516
- :func:`ndonnx.__array_namespace_info__.capabilities()` now reports the number of supported dimensions via the ``"max dimensions"`` entry rather than ``"max rank"``.
16-
- Add missing onnxruntime workaround for uint32 inputs to ``ndonnx.min`` and ``ndonnx.max``.
17-
- Fix array instantiation with ``ndonnx.asarray`` and very large Python integers for ``uint64`` data types.
18-
- Fix passing an Python scalar as the second argument to ``ndonnx.where``.
17+
- Add missing onnxruntime workaround for uint32 inputs to :func:`ndonnx.min` and :func:`ndonnx.max`.
18+
- Fix array instantiation with :func:`ndonnx.asarray` and very large Python integers for ``uint64`` data types.
19+
- Fix passing an Python scalar as the second argument to :func:`ndonnx.where`.
1920

2021

2122
**New features**

ndonnx/_typed_array/onnx.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -724,21 +724,7 @@ def concat(self, others: Sequence[Self], axis: None | int) -> Self:
724724
if len(arrays) == 1:
725725
return arrays[0].copy()
726726

727-
# It seems that there is currently a bug(?) in the type/shape
728-
# inference in ONNX which prohibits us from concatenating
729-
# empty 1D int32 and int64 arrays (see test case). We therefor
730-
# do some hacky special-casing here for those types and those
731-
# types only. Other data types are fine.
732-
#
733-
# TODO: File upstream bug; this may also be what caused the
734-
# segfaults in onnxruntime in the past!
735-
if self.dtype in (int32, int64) and self.ndim == 1:
736-
dummy_axis = op.const([axis + 1], dtype=np.int64)
737-
vars = [op.unsqueeze(a._var, dummy_axis) for a in arrays]
738-
var = op.concat(vars, axis=axis)
739-
var = op.squeeze(var, dummy_axis)
740-
else:
741-
var = op.concat([a._var for a in arrays], axis=0 if axis is None else axis)
727+
var = op.concat([a._var for a in arrays], axis=0 if axis is None else axis)
742728
return type(self)(var)
743729

744730
def copy(self) -> Self:

ndonnx/_typed_array/ort_compat.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from spox.opset.ai.onnx.v21 import cast as cast
2828
from spox.opset.ai.onnx.v21 import ceil as ceil
2929
from spox.opset.ai.onnx.v21 import compress as compress
30-
from spox.opset.ai.onnx.v21 import concat as concat
3130
from spox.opset.ai.onnx.v21 import const as const
3231
from spox.opset.ai.onnx.v21 import exp as exp # Only for floats in both standards
3332
from spox.opset.ai.onnx.v21 import expand as expand
@@ -792,3 +791,29 @@ def einsum(
792791
mapping=mapping,
793792
fun_name="einsum",
794793
)(x1, x2)
794+
795+
796+
def concat(
797+
inputs: Sequence[Var],
798+
*,
799+
axis: int,
800+
) -> Var:
801+
# It seems that there is currently a bug(?) in the type/shape
802+
# inference in ONNX which prohibits us from concatenating
803+
# empty 1D int32 and int64 arrays (see test case). We therefor
804+
# do some hacky special-casing here for those types and those
805+
# types only. Other data types are fine.
806+
#
807+
# TODO: File upstream bug; this may also be what caused the
808+
# segfaults in onnxruntime in the past!
809+
tensors = [el.unwrap_tensor() for el in inputs]
810+
# All elements must have the same rank at this point
811+
first_shape = tensors[0].shape
812+
if first_shape is None:
813+
raise ValueError("rank of inputs must be statically known")
814+
if tensors[0].dtype in (np.int32, np.int64) and len(first_shape) == 1:
815+
dummy_axis = op.const([axis + 1], dtype=np.int64)
816+
vars = [op.unsqueeze(el, dummy_axis) for el in inputs]
817+
var = op.concat(vars, axis=axis)
818+
return op.squeeze(var, dummy_axis)
819+
return op.concat(inputs, axis=0 if axis is None else axis)

tests/test_manipulation_functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) QuantCo 2023-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import numpy as np
5+
import pytest
6+
7+
import ndonnx as ndx
8+
9+
10+
@pytest.mark.skipif(np.__version__ < "2", reason="NumPy 1.x does not provide 'concat'")
11+
def test_concat_axis_none_zero_sized():
12+
def do(npx):
13+
a1 = npx.zeros(shape=(0, 0), dtype=npx.int8)
14+
a2 = npx.asarray(0, dtype=npx.int32)
15+
return npx.concat([a1, a2], axis=None)
16+
17+
np.testing.assert_array_equal(do(np), do(ndx))

0 commit comments

Comments
 (0)