Skip to content

Commit d6f156b

Browse files
committed
Add rmatmul to DataFrame, Series
1 parent 7db1e25 commit d6f156b

File tree

4 files changed

+101
-21
lines changed

4 files changed

+101
-21
lines changed

pandas/core/frame.py

+4
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,10 @@ def __matmul__(self, other):
910910
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
911911
return self.dot(other)
912912

913+
def __rmatmul__(self, other):
914+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
915+
return self.T.dot(np.transpose(other)).T
916+
913917
# ----------------------------------------------------------------------
914918
# IO methods (to / from other formats)
915919

pandas/core/series.py

+4
Original file line numberDiff line numberDiff line change
@@ -2035,6 +2035,10 @@ def __matmul__(self, other):
20352035
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
20362036
return self.dot(other)
20372037

2038+
def __rmatmul__(self, other):
2039+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
2040+
return self.dot(other)
2041+
20382042
@Substitution(klass='Series')
20392043
@Appender(base._shared_docs['searchsorted'])
20402044
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+52-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from datetime import timedelta
77
from distutils.version import LooseVersion
8+
import operator
89
import sys
910
import pytest
1011

@@ -2091,42 +2092,40 @@ def test_clip_with_na_args(self):
20912092
self.frame)
20922093

20932094
# Matrix-like
2094-
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
2095-
def test_dot(self, dot_fn):
2096-
# __matmul__ test is for GH #10259
2095+
def test_dot(self):
20972096
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20982097
columns=['p', 'q', 'r', 's'])
20992098
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
21002099
columns=['one', 'two'])
21012100

2102-
result = dot_fn(a, b)
2101+
result = a.dot(b)
21032102
expected = DataFrame(np.dot(a.values, b.values),
21042103
index=['a', 'b', 'c'],
21052104
columns=['one', 'two'])
21062105
# Check alignment
21072106
b1 = b.reindex(index=reversed(b.index))
2108-
result = dot_fn(a, b)
2107+
result = a.dot(b)
21092108
tm.assert_frame_equal(result, expected)
21102109

21112110
# Check series argument
2112-
result = dot_fn(a, b['one'])
2111+
result = a.dot(b['one'])
21132112
tm.assert_series_equal(result, expected['one'], check_names=False)
21142113
assert result.name is None
21152114

2116-
result = dot_fn(a, b1['one'])
2115+
result = a.dot(b1['one'])
21172116
tm.assert_series_equal(result, expected['one'], check_names=False)
21182117
assert result.name is None
21192118

21202119
# can pass correct-length arrays
21212120
row = a.iloc[0].values
21222121

2123-
result = dot_fn(a, row)
2124-
exp = dot_fn(a, a.iloc[0])
2122+
result = a.dot(row)
2123+
exp = a.dot(a.iloc[0])
21252124
tm.assert_series_equal(result, exp)
21262125

21272126
with tm.assert_raises_regex(ValueError,
21282127
'Dot product shape mismatch'):
2129-
dot_fn(a, row[:-1])
2128+
a.dot(row[:-1])
21302129

21312130
a = np.random.rand(1, 5)
21322131
b = np.random.rand(5, 1)
@@ -2136,14 +2135,55 @@ def test_dot(self, dot_fn):
21362135
B = DataFrame(b) # noqa
21372136

21382137
# it works
2139-
result = dot_fn(A, b)
2138+
result = A.dot(b)
21402139

21412140
# unaligned
21422141
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
21432142
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
21442143

