diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt index af5371b06192f..ce63cb2473bc4 100644 --- a/doc/source/whatsnew/v0.23.0.txt +++ b/doc/source/whatsnew/v0.23.0.txt @@ -403,6 +403,7 @@ Other Enhancements ``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`) - :func:`read_html` now accepts a ``displayed_only`` keyword argument to controls whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`) - zip compression is supported via ``compression=zip`` in :func:`DataFrame.to_pickle`, :func:`Series.to_pickle`, :func:`DataFrame.to_csv`, :func:`Series.to_csv`, :func:`DataFrame.to_json`, :func:`Series.to_json`. (:issue:`17778`) +- :class:`DataFrame` and :class:`Series` now support the matrix multiplication operator (``@``) on Python >= 3.5 (:issue:`10259`) .. _whatsnew_0230.api_breaking: diff --git a/pandas/core/frame.py b/pandas/core/frame.py index ace975385ce32..9626079660771 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -863,7 +863,8 @@ def __len__(self): def dot(self, other): """ - Matrix multiplication with DataFrame or Series objects + Matrix multiplication with DataFrame or Series objects. Can also be + called using `self @ other` in Python >= 3.5. 
Parameters ---------- @@ -905,6 +906,14 @@ def dot(self, other): else: # pragma: no cover raise TypeError('unsupported type: %s' % type(other)) + def __matmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.dot(other) + + def __rmatmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.T.dot(np.transpose(other)).T + # ---------------------------------------------------------------------- # IO methods (to / from other formats) diff --git a/pandas/core/series.py b/pandas/core/series.py index 1b07f24e148e3..f3630dc43fbd1 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -1994,7 +1994,7 @@ def autocorr(self, lag=1): def dot(self, other): """ Matrix multiplication with DataFrame or inner-product with Series - objects + objects. Can also be called using `self @ other` in Python >= 3.5. Parameters ---------- @@ -2033,6 +2033,14 @@ def dot(self, other): else: # pragma: no cover raise TypeError('unsupported type: %s' % type(other)) + def __matmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.dot(other) + + def __rmatmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.dot(np.transpose(other)) + @Substitution(klass='Series') @Appender(base._shared_docs['searchsorted']) @deprecate_kwarg(old_arg_name='v', new_arg_name='value') diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index 8efa140237614..7949636fcafbb 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -5,6 +5,7 @@ import warnings from datetime import timedelta from distutils.version import LooseVersion +import operator import sys import pytest @@ -13,7 +14,7 @@ from numpy.random import randn import numpy as np -from pandas.compat import lrange, product +from pandas.compat import lrange, product, PY35 from pandas import 
(compat, isna, notna, DataFrame, Series, MultiIndex, date_range, Timestamp, Categorical, _np_version_under1p15) @@ -2091,7 +2092,6 @@ def test_clip_with_na_args(self): self.frame) # Matrix-like - def test_dot(self): a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'], columns=['p', 'q', 'r', 's']) @@ -2144,6 +2144,63 @@ def test_dot(self): with tm.assert_raises_regex(ValueError, 'aligned'): df.dot(df2) + @pytest.mark.skipif(not PY35, + reason='matmul supported for Python>=3.5') + def test_matmul(self): + # matmul test is for GH #10259 + a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'], + columns=['p', 'q', 'r', 's']) + b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'], + columns=['one', 'two']) + + # DataFrame @ DataFrame + result = operator.matmul(a, b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_frame_equal(result, expected) + + # DataFrame @ Series + result = operator.matmul(a, b.one) + expected = Series(np.dot(a.values, b.one.values), + index=['a', 'b', 'c']) + tm.assert_series_equal(result, expected) + + # np.array @ DataFrame + result = operator.matmul(a.values, b) + expected = np.dot(a.values, b.values) + tm.assert_almost_equal(result, expected) + + # nested list @ DataFrame (__rmatmul__) + result = operator.matmul(a.values.tolist(), b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_almost_equal(result.values, expected.values) + + # mixed dtype DataFrame @ DataFrame + a['q'] = a.q.round().astype(int) + result = operator.matmul(a, b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_frame_equal(result, expected) + + # different dtypes DataFrame @ DataFrame + a = a.astype(int) + result = operator.matmul(a, b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + 
tm.assert_frame_equal(result, expected) + + # unaligned + df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4)) + df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3]) + + with tm.assert_raises_regex(ValueError, 'aligned'): + operator.matmul(df, df2) + @pytest.fixture def df_duplicates(): diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index 0e6e44e839464..f93aaf2115601 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -3,7 +3,7 @@ from itertools import product from distutils.version import LooseVersion - +import operator import pytest from numpy import nan @@ -18,7 +18,7 @@ from pandas.core.indexes.timedeltas import Timedelta import pandas.core.nanops as nanops -from pandas.compat import lrange, range +from pandas.compat import lrange, range, PY35 from pandas import compat from pandas.util.testing import (assert_series_equal, assert_almost_equal, assert_frame_equal, assert_index_equal) @@ -921,6 +921,52 @@ def test_dot(self): pytest.raises(Exception, a.dot, a.values[:3]) pytest.raises(ValueError, a.dot, b.T) + @pytest.mark.skipif(not PY35, + reason='matmul supported for Python>=3.5') + def test_matmul(self): + # matmul test is for GH #10259 + a = Series(np.random.randn(4), index=['p', 'q', 'r', 's']) + b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'], + columns=['p', 'q', 'r', 's']).T + + # Series @ DataFrame + result = operator.matmul(a, b) + expected = Series(np.dot(a.values, b.values), index=['1', '2', '3']) + assert_series_equal(result, expected) + + # DataFrame @ Series + result = operator.matmul(b.T, a) + expected = Series(np.dot(b.T.values, a.T.values), + index=['1', '2', '3']) + assert_series_equal(result, expected) + + # Series @ Series + result = operator.matmul(a, a) + expected = np.dot(a.values, a.values) + assert_almost_equal(result, expected) + + # np.array @ Series (__rmatmul__) + result = operator.matmul(a.values, a) + expected 
= np.dot(a.values, a.values) + assert_almost_equal(result, expected) + + # mixed dtype DataFrame @ Series + a['p'] = int(a.p) + result = operator.matmul(b.T, a) + expected = Series(np.dot(b.T.values, a.T.values), + index=['1', '2', '3']) + assert_series_equal(result, expected) + + # different dtypes DataFrame @ Series + a = a.astype(int) + result = operator.matmul(b.T, a) + expected = Series(np.dot(b.T.values, a.T.values), + index=['1', '2', '3']) + assert_series_equal(result, expected) + + pytest.raises(Exception, a.dot, a.values[:3]) + pytest.raises(ValueError, a.dot, b.T) + def test_value_counts_nunique(self): # basics.rst doc example