Skip to content

Commit 70c2f46

Browse files
committed
Switch to using math.lgamma
This reduces the amount of `xp.asarray` that we need to convert scalars to arrays for the array API
1 parent 5e4356a commit 70c2f46

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

sklearn/decomposition/_pca.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
# Authors: The scikit-learn developers
44
# SPDX-License-Identifier: BSD-3-Clause
55

6-
from math import log, sqrt
6+
from math import lgamma, log, sqrt
77
from numbers import Integral, Real
88

99
import numpy as np
1010
from scipy import linalg
1111
from scipy.sparse import issparse
1212
from scipy.sparse.linalg import svds
13-
from scipy.special import gammaln
1413

1514
from ..base import _fit_context
1615
from ..utils import check_random_state
1716
from ..utils._arpack import _init_arpack_v0
18-
from ..utils._array_api import _convert_to_numpy, device, get_namespace
17+
from ..utils._array_api import _convert_to_numpy, get_namespace
1918
from ..utils._param_validation import Interval, RealNotInt, StrOptions
2019
from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
2120
from ..utils.sparsefuncs import _implicit_column_offset, mean_variance_axis
@@ -71,8 +70,7 @@ def _assess_dimension(spectrum, rank, n_samples):
7170
pu = -rank * log(2.0)
7271
for i in range(1, rank + 1):
7372
pu += (
74-
gammaln((n_features - i + 1) / 2.0)
75-
- log(xp.pi) * (n_features - i + 1) / 2.0
73+
lgamma((n_features - i + 1) / 2.0) - log(xp.pi) * (n_features - i + 1) / 2.0
7674
)
7775

7876
pl = xp.sum(xp.log(spectrum[:rank]))
@@ -93,7 +91,6 @@ def _assess_dimension(spectrum, rank, n_samples):
9391
(spectrum[i] - spectrum[j]) * (1.0 / spectrum_[j] - 1.0 / spectrum_[i])
9492
) + log(n_samples)
9593

96-
pu = xp.asarray(pu, device=device(spectrum), dtype=spectrum.dtype)
9794
ll = pu + pl + pv + pp - pa / 2.0 - rank * log(n_samples) / 2.0
9895

9996
return ll

0 commit comments

Comments
 (0)