21452144
with tm.assert_raises_regex(ValueError, 'aligned'):
2146-
dot_fn(df, df2)
2145+
df.dot(df2)
2146+
2147+
@pytest.mark.skipif(sys.version_info < (3, 5),
2148+
reason='matmul supported for Python>=3.5')
2149+
def test_matmul(self):
2150+
# matmul test is for GH #10259
2151+
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
2152+
columns=['p', 'q', 'r', 's'])
2153+
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
2154+
columns=['one', 'two'])
2155+
2156+
# DataFrame @ DataFrame
2157+
result = operator.matmul(a, b)
2158+
expected = DataFrame(np.dot(a.values, b.values),
2159+
index=['a', 'b', 'c'],
2160+
columns=['one', 'two'])
2161+
tm.assert_frame_equal(result, expected)
2162+
2163+
# DataFrame @ Series
2164+
result = operator.matmul(a, b.one)
2165+
expected = Series(np.dot(a.values, b.one.values),
2166+
index=['a', 'b', 'c'])
2167+
tm.assert_series_equal(result, expected)
2168+
2169+
# np.array @ DataFrame
2170+
result = operator.matmul(a.values, b)
2171+
expected = np.dot(a.values, b.values)
2172+
tm.assert_almost_equal(result, expected)
2173+
2174+
# nested list @ DataFrame (__rmatmul__)
2175+
result = operator.matmul(a.values.tolist(), b)
2176+
expected = DataFrame(np.dot(a.values, b.values),
2177+
index=['a', 'b', 'c'],
2178+
columns=['one', 'two'])
2179+
tm.assert_almost_equal(result.values, expected.values)
2180+
2181+
# unaligned
2182+
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
2183+
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
2184+
2185+
with tm.assert_raises_regex(ValueError, 'aligned'):
2186+
operator.matmul(df, df2)
21472187

21482188

21492189
@pytest.fixture

pandas/tests/series/test_analytics.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from itertools import product
55
from distutils.version import LooseVersion
6+
import operator
7+
import sys
68

79
import pytest
810

@@ -895,30 +897,60 @@ def test_count(self):
895897
ts.iloc[[0, 3, 5]] = nan
896898
assert_series_equal(ts.count(level=1), right - 1)
897899

898-
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
899-
def test_dot(self, dot_fn):
900-
# __matmul__ test is for GH #10259
900+
def test_dot(self):
901901
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
902902
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
903903
columns=['p', 'q', 'r', 's']).T
904904

905-
result = dot_fn(a, b)
905+
result = a.dot(b)
906906
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
907907
assert_series_equal(result, expected)
908908

909909
# Check index alignment
910910
b2 = b.reindex(index=reversed(b.index))
911-
result = dot_fn(a, b)
911+
result = a.dot(b)
912912
assert_series_equal(result, expected)
913913

914914
# Check ndarray argument
915-
result = dot_fn(a, b.values)
915+
result = a.dot(b.values)
916916
assert np.all(result == expected.values)
917-
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
917+
assert_almost_equal(a.dot(b['2'].values), expected['2'])
918918

919919
# Check series argument
920-
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
921-
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
920+
assert_almost_equal(a.dot(b['1']), expected['1'])
921+
assert_almost_equal(a.dot(b2['1']), expected['1'])
922+
923+
pytest.raises(Exception, a.dot, a.values[:3])
924+
pytest.raises(ValueError, a.dot, b.T)
925+
926+
@pytest.mark.skipif(sys.version_info < (3, 5),
927+
reason='matmul supported for Python>=3.5')
928+
def test_matmul(self):
929+
# matmul test is for GH #10259
930+
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
931+
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
932+
columns=['p', 'q', 'r', 's']).T
933+
934+
# Series @ DataFrame
935+
result = operator.matmul(a, b)
936+
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
937+
assert_series_equal(result, expected)
938+
939+
# DataFrame @ Series
940+
result = operator.matmul(b.T, a)
941+
expected = Series(np.dot(b.T.values, a.T.values),
942+
index=['1', '2', '3'])
943+
assert_series_equal(result, expected)
944+
945+
# Series @ Series
946+
result = operator.matmul(a, a)
947+
expected = np.dot(a.values, a.values)
948+
assert_almost_equal(result, expected)
949+
950+
# np.array @ Series (__rmatmul__)
951+
result = operator.matmul(a.values, a)
952+
expected = np.dot(a.values, a.values)
953+
assert_almost_equal(result, expected)
922954

923955
pytest.raises(Exception, a.dot, a.values[:3])
924956
pytest.raises(ValueError, a.dot, b.T)

0 commit comments

Comments
 (0)