Skip to content

Commit 8af8cdb

Browse files
jbrockmendelTomAugspurger
authored andcommitted
use tm.assert_equal instead of parametrizing assert funcs (pandas-dev#22995)
* Use tm.assert_equal instead of parametrizing assert_func * Extend assert_equal * Use tm.assert_equal in more places * typo fixup
1 parent d430195 commit 8af8cdb

File tree

5 files changed

+35
-40
lines changed

5 files changed

+35
-40
lines changed

doc/source/contributing.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ If your change involves checking that a warning is actually emitted, use
880880

881881
.. code-block:: python
882882
883-
with tm.assert_prodcues_warning(FutureWarning):
883+
with tm.assert_produces_warning(FutureWarning):
884884
df.some_operation()
885885
886886
We prefer this to the ``pytest.warns`` context manager because ours checks that the warning's

pandas/tests/indexes/datetimes/test_tools.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ def test_to_datetime_format_weeks(self, cache):
180180
for s, format, dt in data:
181181
assert to_datetime(s, format=format, cache=cache) == dt
182182

183-
@pytest.mark.parametrize("box,const,assert_equal", [
184-
[True, pd.Index, 'assert_index_equal'],
185-
[False, np.array, 'assert_numpy_array_equal']])
183+
@pytest.mark.parametrize("box,const", [
184+
[True, pd.Index],
185+
[False, np.array]])
186186
@pytest.mark.parametrize("fmt,dates,expected_dates", [
187187
['%Y-%m-%d %H:%M:%S %Z',
188188
['2010-01-01 12:00:00 UTC'] * 2,
@@ -215,12 +215,11 @@ def test_to_datetime_format_weeks(self, cache):
215215
pd.Timestamp('2010-01-01 12:00:00',
216216
tzinfo=pytz.FixedOffset(0))]]])
217217
def test_to_datetime_parse_tzname_or_tzoffset(self, box, const,
218-
assert_equal, fmt,
219-
dates, expected_dates):
218+
fmt, dates, expected_dates):
220219
# GH 13486
221220
result = pd.to_datetime(dates, format=fmt, box=box)
222221
expected = const(expected_dates)
223-
getattr(tm, assert_equal)(result, expected)
222+
tm.assert_equal(result, expected)
224223

225224
with pytest.raises(ValueError):
226225
pd.to_datetime(dates, format=fmt, box=box, utc=True)
@@ -1049,17 +1048,16 @@ def test_to_datetime_types(self, cache):
10491048
# assert result == expected
10501049

