30
30
from ndonnx .types import NestedSequence , OnnxShape , PyScalar
31
31
32
32
33
- Unit = Literal ["ns" , "s" ]
33
+ Unit = Literal ["ns" , "us" , "ms" , " s" ]
34
34
35
35
_NAT_SENTINEL = onnx .const (np .iinfo (np .int64 ).min ).astype (onnx .int64 )
36
36
TIMEARRAY_co = TypeVar ("TIMEARRAY_co" , bound = "TimeBaseArray" , covariant = True )
@@ -63,6 +63,9 @@ def __ndx_cast_from__(self, arr: TyArrayBase) -> TIMEARRAY_co:
63
63
def __ndx_result_type__ (self , other : DType | PyScalar ) -> DType :
64
64
if isinstance (other , int ):
65
65
return self
66
+ if isinstance (other , BaseTimeDType ):
67
+ target_unit = _result_unit (self .unit , other .unit )
68
+ return type (self )(target_unit )
66
69
return NotImplemented
67
70
68
71
def __ndx_argument__ (self , shape : OnnxShape ) -> TIMEARRAY_co :
@@ -277,7 +280,7 @@ def _apply_comp(
277
280
if type (self ) is not type (other ):
278
281
return NotImplemented
279
282
280
- self , other = _coerce_units (self , other_arr )
283
+ self , other = _promote_unit (self , other_arr )
281
284
282
285
data = op (self ._data , other ._data )
283
286
is_nat = self .is_nat | other .is_nat
@@ -326,10 +329,10 @@ def __ndx_where__(
326
329
) -> TyArrayBase :
327
330
if not isinstance (other , TyArrayBase ):
328
331
return NotImplemented
329
- if self . dtype != other . dtype or not isinstance (other , type (self )):
330
- return NotImplemented
331
-
332
- return self . dtype . _build ( onnx . where ( cond , self . _data , other . _data ))
332
+ if isinstance (other , type (self )):
333
+ a , b = _promote_unit ( self , other )
334
+ return a . dtype . _build ( onnx . where ( cond , a . _data , b . _data ))
335
+ return NotImplemented
333
336
334
337
def clip (
335
338
self , / , min : TyArrayBase | None = None , max : TyArrayBase | None = None
@@ -394,10 +397,11 @@ def __add__(self, rhs: TyArrayBase | PyScalar) -> TyArrayTimeDelta:
394
397
if isinstance (rhs , int ):
395
398
rhs = TyArrayTimeDelta (onnx .const (rhs ), self .dtype .unit )
396
399
if isinstance (rhs , TyArrayTimeDelta ):
397
- if {self .dtype .unit , rhs .dtype .unit } == {"s" , "ns" }:
398
- self = self .astype (TimeDelta64DType ("ns" ))
399
- rhs = rhs .astype (TimeDelta64DType ("ns" ))
400
- return _apply_op (self , rhs , operator .add , True )
400
+ allowed_units = set (get_args (Unit ))
401
+ lhs = self
402
+ if lhs .dtype .unit in allowed_units and rhs .dtype .unit in allowed_units :
403
+ lhs , rhs = _promote_unit (lhs , rhs )
404
+ return _apply_op (lhs , rhs , operator .add , True )
401
405
return NotImplemented
402
406
403
407
def __radd__ (self , lhs : TyArrayBase | PyScalar ) -> TyArrayTimeDelta :
@@ -419,10 +423,11 @@ def __sub__(self, rhs: TyArrayBase | PyScalar) -> TyArrayTimeDelta:
419
423
if isinstance (rhs , int ):
420
424
rhs = TyArrayTimeDelta (onnx .const (rhs ), self .dtype .unit )
421
425
if isinstance (rhs , TyArrayTimeDelta ):
422
- if {self .dtype .unit , rhs .dtype .unit } == {"s" , "ns" }:
423
- self = self .astype (TimeDelta64DType ("ns" ))
424
- rhs = rhs .astype (TimeDelta64DType ("ns" ))
425
- return _apply_op (self , rhs , operator .sub , True )
426
+ allowed_units = set (get_args (Unit ))
427
+ lhs = self
428
+ if lhs .dtype .unit in allowed_units and rhs .dtype .unit in allowed_units :
429
+ lhs , rhs = _promote_unit (lhs , rhs )
430
+ return _apply_op (lhs , rhs , operator .sub , True )
426
431
return NotImplemented
427
432
428
433
def __rsub__ (self , lhs : TyArrayBase | PyScalar ) -> TyArrayTimeDelta :
@@ -551,7 +556,7 @@ def __add__(self, rhs: TyArrayBase | PyScalar) -> Self:
551
556
if rhs is NotImplemented :
552
557
return NotImplemented
553
558
554
- lhs , rhs = _coerce_units (self , rhs )
559
+ lhs , rhs = _promote_unit (self , rhs )
555
560
556
561
data = lhs ._data + rhs ._data
557
562
is_nat = lhs .is_nat | rhs .is_nat
@@ -571,7 +576,7 @@ def _sub(self, other, forward: bool):
571
576
return self - other_ if forward else other_ - self
572
577
573
578
if isinstance (other , TyArrayDateTime ):
574
- a , b = _coerce_units (self , other )
579
+ a , b = _promote_unit (self , other )
575
580
is_nat = a .is_nat | b .is_nat
576
581
data = safe_cast (
577
582
onnx .TyArrayInt64 , a ._data - b ._data if forward else b ._data - a ._data
@@ -582,7 +587,7 @@ def _sub(self, other, forward: bool):
582
587
583
588
elif isinstance (other , TyArrayTimeDelta ) and forward :
584
589
# *_ due to types of various locals set in the previous if statement
585
- a_ , b_ = _coerce_units (self , other )
590
+ a_ , b_ = _promote_unit (self , other )
586
591
is_nat = a_ .is_nat | b_ .is_nat
587
592
data = safe_cast (
588
593
onnx .TyArrayInt64 ,
@@ -610,13 +615,10 @@ def __ndx_equal__(self, other) -> onnx.TyArrayBool:
610
615
611
616
if not isinstance (other , TyArrayDateTime ):
612
617
return NotImplemented
613
- if self .dtype .unit != other .dtype .unit :
614
- raise TypeError (
615
- "comparison between different units is not implemented, yet"
616
- )
617
618
618
- res = self ._data == other ._data
619
- is_nat = self .is_nat | other .is_nat
619
+ lhs , rhs = _promote_unit (self , other )
620
+ res = lhs ._data == rhs ._data
621
+ is_nat = lhs .is_nat | rhs .is_nat
620
622
621
623
return safe_cast (onnx .TyArrayBool , res & ~ is_nat )
622
624
@@ -662,17 +664,16 @@ def _coerce_other(
662
664
return NotImplemented
663
665
664
666
665
- def _coerce_units (a : T1 , b : T2 ) -> tuple [T1 , T2 ]:
666
- table : dict [tuple [Unit , Unit ], Unit ] = {
667
- ("ns" , "s" ): "ns" ,
668
- ("s" , "ns" ): "ns" ,
669
- ("s" , "s" ): "s" ,
670
- ("ns" , "ns" ): "ns" ,
671
- }
672
- target = table [(a .dtype .unit , b .dtype .unit )]
673
- dtype_a = type (a .dtype )(unit = target )
674
- dtype_b = type (b .dtype )(unit = target )
675
- return (a .astype (dtype_a ), b .astype (dtype_b ))
667
+ def _promote_unit (a : T1 , b : T2 ) -> tuple [T1 , T2 ]:
668
+ unit = _result_unit (a .dtype .unit , b .dtype .unit )
669
+
670
+ return a .astype (type (a .dtype )(unit = unit )), b .astype (type (b .dtype )(unit = unit ))
671
+
672
+
673
+ def _result_unit (a : Unit , b : Unit ) -> Unit :
674
+ ordered_units = ["ns" , "us" , "ms" , "s" ]
675
+ res , _ = sorted ([a , b ], key = lambda el : ordered_units .index (el ))
676
+ return res # type: ignore
676
677
677
678
678
679
def validate_unit (unit : str ) -> Unit :
0 commit comments