|
5 | 5 | import pytest
|
6 | 6 |
|
7 | 7 | from pandas import Timestamp
|
| 8 | +import pandas._testing as tm |
8 | 9 |
|
9 | 10 |
|
10 | 11 | class TestTimestampComparison:
|
| 12 | + def test_comparison_dt64_ndarray(self): |
| 13 | + ts = Timestamp.now() |
| 14 | + ts2 = Timestamp("2019-04-05") |
| 15 | + arr = np.array([[ts.asm8, ts2.asm8]], dtype="M8[ns]") |
| 16 | + |
| 17 | + result = ts == arr |
| 18 | + expected = np.array([[True, False]], dtype=bool) |
| 19 | + tm.assert_numpy_array_equal(result, expected) |
| 20 | + |
| 21 | + result = arr == ts |
| 22 | + tm.assert_numpy_array_equal(result, expected) |
| 23 | + |
| 24 | + result = ts != arr |
| 25 | + tm.assert_numpy_array_equal(result, ~expected) |
| 26 | + |
| 27 | + result = arr != ts |
| 28 | + tm.assert_numpy_array_equal(result, ~expected) |
| 29 | + |
| 30 | + result = ts2 < arr |
| 31 | + tm.assert_numpy_array_equal(result, expected) |
| 32 | + |
| 33 | + result = arr < ts2 |
| 34 | + tm.assert_numpy_array_equal(result, np.array([[False, False]], dtype=bool)) |
| 35 | + |
| 36 | + result = ts2 <= arr |
| 37 | + tm.assert_numpy_array_equal(result, np.array([[True, True]], dtype=bool)) |
| 38 | + |
| 39 | + result = arr <= ts2 |
| 40 | + tm.assert_numpy_array_equal(result, ~expected) |
| 41 | + |
| 42 | + result = ts >= arr |
| 43 | + tm.assert_numpy_array_equal(result, np.array([[True, True]], dtype=bool)) |
| 44 | + |
| 45 | + result = arr >= ts |
| 46 | + tm.assert_numpy_array_equal(result, np.array([[True, False]], dtype=bool)) |
| 47 | + |
| 48 | + @pytest.mark.parametrize("reverse", [True, False]) |
| 49 | + def test_comparison_dt64_ndarray_tzaware(self, reverse, all_compare_operators): |
| 50 | + op = getattr(operator, all_compare_operators.strip("__")) |
| 51 | + |
| 52 | + ts = Timestamp.now("UTC") |
| 53 | + arr = np.array([ts.asm8, ts.asm8], dtype="M8[ns]") |
| 54 | + |
| 55 | + left, right = ts, arr |
| 56 | + if reverse: |
| 57 | + left, right = arr, ts |
| 58 | + |
| 59 | + msg = "Cannot compare tz-naive and tz-aware timestamps" |
| 60 | + with pytest.raises(TypeError, match=msg): |
| 61 | + op(left, right) |
| 62 | + |
11 | 63 | def test_comparison_object_array(self):
|
12 | 64 | # GH#15183
|
13 | 65 | ts = Timestamp("2011-01-03 00:00:00-0500", tz="US/Eastern")
|
|
0 commit comments