Skip to content

Commit 94662fa

Browse files
bnaulkornilova203
authored andcommitted
ENH: Add matmul to DataFrame, Series (pandas-dev#19035)
1 parent b3e5292 commit 94662fa

File tree

5 files changed

+127
-6
lines changed

5 files changed

+127
-6
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ Other Enhancements
403403
``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`)
404404
- :func:`read_html` now accepts a ``displayed_only`` keyword argument to controls whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`)
405405
- zip compression is supported via ``compression=zip`` in :func:`DataFrame.to_pickle`, :func:`Series.to_pickle`, :func:`DataFrame.to_csv`, :func:`Series.to_csv`, :func:`DataFrame.to_json`, :func:`Series.to_json`. (:issue:`17778`)
406+
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`) for Python>=3.5
406407

407408
.. _whatsnew_0230.api_breaking:
408409

pandas/core/frame.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,8 @@ def __len__(self):
863863

864864
def dot(self, other):
865865
"""
866-
Matrix multiplication with DataFrame or Series objects
866+
Matrix multiplication with DataFrame or Series objects. Can also be
867+
called using `self @ other` in Python >= 3.5.
867868
868869
Parameters
869870
----------
@@ -905,6 +906,14 @@ def dot(self, other):
905906
else: # pragma: no cover
906907
raise TypeError('unsupported type: %s' % type(other))
907908

909+
def __matmul__(self, other):
910+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
911+
return self.dot(other)
912+
913+
def __rmatmul__(self, other):
914+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
915+
return self.T.dot(np.transpose(other)).T
916+
908917
# ----------------------------------------------------------------------
909918
# IO methods (to / from other formats)
910919

pandas/core/series.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1997,7 +1997,7 @@ def autocorr(self, lag=1):
19971997
def dot(self, other):
19981998
"""
19991999
Matrix multiplication with DataFrame or inner-product with Series
2000-
objects
2000+
objects. Can also be called using `self @ other` in Python >= 3.5.
20012001
20022002
Parameters
20032003
----------
@@ -2036,6 +2036,14 @@ def dot(self, other):
20362036
else: # pragma: no cover
20372037
raise TypeError('unsupported type: %s' % type(other))
20382038

2039+
def __matmul__(self, other):
2040+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
2041+
return self.dot(other)
2042+
2043+
def __rmatmul__(self, other):
2044+
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
2045+
return self.dot(other)
2046+
20392047
@Substitution(klass='Series')
20402048
@Appender(base._shared_docs['searchsorted'])
20412049
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')

pandas/tests/frame/test_analytics.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from datetime import timedelta
77
from distutils.version import LooseVersion
8+
import operator
89
import sys
910
import pytest
1011

@@ -13,7 +14,7 @@
1314
from numpy.random import randn
1415
import numpy as np
1516

16-
from pandas.compat import lrange, product
17+
from pandas.compat import lrange, product, PY35
1718
from pandas import (compat, isna, notna, DataFrame, Series,
1819
MultiIndex, date_range, Timestamp, Categorical,
1920
_np_version_under1p15)
@@ -2091,7 +2092,6 @@ def test_clip_with_na_args(self):
20912092
self.frame)
20922093

20932094
# Matrix-like
2094-
20952095
def test_dot(self):
20962096
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
20972097
columns=['p', 'q', 'r', 's'])
@@ -2144,6 +2144,63 @@ def test_dot(self):
21442144
with tm.assert_raises_regex(ValueError, 'aligned'):
21452145
df.dot(df2)
21462146

2147+
@pytest.mark.skipif(not PY35,
2148+
reason='matmul supported for Python>=3.5')
2149+
def test_matmul(self):
2150+
# matmul test is for GH #10259
2151+
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
2152+
columns=['p', 'q', 'r', 's'])
2153+
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
2154+
columns=['one', 'two'])
2155+
2156+
# DataFrame @ DataFrame
2157+
result = operator.matmul(a, b)
2158+
expected = DataFrame(np.dot(a.values, b.values),
2159+
index=['a', 'b', 'c'],
2160+
columns=['one', 'two'])
2161+
tm.assert_frame_equal(result, expected)
2162+
2163+
# DataFrame @ Series
2164+
result = operator.matmul(a, b.one)
2165+
expected = Series(np.dot(a.values, b.one.values),
2166+
index=['a', 'b', 'c'])
2167+
tm.assert_series_equal(result, expected)
2168+
2169+
# np.array @ DataFrame
2170+
result = operator.matmul(a.values, b)
2171+
expected = np.dot(a.values, b.values)
2172+
tm.assert_almost_equal(result, expected)
2173+
2174+
# nested list @ DataFrame (__rmatmul__)
2175+
result = operator.matmul(a.values.tolist(), b)
2176+
expected = DataFrame(np.dot(a.values, b.values),
2177+
index=['a', 'b', 'c'],
2178+
columns=['one', 'two'])
2179+
tm.assert_almost_equal(result.values, expected.values)
2180+
2181+
# mixed dtype DataFrame @ DataFrame
2182+
a['q'] = a.q.round().astype(int)
2183+
result = operator.matmul(a, b)
2184+
expected = DataFrame(np.dot(a.values, b.values),
2185+
index=['a', 'b', 'c'],
2186+
columns=['one', 'two'])
2187+
tm.assert_frame_equal(result, expected)
2188+
2189+
# different dtypes DataFrame @ DataFrame
2190+
a = a.astype(int)
2191+
result = operator.matmul(a, b)
2192+
expected = DataFrame(np.dot(a.values, b.values),
2193+
index=['a', 'b', 'c'],
2194+
columns=['one', 'two'])
2195+
tm.assert_frame_equal(result, expected)
2196+
2197+
# unaligned
2198+
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
2199+
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])
2200+
2201+
with tm.assert_raises_regex(ValueError, 'aligned'):
2202+
operator.matmul(df, df2)
2203+
21472204

21482205
@pytest.fixture
21492206
def df_duplicates():

pandas/tests/series/test_analytics.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from itertools import product
55
from distutils.version import LooseVersion
6-
6+
import operator
77
import pytest
88

99
from numpy import nan
@@ -18,7 +18,7 @@
1818
from pandas.core.indexes.timedeltas import Timedelta
1919
import pandas.core.nanops as nanops
2020

21-
from pandas.compat import lrange, range
21+
from pandas.compat import lrange, range, PY35
2222
from pandas import compat
2323
from pandas.util.testing import (assert_series_equal, assert_almost_equal,
2424
assert_frame_equal, assert_index_equal)
@@ -921,6 +921,52 @@ def test_dot(self):
921921
pytest.raises(Exception, a.dot, a.values[:3])
922922
pytest.raises(ValueError, a.dot, b.T)
923923

924+
@pytest.mark.skipif(not PY35,
925+
reason='matmul supported for Python>=3.5')
926+
def test_matmul(self):
927+
# matmul test is for GH #10259
928+
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
929+
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
930+
columns=['p', 'q', 'r', 's']).T
931+
932+
# Series @ DataFrame
933+
result = operator.matmul(a, b)
934+
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
935+
assert_series_equal(result, expected)
936+
937+
# DataFrame @ Series
938+
result = operator.matmul(b.T, a)
939+
expected = Series(np.dot(b.T.values, a.T.values),
940+
index=['1', '2', '3'])
941+
assert_series_equal(result, expected)
942+
943+
# Series @ Series
944+
result = operator.matmul(a, a)
945+
expected = np.dot(a.values, a.values)
946+
assert_almost_equal(result, expected)
947+
948+
# np.array @ Series (__rmatmul__)
949+
result = operator.matmul(a.values, a)
950+
expected = np.dot(a.values, a.values)
951+
assert_almost_equal(result, expected)
952+
953+
# mixed dtype DataFrame @ Series
954+
a['p'] = int(a.p)
955+
result = operator.matmul(b.T, a)
956+
expected = Series(np.dot(b.T.values, a.T.values),
957+
index=['1', '2', '3'])
958+
assert_series_equal(result, expected)
959+
960+
# different dtypes DataFrame @ Series
961+
a = a.astype(int)
962+
result = operator.matmul(b.T, a)
963+
expected = Series(np.dot(b.T.values, a.T.values),
964+
index=['1', '2', '3'])
965+
assert_series_equal(result, expected)
966+
967+
pytest.raises(Exception, a.dot, a.values[:3])
968+
pytest.raises(ValueError, a.dot, b.T)
969+
924970
def test_value_counts_nunique(self):
925971

926972
# basics.rst doc example

0 commit comments

Comments
 (0)