Skip to content

Commit d5110ed

Browse files
committed
Add matmul to DataFrame, Series
1 parent bbfbe48 commit d5110ed

File tree

5 files changed

+33
-17
lines changed

5 files changed

+33
-17
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ Other Enhancements
202202
- ``Resampler`` objects now have a functioning :attr:`~pandas.core.resample.Resampler.pipe` method.
203203
Previously, calls to ``pipe`` were diverted to the ``mean`` method (:issue:`17905`).
204204
- :func:`~pandas.api.types.is_scalar` now returns ``True`` for ``DateOffset`` objects (:issue:`18943`).
205+
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`)
205206

206207
.. _whatsnew_0230.api_breaking:
207208

pandas/core/frame.py

+6
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,12 @@ def dot(self, other):
874874
else: # pragma: no cover
875875
raise TypeError('unsupported type: %s' % type(other))
876876

877+
def __matmul__(self, other):
878+
try:
879+
return self.dot(other)
880+
except TypeError:
881+
return NotImplemented
882+
877883
# ----------------------------------------------------------------------
878884
# IO methods (to / from other formats)
879885

pandas/core/series.py

+6
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,12 @@ def dot(self, other):
15961596
else: # pragma: no cover
15971597
raise TypeError('unsupported type: %s' % type(other))
15981598

1599+
def __matmul__(self, other):
1600+
try:
1601+
return self.dot(other)
1602+
except TypeError:
1603+
return NotImplemented
1604+
15991605
@Substitution(klass='Series')
16001606
@Appender(base._shared_docs['searchsorted'])
16011607
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -2075,41 +2075,42 @@ def test_clip_with_na_args(self):
20752075
self.frame)
20762076

20772077
# Matrix-like
2078-
2078+
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
20792079
def test_dot(self):
2080+
# __matmul__ test is for GH #10259
20802081
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20812082
columns=['p', 'q', 'r', 's'])
20822083
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
20832084
columns=['one', 'two'])
20842085

2085-
result = a.dot(b)
2086+
result = dot_fn(a, b)
20862087
expected = DataFrame(np.dot(a.values, b.values),
20872088
index=['a', 'b', 'c'],
20882089
columns=['one', 'two'])
20892090
# Check alignment
20902091
b1 = b.reindex(index=reversed(b.index))
2091-
result = a.dot(b)
2092+
result = dot_fn(a, b)
20922093
tm.assert_frame_equal(result, expected)
20932094

20942095
# Check series argument
2095-
result = a.dot(b['one'])
2096+
result = dot_fn(a, b['one'])
20962097
tm.assert_series_equal(result, expected['one'], check_names=False)
20972098
assert result.name is None
20982099

2099-
result = a.dot(b1['one'])
2100+
result = dot_fn(a, b1['one'])
21002101
tm.assert_series_equal(result, expected['one'], check_names=False)
21012102
assert result.name is None
21022103

21032104
# can pass correct-length arrays
21042105
row = a.iloc[0].values
21052106

2106-
result = a.dot(row)
2107-
exp = a.dot(a.iloc[0])
2107+
result = dot_fn(a, row)
2108+
exp = dot_fn(a, a.iloc[0])
21082109
tm.assert_series_equal(result, exp)
21092110

21102111
with tm.assert_raises_regex(ValueError,
21112112
'Dot product shape mismatch'):
2112-
a.dot(row[:-1])
2113+
dot_fn(a, row[:-1])
21132114

21142115
a = np.random.rand(1, 5)
21152116
b = np.random.rand(5, 1)
@@ -2119,14 +2120,14 @@ def test_dot(self):
21192120
B = DataFrame(b) # noqa
21202121

21212122
# it works
2122-
result = A.dot(b)
2123+
result = dot_fn(A, b)
21232124

21242125
# unaligned
21252126
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
21262127
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
21272128

21282129
with tm.assert_raises_regex(ValueError, 'aligned'):
2129-
df.dot(df2)
2130+
dot_fn(df, df2)
21302131

21312132

21322133
@pytest.fixture

pandas/tests/series/test_analytics.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -895,28 +895,30 @@ def test_count(self):
895895
ts.iloc[[0, 3, 5]] = nan
896896
assert_series_equal(ts.count(level=1), right - 1)
897897

898-
def test_dot(self):
898+
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
899+
def test_dot(self, dot_fn):
900+
# __matmul__ test is for GH #10259
899901
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
900902
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
901903
columns=['p', 'q', 'r', 's']).T
902904

903-
result = a.dot(b)
905+
result = dot_fn(a, b)
904906
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
905907
assert_series_equal(result, expected)
906908

907909
# Check index alignment
908910
b2 = b.reindex(index=reversed(b.index))
909-
result = a.dot(b)
911+
result = dot_fn(a, b)
910912
assert_series_equal(result, expected)
911913

912914
# Check ndarray argument
913-
result = a.dot(b.values)
915+
result = dot_fn(a, b.values)
914916
assert np.all(result == expected.values)
915-
assert_almost_equal(a.dot(b['2'].values), expected['2'])
917+
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
916918

917919
# Check series argument
918-
assert_almost_equal(a.dot(b['1']), expected['1'])
919-
assert_almost_equal(a.dot(b2['1']), expected['1'])
920+
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
921+
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
920922

921923
pytest.raises(Exception, a.dot, a.values[:3])
922924
pytest.raises(ValueError, a.dot, b.T)

0 commit comments

Comments
 (0)