5
5
import warnings
6
6
from datetime import timedelta
7
7
from distutils .version import LooseVersion
8
+ import operator
8
9
import sys
9
10
import pytest
10
11
@@ -2091,42 +2092,40 @@ def test_clip_with_na_args(self):
2091
2092
self .frame )
2092
2093
2093
2094
# Matrix-like
2094
- @pytest .mark .parametrize ('dot_fn' , [DataFrame .dot , DataFrame .__matmul__ ])
2095
- def test_dot (self , dot_fn ):
2096
- # __matmul__ test is for GH #10259
2095
+ def test_dot (self ):
2097
2096
a = DataFrame (np .random .randn (3 , 4 ), index = ['a' , 'b' , 'c' ],
2098
2097
columns = ['p' , 'q' , 'r' , 's' ])
2099
2098
b = DataFrame (np .random .randn (4 , 2 ), index = ['p' , 'q' , 'r' , 's' ],
2100
2099
columns = ['one' , 'two' ])
2101
2100
2102
- result = dot_fn ( a , b )
2101
+ result = a . dot ( b )
2103
2102
expected = DataFrame (np .dot (a .values , b .values ),
2104
2103
index = ['a' , 'b' , 'c' ],
2105
2104
columns = ['one' , 'two' ])
2106
2105
# Check alignment
2107
2106
b1 = b .reindex (index = reversed (b .index ))
2108
- result = dot_fn ( a , b )
2107
+ result = a . dot ( b )
2109
2108
tm .assert_frame_equal (result , expected )
2110
2109
2111
2110
# Check series argument
2112
- result = dot_fn ( a , b ['one' ])
2111
+ result = a . dot ( b ['one' ])
2113
2112
tm .assert_series_equal (result , expected ['one' ], check_names = False )
2114
2113
assert result .name is None
2115
2114
2116
- result = dot_fn ( a , b1 ['one' ])
2115
+ result = a . dot ( b1 ['one' ])
2117
2116
tm .assert_series_equal (result , expected ['one' ], check_names = False )
2118
2117
assert result .name is None
2119
2118
2120
2119
# can pass correct-length arrays
2121
2120
row = a .iloc [0 ].values
2122
2121
2123
- result = dot_fn ( a , row )
2124
- exp = dot_fn ( a , a .iloc [0 ])
2122
+ result = a . dot ( row )
2123
+ exp = a . dot ( a .iloc [0 ])
2125
2124
tm .assert_series_equal (result , exp )
2126
2125
2127
2126
with tm .assert_raises_regex (ValueError ,
2128
2127
'Dot product shape mismatch' ):
2129
- dot_fn ( a , row [:- 1 ])
2128
+ a . dot ( row [:- 1 ])
2130
2129
2131
2130
a = np .random .rand (1 , 5 )
2132
2131
b = np .random .rand (5 , 1 )
@@ -2136,14 +2135,55 @@ def test_dot(self, dot_fn):
2136
2135
B = DataFrame (b ) # noqa
2137
2136
2138
2137
# it works
2139
- result = dot_fn ( A , b )
2138
+ result = A . dot ( b )
2140
2139
2141
2140
# unaligned
2142
2141
df = DataFrame (randn (3 , 4 ), index = [1 , 2 , 3 ], columns = lrange (4 ))
2143
2142
df2 = DataFrame (randn (5 , 3 ), index = lrange (5 ), columns = [1 , 2 , 3 ])
2144
2143
2145
2144
with tm .assert_raises_regex (ValueError , 'aligned' ):
2146
- dot_fn (df , df2 )
2145
+ df .dot (df2 )
2146
+
2147
+ @pytest .mark .skipif (sys .version_info < (3 , 5 ),
2148
+ reason = 'matmul supported for Python>=3.5' )
2149
+ def test_matmul (self ):
2150
+ # matmul test is for GH #10259
2151
+ a = DataFrame (np .random .randn (3 , 4 ), index = ['a' , 'b' , 'c' ],
2152
+ columns = ['p' , 'q' , 'r' , 's' ])
2153
+ b = DataFrame (np .random .randn (4 , 2 ), index = ['p' , 'q' , 'r' , 's' ],
2154
+ columns = ['one' , 'two' ])
2155
+
2156
+ # DataFrame @ DataFrame
2157
+ result = operator .matmul (a , b )
2158
+ expected = DataFrame (np .dot (a .values , b .values ),
2159
+ index = ['a' , 'b' , 'c' ],
2160
+ columns = ['one' , 'two' ])
2161
+ tm .assert_frame_equal (result , expected )
2162
+
2163
+ # DataFrame @ Series
2164
+ result = operator .matmul (a , b .one )
2165
+ expected = Series (np .dot (a .values , b .one .values ),
2166
+ index = ['a' , 'b' , 'c' ])
2167
+ tm .assert_series_equal (result , expected )
2168
+
2169
+ # np.array @ DataFrame
2170
+ result = operator .matmul (a .values , b )
2171
+ expected = np .dot (a .values , b .values )
2172
+ tm .assert_almost_equal (result , expected )
2173
+
2174
+ # nested list @ DataFrame (__rmatmul__)
2175
+ result = operator .matmul (a .values .tolist (), b )
2176
+ expected = DataFrame (np .dot (a .values , b .values ),
2177
+ index = ['a' , 'b' , 'c' ],
2178
+ columns = ['one' , 'two' ])
2179
+ tm .assert_almost_equal (result .values , expected .values )
2180
+
2181
+ # unaligned
2182
+ df = DataFrame (randn (3 , 4 ), index = [1 , 2 , 3 ], columns = lrange (4 ))
2183
+ df2 = DataFrame (randn (5 , 3 ), index = lrange (5 ), columns = [1 , 2 , 3 ])
2184
+
2185
+ with tm .assert_raises_regex (ValueError , 'aligned' ):
2186
+ operator .matmul (df , df2 )
2147
2187
2148
2188
2149
2189
@pytest .fixture
0 commit comments