Skip to content

Commit 77618f0

Browse files
committed
Merge pull request #3470 from dieterv77/SeriesDot
ENH: Bring Series.dot up to par with DataFrame.dot
2 parents 6e7c4d6 + 8a39682 commit 77618f0

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

pandas/core/series.py

+41
Original file line numberDiff line numberDiff line change
@@ -1944,6 +1944,47 @@ def clip_lower(self, threshold):
19441944
"""
19451945
return pa.where(self < threshold, threshold, self)
19461946

1947+
def dot(self, other):
1948+
"""
1949+
Matrix multiplication with DataFrame or inner-product with Series objects
1950+
1951+
Parameters
1952+
----------
1953+
other : Series or DataFrame
1954+
1955+
Returns
1956+
-------
1957+
dot_product : scalar or Series
1958+
"""
1959+
from pandas.core.frame import DataFrame
1960+
if isinstance(other, (Series, DataFrame)):
1961+
common = self.index.union(other.index)
1962+
if (len(common) > len(self.index) or
1963+
len(common) > len(other.index)):
1964+
raise ValueError('matrices are not aligned')
1965+
1966+
left = self.reindex(index=common, copy=False)
1967+
right = other.reindex(index=common, copy=False)
1968+
lvals = left.values
1969+
rvals = right.values
1970+
else:
1971+
left = self
1972+
lvals = self.values
1973+
rvals = np.asarray(other)
1974+
if lvals.shape[0] != rvals.shape[0]:
1975+
raise Exception('Dot product shape mismatch, %s vs %s' %
1976+
(lvals.shape, rvals.shape))
1977+
1978+
if isinstance(other, DataFrame):
1979+
return self._constructor(np.dot(lvals, rvals),
1980+
index=other.columns)
1981+
elif isinstance(other, Series):
1982+
return np.dot(lvals, rvals)
1983+
elif isinstance(rvals, np.ndarray):
1984+
return np.dot(lvals, rvals)
1985+
else: # pragma: no cover
1986+
raise TypeError('unsupported type: %s' % type(other))
1987+
19471988
#------------------------------------------------------------------------------
19481989
# Combination
19491990

pandas/tests/test_series.py

+27
Original file line numberDiff line numberDiff line change
@@ -2486,6 +2486,33 @@ def test_count(self):
24862486

24872487
self.assertEqual(self.ts.count(), np.isfinite(self.ts).sum())
24882488

2489+
def test_dot(self):
2490+
a = Series(np.random.randn(4), index=['p', 'q', 'r', 's'])
2491+
b = DataFrame(np.random.randn(3, 4), index=['1', '2', '3'],
2492+
columns=['p', 'q', 'r', 's']).T
2493+
2494+
result = a.dot(b)
2495+
expected = Series(np.dot(a.values, b.values),
2496+
index=['1', '2', '3'])
2497+
assert_series_equal(result, expected)
2498+
2499+
#Check index alignment
2500+
b2 = b.reindex(index=reversed(b.index))
2501+
result = a.dot(b)
2502+
assert_series_equal(result, expected)
2503+
2504+
# Check ndarray argument
2505+
result = a.dot(b.values)
2506+
self.assertTrue(np.all(result == expected.values))
2507+
self.assertEquals(a.dot(b['2'].values), expected['2'])
2508+
2509+
#Check series argument
2510+
self.assertEquals(a.dot(b['1']), expected['1'])
2511+
self.assertEquals(a.dot(b2['1']), expected['1'])
2512+
2513+
self.assertRaises(Exception, a.dot, a.values[:3])
2514+
self.assertRaises(ValueError, a.dot, b.T)
2515+
24892516
def test_value_counts_nunique(self):
24902517
s = Series(['a', 'b', 'b', 'b', 'b', 'a', 'c', 'd', 'd', 'a'])
24912518
hist = s.value_counts()

0 commit comments

Comments
 (0)