Skip to content

Commit e391293

Browse files
committed
Add rmatmul to DataFrame, Series
1 parent fa753be commit e391293

File tree

4 files changed

+94
-21
lines changed

4 files changed

+94
-21
lines changed

pandas/core/frame.py

+4
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,10 @@ def __matmul__(self, other):
910910
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
911911
return self.dot(other)
912912

913+
def __rmatmul__(self, other):
914+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
915+
return self.T.dot(np.transpose(T)).T
916+
913917
# ----------------------------------------------------------------------
914918
# IO methods (to / from other formats)
915919

pandas/core/series.py

+4
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,10 @@ def __matmul__(self, other):
19921992
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
19931993
return self.dot(other)
19941994

1995+
def __rmatmul__(self, other):
1996+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
1997+
return self.dot(other)
1998+
19951999
@Substitution(klass='Series')
19962000
@Appender(base._shared_docs['searchsorted'])
19972001
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from datetime import timedelta
77
from distutils.version import LooseVersion
8+
import operator
89
import sys
910
import pytest
1011

@@ -2089,42 +2090,40 @@ def test_clip_with_na_args(self):
20892090
self.frame)
20902091

20912092
# Matrix-like
2092-
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
2093-
def test_dot(self, dot_fn):
2094-
# __matmul__ test is for GH #10259
2093+
def test_dot(self):
20952094
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20962095
columns=['p', 'q', 'r', 's'])
20972096
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
20982097
columns=['one', 'two'])
20992098

2100-
result = dot_fn(a, b)
2099+
result = a.dot(b)
21012100
expected = DataFrame(np.dot(a.values, b.values),
21022101
index=['a', 'b', 'c'],
21032102
columns=['one', 'two'])
21042103
# Check alignment
21052104
b1 = b.reindex(index=reversed(b.index))
2106-
result = dot_fn(a, b)
2105+
result = a.dot(b)
21072106
tm.assert_frame_equal(result, expected)
21082107

21092108
# Check series argument
2110-
result = dot_fn(a, b['one'])
2109+
result = a.dot(b['one'])
21112110
tm.assert_series_equal(result, expected['one'], check_names=False)
21122111
assert result.name is None
21132112

2114-
result = dot_fn(a, b1['one'])
2113+
result = a.dot(b1['one'])
21152114
tm.assert_series_equal(result, expected['one'], check_names=False)
21162115
assert result.name is None
21172116

21182117
# can pass correct-length arrays
21192118
row = a.iloc[0].values
21202119

2121-
result = dot_fn(a, row)
2122-
exp = dot_fn(a, a.iloc[0])
2120+
result = a.dot(row)
2121+
exp = a.dot(a.iloc[0])
21232122
tm.assert_series_equal(result, exp)
21242123

21252124
with tm.assert_raises_regex(ValueError,
21262125
'Dot product shape mismatch'):
2127-
dot_fn(a, row[:-1])
2126+
a.dot(row[:-1])
21282127

21292128
a = np.random.rand(1, 5)
21302129
b = np.random.rand(5, 1)
@@ -2134,14 +2133,48 @@ def test_dot(self, dot_fn):
21342133
B = DataFrame(b) # noqa
21352134

21362135
# it works
2137-
result = dot_fn(A, b)
2136+
result = A.dot(b)
21382137

21392138
# unaligned
21402139
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
21412140
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
21422141

21432142
with tm.assert_raises_regex(ValueError, 'aligned'):
2144-
dot_fn(df, df2)
2143+
df.dot(df2)
2144+
2145+
@pytest.mark.skipif(sys.version_info < (3, 5),
2146+
reason='matmul supported for Python>=3.5')
2147+
def test_matmul(self):
2148+
# matmul test is for GH #10259
2149+
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
2150+
columns=['p', 'q', 'r', 's'])
2151+
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
2152+
columns=['one', 'two'])
2153+
2154+
# DataFrame @ DataFrame
2155+
result = operator.matmul(a, b)
2156+
expected = DataFrame(np.dot(a.values, b.values),
2157+
index=['a', 'b', 'c'],
2158+
columns=['one', 'two'])
2159+
tm.assert_frame_equal(result, expected)
2160+
2161+
# DataFrame @ Series
2162+
result = operator.matmul(a, b.one)
2163+
expected = Series(np.dot(a.values, b.one.values),
2164+
index=['a', 'b', 'c'])
2165+
tm.assert_series_equal(result, expected)
2166+
2167+
# np.array @ DataFrame (__rmatmul__)
2168+
result = operator.matmul(a.values, b)
2169+
expected = np.dot(a.values, b.values)
2170+
tm.assert_almost_equal(result, expected)
2171+
2172+
# unaligned
2173+
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
2174+
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
2175+
2176+
with tm.assert_raises_regex(ValueError, 'aligned'):
2177+
operator.matmul(df, df2)
21452178

21462179

21472180
@pytest.fixture

pandas/tests/series/test_analytics.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from itertools import product
55
from distutils.version import LooseVersion
6+
import operator
7+
import sys
68

79
import pytest
810

@@ -895,30 +897,60 @@ def test_count(self):
895897
ts.iloc[[0, 3, 5]] = nan
896898
assert_series_equal(ts.count(level=1), right - 1)
897899

898-
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
899-
def test_dot(self, dot_fn):
900-
# __matmul__ test is for GH #10259
900+
def test_dot(self):
901901
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
902902
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
903903
columns=['p', 'q', 'r', 's']).T
904904

905-
result = dot_fn(a, b)
905+
result = a.dot(b)
906906
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
907907
assert_series_equal(result, expected)
908908

909909
# Check index alignment
910910
b2 = b.reindex(index=reversed(b.index))
911-
result = dot_fn(a, b)
911+
result = a.dot(b)
912912
assert_series_equal(result, expected)
913913

914914
# Check ndarray argument
915-
result = dot_fn(a, b.values)
915+
result = a.dot(b.values)
916916
assert np.all(result == expected.values)
917-
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
917+
assert_almost_equal(a.dot(b['2'].values), expected['2'])
918918

919919
# Check series argument
920-
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
921-
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
920+
assert_almost_equal(a.dot(b['1']), expected['1'])
921+
assert_almost_equal(a.dot(b2['1']), expected['1'])
922+
923+
pytest.raises(Exception, a.dot, a.values[:3])
924+
pytest.raises(ValueError, a.dot, b.T)
925+
926+
@pytest.mark.skipif(sys.version_info < (3, 5),
927+
reason='matmul supported for Python>=3.5')
928+
def test_matmul(self):
929+
# matmul test is for GH #10259
930+
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
931+
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
932+
columns=['p', 'q', 'r', 's']).T
933+
934+
# Series @ DataFrame
935+
result = operator.matmul(a, b)
936+
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
937+
assert_series_equal(result, expected)
938+
939+
# DataFrame @ Series
940+
result = operator.matmul(b.T, a)
941+
expected = Series(np.dot(b.T.values, a.T.values),
942+
index=['1', '2', '3'])
943+
assert_series_equal(result, expected)
944+
945+
# Series @ Series
946+
result = operator.matmul(a, a)
947+
expected = np.dot(a.values, a.values)
948+
assert_almost_equal(result, expected)
949+
950+
# np.array @ Series (__rmatmul__)
951+
result = operator.matmul(a.values, a)
952+
expected = np.dot(a.values, a.values)
953+
assert_almost_equal(result, expected)
922954

923955
pytest.raises(Exception, a.dot, a.values[:3])
924956
pytest.raises(ValueError, a.dot, b.T)

0 commit comments

Comments
 (0)