5
5
import warnings
6
6
from datetime import timedelta
7
7
from distutils .version import LooseVersion
8
+ import operator
8
9
import sys
9
10
import pytest
10
11
@@ -2089,42 +2090,40 @@ def test_clip_with_na_args(self):
2089
2090
self .frame )
2090
2091
2091
2092
# Matrix-like
2092
- @pytest .mark .parametrize ('dot_fn' , [DataFrame .dot , DataFrame .__matmul__ ])
2093
- def test_dot (self , dot_fn ):
2094
- # __matmul__ test is for GH #10259
2093
+ def test_dot (self ):
2095
2094
a = DataFrame (np .random .randn (3 , 4 ), index = ['a' , 'b' , 'c' ],
2096
2095
columns = ['p' , 'q' , 'r' , 's' ])
2097
2096
b = DataFrame (np .random .randn (4 , 2 ), index = ['p' , 'q' , 'r' , 's' ],
2098
2097
columns = ['one' , 'two' ])
2099
2098
2100
- result = dot_fn ( a , b )
2099
+ result = a . dot ( b )
2101
2100
expected = DataFrame (np .dot (a .values , b .values ),
2102
2101
index = ['a' , 'b' , 'c' ],
2103
2102
columns = ['one' , 'two' ])
2104
2103
# Check alignment
2105
2104
b1 = b .reindex (index = reversed (b .index ))
2106
- result = dot_fn ( a , b )
2105
+ result = a . dot ( b )
2107
2106
tm .assert_frame_equal (result , expected )
2108
2107
2109
2108
# Check series argument
2110
- result = dot_fn ( a , b ['one' ])
2109
+ result = a . dot ( b ['one' ])
2111
2110
tm .assert_series_equal (result , expected ['one' ], check_names = False )
2112
2111
assert result .name is None
2113
2112
2114
- result = dot_fn ( a , b1 ['one' ])
2113
+ result = a . dot ( b1 ['one' ])
2115
2114
tm .assert_series_equal (result , expected ['one' ], check_names = False )
2116
2115
assert result .name is None
2117
2116
2118
2117
# can pass correct-length arrays
2119
2118
row = a .iloc [0 ].values
2120
2119
2121
- result = dot_fn ( a , row )
2122
- exp = dot_fn ( a , a .iloc [0 ])
2120
+ result = a . dot ( row )
2121
+ exp = a . dot ( a .iloc [0 ])
2123
2122
tm .assert_series_equal (result , exp )
2124
2123
2125
2124
with tm .assert_raises_regex (ValueError ,
2126
2125
'Dot product shape mismatch' ):
2127
- dot_fn ( a , row [:- 1 ])
2126
+ a . dot ( row [:- 1 ])
2128
2127
2129
2128
a = np .random .rand (1 , 5 )
2130
2129
b = np .random .rand (5 , 1 )
@@ -2134,14 +2133,55 @@ def test_dot(self, dot_fn):
2134
2133
B = DataFrame (b ) # noqa
2135
2134
2136
2135
# it works
2137
- result = dot_fn ( A , b )
2136
+ result = A . dot ( b )
2138
2137
2139
2138
# unaligned
2140
2139
df = DataFrame (randn (3 , 4 ), index = [1 , 2 , 3 ], columns = lrange (4 ))
2141
2140
df2 = DataFrame (randn (5 , 3 ), index = lrange (5 ), columns = [1 , 2 , 3 ])
2142
2141
2143
2142
with tm .assert_raises_regex (ValueError , 'aligned' ):
2144
- dot_fn (df , df2 )
2143
+ df .dot (df2 )
2144
+
2145
+ @pytest .mark .skipif (sys .version_info < (3 , 5 ),
2146
+ reason = 'matmul supported for Python>=3.5' )
2147
+ def test_matmul (self ):
2148
+ # matmul test is for GH #10259
2149
+ a = DataFrame (np .random .randn (3 , 4 ), index = ['a' , 'b' , 'c' ],
2150
+ columns = ['p' , 'q' , 'r' , 's' ])
2151
+ b = DataFrame (np .random .randn (4 , 2 ), index = ['p' , 'q' , 'r' , 's' ],
2152
+ columns = ['one' , 'two' ])
2153
+
2154
+ # DataFrame @ DataFrame
2155
+ result = operator .matmul (a , b )
2156
+ expected = DataFrame (np .dot (a .values , b .values ),
2157
+ index = ['a' , 'b' , 'c' ],
2158
+ columns = ['one' , 'two' ])
2159
+ tm .assert_frame_equal (result , expected )
2160
+
2161
+ # DataFrame @ Series
2162
+ result = operator .matmul (a , b .one )
2163
+ expected = Series (np .dot (a .values , b .one .values ),
2164
+ index = ['a' , 'b' , 'c' ])
2165
+ tm .assert_series_equal (result , expected )
2166
+
2167
+ # np.array @ DataFrame
2168
+ result = operator .matmul (a .values , b )
2169
+ expected = np .dot (a .values , b .values )
2170
+ tm .assert_almost_equal (result , expected )
2171
+
2172
+ # nested list @ DataFrame (__rmatmul__)
2173
+ result = operator .matmul (a .values .tolist (), b )
2174
+ expected = DataFrame (np .dot (a .values , b .values ),
2175
+ index = ['a' , 'b' , 'c' ],
2176
+ columns = ['one' , 'two' ])
2177
+ tm .assert_almost_equal (result .values , expected .values )
2178
+
2179
+ # unaligned
2180
+ df = DataFrame (randn (3 , 4 ), index = [1 , 2 , 3 ], columns = lrange (4 ))
2181
+ df2 = DataFrame (randn (5 , 3 ), index = lrange (5 ), columns = [1 , 2 , 3 ])
2182
+
2183
+ with tm .assert_raises_regex (ValueError , 'aligned' ):
2184
+ operator .matmul (df , df2 )
2145
2185
2146
2186
2147
2187
@pytest .fixture
0 commit comments