Skip to content

ENH: Add matmul to DataFrame, Series #19035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.23.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ Other Enhancements
``SQLAlchemy`` dialects supporting multivalue inserts include: ``mysql``, ``postgresql``, ``sqlite`` and any dialect with ``supports_multivalues_insert``. (:issue:`14315`, :issue:`8953`)
- :func:`read_html` now accepts a ``displayed_only`` keyword argument to controls whether or not hidden elements are parsed (``True`` by default) (:issue:`20027`)
- zip compression is supported via ``compression=zip`` in :func:`DataFrame.to_pickle`, :func:`Series.to_pickle`, :func:`DataFrame.to_csv`, :func:`Series.to_csv`, :func:`DataFrame.to_json`, :func:`Series.to_json`. (:issue:`17778`)
- :class:`DataFrame` and :class:`Series` now support matrix multiplication (```@```) operator (:issue:`10259`) for Python>=3.5

.. _whatsnew_0230.api_breaking:

Expand Down
11 changes: 10 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,8 @@ def __len__(self):

def dot(self, other):
"""
Matrix multiplication with DataFrame or Series objects
Matrix multiplication with DataFrame or Series objects. Can also be
called using `self @ other` in Python >= 3.5.

Parameters
----------
Expand Down Expand Up @@ -905,6 +906,14 @@ def dot(self, other):
else: # pragma: no cover
raise TypeError('unsupported type: %s' % type(other))

def __matmul__(self, other):
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
return self.dot(other)

def __rmatmul__(self, other):
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
return self.T.dot(np.transpose(other)).T

# ----------------------------------------------------------------------
# IO methods (to / from other formats)

Expand Down
10 changes: 9 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,7 +1994,7 @@ def autocorr(self, lag=1):
def dot(self, other):
"""
Matrix multiplication with DataFrame or inner-product with Series
objects
objects. Can also be called using `self @ other` in Python >= 3.5.

Parameters
----------
Expand Down Expand Up @@ -2033,6 +2033,14 @@ def dot(self, other):
else: # pragma: no cover
raise TypeError('unsupported type: %s' % type(other))

def __matmul__(self, other):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can u update the .dot doc string that ‘@‘ operator is supported in >=3.5

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you do this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean update the .dot doc string? I believe that is done

""" Matrix multiplication using binary `@` operator in Python>=3.5 """
return self.dot(other)

def __rmatmul__(self, other):
""" Matrix multiplication using binary `@` operator in Python>=3.5 """
return self.dot(other)

@Substitution(klass='Series')
@Appender(base._shared_docs['searchsorted'])
@deprecate_kwarg(old_arg_name='v', new_arg_name='value')
Expand Down
61 changes: 59 additions & 2 deletions pandas/tests/frame/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from datetime import timedelta
from distutils.version import LooseVersion
import operator
import sys
import pytest

Expand All @@ -13,7 +14,7 @@
from numpy.random import randn
import numpy as np

from pandas.compat import lrange, product
from pandas.compat import lrange, product, PY35
from pandas import (compat, isna, notna, DataFrame, Series,
MultiIndex, date_range, Timestamp, Categorical,
_np_version_under1p15)
Expand Down Expand Up @@ -2091,7 +2092,6 @@ def test_clip_with_na_args(self):
self.frame)

# Matrix-like

def test_dot(self):
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
columns=['p', 'q', 'r', 's'])
Expand Down Expand Up @@ -2144,6 +2144,63 @@ def test_dot(self):
with tm.assert_raises_regex(ValueError, 'aligned'):
df.dot(df2)

@pytest.mark.skipif(not PY35,
reason='matmul supported for Python>=3.5')
def test_matmul(self):
# matmul test is for GH #10259
a = DataFrame(np.random.randn(3, 4), index=['a', 'b', 'c'],
columns=['p', 'q', 'r', 's'])
b = DataFrame(np.random.randn(4, 2), index=['p', 'q', 'r', 's'],
columns=['one', 'two'])

