Skip to content

Commit 101159d

Browse files
committed
Parametrized TestWhereCoercion except for date/time
1 parent 27bb751 commit 101159d

File tree

1 file changed

+80
-191
lines changed

1 file changed

+80
-191
lines changed

pandas/tests/indexing/test_coercion.py

+80-191
Original file line numberDiff line numberDiff line change
@@ -544,231 +544,120 @@ def _assert_where_conversion(self, original, cond, values,
544544
res = target.where(cond, values)
545545
self._assert(res, expected, expected_dtype)
546546

547-
def _where_object_common(self, klass):
547+
@pytest.mark.parametrize("klass", [pd.Series, pd.Index])
548+
@pytest.mark.parametrize("fill_val,exp_dtype", [
549+
(1, np.object),
550+
(1.1, np.object),
551+
(1 + 1j, np.object),
552+
(True, np.object)])
553+
def test_where_object(self, klass, fill_val,exp_dtype):
548554
obj = klass(list('abcd'))
549555
assert obj.dtype == np.object
550556
cond = klass([True, False, True, False])
551557

552-
# object + int -> object
553-
exp = klass(['a', 1, 'c', 1])
554-
self._assert_where_conversion(obj, cond, 1, exp, np.object)
555-
556-
values = klass([5, 6, 7, 8])
557-
exp = klass(['a', 6, 'c', 8])
558-
self._assert_where_conversion(obj, cond, values, exp, np.object)
559-
560-
# object + float -> object
561-
exp = klass(['a', 1.1, 'c', 1.1])
562-
self._assert_where_conversion(obj, cond, 1.1, exp, np.object)
563-
564-
values = klass([5.5, 6.6, 7.7, 8.8])
565-
exp = klass(['a', 6.6, 'c', 8.8])
566-
self._assert_where_conversion(obj, cond, values, exp, np.object)
567-
568-
# object + complex -> object
569-
exp = klass(['a', 1 + 1j, 'c', 1 + 1j])
570-
self._assert_where_conversion(obj, cond, 1 + 1j, exp, np.object)
571-
572-
values = klass([5 + 5j, 6 + 6j, 7 + 7j, 8 + 8j])
573-
exp = klass(['a', 6 + 6j, 'c', 8 + 8j])
574-
self._assert_where_conversion(obj, cond, values, exp, np.object)
575-
576-
if klass is pd.Series:
577-
exp = klass(['a', 1, 'c', 1])
578-
self._assert_where_conversion(obj, cond, True, exp, np.object)
558+
if fill_val is True and klass is pd.Series:
559+
ret_val = 1
560+
else:
561+
ret_val = fill_val
579562

580-
values = klass([True, False, True, True])
581-
exp = klass(['a', 0, 'c', 1])
582-
self._assert_where_conversion(obj, cond, values, exp, np.object)
583-
elif klass is pd.Index:
584-
# object + bool -> object
585-
exp = klass(['a', True, 'c', True])
586-
self._assert_where_conversion(obj, cond, True, exp, np.object)
563+
exp = klass(['a', ret_val, 'c', ret_val])
564+
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
587565

566+
if fill_val is True:
588567
values = klass([True, False, True, True])
589-
exp = klass(['a', False, 'c', True])
590-
self._assert_where_conversion(obj, cond, values, exp, np.object)
591568
else:
592-
NotImplementedError
569+
values = klass(fill_val*x for x in [5, 6, 7, 8])
570+
571+
exp = klass(['a', values[1], 'c', values[3]])
572+
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
593573

594-
def test_where_series_object(self):
595-
self._where_object_common(pd.Series)
596-
597-
def test_where_index_object(self):
598-
self._where_object_common(pd.Index)
599-
600-
def _where_int64_common(self, klass):
574+
@pytest.mark.parametrize("klass", [pd.Series, pd.Index])
575+
@pytest.mark.parametrize("fill_val,exp_dtype", [
576+
(1, np.int64),
577+
(1.1, np.float64),
578+
(1 + 1j, np.complex128),
579+
(True, np.object)])
580+
def test_where_int64(self, klass, fill_val, exp_dtype):
581+
if klass is pd.Index and exp_dtype is np.complex128:
582+
pytest.skip("Complex Index not supported")
601583
obj = klass([1, 2, 3, 4])
602584
assert obj.dtype == np.int64
603585
cond = klass([True, False, True, False])
604586

605-
# int + int -> int
606-
exp = klass([1, 1, 3, 1])
607-
self._assert_where_conversion(obj, cond, 1, exp, np.int64)
608-
609-
values = klass([5, 6, 7, 8])
610-
exp = klass([1, 6, 3, 8])
611-
self._assert_where_conversion(obj, cond, values, exp, np.int64)
612-
613-
# int + float -> float
614-
exp = klass([1, 1.1, 3, 1.1])
615-
self._assert_where_conversion(obj, cond, 1.1, exp, np.float64)
616-
617-
values = klass([5.5, 6.6, 7.7, 8.8])
618-
exp = klass([1, 6.6, 3, 8.8])
619-
self._assert_where_conversion(obj, cond, values, exp, np.float64)
620-
621-
# int + complex -> complex
622-
if klass is pd.Series:
623-
exp = klass([1, 1 + 1j, 3, 1 + 1j])
624-
self._assert_where_conversion(obj, cond, 1 + 1j, exp,
625-
np.complex128)
626-
627-
values = klass([5 + 5j, 6 + 6j, 7 + 7j, 8 + 8j])
628-
exp = klass([1, 6 + 6j, 3, 8 + 8j])
629-
self._assert_where_conversion(obj, cond, values, exp,
630-
np.complex128)
631-
632-
# int + bool -> object
633-
exp = klass([1, True, 3, True])
634-
self._assert_where_conversion(obj, cond, True, exp, np.object)
587+
exp = klass([1, fill_val, 3, fill_val])
588+
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
635589

636-
values = klass([True, False, True, True])
637-
exp = klass([1, False, 3, True])
638-
self._assert_where_conversion(obj, cond, values, exp, np.object)
639-
640-
def test_where_series_int64(self):
641-
self._where_int64_common(pd.Series)
642-
643-
def test_where_index_int64(self):
644-
self._where_int64_common(pd.Index)
590+
if fill_val is True:
591+
values = klass([True, False, True, True])
592+
else:
593+
values = klass(x*fill_val for x in [5, 6, 7, 8])
594+
exp = klass([1, values[1], 3, values[3]])
595+
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
645596

646-
def _where_float64_common(self, klass):
597+
@pytest.mark.parametrize("klass", [pd.Series, pd.Index])
598+
@pytest.mark.parametrize("fill_val, exp_dtype", [
599+
(1, np.float64),
600+
(1.1, np.float64),
601+
(1 + 1j, np.complex128),
602+
(True, np.object)])
603+
def test_where_float64(self, klass, fill_val, exp_dtype):
604+
if klass is pd.Index and exp_dtype is np.complex128:
605+
pytest.skip("Complex Index not supported")
647606
obj = klass([1.1, 2.2, 3.3, 4.4])
648607
assert obj.dtype == np.float64
649608
cond = klass([True, False, True, False])
650609

651-
# float + int -> float
652-
exp = klass([1.1, 1.0, 3.3, 1.0])
653-
self._assert_where_conversion(obj, cond, 1, exp, np.float64)
610+
exp = klass([1.1, fill_val, 3.3, fill_val])
611+
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
654612

655-
values = klass([5, 6, 7, 8])
656-
exp = klass([1.1, 6.0, 3.3, 8.0])
657-
self._assert_where_conversion(obj, cond, values, exp, np.float64)
658-
659-
# float + float -> float
660-
exp = klass([1.1, 1.1, 3.3, 1.1])
661-
self._assert_where_conversion(obj, cond, 1.1, exp, np.float64)
662-
663-
values = klass([5.5, 6.6, 7.7, 8.8])
664-
exp = klass([1.1, 6.6, 3.3, 8.8])
665-
self._assert_where_conversion(obj, cond, values, exp, np.float64)
666-
667-
# float + complex -> complex
668-
if klass is pd.Series:
669-
exp = klass([1.1, 1 + 1j, 3.3, 1 + 1j])
670-
self._assert_where_conversion(obj, cond, 1 + 1j, exp,
671-
np.complex128)
672-
673-
values = klass([5 + 5j, 6 + 6j, 7 + 7j, 8 + 8j])
674-
exp = klass([1.1, 6 + 6j, 3.3, 8 + 8j])
675-
self._assert_where_conversion(obj, cond, values, exp,
676-
np.complex128)
677-
678-
# float + bool -> object
679-
exp = klass([1.1, True, 3.3, True])
680-
self._assert_where_conversion(obj, cond, True, exp, np.object)
681-
682-
values = klass([True, False, True, True])
683-
exp = klass([1.1, False, 3.3, True])
684-
self._assert_where_conversion(obj, cond, values, exp, np.object)
685-
686-
def test_where_series_float64(self):
687-
self._where_float64_common(pd.Series)
688-
689-
def test_where_index_float64(self):
690-
self._where_float64_common(pd.Index)
613+
if fill_val is True:
614+
values = klass([True, False, True, True])
615+
else:
616+
values = klass(x*fill_val for x in [5, 6, 7, 8])
617+
exp = klass([1.1, values[1], 3.3, values[3]])
618+
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
691619

692-
def test_where_series_complex128(self):
620+
@pytest.mark.parametrize("fill_val,exp_dtype", [
621+
(1, np.complex128),
622+
(1.1, np.complex128),
623+
(1 + 1j, np.complex128),
624+
(True, np.object)])
625+
def test_where_series_complex128(self, fill_val, exp_dtype):
693626
obj = pd.Series([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
694627
assert obj.dtype == np.complex128
695628
cond = pd.Series([True, False, True, False])
696629

697630
# complex + int -> complex
698-
exp = pd.Series([1 + 1j, 1, 3 + 3j, 1])
699-
self._assert_where_conversion(obj, cond, 1, exp, np.complex128)
700-
701-
values = pd.Series([5, 6, 7, 8])
702-
exp = pd.Series([1 + 1j, 6.0, 3 + 3j, 8.0])
703-
self._assert_where_conversion(obj, cond, values, exp, np.complex128)
704-
705-
# complex + float -> complex
706-
exp = pd.Series([1 + 1j, 1.1, 3 + 3j, 1.1])
707-
self._assert_where_conversion(obj, cond, 1.1, exp, np.complex128)
708-
709-
values = pd.Series([5.5, 6.6, 7.7, 8.8])
710-
exp = pd.Series([1 + 1j, 6.6, 3 + 3j, 8.8])
711-
self._assert_where_conversion(obj, cond, values, exp, np.complex128)
712-
713-
# complex + complex -> complex
714-
exp = pd.Series([1 + 1j, 1 + 1j, 3 + 3j, 1 + 1j])
715-
self._assert_where_conversion(obj, cond, 1 + 1j, exp, np.complex128)
716-
717-
values = pd.Series([5 + 5j, 6 + 6j, 7 + 7j, 8 + 8j])
718-
exp = pd.Series([1 + 1j, 6 + 6j, 3 + 3j, 8 + 8j])
719-
self._assert_where_conversion(obj, cond, values, exp, np.complex128)
720-
721-
# complex + bool -> object
722-
exp = pd.Series([1 + 1j, True, 3 + 3j, True])
723-
self._assert_where_conversion(obj, cond, True, exp, np.object)
724-
725-
values = pd.Series([True, False, True, True])
726-
exp = pd.Series([1 + 1j, False, 3 + 3j, True])
727-
self._assert_where_conversion(obj, cond, values, exp, np.object)
631+
exp = pd.Series([1 + 1j, fill_val, 3 + 3j, fill_val])
632+
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
728633

729-
def test_where_index_complex128(self):
730-
pass
634+
if fill_val is True:
635+
values = pd.Series([True, False, True, True])
636+
else:
637+
values = pd.Series(x*fill_val for x in [5, 6, 7, 8])
638+
exp = pd.Series([1 + 1j, values[1], 3 + 3j, values[3]])
639+
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
731640

732-
def test_where_series_bool(self):
641+
@pytest.mark.parametrize("fill_val,exp_dtype", [
642+
(1, np.object),
643+
(1.1, np.object),
644+
(1 + 1j, np.object),
645+
(True, np.bool)])
646+
def test_where_series_bool(self, fill_val, exp_dtype):
733647

734648
obj = pd.Series([True, False, True, False])
735649
assert obj.dtype == np.bool
736650
cond = pd.Series([True, False, True, False])
737651

738-
# bool + int -> object
739-
exp = pd.Series([True, 1, True, 1])
740-
self._assert_where_conversion(obj, cond, 1, exp, np.object)
652+
exp = pd.Series([True, fill_val, True, fill_val])
653+
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
741654

742-
values = pd.Series([5, 6, 7, 8])
743-
exp = pd.Series([True, 6, True, 8])
744-
self._assert_where_conversion(obj, cond, values, exp, np.object)
745-
746-
# bool + float -> object
747-
exp = pd.Series([True, 1.1, True, 1.1])
748-
self._assert_where_conversion(obj, cond, 1.1, exp, np.object)
749-
750-
values = pd.Series([5.5, 6.6, 7.7, 8.8])
751-
exp = pd.Series([True, 6.6, True, 8.8])
752-
self._assert_where_conversion(obj, cond, values, exp, np.object)
753-
754-
# bool + complex -> object
755-
exp = pd.Series([True, 1 + 1j, True, 1 + 1j])
756-
self._assert_where_conversion(obj, cond, 1 + 1j, exp, np.object)
757-
758-
values = pd.Series([5 + 5j, 6 + 6j, 7 + 7j, 8 + 8j])
759-
exp = pd.Series([True, 6 + 6j, True, 8 + 8j])
760-
self._assert_where_conversion(obj, cond, values, exp, np.object)
761-
762-
# bool + bool -> bool
763-
exp = pd.Series([True, True, True, True])
764-
self._assert_where_conversion(obj, cond, True, exp, np.bool)
765-
766-
values = pd.Series([True, False, True, True])
767-
exp = pd.Series([True, False, True, True])
768-
self._assert_where_conversion(obj, cond, values, exp, np.bool)
769-
770-
def test_where_index_bool(self):
771-
pass
655+
if fill_val is True:
656+
values = pd.Series([True, False, True, True])
657+
else:
658+
values = pd.Series(x*fill_val for x in [5, 6, 7, 8])
659+
exp = pd.Series([True, values[1], True, values[3]])
660+
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
772661

773662
def test_where_series_datetime64(self):
774663
obj = pd.Series([pd.Timestamp('2011-01-01'),

0 commit comments

Comments
 (0)