From 7db1e25e1f258aa96476169f56b11a7f8818dcae Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Sat, 6 Jan 2018 15:09:51 -0800 Subject: [PATCH 1/4] Add matmul to DataFrame, Series --- doc/source/whatsnew/v0.23.0.txt | 1 + pandas/core/frame.py | 7 ++++++- pandas/core/series.py | 6 +++++- pandas/tests/frame/test_analytics.py | 23 ++++++++++++----------- pandas/tests/series/test_analytics.py | 16 +++++++++------- 5 files changed, 33 insertions(+), 20 deletions(-) diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt index d7c92ed822ffc..312faefdc1a22 100644 --- a/doc/source/whatsnew/v0.23.0.txt +++ b/doc/source/whatsnew/v0.23.0.txt @@ -346,6 +346,7 @@ Other Enhancements ``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`) - :func:`read_html` now accepts a ``displayed_only`` keyword argument to controls whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`) - zip compression is supported via ``compression=zip`` in :func:`DataFrame.to_pickle`, :func:`Series.to_pickle`, :func:`DataFrame.to_csv`, :func:`Series.to_csv`, :func:`DataFrame.to_json`, :func:`Series.to_json`. (:issue:`17778`) +- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`) for Python>=3.5 .. _whatsnew_0230.api_breaking: diff --git a/pandas/core/frame.py b/pandas/core/frame.py index d2617305d220a..a49990bea7203 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -863,7 +863,8 @@ def __len__(self): def dot(self, other): """ - Matrix multiplication with DataFrame or Series objects + Matrix multiplication with DataFrame or Series objects. Can also be + called using `self @ other` in Python >= 3.5. Parameters ---------- @@ -905,6 +906,10 @@ def dot(self, other): else: # pragma: no cover raise TypeError('unsupported type: %s' % type(other)) + def __matmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.dot(other) + # ---------------------------------------------------------------------- # IO methods (to / from other formats) diff --git a/pandas/core/series.py b/pandas/core/series.py index 48e6453e36491..6d791c3336f34 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -1992,7 +1992,7 @@ def autocorr(self, lag=1): def dot(self, other): """ Matrix multiplication with DataFrame or inner-product with Series - objects + objects. Can also be called using `self @ other` in Python >= 3.5. Parameters ---------- @@ -2031,6 +2031,10 @@ def dot(self, other): else: # pragma: no cover raise TypeError('unsupported type: %s' % type(other)) + def __matmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.dot(other) + @Substitution(klass='Series') @Appender(base._shared_docs['searchsorted']) @deprecate_kwarg(old_arg_name='v', new_arg_name='value') diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index 8efa140237614..eea5a729cf46e 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -2091,41 +2091,42 @@ def test_clip_with_na_args(self): self.frame) # Matrix-like - - def test_dot(self): + @pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__]) + def test_dot(self, dot_fn): + # __matmul__ test is for GH #10259 a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'], columns=['p', 'q', 'r', 's']) b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'], columns=['one', 'two']) - result = a.dot(b) + result = dot_fn(a, b) expected = DataFrame(np.dot(a.values, b.values), index=['a', 'b', 'c'], columns=['one', 'two']) # Check alignment b1 = b.reindex(index=reversed(b.index)) - result = a.dot(b) + result = dot_fn(a, b) tm.assert_frame_equal(result, expected) # Check series argument - result = a.dot(b['one']) + result = dot_fn(a, b['one']) tm.assert_series_equal(result, expected['one'], check_names=False) assert result.name is None - result = a.dot(b1['one']) + result = dot_fn(a, b1['one']) tm.assert_series_equal(result, expected['one'], check_names=False) assert result.name is None # can pass correct-length arrays row = a.iloc[0].values - result = a.dot(row) - exp = a.dot(a.iloc[0]) + result = dot_fn(a, row) + exp = dot_fn(a, a.iloc[0]) tm.assert_series_equal(result, exp) with tm.assert_raises_regex(ValueError, 'Dot product shape mismatch'): - a.dot(row[:-1]) + dot_fn(a, row[:-1]) a = np.random.rand(1, 5) b = np.random.rand(5, 1) @@ -2135,14 +2136,14 @@ def test_dot(self): B = DataFrame(b) # noqa # it works - result = A.dot(b) + result = dot_fn(A, b) # unaligned df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4)) df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3]) with tm.assert_raises_regex(ValueError, 'aligned'): - df.dot(df2) + dot_fn(df, df2) @pytest.fixture diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index 0e6e44e839464..d351be800062f 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -895,28 +895,30 @@ def test_count(self): ts.iloc[[0, 3, 5]] = nan assert_series_equal(ts.count(level=1), right - 1) - def test_dot(self): + @pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__]) + def test_dot(self, dot_fn): + # __matmul__ test is for GH #10259 a = Series(np.random.randn(4), index=['p', 'q', 'r', 's']) b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'], columns=['p', 'q', 'r', 's']).T - result = a.dot(b) + result = dot_fn(a, b) expected = Series(np.dot(a.values, b.values), index=['1', '2', '3']) assert_series_equal(result, expected) # Check index alignment b2 = b.reindex(index=reversed(b.index)) - result = a.dot(b) + result = dot_fn(a, b) assert_series_equal(result, expected) # Check ndarray argument - result = a.dot(b.values) + result = dot_fn(a, b.values) assert np.all(result == expected.values) - assert_almost_equal(a.dot(b['2'].values), expected['2']) + assert_almost_equal(dot_fn(a, b['2'].values), expected['2']) # Check series argument - assert_almost_equal(a.dot(b['1']), expected['1']) - assert_almost_equal(a.dot(b2['1']), expected['1']) + assert_almost_equal(dot_fn(a, b['1']), expected['1']) + assert_almost_equal(dot_fn(a, b2['1']), expected['1']) pytest.raises(Exception, a.dot, a.values[:3]) pytest.raises(ValueError, a.dot, b.T) From d6f156b8da7d856539b2de08a7d161a5810cff56 Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Sat, 17 Mar 2018 12:37:06 -0700 Subject: [PATCH 2/4] Add rmatmul to DataFrame, Series --- pandas/core/frame.py | 4 ++ pandas/core/series.py | 4 ++ pandas/tests/frame/test_analytics.py | 64 ++++++++++++++++++++++----- pandas/tests/series/test_analytics.py | 50 +++++++++++++++++---- 4 files changed, 101 insertions(+), 21 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index a49990bea7203..c53a739bb8e34 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -910,6 +910,10 @@ def __matmul__(self, other): """ Matrix multiplication using binary `@` operator in Python>=3.5 """ return self.dot(other) + def __rmatmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.T.dot(np.transpose(other)).T + # ---------------------------------------------------------------------- # IO methods (to / from other formats) diff --git a/pandas/core/series.py b/pandas/core/series.py index 6d791c3336f34..c9e74a7b067cc 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -2035,6 +2035,10 @@ def __matmul__(self, other): """ Matrix multiplication using binary `@` operator in Python>=3.5 """ return self.dot(other) + def __rmatmul__(self, other): + """ Matrix multiplication using binary `@` operator in Python>=3.5 """ + return self.dot(other) + @Substitution(klass='Series') @Appender(base._shared_docs['searchsorted']) @deprecate_kwarg(old_arg_name='v', new_arg_name='value') diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index eea5a729cf46e..eddf135d1fa24 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -5,6 +5,7 @@ import warnings from datetime import timedelta from distutils.version import LooseVersion +import operator import sys import pytest @@ -2091,42 +2092,40 @@ def test_clip_with_na_args(self): self.frame) # Matrix-like - @pytest.mark.parametrize('dot_fn', [DataFrame.dot, DataFrame.__matmul__]) - def test_dot(self, dot_fn): - # __matmul__ test is for GH #10259 + def test_dot(self): a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'], columns=['p', 'q', 'r', 's']) b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'], columns=['one', 'two']) - result = dot_fn(a, b) + result = a.dot(b) expected = DataFrame(np.dot(a.values, b.values), index=['a', 'b', 'c'], columns=['one', 'two']) # Check alignment b1 = b.reindex(index=reversed(b.index)) - result = dot_fn(a, b) + result = a.dot(b) tm.assert_frame_equal(result, expected) # Check series argument - result = dot_fn(a, b['one']) + result = a.dot(b['one']) tm.assert_series_equal(result, expected['one'], check_names=False) assert result.name is None - result = dot_fn(a, b1['one']) + result = a.dot(b1['one']) tm.assert_series_equal(result, expected['one'], check_names=False) assert result.name is None # can pass correct-length arrays row = a.iloc[0].values - result = dot_fn(a, row) - exp = dot_fn(a, a.iloc[0]) + result = a.dot(row) + exp = a.dot(a.iloc[0]) tm.assert_series_equal(result, exp) with tm.assert_raises_regex(ValueError, 'Dot product shape mismatch'): - dot_fn(a, row[:-1]) + a.dot(row[:-1]) a = np.random.rand(1, 5) b = np.random.rand(5, 1) @@ -2136,14 +2135,55 @@ def test_dot(self, dot_fn): B = DataFrame(b) # noqa # it works - result = dot_fn(A, b) + result = A.dot(b) # unaligned df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4)) df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3]) with tm.assert_raises_regex(ValueError, 'aligned'): - dot_fn(df, df2) + df.dot(df2) + + @pytest.mark.skipif(sys.version_info < (3, 5), + reason='matmul supported for Python>=3.5') + def test_matmul(self): + # matmul test is for GH #10259 + a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'], + columns=['p', 'q', 'r', 's']) + b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'], + columns=['one', 'two']) + + # DataFrame @ DataFrame + result = operator.matmul(a, b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_frame_equal(result, expected) + + # DataFrame @ Series + result = operator.matmul(a, b.one) + expected = Series(np.dot(a.values, b.one.values), + index=['a', 'b', 'c']) + tm.assert_series_equal(result, expected) + + # np.array @ DataFrame + result = operator.matmul(a.values, b) + expected = np.dot(a.values, b.values) + tm.assert_almost_equal(result, expected) + + # nested list @ DataFrame (__rmatmul__) + result = operator.matmul(a.values.tolist(), b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_almost_equal(result.values, expected.values) + + # unaligned + df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4)) + df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3]) + + with tm.assert_raises_regex(ValueError, 'aligned'): + operator.matmul(df, df2) @pytest.fixture diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index d351be800062f..26f4540fa0bb4 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -3,6 +3,8 @@ from itertools import product from distutils.version import LooseVersion +import operator +import sys import pytest @@ -895,30 +897,60 @@ def test_count(self): ts.iloc[[0, 3, 5]] = nan assert_series_equal(ts.count(level=1), right - 1) - @pytest.mark.parametrize('dot_fn', [Series.dot, Series.__matmul__]) - def test_dot(self, dot_fn): - # __matmul__ test is for GH #10259 + def test_dot(self): a = Series(np.random.randn(4), index=['p', 'q', 'r', 's']) b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'], columns=['p', 'q', 'r', 's']).T - result = dot_fn(a, b) + result = a.dot(b) expected = Series(np.dot(a.values, b.values), index=['1', '2', '3']) assert_series_equal(result, expected) # Check index alignment b2 = b.reindex(index=reversed(b.index)) - result = dot_fn(a, b) + result = a.dot(b) assert_series_equal(result, expected) # Check ndarray argument - result = dot_fn(a, b.values) + result = a.dot(b.values) assert np.all(result == expected.values) - assert_almost_equal(dot_fn(a, b['2'].values), expected['2']) + assert_almost_equal(a.dot(b['2'].values), expected['2']) # Check series argument - assert_almost_equal(dot_fn(a, b['1']), expected['1']) - assert_almost_equal(dot_fn(a, b2['1']), expected['1']) + assert_almost_equal(a.dot(b['1']), expected['1']) + assert_almost_equal(a.dot(b2['1']), expected['1']) + + pytest.raises(Exception, a.dot, a.values[:3]) + pytest.raises(ValueError, a.dot, b.T) + + @pytest.mark.skipif(sys.version_info < (3, 5), + reason='matmul supported for Python>=3.5') + def test_matmul(self): + # matmul test is for GH #10259 + a = Series(np.random.randn(4), index=['p', 'q', 'r', 's']) + b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'], + columns=['p', 'q', 'r', 's']).T + + # Series @ DataFrame + result = operator.matmul(a, b) + expected = Series(np.dot(a.values, b.values), index=['1', '2', '3']) + assert_series_equal(result, expected) + + # DataFrame @ Series + result = operator.matmul(b.T, a) + expected = Series(np.dot(b.T.values, a.T.values), + index=['1', '2', '3']) + assert_series_equal(result, expected) + + # Series @ Series + result = operator.matmul(a, a) + expected = np.dot(a.values, a.values) + assert_almost_equal(result, expected) + + # np.array @ Series (__rmatmul__) + result = operator.matmul(a.values, a) + expected = np.dot(a.values, a.values) + assert_almost_equal(result, expected) pytest.raises(Exception, a.dot, a.values[:3]) pytest.raises(ValueError, a.dot, b.T) From 396a45621eeba6e70d5acc7d635cf12197a05005 Mon Sep 17 00:00:00 2001 From: Brett Naul Date: Wed, 28 Mar 2018 10:07:08 -0700 Subject: [PATCH 3/4] Mixed dtype matmul tests --- pandas/tests/frame/test_analytics.py | 16 ++++++++++++++++ pandas/tests/series/test_analytics.py | 16 +++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index eddf135d1fa24..4ccd7fb84f5bf 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -2178,6 +2178,22 @@ def test_matmul(self): columns=['one', 'two']) tm.assert_almost_equal(result.values, expected.values) + # mixed dtype DataFrame @ DataFrame + a['q'] = a.q.round().astype(int) + result = operator.matmul(a, b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_frame_equal(result, expected) + + # different dtypes DataFrame @ DataFrame + a = a.astype(int) + result = operator.matmul(a, b) + expected = DataFrame(np.dot(a.values, b.values), + index=['a', 'b', 'c'], + columns=['one', 'two']) + tm.assert_frame_equal(result, expected) + # unaligned df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4)) df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3]) diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index 26f4540fa0bb4..f938ad357716c 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -947,11 +947,25 @@ def test_matmul(self): expected = np.dot(a.values, a.values) assert_almost_equal(result, expected) - # np.array @ Series (__rmatmul__) + # np.array @ Series (__rmatmul__) result = operator.matmul(a.values, a) expected = np.dot(a.values, a.values) assert_almost_equal(result, expected) + # mixed dtype DataFrame @ Series + a['p'] = int(a.p) + result = operator.matmul(b.T, a) + expected = Series(np.dot(b.T.values, a.T.values), + index=['1', '2', '3']) + assert_series_equal(result, expected) + + # different dtypes DataFrame @ Series + a = a.astype(int) + result = operator.matmul(b.T, a) + expected = Series(np.dot(b.T.values, a.T.values), + index=['1', '2', '3']) + assert_series_equal(result, expected) + pytest.raises(Exception, a.dot, a.values[:3]) pytest.raises(ValueError, a.dot, b.T) From c036ed0da8f147790529f123106f46e4c23cdd36 Mon Sep 17 00:00:00 2001 From: Jeff Reback Date: Fri, 30 Mar 2018 17:54:28 -0400 Subject: [PATCH 4/4] use compat for version check --- pandas/tests/frame/test_analytics.py | 4 ++-- pandas/tests/series/test_analytics.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index 4ccd7fb84f5bf..7949636fcafbb 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -14,7 +14,7 @@ from numpy.random import randn import numpy as np -from pandas.compat import lrange, product +from pandas.compat import lrange, product, PY35 from pandas import (compat, isna, notna, DataFrame, Series, MultiIndex, date_range, Timestamp, Categorical, _np_version_under1p15) @@ -2144,7 +2144,7 @@ def test_dot(self): with tm.assert_raises_regex(ValueError, 'aligned'): df.dot(df2) - @pytest.mark.skipif(sys.version_info < (3, 5), + @pytest.mark.skipif(not PY35, reason='matmul supported for Python>=3.5') def test_matmul(self): # matmul test is for GH #10259 diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index f938ad357716c..f93aaf2115601 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -4,8 +4,6 @@ from itertools import product from distutils.version import LooseVersion import operator -import sys - import pytest from numpy import nan @@ -20,7 +18,7 @@ from pandas.core.indexes.timedeltas import Timedelta import pandas.core.nanops as nanops -from pandas.compat import lrange, range +from pandas.compat import lrange, range, PY35 from pandas import compat from pandas.util.testing import (assert_series_equal, assert_almost_equal, assert_frame_equal, assert_index_equal) @@ -923,7 +921,7 @@ def test_dot(self): pytest.raises(Exception, a.dot, a.values[:3]) pytest.raises(ValueError, a.dot, b.T) - @pytest.mark.skipif(sys.version_info < (3, 5), + @pytest.mark.skipif(not PY35, reason='matmul supported for Python>=3.5') def test_matmul(self): # matmul test is for GH #10259