Skip to content

Commit ce6e8f6

Browse files
authored
Extend time unit support (#132)
1 parent 4a4d37d commit ce6e8f6

File tree

3 files changed

+80
-46
lines changed

3 files changed

+80
-46
lines changed

CHANGELOG.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
Changelog
88
=========
99

10+
0.13.0 (2025-05-22)
11+
-------------------
12+
13+
**New features**
14+
15+
- The :class:`ndonnx.TimeDelta64DType` and :class:`ndonnx.DateTime64DType` gained support for milli and microseconds as units.
16+
- :func:`ndonnx.where` now promotes time units between the two branches.
17+
- Addition, multiplication, division, and subtraction between arrays with timedelta or datetime data types now support promotion between time units.
18+
- Comparison operations between arrays with timedelta or datetime data types now support promotion between time units.
19+
20+
1021
0.12.0 (2025-05-15)
1122
-------------------
1223

ndonnx/_typed_array/datetime.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ndonnx.types import NestedSequence, OnnxShape, PyScalar
3131

3232

33-
Unit = Literal["ns", "s"]
33+
Unit = Literal["ns", "us", "ms", "s"]
3434

3535
_NAT_SENTINEL = onnx.const(np.iinfo(np.int64).min).astype(onnx.int64)
3636
TIMEARRAY_co = TypeVar("TIMEARRAY_co", bound="TimeBaseArray", covariant=True)
@@ -63,6 +63,9 @@ def __ndx_cast_from__(self, arr: TyArrayBase) -> TIMEARRAY_co:
6363
def __ndx_result_type__(self, other: DType | PyScalar) -> DType:
6464
if isinstance(other, int):
6565
return self
66+
if isinstance(other, BaseTimeDType):
67+
target_unit = _result_unit(self.unit, other.unit)
68+
return type(self)(target_unit)
6669
return NotImplemented
6770

6871
def __ndx_argument__(self, shape: OnnxShape) -> TIMEARRAY_co:
@@ -277,7 +280,7 @@ def _apply_comp(
277280
if type(self) is not type(other):
278281
return NotImplemented
279282

280-
self, other = _coerce_units(self, other_arr)
283+
self, other = _promote_unit(self, other_arr)
281284

282285
data = op(self._data, other._data)
283286
is_nat = self.is_nat | other.is_nat
@@ -326,10 +329,10 @@ def __ndx_where__(
326329
) -> TyArrayBase:
327330
if not isinstance(other, TyArrayBase):
328331
return NotImplemented
329-
if self.dtype != other.dtype or not isinstance(other, type(self)):
330-
return NotImplemented
331-
332-
return self.dtype._build(onnx.where(cond, self._data, other._data))
332+
if isinstance(other, type(self)):
333+
a, b = _promote_unit(self, other)
334+
return a.dtype._build(onnx.where(cond, a._data, b._data))
335+
return NotImplemented
333336

334337
def clip(
335338
self, /, min: TyArrayBase | None = None, max: TyArrayBase | None = None
@@ -394,10 +397,11 @@ def __add__(self, rhs: TyArrayBase | PyScalar) -> TyArrayTimeDelta:
394397
if isinstance(rhs, int):
395398
rhs = TyArrayTimeDelta(onnx.const(rhs), self.dtype.unit)
396399
if isinstance(rhs, TyArrayTimeDelta):
397-
if {self.dtype.unit, rhs.dtype.unit} == {"s", "ns"}:
398-
self = self.astype(TimeDelta64DType("ns"))
399-
rhs = rhs.astype(TimeDelta64DType("ns"))
400-
return _apply_op(self, rhs, operator.add, True)
400+
allowed_units = set(get_args(Unit))
401+
lhs = self
402+
if lhs.dtype.unit in allowed_units and rhs.dtype.unit in allowed_units:
403+
lhs, rhs = _promote_unit(lhs, rhs)
404+
return _apply_op(lhs, rhs, operator.add, True)
401405
return NotImplemented
402406

403407
def __radd__(self, lhs: TyArrayBase | PyScalar) -> TyArrayTimeDelta:
@@ -419,10 +423,11 @@ def __sub__(self, rhs: TyArrayBase | PyScalar) -> TyArrayTimeDelta:
419423
if isinstance(rhs, int):
420424
rhs = TyArrayTimeDelta(onnx.const(rhs), self.dtype.unit)
421425
if isinstance(rhs, TyArrayTimeDelta):
422-
if {self.dtype.unit, rhs.dtype.unit} == {"s", "ns"}:
423-
self = self.astype(TimeDelta64DType("ns"))
424-
rhs = rhs.astype(TimeDelta64DType("ns"))
425-
return _apply_op(self, rhs, operator.sub, True)
426+
allowed_units = set(get_args(Unit))
427+
lhs = self
428+
if lhs.dtype.unit in allowed_units and rhs.dtype.unit in allowed_units:
429+
lhs, rhs = _promote_unit(lhs, rhs)
430+
return _apply_op(lhs, rhs, operator.sub, True)
426431
return NotImplemented
427432

428433
def __rsub__(self, lhs: TyArrayBase | PyScalar) -> TyArrayTimeDelta:
@@ -551,7 +556,7 @@ def __add__(self, rhs: TyArrayBase | PyScalar) -> Self:
551556
if rhs is NotImplemented:
552557
return NotImplemented
553558

554-
lhs, rhs = _coerce_units(self, rhs)
559+
lhs, rhs = _promote_unit(self, rhs)
555560

556561
data = lhs._data + rhs._data
557562
is_nat = lhs.is_nat | rhs.is_nat
@@ -571,7 +576,7 @@ def _sub(self, other, forward: bool):
571576
return self - other_ if forward else other_ - self
572577

573578
if isinstance(other, TyArrayDateTime):
574-
a, b = _coerce_units(self, other)
579+
a, b = _promote_unit(self, other)
575580
is_nat = a.is_nat | b.is_nat
576581
data = safe_cast(
577582
onnx.TyArrayInt64, a._data - b._data if forward else b._data - a._data
@@ -582,7 +587,7 @@ def _sub(self, other, forward: bool):
582587

583588
elif isinstance(other, TyArrayTimeDelta) and forward:
584589
# *_ due to types of various locals set in the previous if statement
585-
a_, b_ = _coerce_units(self, other)
590+
a_, b_ = _promote_unit(self, other)
586591
is_nat = a_.is_nat | b_.is_nat
587592
data = safe_cast(
588593
onnx.TyArrayInt64,
@@ -610,13 +615,10 @@ def __ndx_equal__(self, other) -> onnx.TyArrayBool:
610615

611616
if not isinstance(other, TyArrayDateTime):
612617
return NotImplemented
613-
if self.dtype.unit != other.dtype.unit:
614-
raise TypeError(
615-
"comparison between different units is not implemented, yet"
616-
)
617618

618-
res = self._data == other._data
619-
is_nat = self.is_nat | other.is_nat
619+
lhs, rhs = _promote_unit(self, other)
620+
res = lhs._data == rhs._data
621+
is_nat = lhs.is_nat | rhs.is_nat
620622

621623
return safe_cast(onnx.TyArrayBool, res & ~is_nat)
622624

@@ -662,17 +664,16 @@ def _coerce_other(
662664
return NotImplemented
663665

664666

665-
def _coerce_units(a: T1, b: T2) -> tuple[T1, T2]:
666-
table: dict[tuple[Unit, Unit], Unit] = {
667-
("ns", "s"): "ns",
668-
("s", "ns"): "ns",
669-
("s", "s"): "s",
670-
("ns", "ns"): "ns",
671-
}
672-
target = table[(a.dtype.unit, b.dtype.unit)]
673-
dtype_a = type(a.dtype)(unit=target)
674-
dtype_b = type(b.dtype)(unit=target)
675-
return (a.astype(dtype_a), b.astype(dtype_b))
667+
def _promote_unit(a: T1, b: T2) -> tuple[T1, T2]:
668+
unit = _result_unit(a.dtype.unit, b.dtype.unit)
669+
670+
return a.astype(type(a.dtype)(unit=unit)), b.astype(type(b.dtype)(unit=unit))
671+
672+
673+
def _result_unit(a: Unit, b: Unit) -> Unit:
674+
ordered_units = ["ns", "us", "ms", "s"]
675+
res, _ = sorted([a, b], key=lambda el: ordered_units.index(el))
676+
return res # type: ignore
676677

677678

678679
def validate_unit(unit: str) -> Unit:

tests/test_datetime.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def test_datetime_from_np_array(ty, unit):
6161
"scalar, dtype, res_dtype",
6262
[
6363
(1, ndx.DateTime64DType("s"), ndx.DateTime64DType("s")),
64+
(1, ndx.DateTime64DType("ms"), ndx.DateTime64DType("ms")),
65+
(1, ndx.DateTime64DType("us"), ndx.DateTime64DType("us")),
6466
(1, ndx.DateTime64DType("ns"), ndx.DateTime64DType("ns")),
6567
],
6668
)
@@ -196,9 +198,11 @@ def test_comparisons_timedelta(op, x, y, unit):
196198
(["NaT"], ["NaT"]),
197199
],
198200
)
199-
def test_comparisons_datetime(op, x, y, unit):
200-
np_x = np.array(x, f"datetime64[{unit}]")
201-
np_y = np.array(y, f"datetime64[{unit}]")
201+
@pytest.mark.parametrize("unit1", get_args(Unit))
202+
@pytest.mark.parametrize("unit2", get_args(Unit))
203+
def test_comparisons_datetime(op, x, y, unit1, unit2):
204+
np_x = np.array(x, f"datetime64[{unit1}]")
205+
np_y = np.array(y, f"datetime64[{unit2}]")
202206

203207
desired = op(np_x, np_y)
204208
actual = op(ndx.asarray(np_x), ndx.asarray(np_y))
@@ -216,9 +220,11 @@ def test_comparisons_datetime(op, x, y, unit):
216220
],
217221
)
218222
@pytest.mark.parametrize("forward", [True, False])
219-
def test_subtraction_datetime_arrays(x, y, unit, forward):
220-
np_x = np.array(x, f"datetime64[{unit}]")
221-
np_y = np.array(y, f"datetime64[{unit}]")
223+
@pytest.mark.parametrize("unit1", get_args(Unit))
224+
@pytest.mark.parametrize("unit2", get_args(Unit))
225+
def test_subtraction_datetime_arrays(x, y, unit1, unit2, forward):
226+
np_x = np.array(x, f"datetime64[{unit1}]")
227+
np_y = np.array(y, f"datetime64[{unit2}]")
222228

223229
desired = np_x - np_y if forward else np_y - np_x
224230
actual = (
@@ -265,13 +271,13 @@ def test_isnan(unit):
265271
)
266272

267273

268-
@pytest.mark.parametrize(
269-
"dtype", ["datetime64[s]", "timedelta64[s]", "datetime64[ns]", "timedelta64[ns]"]
270-
)
271-
def test_where(dtype):
274+
@pytest.mark.parametrize("unit1", get_args(Unit))
275+
@pytest.mark.parametrize("unit2", get_args(Unit))
276+
@pytest.mark.parametrize("dtype_name", ["datetime64", "timedelta64"])
277+
def test_where(dtype_name, unit1, unit2):
272278
cond = np.asarray([False, True, False])
273-
np_arr1 = np.asarray(["NaT", 1, 2], dtype=dtype)
274-
np_arr2 = np.asarray(["NaT", "NaT", "NaT"], dtype=dtype)
279+
np_arr1 = np.asarray(["NaT", 1, 2], dtype=f"{dtype_name}[{unit1}]")
280+
np_arr2 = np.asarray(["NaT", "NaT", "NaT"], dtype=f"{dtype_name}[{unit2}]")
275281

276282
expected = np.where(cond, np_arr1, np_arr2)
277283
actual = ndx.where(ndx.asarray(cond), ndx.asarray(np_arr1), ndx.asarray(np_arr2))
@@ -400,3 +406,19 @@ def test_datetime_dtypes_have_numpy_repr(unit):
400406
dtype = ndx.DateTime64DType(unit)
401407

402408
assert dtype.unwrap_numpy() == np.dtype(f"datetime64[{unit}]")
409+
410+
411+
@pytest.mark.parametrize("dtype_cls", [ndx.DateTime64DType, ndx.TimeDelta64DType])
412+
@pytest.mark.parametrize("unit1", get_args(Unit))
413+
@pytest.mark.parametrize("unit2", get_args(Unit))
414+
def test_result_type(dtype_cls, unit1, unit2):
415+
def do(npx):
416+
dtype1 = dtype_cls(unit1)
417+
dtype2 = dtype_cls(unit2)
418+
if npx == np:
419+
dtype1 = dtype1.unwrap_numpy()
420+
dtype2 = dtype2.unwrap_numpy()
421+
422+
return npx.result_type(dtype1, dtype2)
423+
424+
assert do(np) == do(ndx).unwrap_numpy()

0 commit comments

Comments
 (0)