Skip to content

Commit 5519e13

Browse files
committed
Add matmul to DataFrame, Series
1 parent ee9c7e9 commit 5519e13

File tree

5 files changed

+33
-17
lines changed

5 files changed

+33
-17
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ Other Enhancements
142142
- ``Categorical.rename_categories``, ``CategoricalIndex.rename_categories`` and :attr:`Series.cat.rename_categories`
143143
can now take a callable as their argument (:issue:`18862`)
144144
- :class:`Interval` and :class:`IntervalIndex` have gained a ``length`` attribute (:issue:`18789`)
145+
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`)
145146

146147
.. _whatsnew_0230.api_breaking:
147148

pandas/core/frame.py

+6
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,12 @@ def dot(self, other):
866866
else: # pragma: no cover
867867
raise TypeError('unsupported type: %s' % type(other))
868868

869+
def __matmul__(self, other):
870+
try:
871+
return self.dot(other)
872+
except TypeError:
873+
return NotImplemented
874+
869875
# ----------------------------------------------------------------------
870876
# IO methods (to / from other formats)
871877

pandas/core/series.py

+6
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,12 @@ def dot(self, other):
16251625
else: # pragma: no cover
16261626
raise TypeError('unsupported type: %s' % type(other))
16271627

1628+
def __matmul__(self, other):
1629+
try:
1630+
return self.dot(other)
1631+
except TypeError:
1632+
return NotImplemented
1633+
16281634
@Substitution(klass='Series')
16291635
@Appender(base._shared_docs['searchsorted'])
16301636
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -2014,41 +2014,42 @@ def test_clip_with_na_args(self):
20142014
self.frame)
20152015

20162016
# Matrix-like
2017-
2017+
@pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__])
20182018
def test_dot(self):
2019+
# __matmul__ test is for GH #10259
20192020
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20202021
columns=['p', 'q', 'r', 's'])
20212022
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
20222023
columns=['one', 'two'])
20232024

2024-
result = a.dot(b)
2025+
result = dot_fn(a, b)
20252026
expected = DataFrame(np.dot(a.values, b.values),
20262027
index=['a', 'b', 'c'],
20272028
columns=['one', 'two'])
20282029
# Check alignment
20292030
b1 = b.reindex(index=reversed(b.index))
2030-
result = a.dot(b)
2031+
result = dot_fn(a, b)
20312032
tm.assert_frame_equal(result, expected)
20322033

20332034
# Check series argument
2034-
result = a.dot(b['one'])
2035+
result = dot_fn(a, b['one'])
20352036
tm.assert_series_equal(result, expected['one'], check_names=False)
20362037
assert result.name is None
20372038

2038-
result = a.dot(b1['one'])
2039+
result = dot_fn(a, b1['one'])
20392040
tm.assert_series_equal(result, expected['one'], check_names=False)
20402041
assert result.name is None
20412042

20422043
# can pass correct-length arrays
20432044
row = a.iloc[0].values
20442045

2045-
result = a.dot(row)
2046-
exp = a.dot(a.iloc[0])
2046+
result = dot_fn(a, row)
2047+
exp = dot_fn(a, a.iloc[0])
20472048
tm.assert_series_equal(result, exp)
20482049

20492050
with tm.assert_raises_regex(ValueError,
20502051
'Dot product shape mismatch'):
2051-
a.dot(row[:-1])
2052+
dot_fn(a, row[:-1])
20522053

20532054
a = np.random.rand(1, 5)
20542055
b = np.random.rand(5, 1)
@@ -2058,14 +2059,14 @@ def test_dot(self):
20582059
B = DataFrame(b) # noqa
20592060

20602061
# it works
2061-
result = A.dot(b)
2062+
result = dot_fn(A, b)
20622063

20632064
# unaligned
20642065
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
20652066
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
20662067

20672068
with tm.assert_raises_regex(ValueError, 'aligned'):
2068-
df.dot(df2)
2069+
dot_fn(df, df2)
20692070

20702071

20712072
@pytest.fixture

pandas/tests/series/test_analytics.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -811,28 +811,30 @@ def test_count(self):
811811
ts.iloc[[0, 3, 5]] = nan
812812
assert_series_equal(ts.count(level=1), right - 1)
813813

814-
def test_dot(self):
814+
@pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__])
815+
def test_dot(self, dot_fn):
816+
# __matmul__ test is for GH #10259
815817
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
816818
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
817819
columns=['p', 'q', 'r', 's']).T
818820

819-
result = a.dot(b)
821+
result = dot_fn(a, b)
820822
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
821823
assert_series_equal(result, expected)
822824

823825
# Check index alignment
824826
b2 = b.reindex(index=reversed(b.index))
825-
result = a.dot(b)
827+
result = dot_fn(a, b)
826828
assert_series_equal(result, expected)
827829

828830
# Check ndarray argument
829-
result = a.dot(b.values)
831+
result = dot_fn(a, b.values)
830832
assert np.all(result == expected.values)
831-
assert_almost_equal(a.dot(b['2'].values), expected['2'])
833+
assert_almost_equal(dot_fn(a, b['2'].values), expected['2'])
832834

833835
# Check series argument
834-
assert_almost_equal(a.dot(b['1']), expected['1'])
835-
assert_almost_equal(a.dot(b2['1']), expected['1'])
836+
assert_almost_equal(dot_fn(a, b['1']), expected['1'])
837+
assert_almost_equal(dot_fn(a, b2['1']), expected['1'])
836838

837839
pytest.raises(Exception, a.dot, a.values[:3])
838840
pytest.raises(ValueError, a.dot, b.T)

0 commit comments

Comments
 (0)