Skip to content

Commit df68cce

Browse files
authored
REF: .dot tests (pandas-dev#33214)
1 parent 8d299e4 commit df68cce

File tree

3 files changed

+128
-85
lines changed

3 files changed

+128
-85
lines changed

pandas/tests/frame/test_analytics.py

-53
Original file line numberDiff line numberDiff line change
@@ -1147,59 +1147,6 @@ def test_any_all_level_axis_none_raises(self, method):
11471147
# ---------------------------------------------------------------------
11481148
# Matrix-like
11491149

1150-
def test_dot(self):
1151-
a = DataFrame(
1152-
np.random.randn(3, 4), index=["a", "b", "c"], columns=["p", "q", "r", "s"]
1153-
)
1154-
b = DataFrame(
1155-
np.random.randn(4, 2), index=["p", "q", "r", "s"], columns=["one", "two"]
1156-
)
1157-
1158-
result = a.dot(b)
1159-
expected = DataFrame(
1160-
np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
1161-
)
1162-
# Check alignment
1163-
b1 = b.reindex(index=reversed(b.index))
1164-
result = a.dot(b)
1165-
tm.assert_frame_equal(result, expected)
1166-
1167-
# Check series argument
1168-
result = a.dot(b["one"])
1169-
tm.assert_series_equal(result, expected["one"], check_names=False)
1170-
assert result.name is None
1171-
1172-
result = a.dot(b1["one"])
1173-
tm.assert_series_equal(result, expected["one"], check_names=False)
1174-
assert result.name is None
1175-
1176-
# can pass correct-length arrays
1177-
row = a.iloc[0].values
1178-
1179-
result = a.dot(row)
1180-
expected = a.dot(a.iloc[0])
1181-
tm.assert_series_equal(result, expected)
1182-
1183-
with pytest.raises(ValueError, match="Dot product shape mismatch"):
1184-
a.dot(row[:-1])
1185-
1186-
a = np.random.rand(1, 5)
1187-
b = np.random.rand(5, 1)
1188-
A = DataFrame(a)
1189-
1190-
# TODO(wesm): unused
1191-
B = DataFrame(b) # noqa
1192-
1193-
# it works
1194-
result = A.dot(b)
1195-
1196-
# unaligned
1197-
df = DataFrame(np.random.randn(3, 4), index=[1, 2, 3], columns=range(4))
1198-
df2 = DataFrame(np.random.randn(5, 3), index=range(5), columns=[1, 2, 3])
1199-
1200-
with pytest.raises(ValueError, match="aligned"):
1201-
df.dot(df2)
1202-
12031150
def test_matmul(self):
12041151
# matmul test is for GH 10259
12051152
a = DataFrame(
+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas import DataFrame, Series
5+
import pandas._testing as tm
6+
7+
8+
class DotSharedTests:
9+
@pytest.fixture
10+
def obj(self):
11+
raise NotImplementedError
12+
13+
@pytest.fixture
14+
def other(self) -> DataFrame:
15+
"""
16+
other is a DataFrame that is indexed so that obj.dot(other) is valid
17+
"""
18+
raise NotImplementedError
19+
20+
@pytest.fixture
21+
def expected(self, obj, other) -> DataFrame:
22+
"""
23+
The expected result of obj.dot(other)
24+
"""
25+
raise NotImplementedError
26+
27+
@classmethod
28+
def reduced_dim_assert(cls, result, expected):
29+
"""
30+
Assertion about results with 1 fewer dimension that self.obj
31+
"""
32+
raise NotImplementedError
33+
34+
def test_dot_equiv_values_dot(self, obj, other, expected):
35+
# `expected` is constructed from obj.values.dot(other.values)
36+
result = obj.dot(other)
37+
tm.assert_equal(result, expected)
38+
39+
def test_dot_2d_ndarray(self, obj, other, expected):
40+
# Check ndarray argument; in this case we get matching values,
41+
# but index/columns may not match
42+
result = obj.dot(other.values)
43+
assert np.all(result == expected.values)
44+
45+
def test_dot_1d_ndarray(self, obj, expected):
46+
# can pass correct-length array
47+
row = obj.iloc[0] if obj.ndim == 2 else obj
48+
49+
result = obj.dot(row.values)
50+
expected = obj.dot(row)
51+
self.reduced_dim_assert(result, expected)
52+
53+
def test_dot_series(self, obj, other, expected):
54+
# Check series argument
55+
result = obj.dot(other["1"])
56+
self.reduced_dim_assert(result, expected["1"])
57+
58+
def test_dot_series_alignment(self, obj, other, expected):
59+
result = obj.dot(other.iloc[::-1]["1"])
60+
self.reduced_dim_assert(result, expected["1"])
61+
62+
def test_dot_aligns(self, obj, other, expected):
63+
# Check index alignment
64+
other2 = other.iloc[::-1]
65+
result = obj.dot(other2)
66+
tm.assert_equal(result, expected)
67+
68+
def test_dot_shape_mismatch(self, obj):
69+
msg = "Dot product shape mismatch"
70+
# exception raised is of type Exception
71+
with pytest.raises(Exception, match=msg):
72+
obj.dot(obj.values[:3])
73+
74+
def test_dot_misaligned(self, obj, other):
75+
msg = "matrices are not aligned"
76+
with pytest.raises(ValueError, match=msg):
77+
obj.dot(other.T)
78+
79+
80+
class TestSeriesDot(DotSharedTests):
81+
@pytest.fixture
82+
def obj(self):
83+
return Series(np.random.randn(4), index=["p", "q", "r", "s"])
84+
85+
@pytest.fixture
86+
def other(self):
87+
return DataFrame(
88+
np.random.randn(3, 4), index=["1", "2", "3"], columns=["p", "q", "r", "s"]
89+
).T
90+
91+
@pytest.fixture
92+
def expected(self, obj, other):
93+
return Series(np.dot(obj.values, other.values), index=other.columns)
94+
95+
@classmethod
96+
def reduced_dim_assert(cls, result, expected):
97+
"""
98+
Assertion about results with 1 fewer dimension that self.obj
99+
"""
100+
tm.assert_almost_equal(result, expected)
101+
102+
103+
class TestDataFrameDot(DotSharedTests):
104+
@pytest.fixture
105+
def obj(self):
106+
return DataFrame(
107+
np.random.randn(3, 4), index=["a", "b", "c"], columns=["p", "q", "r", "s"]
108+
)
109+
110+
@pytest.fixture
111+
def other(self):
112+
return DataFrame(
113+
np.random.randn(4, 2), index=["p", "q", "r", "s"], columns=["1", "2"]
114+
)
115+
116+
@pytest.fixture
117+
def expected(self, obj, other):
118+
return DataFrame(
119+
np.dot(obj.values, other.values), index=obj.index, columns=other.columns
120+
)
121+
122+
@classmethod
123+
def reduced_dim_assert(cls, result, expected):
124+
"""
125+
Assertion about results with 1 fewer dimension that self.obj
126+
"""
127+
tm.assert_series_equal(result, expected, check_names=False)
128+
assert result.name is None

pandas/tests/series/test_analytics.py

-32
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,6 @@ def test_prod_numpy16_bug(self):
1717

1818
assert not isinstance(result, Series)
1919

20-
def test_dot(self):
21-
a = Series(np.random.randn(4), index=["p", "q", "r", "s"])
22-
b = DataFrame(
23-
np.random.randn(3, 4), index=["1", "2", "3"], columns=["p", "q", "r", "s"]
24-
).T
25-
26-
result = a.dot(b)
27-
expected = Series(np.dot(a.values, b.values), index=["1", "2", "3"])
28-
tm.assert_series_equal(result, expected)
29-
30-
# Check index alignment
31-
b2 = b.reindex(index=reversed(b.index))
32-
result = a.dot(b)
33-
tm.assert_series_equal(result, expected)
34-
35-
# Check ndarray argument
36-
result = a.dot(b.values)
37-
assert np.all(result == expected.values)
38-
tm.assert_almost_equal(a.dot(b["2"].values), expected["2"])
39-
40-
# Check series argument
41-
tm.assert_almost_equal(a.dot(b["1"]), expected["1"])
42-
tm.assert_almost_equal(a.dot(b2["1"]), expected["1"])
43-
44-
msg = r"Dot product shape mismatch, \(4,\) vs \(3,\)"
45-
# exception raised is of type Exception
46-
with pytest.raises(Exception, match=msg):
47-
a.dot(a.values[:3])
48-
msg = "matrices are not aligned"
49-
with pytest.raises(ValueError, match=msg):
50-
a.dot(b.T)
51-
5220
def test_matmul(self):
5321
# matmul test is for GH #10259
5422
a = Series(np.random.randn(4), index=["p", "q", "r", "s"])

0 commit comments

Comments
 (0)