Skip to content

Commit 36a0cb8

Browse files
committed
Add matmul to DataFrame, Series
1 parent 718d067 commit 36a0cb8

File tree

5 files changed

+44
-19
lines changed

5 files changed

+44
-19
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ Other Enhancements
324324
- ``IntervalIndex.astype`` now supports conversions between subtypes when passed an ``IntervalDtype`` (:issue:`19197`)
325325
- :class:`IntervalIndex` and its associated constructor methods (``from_arrays``, ``from_breaks``, ``from_tuples``) have gained a ``dtype`` parameter (:issue:`19262`)
326326
- Added :func:`SeriesGroupBy.is_monotonic_increasing` and :func:`SeriesGroupBy.is_monotonic_decreasing` (:issue:`17015`)
327+
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`) for Python>=3.5
327328

328329
.. _whatsnew_0230.api_breaking:
329330

pandas/core/frame.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,8 @@ def __len__(self):
830830

831831
def dot(self, other):
832832
"""
833-
Matrix multiplication with DataFrame or Series objects
833+
Matrix multiplication with DataFrame or Series objects. Can also be
834+
called using `self @ other` in Python >= 3.5.
834835
835836
Parameters
836837
----------
@@ -872,6 +873,13 @@ def dot(self, other):
872873
else: # pragma: no cover
873874
raise TypeError('unsupported type: %s' % type(other))
874875

876+
def __matmul__(self, other):
877+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
878+
try:
879+
return self.dot(other)
880+
except TypeError:
881+
return NotImplemented
882+
875883
# ----------------------------------------------------------------------
876884
# IO methods (to / from other formats)
877885

pandas/core/series.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,7 @@ def autocorr(self, lag=1):
15591559
def dot(self, other):
15601560
"""
15611561
Matrix multiplication with DataFrame or inner-product with Series
1562-
objects
1562+
objects. Can also be called using `self @ other` in Python >= 3.5.
15631563
15641564
Parameters
15651565
----------
@@ -1598,6 +1598,13 @@ def dot(self, other):
15981598
else: # pragma: no cover
15991599
raise TypeError('unsupported type: %s' % type(other))
16001600

1601+
def __matmul__(self, other):
1602+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
1603+
try:
1604+
return self.dot(other)
1605+
except TypeError:
1606+
return NotImplemented
1607+
16011608
@Substitution(klass='Series')
16021609
@Appender(base._shared_docs['searchsorted'])
16031610
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -2075,41 +2075,42 @@ def test_clip_with_na_args(self):
20752075
self.frame)
20762076

20772077
# Matrix-like
2078-
2078+
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
20792079
def test_dot(self):
2080+
# __matmul__ test is for GH #10259
20802081
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20812082
columns=['p', 'q', 'r', 's'])
20822083
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
20832084
columns=['one', 'two'])
20842085

2085-
result = a.dot(b)
2086+
result = dot_fn(a, b)
20862087
expected = DataFrame(np.dot(a.values, b.values),
20872088
index=['a', 'b', 'c'],
20882089
columns=['one', 'two'])
20892090
# Check alignment
20902091
b1 = b.reindex(index=reversed(b.index))
2091-
result = a.dot(b)
2092+
result = dot_fn(a, b)
20922093
tm.assert_frame_equal(result, expected)
20932094

20942095
# Check series argument
2095-
result = a.dot(b['one'])
2096+
result = dot_fn(a, b['one'])
20962097
tm.assert_series_equal(result, expected['one'], check_names=False)
20972098
assert result.name is None
20982099

2099-
result = a.dot(b1['one'])
2100+
result = dot_fn(a, b1['one'])
21002101
tm.assert_series_equal(result, expected['one'], check_names=False)
21012102
assert result.name is None
21022103

21032104
# can pass correct-length arrays
21042105
row = a.iloc[0].values
21052106

2106-
result = a.dot(row)
2107-
exp = a.dot(a.iloc[0])
2107+
result = dot_fn(a, row)
2108+
exp = dot_fn(a, a.iloc[0])
21082109
tm.assert_series_equal(result, exp)
21092110

21102111
with tm.assert_raises_regex(ValueError,
21112112
'Dot product shape mismatch'):
2112-
a.dot(row[:-1])
2113+
dot_fn(a, row[:-1])
21132114

21142115
a = np.random.rand(1, 5)
21152116
b = np.random.rand(5, 1)
@@ -2119,14 +2120,19 @@ def test_dot(self):
21192120
B = DataFrame(b) # noqa
21202121

21212122
# it works
2122-
result = A.dot(b)
2123+
result = dot_fn(A, b)
21232124

21242125
# unaligned
21252126
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
21262127
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
21272128

21282129
with tm.assert_raises_regex(ValueError, 'aligned'):
2129-
df.dot(df2)
2130+
dot_fn(df, df2)
2131+
2132+
# invalid type
2133+
with tm.assert_raises_regex(TypeError, 'unsupported'):
2134+
x = 1
2135+
dot_fn(df, 1)
21302136

21312137

21322138
@pytest.fixture

pandas/tests/series/test_analytics.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -895,31 +895,34 @@ def test_count(self):
895895
ts.iloc[[0, 3, 5]] = nan
896896
assert_series_equal(ts.count(level=1), right - 1)
897897

898-
def test_dot(self):
898+
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
899+
def test_dot(self, dot_fn):
900+
# __matmul__ test is for GH #10259
899901
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
900902
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
901903
columns=['p', 'q', 'r', 's']).T
902904

903-
result = a.dot(b)
905+
result = dot_fn(a, b)
904906
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
905907
assert_series_equal(result, expected)
906908

907909
# Check index alignment
908910
b2 = b.reindex(index=reversed(b.index))
909-
result = a.dot(b)
911+
result = dot_fn(a, b)
910912
assert_series_equal(result, expected)
911913

912914
# Check ndarray argument
913-
result = a.dot(b.values)
915+
result = dot_fn(a, b.values)
914916
assert np.all(result == expected.values)
915-
assert_almost_equal(a.dot(b['2'].values), expected['2'])
917+
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
916918

917919
# Check series argument
918-
assert_almost_equal(a.dot(b['1']), expected['1'])
919-
assert_almost_equal(a.dot(b2['1']), expected['1'])
920+
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
921+
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
920922

921923
pytest.raises(Exception, a.dot, a.values[:3])
922924
pytest.raises(ValueError, a.dot, b.T)
925+
pytest.raises(TypeError, a.dot, 1)
923926

924927
def test_value_counts_nunique(self):
925928

0 commit comments

Comments
 (0)