Skip to content

Commit 0ea8e8b

Browse files
asanakoyNelleV
authored andcommitted
[MRG + 1] fix bug with negative values in cosine_distances (scikit-learn#7732)
* fix bug with negative values in cosine_distances clip distances to [0, 2] set distances between vectors and themselves to 0 * add test * add test on big random matrix * use np.diag_indices_from instead of slicing
1 parent 9d535ad commit 0ea8e8b

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

sklearn/metrics/pairwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,11 @@ def cosine_distances(X, Y=None):
570570
S = cosine_similarity(X, Y)
571571
S *= -1
572572
S += 1
573+
np.clip(S, 0, 2, out=S)
574+
if X is Y or Y is None:
575+
# Ensure that distances between vectors and themselves are set to 0.0.
576+
# This may not be the case due to floating point rounding errors.
577+
S[np.diag_indices_from(S)] = 0.0
573578
return S
574579

575580

sklearn/metrics/tests/test_pairwise.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,36 @@ def test_euclidean_distances():
407407
assert_greater(np.max(np.abs(wrong_D - D1)), .01)
408408

409409

410+
def test_cosine_distances():
411+
# Check the pairwise Cosine distances computation
412+
rng = np.random.RandomState(1337)
413+
x = np.abs(rng.rand(910))
414+
XA = np.vstack([x, x])
415+
D = cosine_distances(XA)
416+
assert_array_almost_equal(D, [[0., 0.], [0., 0.]])
417+
# check that all elements are in [0, 2]
418+
assert_true(np.all(D >= 0.))
419+
assert_true(np.all(D <= 2.))
420+
# check that diagonal elements are equal to 0
421+
assert_array_equal(D[np.diag_indices_from(D)], [0., 0.])
422+
423+
XB = np.vstack([x, -x])
424+
D2 = cosine_distances(XB)
425+
# check that all elements are in [0, 2]
426+
assert_true(np.all(D2 >= 0.))
427+
assert_true(np.all(D2 <= 2.))
428+
# check that diagonal elements are equal to 0 and non diagonal to 2
429+
assert_array_equal(D2, [[0., 2.], [2., 0.]])
430+
431+
# check large random matrix
432+
X = np.abs(rng.rand(1000, 5000))
433+
D = cosine_distances(X)
434+
# check that diagonal elements are equal to 0
435+
assert_array_almost_equal(D[np.diag_indices_from(D)], [0.] * D.shape[0])
436+
assert_true(np.all(D >= 0.))
437+
assert_true(np.all(D <= 2.))
438+
439+
410440
# Paired distances
411441

412442
def test_paired_euclidean_distances():

0 commit comments

Comments
 (0)