Skip to content

Commit 75ca6b0

Browse files
committed
Fix ndonnx.roll and run more ci examples
1 parent 6d829bd commit 75ca6b0

File tree

4 files changed

+37
-23
lines changed

4 files changed

+37
-23
lines changed

.github/workflows/array-api.yml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Array API coverage tests
1+
name: Array API tests
22

33
on:
44
# We would like to trigger for CI for any pull request action -
@@ -28,9 +28,15 @@ jobs:
2828
strategy:
2929
fail-fast: false
3030
matrix:
31-
os:
32-
- ubuntu-latest
33-
- macos-latest
31+
include:
32+
# Different platforms are faster than others. We try to keep
33+
# the CI snappy at <= 2min.
34+
- os: ubuntu-latest
35+
max_examples_ci: 50
36+
- os: macos-latest
37+
max_examples_ci: 100
38+
- os: windows-latest
39+
max_examples_ci: 30
3440
steps:
3541
- name: Checkout branch
3642
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -42,11 +48,12 @@ jobs:
4248
- name: Install repository
4349
run: pixi run postinstall
4450
- name: Run Array API tests (Scheduled)
51+
# TODO: Move schedule part into its own workflow
4552
if: ${{ github.event_name == 'schedule' && github.ref == 'refs/heads/main' }}
4653
run: pixi run arrayapitests --max-examples 1000 --hypothesis-seed=""
4754
- name: Run Array API tests (PR and main branch)
4855
if: ${{ github.event_name != 'schedule' }}
49-
run: pixi run arrayapitests
56+
run: pixi run arrayapitests --max-examples ${{ matrix.max_examples_ci }}
5057
- name: Upload Array API tests report
5158
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
5259
with:

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ Changelog
1414

1515
- :func:`ndonnx.concat` no longer raises an error if ``axis=None``, the resulting data type is ``int32`` or ``int64``, and one of the provided arrays is zero-sized.
1616
- :func:`ndonnx.__array_namespace_info__.capabilities()` now reports the number of supported dimensions via the ``"max dimensions"`` entry rather than ``"max rank"``.
17-
- Add missing onnxruntime workaround for uint32 inputs to :func:`ndonnx.min` and :func:`ndonnx.max`.
17+
- Add missing onnxruntime workaround for ``uint32`` inputs to :func:`ndonnx.min` and :func:`ndonnx.max`.
1818
- Fix array instantiation with :func:`ndonnx.asarray` and very large Python integers for ``uint64`` data types.
1919
- Fix passing an Python scalar as the second argument to :func:`ndonnx.where`.
20+
- Calling :func:`ndonnx.roll` on zero-sized inputs no longer causes a segfault on Linux.
2021

2122

2223
**New features**

ndonnx/_typed_array/typed_array.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
from typing_extensions import Self
1313

1414
from ndonnx import DType
15+
from ndonnx._typed_array import safe_cast
1516
from ndonnx.types import OnnxShape, PyScalar
1617

17-
from .ort_compat import const, if_
18-
1918
if TYPE_CHECKING:
2019
from spox import Var
2120

@@ -210,7 +209,15 @@ def roll(
210209
*,
211210
axis: int | tuple[int, ...] | None = None,
212211
) -> Self:
212+
# From a user's perspective this is a dtype-independent
213+
# default that is very nice to have. However, some
214+
# implementation details require us to use int64 and float64
215+
# casts on indices explicitly.
216+
from . import onnx
217+
from .funcs import where
218+
213219
x = self
220+
214221
axis_ = axis
215222
if isinstance(shift, int):
216223
shift = (shift,)
@@ -223,27 +230,26 @@ def roll(
223230
raise ValueError("'shift' and 'axis' must be tuples of equal length")
224231

225232
def _roll_axis(x: Self, shift: int, axis: int, /) -> Self:
226-
from .funcs import astyarray
227-
228233
if shift == 0:
229234
return x
230235

231-
indices_a = [slice(None) for i in range(x.ndim)]
232-
indices_b = [slice(None) for i in range(x.ndim)]
233-
234-
dim = x.dynamic_shape[axis]
235-
(shift_,) = map(
236-
astyarray,
237-
if_(
238-
(dim == 0).disassemble(),
239-
then_branch=lambda: (const(0, dtype=np.int64),),
240-
else_branch=lambda: ((astyarray(shift) % dim).disassemble(),),
241-
),
236+
indices_a = [slice(None) for _ in range(x.ndim)]
237+
indices_b = [slice(None) for _ in range(x.ndim)]
238+
239+
# May be zero. mod may crash for integer data types which
240+
# is why we compute the following for floats, then fill
241+
# the NaNs, and cast back to int
242+
dim_length = x.dynamic_shape[axis].astype(onnx.float64)
243+
length_mod = safe_cast(
244+
onnx.TyArrayFloat64, onnx.const(float(shift), onnx.float64) % dim_length
245+
)
246+
shift_ = where(length_mod.isnan(), onnx.const(0.0), length_mod).astype(
247+
onnx.int64
242248
)
249+
243250
# pre roll: |----a------|---b---|
244251
# postroll: |---b---|----a------|
245252
# |-shift-|
246-
247253
indices_a[axis] = slice(None, -shift_, 1)
248254
indices_b[axis] = slice(-shift_, None, 1)
249255
return x[tuple(indices_b)].concat([x[tuple(indices_a)]], axis=axis)

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ test-coverage = "pytest --cov=ndonnx --cov-report=xml --cov-report=term-missing"
5151

5252
[feature.test.tasks.arrayapitests]
5353
# Seed and max-examples can be overridden by providing them a second time
54-
cmd = "pytest array-api-tests/array_api_tests/ -v -rfX --json-report --json-report-file=array-api-tests.json --disable-deadline --disable-extension linalg,fft --skips-file=skips.txt -s --hypothesis-seed 0 --max-examples 16"
54+
cmd = "pytest array-api-tests/array_api_tests/ -v -rfX --json-report --json-report-file=array-api-tests.json --disable-deadline --disable-extension linalg,fft --skips-file=skips.txt -s --hypothesis-seed 0 --max-examples 100 -n auto"
5555
[feature.test.tasks.arrayapitests.env]
5656
ARRAY_API_TESTS_MODULE="ndonnx"
5757
ARRAY_API_TESTS_VERSION="2024.12"

0 commit comments

Comments
 (0)