Skip to content

Commit 379f54b

Browse files
KaiTomDLT
authored andcommitted
[MRG+2] DOC correct docstring for sample_gaussian (scikit-learn#6957)
* DOC correct docstring for sample_gaussian Docstring did not match function behaviour. Caused some trouble trying to implement a compatible version for a different distribution. * MAINT Use _sample_gaussian internally This avoids triggering the deprecation warning. sample_gaussian is just a thin wrapper. * Fix _sample_gaussian test path _sample_gaussian is not publicly exported in mixture, so use gmm module.
1 parent 0ea8e8b commit 379f54b

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

sklearn/mixture/gmm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
8484
mean : array_like, shape (n_features,)
8585
Mean of the distribution.
8686
87-
covar : array_like, optional
87+
covar : array_like
8888
Covariance of the distribution. The shape depends on `covariance_type`:
8989
scalar if 'spherical',
9090
(n_features) if 'diag',
@@ -99,9 +99,17 @@ def sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
9999
100100
Returns
101101
-------
102-
X : array, shape (n_features, n_samples)
103-
Randomly generated sample
102+
X : array
103+
Randomly generated sample. The shape depends on `n_samples`:
104+
(n_features,) if `1`
105+
(n_features, n_samples) otherwise
104106
"""
107+
_sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
108+
random_state=None)
109+
110+
111+
def _sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
112+
random_state=None):
105113
rng = check_random_state(random_state)
106114
n_dim = len(mean)
107115
rand = rng.randn(n_dim, n_samples)
@@ -423,7 +431,7 @@ def sample(self, n_samples=1, random_state=None):
423431
cv = self.covars_[comp][0]
424432
else:
425433
cv = self.covars_[comp]
426-
X[comp_in_X] = sample_gaussian(
434+
X[comp_in_X] = _sample_gaussian(
427435
self.means_[comp], cv, self.covariance_type,
428436
num_comp_in_X, random_state=random_state).T
429437
return X

sklearn/mixture/tests/test_gmm.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ def test_sample_gaussian():
3333
mu = rng.randint(10) * rng.rand(n_features)
3434
cv = (rng.rand(n_features) + 1.0) ** 2
3535

36-
samples = mixture.sample_gaussian(
36+
samples = mixture.gmm._sample_gaussian(
3737
mu, cv, covariance_type='diag', n_samples=n_samples)
3838

3939
assert_true(np.allclose(samples.mean(axis), mu, atol=1.3))
4040
assert_true(np.allclose(samples.var(axis), cv, atol=1.5))
4141

4242
# the same for spherical covariances
4343
cv = (rng.rand() + 1.0) ** 2
44-
samples = mixture.sample_gaussian(
44+
samples = mixture.gmm._sample_gaussian(
4545
mu, cv, covariance_type='spherical', n_samples=n_samples)
4646

4747
assert_true(np.allclose(samples.mean(axis), mu, atol=1.5))
@@ -51,16 +51,15 @@ def test_sample_gaussian():
5151
# and for full covariances
5252
A = rng.randn(n_features, n_features)
5353
cv = np.dot(A.T, A) + np.eye(n_features)
54-
samples = mixture.sample_gaussian(
54+
samples = mixture.gmm._sample_gaussian(
5555
mu, cv, covariance_type='full', n_samples=n_samples)
5656
assert_true(np.allclose(samples.mean(axis), mu, atol=1.3))
5757
assert_true(np.allclose(np.cov(samples), cv, atol=2.5))
5858

5959
# Numerical stability check: in SciPy 0.12.0 at least, eigh may return
6060
# tiny negative values in its second return value.
61-
from sklearn.mixture import sample_gaussian
62-
x = sample_gaussian([0, 0], [[4, 3], [1, .1]],
63-
covariance_type='full', random_state=42)
61+
x = mixture.gmm._sample_gaussian(
62+
[0, 0], [[4, 3], [1, .1]], covariance_type='full', random_state=42)
6463
assert_true(np.isfinite(x).all())
6564

6665

0 commit comments

Comments
 (0)