Skip to content

Commit 37bd4dc

Browse files
authored
PERF: Sparse Series to scipy COO sparse matrix (#42925)
1 parent 31759fa commit 37bd4dc

File tree

5 files changed

+189
-78
lines changed

5 files changed

+189
-78
lines changed

asv_bench/benchmarks/sparse.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,28 @@ def time_sparse_series_from_coo(self):
6767

6868

6969
class ToCoo:
70-
def setup(self):
70+
params = [True, False]
71+
param_names = ["sort_labels"]
72+
73+
def setup(self, sort_labels):
7174
s = Series([np.nan] * 10000)
7275
s[0] = 3.0
7376
s[100] = -1.0
7477
s[999] = 12.1
75-
s.index = MultiIndex.from_product([range(10)] * 4)
76-
self.ss = s.astype("Sparse")
7778

78-
def time_sparse_series_to_coo(self):
79-
self.ss.sparse.to_coo(row_levels=[0, 1], column_levels=[2, 3], sort_labels=True)
79+
s_mult_lvl = s.set_axis(MultiIndex.from_product([range(10)] * 4))
80+
self.ss_mult_lvl = s_mult_lvl.astype("Sparse")
81+
82+
s_two_lvl = s.set_axis(MultiIndex.from_product([range(100)] * 2))
83+
self.ss_two_lvl = s_two_lvl.astype("Sparse")
84+
85+
def time_sparse_series_to_coo(self, sort_labels):
86+
self.ss_mult_lvl.sparse.to_coo(
87+
row_levels=[0, 1], column_levels=[2, 3], sort_labels=sort_labels
88+
)
89+
90+
def time_sparse_series_to_coo_single_level(self, sort_labels):
91+
self.ss_two_lvl.sparse.to_coo(sort_labels=sort_labels)
8092

8193

8294
class Arithmetic:

doc/source/whatsnew/v1.4.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ Performance improvements
290290
- Performance improvement in some :meth:`GroupBy.apply` operations (:issue:`42992`)
291291
- Performance improvement in :func:`read_stata` (:issue:`43059`)
292292
- Performance improvement in :meth:`to_datetime` with ``uint`` dtypes (:issue:`42606`)
293+
- Performance improvement in :meth:`Series.sparse.to_coo` (:issue:`42880`)
294+
-
293295

294296
.. ---------------------------------------------------------------------------
295297

pandas/core/arrays/sparse/accessor.py

+2
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def to_coo(self, row_levels=(0,), column_levels=(1,), sort_labels=False):
113113
column_levels : tuple/list
114114
sort_labels : bool, default False
115115
Sort the row and column labels before forming the sparse matrix.
116+
When `row_levels` and/or `column_levels` refer to a single level,
117+
set to `True` for a faster execution.
116118
117119
Returns
118120
-------

pandas/core/arrays/sparse/scipy_sparse.py

+126-67
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,32 @@
33
44
Currently only includes to_coo helpers.
55
"""
6-
from pandas.core.indexes.api import (
7-
Index,
8-
MultiIndex,
6+
from __future__ import annotations
7+
8+
from typing import (
9+
TYPE_CHECKING,
10+
Iterable,
11+
)
12+
13+
import numpy as np
14+
15+
from pandas._libs import lib
16+
from pandas._typing import (
17+
IndexLabel,
18+
npt,
919
)
20+
21+
from pandas.core.dtypes.missing import notna
22+
23+
from pandas.core.algorithms import factorize
24+
from pandas.core.indexes.api import MultiIndex
1025
from pandas.core.series import Series
1126

27+
if TYPE_CHECKING:
28+
import scipy.sparse
29+
1230

13-
def _check_is_partition(parts, whole):
31+
def _check_is_partition(parts: Iterable, whole: Iterable):
1432
whole = set(whole)
1533
parts = [set(x) for x in parts]
1634
if set.intersection(*parts) != set():
@@ -19,76 +37,115 @@ def _check_is_partition(parts, whole):
1937
raise ValueError("Is not a partition because union is not the whole.")
2038

2139

22-
def _to_ijv(ss, row_levels=(0,), column_levels=(1,), sort_labels=False):
23-
"""
24-
For arbitrary (MultiIndexed) sparse Series return
25-
(v, i, j, ilabels, jlabels) where (v, (i, j)) is suitable for
26-
passing to scipy.sparse.coo constructor.
40+
def _levels_to_axis(
41+
ss,
42+
levels: tuple[int] | list[int],
43+
valid_ilocs: npt.NDArray[np.intp],
44+
sort_labels: bool = False,
45+
) -> tuple[npt.NDArray[np.intp], list[IndexLabel]]:
2746
"""
28-
# index and column levels must be a partition of the index
29-
_check_is_partition([row_levels, column_levels], range(ss.index.nlevels))
47+
For a MultiIndexed sparse Series `ss`, return `ax_coords` and `ax_labels`,
48+
where `ax_coords` are the coordinates along one of the two axes of the
49+
destination sparse matrix, and `ax_labels` are the labels from `ss`' Index
50+
which correspond to these coordinates.
51+
52+
Parameters
53+
----------
54+
ss : Series
55+
levels : tuple/list
56+
valid_ilocs : numpy.ndarray
57+
Array of integer positions of valid values for the sparse matrix in ss.
58+
sort_labels : bool, default False
59+
Sort the axis labels before forming the sparse matrix. When `levels`
60+
refers to a single level, set to True for a faster execution.
3061
31-
# from the sparse Series: get the labels and data for non-null entries
32-
values = ss.array._valid_sp_values
33-
34-
nonnull_labels = ss.dropna()
35-
36-
def get_indexers(levels):
37-
"""Return sparse coords and dense labels for subset levels"""
38-
# TODO: how to do this better? cleanly slice nonnull_labels given the
39-
# coord
40-
values_ilabels = [tuple(x[i] for i in levels) for x in nonnull_labels.index]
41-
if len(levels) == 1:
42-
values_ilabels = [x[0] for x in values_ilabels]
43-
44-
# # performance issues with groupby ###################################
45-
# TODO: these two lines can replace the code below but
46-
# groupby is too slow (in some cases at least)
47-
# labels_to_i = ss.groupby(level=levels, sort=sort_labels).first()
48-
# labels_to_i[:] = np.arange(labels_to_i.shape[0])
49-
50-
def _get_label_to_i_dict(labels, sort_labels=False):
51-
"""
52-
Return dict of unique labels to number.
53-
Optionally sort by label.
54-
"""
55-
labels = Index(map(tuple, labels)).unique().tolist() # squish
56-
if sort_labels:
57-
labels = sorted(labels)
58-
return {k: i for i, k in enumerate(labels)}
59-
60-
def _get_index_subset_to_coord_dict(index, subset, sort_labels=False):
61-
ilabels = list(zip(*(index._get_level_values(i) for i in subset)))
62-
labels_to_i = _get_label_to_i_dict(ilabels, sort_labels=sort_labels)
63-
labels_to_i = Series(labels_to_i)
64-
if len(subset) > 1:
65-
labels_to_i.index = MultiIndex.from_tuples(labels_to_i.index)
66-
labels_to_i.index.names = [index.names[i] for i in subset]
67-
else:
68-
labels_to_i.index = Index(x[0] for x in labels_to_i.index)
69-
labels_to_i.index.name = index.names[subset[0]]
70-
71-
labels_to_i.name = "value"
72-
return labels_to_i
73-
74-
labels_to_i = _get_index_subset_to_coord_dict(
75-
ss.index, levels, sort_labels=sort_labels
62+
Returns
63+
-------
64+
ax_coords : numpy.ndarray (axis coordinates)
65+
ax_labels : list (axis labels)
66+
"""
67+
# Since the labels are sorted in `Index.levels`, when we wish to sort and
68+
# there is only one level of the MultiIndex for this axis, the desired
69+
# output can be obtained in the following simpler, more efficient way.
70+
if sort_labels and len(levels) == 1:
71+
ax_coords = ss.index.codes[levels[0]][valid_ilocs]
72+
ax_labels = ss.index.levels[levels[0]]
73+
74+
else:
75+
levels_values = lib.fast_zip(
76+
[ss.index.get_level_values(lvl).values for lvl in levels]
7677
)
77-
# #####################################################################
78-
# #####################################################################
78+
codes, ax_labels = factorize(levels_values, sort=sort_labels)
79+
ax_coords = codes[valid_ilocs]
80+
81+
ax_labels = ax_labels.tolist()
82+
return ax_coords, ax_labels
83+
84+
85+
def _to_ijv(
86+
ss,
87+
row_levels: tuple[int] | list[int] = (0,),
88+
column_levels: tuple[int] | list[int] = (1,),
89+
sort_labels: bool = False,
90+
) -> tuple[
91+
np.ndarray,
92+
npt.NDArray[np.intp],
93+
npt.NDArray[np.intp],
94+
list[IndexLabel],
95+
list[IndexLabel],
96+
]:
97+
"""
98+
For an arbitrary MultiIndexed sparse Series return (v, i, j, ilabels,
99+
jlabels) where (v, (i, j)) is suitable for passing to scipy.sparse.coo
100+
constructor, and ilabels and jlabels are the row and column labels
101+
respectively.
79102
80-
i_coord = labels_to_i[values_ilabels].tolist()
81-
i_labels = labels_to_i.index.tolist()
103+
Parameters
104+
----------
105+
ss : Series
106+
row_levels : tuple/list
107+
column_levels : tuple/list
108+
sort_labels : bool, default False
109+
Sort the row and column labels before forming the sparse matrix.
110+
When `row_levels` and/or `column_levels` refer to a single level,
111+
set to `True` for a faster execution.
82112
83-
return i_coord, i_labels
113+
Returns
114+
-------
115+
values : numpy.ndarray
116+
Valid values to populate a sparse matrix, extracted from
117+
ss.
118+
i_coords : numpy.ndarray (row coordinates of the values)
119+
j_coords : numpy.ndarray (column coordinates of the values)
120+
i_labels : list (row labels)
121+
j_labels : list (column labels)
122+
"""
123+
# index and column levels must be a partition of the index
124+
_check_is_partition([row_levels, column_levels], range(ss.index.nlevels))
125+
# From the sparse Series, get the integer indices and data for valid sparse
126+
# entries.
127+
sp_vals = ss.array.sp_values
128+
na_mask = notna(sp_vals)
129+
values = sp_vals[na_mask]
130+
valid_ilocs = ss.array.sp_index.indices[na_mask]
131+
132+
i_coords, i_labels = _levels_to_axis(
133+
ss, row_levels, valid_ilocs, sort_labels=sort_labels
134+
)
84135

85-
i_coord, i_labels = get_indexers(row_levels)
86-
j_coord, j_labels = get_indexers(column_levels)
136+
j_coords, j_labels = _levels_to_axis(
137+
ss, column_levels, valid_ilocs, sort_labels=sort_labels
138+
)
87139

88-
return values, i_coord, j_coord, i_labels, j_labels
140+
return values, i_coords, j_coords, i_labels, j_labels
89141

90142

91-
def sparse_series_to_coo(ss, row_levels=(0,), column_levels=(1,), sort_labels=False):
143+
def sparse_series_to_coo(
144+
ss: Series,
145+
row_levels: Iterable[int] = (0,),
146+
column_levels: Iterable[int] = (1,),
147+
sort_labels: bool = False,
148+
) -> tuple[scipy.sparse.coo_matrix, list[IndexLabel], list[IndexLabel]]:
92149
"""
93150
Convert a sparse Series to a scipy.sparse.coo_matrix using index
94151
levels row_levels, column_levels as the row and column
@@ -97,7 +154,7 @@ def sparse_series_to_coo(ss, row_levels=(0,), column_levels=(1,), sort_labels=Fa
97154
import scipy.sparse
98155

99156
if ss.index.nlevels < 2:
100-
raise ValueError("to_coo requires MultiIndex with nlevels > 2")
157+
raise ValueError("to_coo requires MultiIndex with nlevels >= 2.")
101158
if not ss.index.is_unique:
102159
raise ValueError(
103160
"Duplicate index entries are not allowed in to_coo transformation."
@@ -116,7 +173,9 @@ def sparse_series_to_coo(ss, row_levels=(0,), column_levels=(1,), sort_labels=Fa
116173
return sparse_matrix, rows, columns
117174

118175

119-
def coo_to_sparse_series(A, dense_index: bool = False):
176+
def coo_to_sparse_series(
177+
A: scipy.sparse.coo_matrix, dense_index: bool = False
178+
) -> Series:
120179
"""
121180
Convert a scipy.sparse.coo_matrix to a SparseSeries.
122181

pandas/tests/arrays/sparse/test_array.py

+42-6
Original file line numberDiff line numberDiff line change
@@ -1196,16 +1196,52 @@ def test_from_coo(self):
11961196
tm.assert_series_equal(result, expected)
11971197

11981198
@td.skip_if_no_scipy
1199-
def test_to_coo(self):
1199+
@pytest.mark.parametrize(
1200+
"sort_labels, expected_rows, expected_cols, expected_values_pos",
1201+
[
1202+
(
1203+
False,
1204+
[("b", 2), ("a", 2), ("b", 1), ("a", 1)],
1205+
[("z", 1), ("z", 2), ("x", 2), ("z", 0)],
1206+
{1: (1, 0), 3: (3, 3)},
1207+
),
1208+
(
1209+
True,
1210+
[("a", 1), ("a", 2), ("b", 1), ("b", 2)],
1211+
[("x", 2), ("z", 0), ("z", 1), ("z", 2)],
1212+
{1: (1, 2), 3: (0, 1)},
1213+
),
1214+
],
1215+
)
1216+
def test_to_coo(
1217+
self, sort_labels, expected_rows, expected_cols, expected_values_pos
1218+
):
12001219
import scipy.sparse
12011220

1202-
ser = pd.Series(
1203-
[1, 2, 3],
1204-
index=pd.MultiIndex.from_product([[0], [1, 2, 3]], names=["a", "b"]),
1205-
dtype="Sparse[int]",
1221+
values = SparseArray([0, np.nan, 1, 0, None, 3], fill_value=0)
1222+
index = pd.MultiIndex.from_tuples(
1223+
[
1224+
("b", 2, "z", 1),
1225+
("a", 2, "z", 2),
1226+
("a", 2, "z", 1),
1227+
("a", 2, "x", 2),
1228+
("b", 1, "z", 1),
1229+
("a", 1, "z", 0),
1230+
]
1231+
)
1232+
ss = pd.Series(values, index=index)
1233+
1234+
expected_A = np.zeros((4, 4))
1235+
for value, (row, col) in expected_values_pos.items():
1236+
expected_A[row, col] = value
1237+
1238+
A, rows, cols = ss.sparse.to_coo(
1239+
row_levels=(0, 1), column_levels=(2, 3), sort_labels=sort_labels
12061240
)
1207-
A, _, _ = ser.sparse.to_coo()
12081241
assert isinstance(A, scipy.sparse.coo.coo_matrix)
1242+
np.testing.assert_array_equal(A.toarray(), expected_A)
1243+
assert rows == expected_rows
1244+
assert cols == expected_cols
12091245

12101246
def test_non_sparse_raises(self):
12111247
ser = pd.Series([1, 2, 3])

0 commit comments

Comments
 (0)