Skip to content

Commit a794c07

Browse files
Fix modulus (#127)
1 parent 2109043 commit a794c07

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

.github/workflows/array-api.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ jobs:
2323
array-api-tests:
2424
# Run if the commit message contains 'run array-api tests' or if the job is triggered on schedule
2525
name: Array API test
26-
timeout-minutes: 90
27-
runs-on: ubuntu-latest
26+
timeout-minutes: 60
27+
runs-on: ${{ matrix.os }}
28+
strategy:
29+
matrix:
30+
os:
31+
- ubuntu-latest
32+
- macos-latest
2833
steps:
2934
- name: Checkout branch
3035
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -44,8 +49,8 @@ jobs:
4449
- name: Upload Array API tests report
4550
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
4651
with:
47-
name: array-api-tests
48-
path: array-api-tests.json
52+
name: array-api-tests-${{ matrix.os }}
53+
path: array-api-tests-${{ matrix.os }}.json
4954
- name: Issue on failure
5055
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
5156
if: ${{ failure() && github.event_name == 'schedule' && github.ref == 'refs/heads/main' }}

ndonnx/_typed_array/onnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,12 +1577,12 @@ def __or__(self, other) -> TyArrayBase:
15771577
def __ror__(self, other) -> TyArrayBase:
15781578
return self._apply_int_only(other, op.bitwise_or, forward=False)
15791579

1580-
def __mod__(self, other) -> TyArrayBase:
1580+
def __mod__(self, other) -> TyArrayInteger:
15811581
return self._apply_int_only(
15821582
other, lambda a, b: op.mod(a, b, fmod=0), forward=True
15831583
)
15841584

1585-
def __rmod__(self, other) -> TyArrayBase:
1585+
def __rmod__(self, other) -> TyArrayInteger:
15861586
return self._apply_int_only(
15871587
other, lambda a, b: op.mod(a, b, fmod=0), forward=False
15881588
)

ndonnx/_typed_array/typed_array.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from ndonnx import DType
1515
from ndonnx.types import OnnxShape, PyScalar
1616

17+
from .ort_compat import const, if_
18+
1719
if TYPE_CHECKING:
1820
from spox import Var
1921

@@ -230,7 +232,14 @@ def _roll_axis(x: Self, shift: int, axis: int, /) -> Self:
230232
indices_b = [slice(None) for i in range(x.ndim)]
231233

232234
dim = x.dynamic_shape[axis]
233-
shift_ = astyarray(shift) % dim
235+
(shift_,) = map(
236+
astyarray,
237+
if_(
238+
(dim == 0).disassemble(),
239+
then_branch=lambda: (const(0, dtype=np.int64),),
240+
else_branch=lambda: ((astyarray(shift) % dim).disassemble(),),
241+
),
242+
)
234243
# pre roll: |----a------|---b---|
235244
# postroll: |---b---|----a------|
236245
# |-shift-|

0 commit comments

Comments
 (0)