Skip to content

Commit bab8393

Browse files
[ArrowStringArray] TST: move/combine a couple of tests (#41473)
1 parent 9ae8f1d commit bab8393

File tree

3 files changed

+38
-49
lines changed

3 files changed

+38
-49
lines changed

pandas/tests/strings/test_case_justify.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,21 @@ def test_lower_upper_mixed_object():
4949
tm.assert_series_equal(result, expected)
5050

5151

52-
def test_capitalize(any_string_dtype):
53-
s = Series(["FOO", "BAR", np.nan, "Blah", "blurg"], dtype=any_string_dtype)
52+
@pytest.mark.parametrize(
53+
"data, expected",
54+
[
55+
(
56+
["FOO", "BAR", np.nan, "Blah", "blurg"],
57+
["Foo", "Bar", np.nan, "Blah", "Blurg"],
58+
),
59+
(["a", "b", "c"], ["A", "B", "C"]),
60+
(["a b", "a bc. de"], ["A b", "A bc. de"]),
61+
],
62+
)
63+
def test_capitalize(data, expected, any_string_dtype):
64+
s = Series(data, dtype=any_string_dtype)
5465
result = s.str.capitalize()
55-
expected = Series(["Foo", "Bar", np.nan, "Blah", "Blurg"], dtype=any_string_dtype)
66+
expected = Series(expected, dtype=any_string_dtype)
5667
tm.assert_series_equal(result, expected)
5768

5869

pandas/tests/strings/test_split_partition.py

+24-19
Original file line numberDiff line numberDiff line change
@@ -614,56 +614,61 @@ def test_partition_sep_kwarg(any_string_dtype):
614614

615615

616616
def test_get():
617-
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"])
618-
619-
result = values.str.split("_").str.get(1)
617+
ser = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"])
618+
result = ser.str.split("_").str.get(1)
620619
expected = Series(["b", "d", np.nan, "g"])
621620
tm.assert_series_equal(result, expected)
622621

623-
# mixed
624-
mixed = Series(["a_b_c", np.nan, "c_d_e", True, datetime.today(), None, 1, 2.0])
625622

626-
rs = Series(mixed).str.split("_").str.get(1)
627-
xp = Series(["b", np.nan, "d", np.nan, np.nan, np.nan, np.nan, np.nan])
623+
def test_get_mixed_object():
624+
ser = Series(["a_b_c", np.nan, "c_d_e", True, datetime.today(), None, 1, 2.0])
625+
result = ser.str.split("_").str.get(1)
626+
expected = Series(["b", np.nan, "d", np.nan, np.nan, np.nan, np.nan, np.nan])
627+
tm.assert_series_equal(result, expected)
628628

629-
assert isinstance(rs, Series)
630-
tm.assert_almost_equal(rs, xp)
631629

632-
# bounds testing
633-
values = Series(["1_2_3_4_5", "6_7_8_9_10", "11_12"])
630+
def test_get_bounds():
631+
ser = Series(["1_2_3_4_5", "6_7_8_9_10", "11_12"])
634632

635633
# positive index
636-
result = values.str.split("_").str.get(2)
634+
result = ser.str.split("_").str.get(2)
637635
expected = Series(["3", "8", np.nan])
638636
tm.assert_series_equal(result, expected)
639637

640638
# negative index
641-
result = values.str.split("_").str.get(-3)
639+
result = ser.str.split("_").str.get(-3)
642640
expected = Series(["3", "8", np.nan])
643641
tm.assert_series_equal(result, expected)
644642

645643

646644
def test_get_complex():
647645
# GH 20671, getting value not in dict raising `KeyError`
648-
values = Series([(1, 2, 3), [1, 2, 3], {1, 2, 3}, {1: "a", 2: "b", 3: "c"}])
646+
ser = Series([(1, 2, 3), [1, 2, 3], {1, 2, 3}, {1: "a", 2: "b", 3: "c"}])
649647

650-
result = values.str.get(1)
648+
result = ser.str.get(1)
651649
expected = Series([2, 2, np.nan, "a"])
652650
tm.assert_series_equal(result, expected)
653651

654-
result = values.str.get(-1)
652+
result = ser.str.get(-1)
655653
expected = Series([3, 3, np.nan, np.nan])
656654
tm.assert_series_equal(result, expected)
657655

658656

659657
@pytest.mark.parametrize("to_type", [tuple, list, np.array])
660658
def test_get_complex_nested(to_type):
661-
values = Series([to_type([to_type([1, 2])])])
659+
ser = Series([to_type([to_type([1, 2])])])
662660

663-
result = values.str.get(0)
661+
result = ser.str.get(0)
664662
expected = Series([to_type([1, 2])])
665663
tm.assert_series_equal(result, expected)
666664

667-
result = values.str.get(1)
665+
result = ser.str.get(1)
668666
expected = Series([np.nan])
669667
tm.assert_series_equal(result, expected)
668+
669+
670+
def test_get_strings(any_string_dtype):
671+
ser = Series(["a", "ab", np.nan, "abc"], dtype=any_string_dtype)
672+
result = ser.str.get(2)
673+
expected = Series([np.nan, np.nan, np.nan, "c"], dtype=any_string_dtype)
674+
tm.assert_series_equal(result, expected)

pandas/tests/strings/test_string_array.py

-27
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
import operator
2-
31
import numpy as np
42
import pytest
53

64
from pandas._libs import lib
75

8-
import pandas as pd
96
from pandas import (
107
DataFrame,
118
Series,
@@ -99,27 +96,3 @@ def test_string_array_extract(nullable_string_dtype):
9996

10097
result = result.astype(object)
10198
tm.assert_equal(result, expected)
102-
103-
104-
def test_str_get_stringarray_multiple_nans(nullable_string_dtype):
105-
s = Series(pd.array(["a", "ab", pd.NA, "abc"], dtype=nullable_string_dtype))
106-
result = s.str.get(2)
107-
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"], dtype=nullable_string_dtype))
108-
tm.assert_series_equal(result, expected)
109-
110-
111-
@pytest.mark.parametrize(
112-
"input, method",
113-
[
114-
(["a", "b", "c"], operator.methodcaller("capitalize")),
115-
(["a b", "a bc. de"], operator.methodcaller("capitalize")),
116-
],
117-
)
118-
def test_capitalize(input, method, nullable_string_dtype):
119-
a = Series(input, dtype=nullable_string_dtype)
120-
b = Series(input, dtype="object")
121-
result = method(a.str)
122-
expected = method(b.str)
123-
124-
assert result.dtype.name == nullable_string_dtype
125-
tm.assert_series_equal(result.astype(object), expected)

0 commit comments

Comments
 (0)