Skip to content

Commit 7db1e25

Browse files
committed
Add matmul to DataFrame, Series
1 parent cdfce2b commit 7db1e25

File tree

5 files changed

+33
-20
lines changed

5 files changed

+33
-20
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ Other Enhancements
346346
``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`)
347347
- :func:`read_html` now accepts a ``displayed_only`` keyword argument to controls whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`)
348348
- zip compression is supported via ``compression=zip`` in :func:`DataFrame.to_pickle`, :func:`Series.to_pickle`, :func:`DataFrame.to_csv`, :func:`Series.to_csv`, :func:`DataFrame.to_json`, :func:`Series.to_json`. (:issue:`17778`)
349+
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`) for Python>=3.5
349350

350351
.. _whatsnew_0230.api_breaking:
351352

pandas/core/frame.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,8 @@ def __len__(self):
863863

864864
def dot(self, other):
865865
"""
866-
Matrix multiplication with DataFrame or Series objects
866+
Matrix multiplication with DataFrame or Series objects. Can also be
867+
called using `self @ other` in Python >= 3.5.
867868
868869
Parameters
869870
----------
@@ -905,6 +906,10 @@ def dot(self, other):
905906
else: # pragma: no cover
906907
raise TypeError('unsupported type: %s' % type(other))
907908

909+
def __matmul__(self, other):
910+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
911+
return self.dot(other)
912+
908913
# ----------------------------------------------------------------------
909914
# IO methods (to / from other formats)
910915

pandas/core/series.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1992,7 +1992,7 @@ def autocorr(self, lag=1):
19921992
def dot(self, other):
19931993
"""
19941994
Matrix multiplication with DataFrame or inner-product with Series
1995-
objects
1995+
objects. Can also be called using `self @ other` in Python >= 3.5.
19961996
19971997
Parameters
19981998
----------
@@ -2031,6 +2031,10 @@ def dot(self, other):
20312031
else: # pragma: no cover
20322032
raise TypeError('unsupported type: %s' % type(other))
20332033

2034+
def __matmul__(self, other):
2035+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
2036+
return self.dot(other)
2037+
20342038
@Substitution(klass='Series')
20352039
@Appender(base._shared_docs['searchsorted'])
20362040
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -2091,41 +2091,42 @@ def test_clip_with_na_args(self):
20912091
self.frame)
20922092

20932093
# Matrix-like
2094-
2095-
def test_dot(self):
2094+
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
2095+
def test_dot(self, dot_fn):
2096+
# __matmul__ test is for GH #10259
20962097
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20972098
columns=['p', 'q', 'r', 's'])
20982099
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
20992100
columns=['one', 'two'])
21002101

2101-
result = a.dot(b)
2102+
result = dot_fn(a, b)
21022103
expected = DataFrame(np.dot(a.values, b.values),
21032104
index=['a', 'b', 'c'],
21042105
columns=['one', 'two'])
21052106
# Check alignment
21062107
b1 = b.reindex(index=reversed(b.index))
2107-
result = a.dot(b)
2108+
result = dot_fn(a, b)
21082109
tm.assert_frame_equal(result, expected)
21092110

21102111
# Check series argument
2111-
result = a.dot(b['one'])
2112+
result = dot_fn(a, b['one'])
21122113
tm.assert_series_equal(result, expected['one'], check_names=False)
21132114
assert result.name is None
21142115

2115-
result = a.dot(b1['one'])
2116+
result = dot_fn(a, b1['one'])
21162117
tm.assert_series_equal(result, expected['one'], check_names=False)
21172118
assert result.name is None
21182119

21192120
# can pass correct-length arrays
21202121
row = a.iloc[0].values
21212122

2122-
result = a.dot(row)
2123-
exp = a.dot(a.iloc[0])
2123+
result = dot_fn(a, row)
2124+
exp = dot_fn(a, a.iloc[0])
21242125
tm.assert_series_equal(result, exp)
21252126

21262127
with tm.assert_raises_regex(ValueError,
21272128
'Dot product shape mismatch'):
2128-
a.dot(row[:-1])
2129+
dot_fn(a, row[:-1])
21292130

21302131
a = np.random.rand(1, 5)
21312132
b = np.random.rand(5, 1)
@@ -2135,14 +2136,14 @@ def test_dot(self):
21352136
B = DataFrame(b) # noqa
21362137

21372138
# it works
2138-
result = A.dot(b)
2139+
result = dot_fn(A, b)
21392140

21402141
# unaligned
21412142
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
21422143
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
21432144

21442145
with tm.assert_raises_regex(ValueError, 'aligned'):
2145-
df.dot(df2)
2146+
dot_fn(df, df2)
21462147

21472148

21482149
@pytest.fixture

pandas/tests/series/test_analytics.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -895,28 +895,30 @@ def test_count(self):
895895
ts.iloc[[0, 3, 5]] = nan
896896
assert_series_equal(ts.count(level=1), right - 1)
897897

898-
def test_dot(self):
898+
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
899+
def test_dot(self, dot_fn):
900+
# __matmul__ test is for GH #10259
899901
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
900902
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
901903
columns=['p', 'q', 'r', 's']).T
902904

903-
result = a.dot(b)
905+
result = dot_fn(a, b)
904906
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
905907
assert_series_equal(result, expected)
906908

907909
# Check index alignment
908910
b2 = b.reindex(index=reversed(b.index))
909-
result = a.dot(b)
911+
result = dot_fn(a, b)
910912
assert_series_equal(result, expected)
911913

912914
# Check ndarray argument
913-
result = a.dot(b.values)
915+
result = dot_fn(a, b.values)
914916
assert np.all(result == expected.values)
915-
assert_almost_equal(a.dot(b['2'].values), expected['2'])
917+
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
916918

917919
# Check series argument
918-
assert_almost_equal(a.dot(b['1']), expected['1'])
919-
assert_almost_equal(a.dot(b2['1']), expected['1'])
920+
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
921+
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
920922

921923
pytest.raises(Exception, a.dot, a.values[:3])
922924
pytest.raises(ValueError, a.dot, b.T)

0 commit comments

Comments
 (0)