Skip to content

Commit 72db40e

Browse files
jbrockmendeljreback
authored andcommitted
REF: cosmetic differences between DTA/TDA/PA comparison methods (#30720)
1 parent 1c9ebd7 commit 72db40e

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

pandas/core/arrays/datetimes.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@
2626
_INT64_DTYPE,
2727
_NS_DTYPE,
2828
is_categorical_dtype,
29+
is_datetime64_any_dtype,
2930
is_datetime64_dtype,
3031
is_datetime64_ns_dtype,
3132
is_datetime64tz_dtype,
3233
is_dtype_equal,
3334
is_extension_array_dtype,
3435
is_float_dtype,
36+
is_list_like,
3537
is_object_dtype,
3638
is_period_dtype,
3739
is_string_dtype,
@@ -148,17 +150,22 @@ def wrapper(self, other):
148150
# string that cannot be parsed to Timestamp
149151
return invalid_comparison(self, other, op)
150152

151-
if isinstance(other, (datetime, np.datetime64)):
152-
other = Timestamp(other)
153+
if isinstance(other, self._recognized_scalars) or other is NaT:
154+
other = self._scalar_type(other)
153155
self._assert_tzawareness_compat(other)
154156

155-
result = op(self.asi8, other.value)
157+
other_i8 = other.value
158+
159+
result = op(self.view("i8"), other_i8)
156160
if isna(other):
157161
result.fill(nat_result)
158-
elif lib.is_scalar(other) or np.ndim(other) == 0:
162+
163+
elif not is_list_like(other):
159164
return invalid_comparison(self, other, op)
165+
160166
elif len(other) != len(self):
161167
raise ValueError("Lengths must match")
168+
162169
else:
163170
if isinstance(other, list):
164171
other = np.array(other)
@@ -178,7 +185,7 @@ def wrapper(self, other):
178185
)
179186
o_mask = isna(other)
180187

181-
elif not (is_datetime64_dtype(other) or is_datetime64tz_dtype(other)):
188+
elif not cls._is_recognized_dtype(other.dtype):
182189
# e.g. is_timedelta64_dtype(other)
183190
return invalid_comparison(self, other, op)
184191

@@ -239,6 +246,8 @@ class DatetimeArray(dtl.DatetimeLikeArrayMixin, dtl.TimelikeOps, dtl.DatelikeOps
239246

240247
_typ = "datetimearray"
241248
_scalar_type = Timestamp
249+
_recognized_scalars = (datetime, np.datetime64)
250+
_is_recognized_dtype = is_datetime64_any_dtype
242251

243252
# define my properties & methods for delegation
244253
_bool_ops = [

pandas/core/arrays/period.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,6 @@ def _period_array_cmp(cls, op):
7575
def wrapper(self, other):
7676
ordinal_op = getattr(self.asi8, opname)
7777

78-
if is_list_like(other) and len(other) != len(self):
79-
raise ValueError("Lengths must match")
80-
8178
if isinstance(other, str):
8279
try:
8380
other = self._scalar_from_string(other)
@@ -90,18 +87,22 @@ def wrapper(self, other):
9087
other = Period(other, freq=self.freq)
9188
result = ordinal_op(other.ordinal)
9289

93-
if isinstance(other, Period):
90+
if isinstance(other, self._recognized_scalars) or other is NaT:
91+
other = self._scalar_type(other)
9492
self._check_compatible_with(other)
9593

96-
result = ordinal_op(other.ordinal)
94+
other_i8 = self._unbox_scalar(other)
9795

98-
elif other is NaT:
99-
result = np.empty(len(self.asi8), dtype=bool)
100-
result.fill(nat_result)
96+
result = op(self.view("i8"), other_i8)
97+
if isna(other):
98+
result.fill(nat_result)
10199

102100
elif not is_list_like(other):
103101
return invalid_comparison(self, other, op)
104102

103+
elif len(other) != len(self):
104+
raise ValueError("Lengths must match")
105+
105106
else:
106107
if isinstance(other, list):
107108
# TODO: could use pd.Index to do inference?
@@ -117,7 +118,7 @@ def wrapper(self, other):
117118
)
118119
o_mask = isna(other)
119120

120-
elif not is_period_dtype(other):
121+
elif not cls._is_recognized_dtype(other.dtype):
121122
# e.g. is_timedelta64_dtype(other)
122123
return invalid_comparison(self, other, op)
123124

@@ -126,7 +127,7 @@ def wrapper(self, other):
126127

127128
self._check_compatible_with(other)
128129

129-
result = ordinal_op(other.asi8)
130+
result = op(self.view("i8"), other.view("i8"))
130131
o_mask = other._isnan
131132

132133
if o_mask.any():
@@ -195,6 +196,8 @@ class PeriodArray(dtl.DatetimeLikeArrayMixin, dtl.DatelikeOps):
195196
__array_priority__ = 1000
196197
_typ = "periodarray" # ABCPeriodArray
197198
_scalar_type = Period
199+
_recognized_scalars = (Period,)
200+
_is_recognized_dtype = is_period_dtype
198201

199202
# Names others delegate to us
200203
_other_ops: List[str] = []

pandas/core/arrays/timedeltas.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,13 @@ def wrapper(self, other):
8989
# failed to parse as timedelta
9090
return invalid_comparison(self, other, op)
9191

92-
if _is_convertible_to_td(other) or other is NaT:
93-
other = Timedelta(other)
92+
if isinstance(other, self._recognized_scalars) or other is NaT:
93+
other = self._scalar_type(other)
94+
self._check_compatible_with(other)
95+
96+
other_i8 = self._unbox_scalar(other)
9497

95-
result = op(self.view("i8"), other.value)
98+
result = op(self.view("i8"), other_i8)
9699
if isna(other):
97100
result.fill(nat_result)
98101

@@ -116,13 +119,15 @@ def wrapper(self, other):
116119
)
117120
o_mask = isna(other)
118121

119-
elif not is_timedelta64_dtype(other):
122+
elif not cls._is_recognized_dtype(other.dtype):
120123
# e.g. other is datetimearray
121124
return invalid_comparison(self, other, op)
122125

123126
else:
124127
other = type(self)._from_sequence(other)
125128

129+
self._check_compatible_with(other)
130+
126131
result = op(self.view("i8"), other.view("i8"))
127132
o_mask = other._isnan
128133

@@ -172,6 +177,9 @@ class TimedeltaArray(dtl.DatetimeLikeArrayMixin, dtl.TimelikeOps):
172177

173178
_typ = "timedeltaarray"
174179
_scalar_type = Timedelta
180+
_recognized_scalars = (timedelta, np.timedelta64, Tick)
181+
_is_recognized_dtype = is_timedelta64_dtype
182+
175183
__array_priority__ = 1000
176184
# define my properties & methods for delegation
177185
_other_ops: List[str] = []

0 commit comments

Comments
 (0)