Skip to content

Commit fa753be

Browse files
committed
Add matmul to DataFrame, Series
1 parent 699a48b commit fa753be

File tree

5 files changed

+33
-20
lines changed

5 files changed

+33
-20
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ Other Enhancements
344344
- :meth:`DataFrame.to_sql` now performs a multivalue insert if the underlying connection supports itk rather than inserting row by row.
345345
``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`)
346346
- :func:`read_html` now accepts a ``displayed_only`` keyword argument to controls whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`)
347+
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`) for Python>=3.5
347348

348349
.. _whatsnew_0230.api_breaking:
349350

pandas/core/frame.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,8 @@ def __len__(self):
863863

864864
def dot(self, other):
865865
"""
866-
Matrix multiplication with DataFrame or Series objects
866+
Matrix multiplication with DataFrame or Series objects. Can also be
867+
called using `self @ other` in Python >= 3.5.
867868
868869
Parameters
869870
----------
@@ -905,6 +906,10 @@ def dot(self, other):
905906
else: # pragma: no cover
906907
raise TypeError('unsupported type: %s' % type(other))
907908

909+
def __matmul__(self, other):
910+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
911+
return self.dot(other)
912+
908913
# ----------------------------------------------------------------------
909914
# IO methods (to / from other formats)
910915

pandas/core/series.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1949,7 +1949,7 @@ def autocorr(self, lag=1):
19491949
def dot(self, other):
19501950
"""
19511951
Matrix multiplication with DataFrame or inner-product with Series
1952-
objects
1952+
objects. Can also be called using `self @ other` in Python >= 3.5.
19531953
19541954
Parameters
19551955
----------
@@ -1988,6 +1988,10 @@ def dot(self, other):
19881988
else: # pragma: no cover
19891989
raise TypeError('unsupported type: %s' % type(other))
19901990

1991+
def __matmul__(self, other):
1992+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
1993+
return self.dot(other)
1994+
19911995
@Substitution(klass='Series')
19921996
@Appender(base._shared_docs['searchsorted'])
19931997
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -2089,41 +2089,42 @@ def test_clip_with_na_args(self):
20892089
self.frame)
20902090

20912091
# Matrix-like
2092-
2093-
def test_dot(self):
2092+
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
2093+
def test_dot(self, dot_fn):
2094+
# __matmul__ test is for GH #10259
20942095
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20952096
columns=['p', 'q', 'r', 's'])
20962097
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
20972098
columns=['one', 'two'])
20982099

2099-
result = a.dot(b)
2100+
result = dot_fn(a, b)
21002101
expected = DataFrame(np.dot(a.values, b.values),
21012102
index=['a', 'b', 'c'],
21022103
columns=['one', 'two'])
21032104
# Check alignment
21042105
b1 = b.reindex(index=reversed(b.index))
2105-
result = a.dot(b)
2106+
result = dot_fn(a, b)
21062107
tm.assert_frame_equal(result, expected)
21072108

21082109
# Check series argument
2109-
result = a.dot(b['one'])
2110+
result = dot_fn(a, b['one'])
21102111
tm.assert_series_equal(result, expected['one'], check_names=False)
21112112
assert result.name is None
21122113

2113-
result = a.dot(b1['one'])
2114+
result = dot_fn(a, b1['one'])
21142115
tm.assert_series_equal(result, expected['one'], check_names=False)
21152116
assert result.name is None
21162117

21172118
# can pass correct-length arrays
21182119
row = a.iloc[0].values
21192120

2120-
result = a.dot(row)
2121-
exp = a.dot(a.iloc[0])
2121+
result = dot_fn(a, row)
2122+
exp = dot_fn(a, a.iloc[0])
21222123
tm.assert_series_equal(result, exp)
21232124

21242125
with tm.assert_raises_regex(ValueError,
21252126
'Dot product shape mismatch'):
2126-
a.dot(row[:-1])
2127+
dot_fn(a, row[:-1])
21272128

21282129
a = np.random.rand(1, 5)
21292130
b = np.random.rand(5, 1)
@@ -2133,14 +2134,14 @@ def test_dot(self):
21332134
B = DataFrame(b) # noqa
21342135

21352136
# it works
2136-
result = A.dot(b)
2137+
result = dot_fn(A, b)
21372138

21382139
# unaligned
21392140
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
21402141
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
21412142

21422143
with tm.assert_raises_regex(ValueError, 'aligned'):
2143-
df.dot(df2)
2144+
dot_fn(df, df2)
21442145

21452146

21462147
@pytest.fixture

pandas/tests/series/test_analytics.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -895,28 +895,30 @@ def test_count(self):
895895
ts.iloc[[0, 3, 5]] = nan
896896
assert_series_equal(ts.count(level=1), right - 1)
897897

898-
def test_dot(self):
898+
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
899+
def test_dot(self, dot_fn):
900+
# __matmul__ test is for GH #10259
899901
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
900902
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
901903
columns=['p', 'q', 'r', 's']).T
902904

903-
result = a.dot(b)
905+
result = dot_fn(a, b)
904906
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
905907
assert_series_equal(result, expected)
906908

907909
# Check index alignment
908910
b2 = b.reindex(index=reversed(b.index))
909-
result = a.dot(b)
911+
result = dot_fn(a, b)
910912
assert_series_equal(result, expected)
911913

912914
# Check ndarray argument
913-
result = a.dot(b.values)
915+
result = dot_fn(a, b.values)
914916
assert np.all(result == expected.values)
915-
assert_almost_equal(a.dot(b['2'].values), expected['2'])
917+
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
916918

917919
# Check series argument
918-
assert_almost_equal(a.dot(b['1']), expected['1'])
919-
assert_almost_equal(a.dot(b2['1']), expected['1'])
920+
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
921+
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
920922

921923
pytest.raises(Exception, a.dot, a.values[:3])
922924
pytest.raises(ValueError, a.dot, b.T)

0 commit comments

Comments
 (0)