# DataFrame @ DataFrame
result = operator.matmul(a, b)
expected = DataFrame(np.dot(a.values, b.values),
index=['a', 'b', 'c'],
columns=['one', 'two'])
tm.assert_frame_equal(result, expected)

# DataFrame @ Series
result = operator.matmul(a, b.one)
expected = Series(np.dot(a.values, b.one.values),
index=['a', 'b', 'c'])
tm.assert_series_equal(result, expected)

# np.array @ DataFrame
result = operator.matmul(a.values, b)
expected = np.dot(a.values, b.values)
tm.assert_almost_equal(result, expected)

# nested list @ DataFrame (__rmatmul__)
result = operator.matmul(a.values.tolist(), b)
expected = DataFrame(np.dot(a.values, b.values),
index=['a', 'b', 'c'],
columns=['one', 'two'])
tm.assert_almost_equal(result.values, expected.values)

# mixed dtype DataFrame @ DataFrame
a['q'] = a.q.round().astype(int)
result = operator.matmul(a, b)
expected = DataFrame(np.dot(a.values, b.values),
index=['a', 'b', 'c'],
columns=['one', 'two'])
tm.assert_frame_equal(result, expected)

# different dtypes DataFrame @ DataFrame
a = a.astype(int)
result = operator.matmul(a, b)
expected = DataFrame(np.dot(a.values, b.values),
index=['a', 'b', 'c'],
columns=['one', 'two'])
tm.assert_frame_equal(result, expected)

# unaligned
df = DataFrame(randn(3, 4), index=[1, 2, 3], columns=lrange(4))
df2 = DataFrame(randn(5, 3), index=lrange(5), columns=[1, 2, 3])

with tm.assert_raises_regex(ValueError, 'aligned'):
operator.matmul(df, df2)


@pytest.fixture
def df_duplicates():
Expand Down
50 changes: 48 additions & 2 deletions pandas/tests/series/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from itertools import product
from distutils.version import LooseVersion

import operator
import pytest

from numpy import nan
Expand All @@ -18,7 +18,7 @@
from pandas.core.indexes.timedeltas import Timedelta
import pandas.core.nanops as nanops

from pandas.compat import lrange, range
from pandas.compat import lrange, range, PY35
from pandas import compat
from pandas.util.testing import (assert_series_equal, assert_almost_equal,
assert_frame_equal, assert_index_equal)
Expand Down Expand Up @@ -921,6 +921,52 @@ def test_dot(self):
pytest.raises(Exception, a.dot, a.values[:3])
pytest.raises(ValueError, a.dot, b.T)

@pytest.mark.skipif(not PY35,
reason='matmul supported for Python>=3.5')
def test_matmul(self):
# matmul test is for GH #10259
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
columns=['p', 'q', 'r', 's']).T

# Series @ DataFrame
result = operator.matmul(a, b)
expected = Series(np.dot(a.values, b.values), index=['1', '2', '3'])
assert_series_equal(result, expected)

# DataFrame @ Series
result = operator.matmul(b.T, a)
expected = Series(np.dot(b.T.values, a.T.values),
index=['1', '2', '3'])
assert_series_equal(result, expected)

# Series @ Series
result = operator.matmul(a, a)
expected = np.dot(a.values, a.values)
assert_almost_equal(result, expected)

# np.array @ Series (__rmatmul__)
result = operator.matmul(a.values, a)
expected = np.dot(a.values, a.values)
assert_almost_equal(result, expected)

# mixed dtype DataFrame @ Series
a['p'] = int(a.p)
result = operator.matmul(b.T, a)
expected = Series(np.dot(b.T.values, a.T.values),
index=['1', '2', '3'])
assert_series_equal(result, expected)

# different dtypes DataFrame @ Series
a = a.astype(int)
result = operator.matmul(b.T, a)
expected = Series(np.dot(b.T.values, a.T.values),
index=['1', '2', '3'])
assert_series_equal(result, expected)

pytest.raises(Exception, a.dot, a.values[:3])
pytest.raises(ValueError, a.dot, b.T)

def test_value_counts_nunique(self):

# basics.rst doc example
Expand Down