10511050
@pytest.mark.parametrize('cache', [True, False])
1052-
@pytest.mark.parametrize('box, klass, assert_method', [
1053-
[True, Index, 'assert_index_equal'],
1054-
[False, np.array, 'assert_numpy_array_equal']
1051+
@pytest.mark.parametrize('box, klass', [
1052+
[True, Index],
1053+
[False, np.array]
10551054
])
1056-
def test_to_datetime_unprocessable_input(self, cache, box, klass,
1057-
assert_method):
1055+
def test_to_datetime_unprocessable_input(self, cache, box, klass):
10581056
# GH 4928
10591057
# GH 21864
10601058
result = to_datetime([1, '1'], errors='ignore', cache=cache, box=box)
10611059
expected = klass(np.array([1, '1'], dtype='O'))
1062-
getattr(tm, assert_method)(result, expected)
1060+
tm.assert_equal(result, expected)
10631061
pytest.raises(TypeError, to_datetime, [1, '1'], errors='raise',
10641062
cache=cache, box=box)
10651063

pandas/tests/scalar/test_nat.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -312,19 +312,16 @@ def test_nat_arithmetic_index():
312312
tm.assert_index_equal(NaT - tdi, tdi_nat)
313313

314314

315-
@pytest.mark.parametrize('box, assert_func', [
316-
(TimedeltaIndex, tm.assert_index_equal),
317-
(Series, tm.assert_series_equal)
318-
])
319-
def test_nat_arithmetic_td64_vector(box, assert_func):
315+
@pytest.mark.parametrize('box', [TimedeltaIndex, Series])
316+
def test_nat_arithmetic_td64_vector(box):
320317
# GH#19124
321318
vec = box(['1 day', '2 day'], dtype='timedelta64[ns]')
322319
box_nat = box([NaT, NaT], dtype='timedelta64[ns]')
323320

324-
assert_func(vec + NaT, box_nat)
325-
assert_func(NaT + vec, box_nat)
326-
assert_func(vec - NaT, box_nat)
327-
assert_func(NaT - vec, box_nat)
321+
tm.assert_equal(vec + NaT, box_nat)
322+
tm.assert_equal(NaT + vec, box_nat)
323+
tm.assert_equal(vec - NaT, box_nat)
324+
tm.assert_equal(NaT - vec, box_nat)
328325

329326

330327
def test_nat_pinned_docstrings():

pandas/tests/tseries/offsets/test_offsets.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -2516,28 +2516,26 @@ def test_onOffset(self, case):
25162516
dt, expected = case
25172517
assert_onOffset(SemiMonthEnd(), dt, expected)
25182518

2519-
@pytest.mark.parametrize('klass,assert_func',
2520-
[(Series, tm.assert_series_equal),
2521-
(DatetimeIndex, tm.assert_index_equal)])
2522-
def test_vectorized_offset_addition(self, klass, assert_func):
2519+
@pytest.mark.parametrize('klass', [Series, DatetimeIndex])
2520+
def test_vectorized_offset_addition(self, klass):
25232521
s = klass([Timestamp('2000-01-15 00:15:00', tz='US/Central'),
25242522
Timestamp('2000-02-15', tz='US/Central')], name='a')
25252523

25262524
result = s + SemiMonthEnd()
25272525
result2 = SemiMonthEnd() + s
25282526
exp = klass([Timestamp('2000-01-31 00:15:00', tz='US/Central'),
25292527
Timestamp('2000-02-29', tz='US/Central')], name='a')
2530-
assert_func(result, exp)
2531-
assert_func(result2, exp)
2528+
tm.assert_equal(result, exp)
2529+
tm.assert_equal(result2, exp)
25322530

25332531
s = klass([Timestamp('2000-01-01 00:15:00', tz='US/Central'),
25342532
Timestamp('2000-02-01', tz='US/Central')], name='a')
25352533
result = s + SemiMonthEnd()
25362534
result2 = SemiMonthEnd() + s
25372535
exp = klass([Timestamp('2000-01-15 00:15:00', tz='US/Central'),
25382536
Timestamp('2000-02-15', tz='US/Central')], name='a')
2539-
assert_func(result, exp)
2540-
assert_func(result2, exp)
2537+
tm.assert_equal(result, exp)
2538+
tm.assert_equal(result2, exp)
25412539

25422540

25432541
class TestSemiMonthBegin(Base):
@@ -2692,27 +2690,25 @@ def test_onOffset(self, case):
26922690
dt, expected = case
26932691
assert_onOffset(SemiMonthBegin(), dt, expected)
26942692

2695-
@pytest.mark.parametrize('klass,assert_func',
2696-
[(Series, tm.assert_series_equal),
2697-
(DatetimeIndex, tm.assert_index_equal)])
2698-
def test_vectorized_offset_addition(self, klass, assert_func):
2693+
@pytest.mark.parametrize('klass', [Series, DatetimeIndex])
2694+
def test_vectorized_offset_addition(self, klass):
26992695
s = klass([Timestamp('2000-01-15 00:15:00', tz='US/Central'),
27002696
Timestamp('2000-02-15', tz='US/Central')], name='a')
27012697
result = s + SemiMonthBegin()
27022698
result2 = SemiMonthBegin() + s
27032699
exp = klass([Timestamp('2000-02-01 00:15:00', tz='US/Central'),
27042700
Timestamp('2000-03-01', tz='US/Central')], name='a')
2705-
assert_func(result, exp)
2706-
assert_func(result2, exp)
2701+
tm.assert_equal(result, exp)
2702+
tm.assert_equal(result2, exp)
27072703

27082704
s = klass([Timestamp('2000-01-01 00:15:00', tz='US/Central'),
27092705
Timestamp('2000-02-01', tz='US/Central')], name='a')
27102706
result = s + SemiMonthBegin()
27112707
result2 = SemiMonthBegin() + s
27122708
exp = klass([Timestamp('2000-01-15 00:15:00', tz='US/Central'),
27132709
Timestamp('2000-02-15', tz='US/Central')], name='a')
2714-
assert_func(result, exp)
2715-
assert_func(result2, exp)
2710+
tm.assert_equal(result, exp)
2711+
tm.assert_equal(result2, exp)
27162712

27172713

27182714
def test_Easter():

pandas/util/testing.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1522,8 +1522,8 @@ def assert_equal(left, right, **kwargs):
15221522
15231523
Parameters
15241524
----------
1525-
left : Index, Series, or DataFrame
1526-
right : Index, Series, or DataFrame
1525+
left : Index, Series, DataFrame, ExtensionArray, or np.ndarray
1526+
right : Index, Series, DataFrame, ExtensionArray, or np.ndarray
15271527
**kwargs
15281528
"""
15291529
if isinstance(left, pd.Index):
@@ -1532,6 +1532,10 @@ def assert_equal(left, right, **kwargs):
15321532
assert_series_equal(left, right, **kwargs)
15331533
elif isinstance(left, pd.DataFrame):
15341534
assert_frame_equal(left, right, **kwargs)
1535+
elif isinstance(left, ExtensionArray):
1536+
assert_extension_array_equal(left, right, **kwargs)
1537+
elif isinstance(left, np.ndarray):
1538+
assert_numpy_array_equal(left, right, **kwargs)
15351539
else:
15361540
raise NotImplementedError(type(left))
15371541

0 commit comments

Comments
 (0)