
Commit 1e7e303

Merge pull request #364 from pavlin-policar/sparse-1
ScPreprocess: Add sparse support
2 parents 80585b6 + 38d0fd5

2 files changed: +262 additions, -115 deletions

orangecontrib/single_cell/preprocess/scpreprocess.py

Lines changed: 80 additions & 44 deletions
@@ -1,12 +1,18 @@
-from typing import Tuple, Optional
+from typing import Tuple
 import warnings
+from typing import Union
+
 import numpy as np
 import scipy.sparse as sp
 from scipy.stats import zscore, percentileofscore
 
 from Orange.data import Domain, Table
 from Orange.preprocess.preprocess import Preprocess
 from Orange.util import Enum
+import Orange.statistics.util as ut
+
+
+AnyArray = Union[np.ndarray, sp.csr_matrix, sp.csc_matrix]
 
 
 class LogarithmicScale(Preprocess):
@@ -17,14 +23,23 @@ class LogarithmicScale(Preprocess):
     def __init__(self, base=BinaryLog):
         self.base = base
 
-    def __call__(self, data):
+    def __call__(self, data: Table) -> Table:
         new_data = data.copy()
+
         if self.base == LogarithmicScale.BinaryLog:
-            new_data.X = np.log2(1 + data.X)
-        elif self.base == LogarithmicScale.NaturalLog:
-            new_data.X = np.log(1 + data.X)
+            def func(x, *args, **kwargs):
+                return np.log2(x + 1, *args, **kwargs)
         elif self.base == LogarithmicScale.CommonLog:
-            new_data.X = np.log10(1 + data.X)
+            def func(x, *args, **kwargs):
+                return np.log10(x + 1, *args, **kwargs)
+        elif self.base == LogarithmicScale.NaturalLog:
+            func = np.log1p
+
+        if sp.issparse(new_data.X):
+            func(new_data.X.data, out=new_data.X.data)
+        else:
+            func(new_data.X, out=new_data.X)
+
         return new_data
 
 
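Why the in-place trick above is safe on sparse data: every supported base is computed as log(1 + x), which maps 0 to 0, so only the explicitly stored values of a CSR/CSC matrix need transforming and the sparsity pattern is untouched. A minimal standalone sketch of the same idea (toy matrix, names not from the commit):

    import numpy as np
    import scipy.sparse as sp

    X = sp.csr_matrix(np.array([[0.0, 1.0, 3.0],
                                [7.0, 0.0, 0.0]]))
    np.log2(X.data + 1, out=X.data)  # transform only the stored nonzeros
    assert X.nnz == 3                # implicit zeros stay implicit
    print(X.toarray())               # [[0. 1. 2.] [3. 0. 0.]]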
@@ -37,18 +52,17 @@ def __init__(self, condition=GreaterOrEqual, threshold=1):
         self.condition = condition
         self.threshold = threshold
 
-    def __call__(self, data):
+    def __call__(self, data: Table) -> Table:
         new_data = data.copy()
         if self.condition == Binarize.GreaterOrEqual:
-            new_data.X = np.where(data.X >= self.threshold, 1, 0)
+            new_data.X = new_data.X >= self.threshold
         elif self.condition == Binarize.Greater:
-            new_data.X = np.where(data.X > self.threshold, 1, 0)
+            new_data.X = new_data.X > self.threshold
         return new_data
 
 
 class Normalize(Preprocess):
-    Method = Enum("Normalize", ("CPM", "Median"),
-                  qualname="Normalize.Method")
+    Method = Enum("Normalize", ("CPM", "Median"), qualname="Normalize.Method")
     CPM, Median = Method
 
     def __init__(self, method=CPM):
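The np.where calls had to go because np.where always returns a dense ndarray; comparing against a scalar instead returns a matrix of the same kind, boolean and sparse for sparse input. A hedged sketch (note that for thresholds at which the implicit zeros would compare true, the result is no longer sparse and scipy will complain):

    import numpy as np
    import scipy.sparse as sp

    X = sp.csr_matrix(np.array([[0.0, 0.5, 2.0],
                                [3.0, 0.0, 1.0]]))
    B = X >= 1          # sparse boolean result; dense input gives dense output
    print(B.toarray())  # [[False False  True]
                        #  [ True False  True]]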
@@ -62,46 +76,67 @@ def normalize(self, *args):
 
 
 class NormalizeSamples(Normalize):
-    def __call__(self, data):
+    def __call__(self, data: Table) -> Table:
         new_data = data.copy()
         new_data.X = self.normalize(data.X)
         return new_data
 
-    def normalize(self, table):
-        row_sums = np.nansum(table, axis=1)
-        row_sums[row_sums == 0] = 1
-        table = table / row_sums[:, None]
-        factor = np.nanmedian(row_sums) \
-            if self.method == NormalizeSamples.Median else 10 ** 6
-        return table * factor
+    def normalize(self, table: AnyArray) -> AnyArray:
+        row_sums = ut.nansum(table, axis=1)
+        row_sums[row_sums == 0] = 1  # avoid division by zero errors
+
+        if self.method == NormalizeSamples.Median:
+            factor = np.nanmedian(row_sums)
+        else:
+            factor = 1e6
+
+        if sp.issparse(table):
+            table = sp.diags(1 / row_sums) @ table
+        else:
+            table = table / row_sums[:, None]
+
+        table *= factor
+
+        return table
 
 
 class NormalizeGroups(Normalize):
     def __init__(self, group_var, method=Normalize.CPM):
         super().__init__(method)
         self.group_var = group_var
 
-    def __call__(self, data):
+    def __call__(self, data: Table) -> Table:
         group_col = data.get_column_view(self.group_var)[0]
         group_col = group_col.astype("int64")
         new_data = data.copy()
         new_data.X = self.normalize(data.X, group_col)
         return new_data
 
-    def normalize(self, table, group_col):
-        group_sums = np.bincount(group_col, np.nansum(table, axis=1))
+    def normalize(self, table: AnyArray, group_col: np.ndarray) -> AnyArray:
+        group_sums = np.bincount(group_col, ut.nansum(table, axis=1))
         group_sums[group_sums == 0] = 1
         group_sums_row = np.zeros_like(group_col)
         medians = []
-        row_sums = np.nansum(table, axis=1)
+        row_sums = ut.nansum(table, axis=1)
         for value, group_sum in zip(np.unique(group_col), group_sums):
             mask = group_col == value
             group_sums_row[mask] = group_sum
             if self.method == NormalizeGroups.Median:
                 medians.append(np.nanmedian(row_sums[mask]))
-        factor = np.min(medians) \
-            if self.method == NormalizeGroups.Median else 10 ** 6
-        return table / group_sums_row[:, None] * factor
+
+        if self.method == NormalizeGroups.Median:
+            factor = np.min(medians)
+        else:
+            factor = 1e6
+
+        if sp.issparse(table):
+            table = sp.diags(1 / group_sums_row) @ table
+        else:
+            table = table / group_sums_row[:, None]
+
+        table *= factor
+
+        return table
 
 
 class Standardize(Preprocess):
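Sparse matrices do not support the broadcast division table / row_sums[:, None] used in the dense branch, so both normalizers express per-row division as left-multiplication by a diagonal matrix of reciprocals, which keeps the result sparse. Roughly, for the CPM case (illustrative values, not from the commit):

    import numpy as np
    import scipy.sparse as sp

    X = sp.csr_matrix(np.array([[1.0, 1.0, 2.0],
                                [0.0, 3.0, 1.0]]))
    row_sums = np.asarray(X.sum(axis=1)).ravel()  # [4. 4.]
    X = sp.diags(1 / row_sums) @ X                # divide each row by its sum
    X *= 1e6                                      # counts-per-million factor
    print(np.asarray(X.sum(axis=1)).ravel())      # [1000000. 1000000.]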
@@ -129,10 +164,10 @@ def __init__(self, method=Dispersion, n_genes=1000, n_groups=20):
         self.n_genes = n_genes
         self.n_groups = n_groups if n_groups and n_groups > 1 else 1
 
-    def __call__(self, data):
+    def __call__(self, data: Table) -> Table:
         n_groups = min(self.n_groups, len(data.domain.attributes))
-        mean = np.nanmean(data.X, axis=0)
-        variance = np.nanvar(data.X, axis=0)
+        mean = ut.nanmean(data.X, axis=0)
+        variance = ut.nanvar(data.X, axis=0)
         percentiles = [percentileofscore(mean, m) for m in mean]
         _, bins = np.histogram(percentiles, n_groups)
         bin_indices = np.digitize(percentiles, bins, True)
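The only substantive change here is swapping np.nanmean/np.nanvar for their counterparts in Orange.statistics.util, which, as the rest of the diff relies on, also accept scipy.sparse inputs. Presumed usage (tiny toy matrix):

    import numpy as np
    import scipy.sparse as sp
    import Orange.statistics.util as ut

    X = sp.csr_matrix(np.array([[0.0, 1.0],
                                [2.0, 3.0],
                                [4.0, 0.0]]))
    print(ut.nanmean(X, axis=0))  # per-column means, no densification
    print(ut.nanvar(X, axis=0))   # per-column variances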
@@ -190,23 +225,24 @@ def __call__(self, data: Table) -> Table:
             warnings.warn(f"{sum(selected)} genes selected", DropoutWarning)
         return self.filter_columns(data, selected)
 
-    def detection(self, table: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-        mask = table > self.threshold
+    def detection(self, table: AnyArray) -> Tuple[np.ndarray, np.ndarray]:
+        with np.errstate(invalid="ignore"):  # comparison can include nans
+            mask = table > self.threshold
+
         if sp.issparse(table):
-            zero_rate = 1 - np.squeeze(np.array(mask.mean(axis=0)))
-            A = table.multiply(mask)
-            A.data = np.log2(A.data)
-            mean_expr = np.zeros_like(zero_rate) * np.nan
-            detected = zero_rate < 1
-            detected_mean = np.squeeze(np.array(A[:, detected].mean(axis=0)))
-            mean_expr[detected] = detected_mean / (1 - zero_rate[detected])
+            A = table.copy()
+            np.log2(A.data, out=A.data)
         else:
-            zero_rate = 1 - np.mean(mask, axis=0)
-            mean_expr = np.zeros_like(zero_rate) * np.nan
-            detected = zero_rate < 1
-            mean_expr[detected] = np.nanmean(
-                np.where(table[:, detected] > self.threshold,
-                         np.log2(table[:, detected]), np.nan), axis=0)
+            A = np.ma.log2(table)  # avoid log2(0)
+            A.mask = False
+
+        detection_rate = ut.nanmean(mask, axis=0)
+        zero_rate = 1 - detection_rate
+        detected = detection_rate > 0
+        detected_mean = ut.nanmean(A[:, detected], axis=0)
+
+        mean_expr = np.full_like(zero_rate, fill_value=np.nan)
+        mean_expr[detected] = detected_mean / detection_rate[detected]
 
         low_detection = np.array(np.sum(mask, axis=0)).squeeze()
         zero_rate[low_detection < self.at_least] = np.nan
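What the rewrite computes, per gene: the zero rate (one minus the fraction of cells above threshold) and the mean log2 expression over detecting cells only. The division by detection_rate works because a mean of the zero-filled log values taken over all cells, rescaled by the fraction of detecting cells, equals the mean over just the detecting ones. A dense-only toy check of that identity (numbers and names local to this sketch):

    import numpy as np

    table = np.array([[0.0, 2.0],
                      [4.0, 0.0],
                      [16.0, 0.0]])
    mask = table > 0                           # detected entries
    detection_rate = mask.mean(axis=0)         # [2/3, 1/3]
    zero_rate = 1 - detection_rate

    log_table = np.zeros_like(table)
    np.log2(table, out=log_table, where=mask)  # log2 where detected, else 0

    mean_expr = log_table.mean(axis=0) / detection_rate
    print(zero_rate)   # [0.3333... 0.6666...]
    print(mean_expr)   # [3. 1.]  == mean log2 over detecting cells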
