
Commit 788a458

ngoix authored and amueller committed
[MRG+2] LOF algorithm (Anomaly Detection) (scikit-learn#5279)
* LOF algorithm: add tests and an example; fix DeprecationWarning by reshape(1, -1) on one-sample data; implement LOF with inheritance; check that lof and lof2 return the same score; fix bugs; optimize and cosmit; rm lof2; rm MixinLOF + fit_predict; fix travis; optimize pairwise_distance as in KNeighborsMixin.kneighbors; add a comparison example + doc; rename LOF -> LocalOutlierFactor; change the LOF API: fit(X).predict() and fit(X).decision_function() do prediction on X without considering samples as their own neighbors (i.e. without considering X as a new dataset, as fit(X).predict(X) does), rm the fit_predict() method, and add a contamination parameter so that predict returns a binary value like other anomaly detection algos; doc + debug example correction; pass on doc + examples; pep8 + fix warnings; first attempt at fixing API issues; minor changes; take tguillemot's advice into account: remove the pairwise_distance calculation as too heavy in memory, and add benchmarks; minor changes + deal with duplicates; fix deprecation warnings
* factorize the two for loops
* take into account @albertthomas88's review and cosmit
* fix doc
* alex review + rebase
* make predict private, add the outlier_factor_ attribute and update tests
* make fit_predict take a y argument
* fix the benchmarks file
* update examples
* make decision_function public (rm the X=None default)
* fix travis
* take into account tguillemot's review + remove the useless k_distance function
* fix broken :meth:`kneighbors` links
* cosmit
* whatsnew
* amueller review + remove the _local_outlier_factor method
* add the n_neighbors_ attribute: the effective number of neighbors used
* make decision_function private and rename the attribute to negative_outlier_factor_
1 parent 73d3f03 commit 788a458
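
As a quick orientation, here is a minimal sketch of the API the commit message above settles on: fit_predict labels the training samples themselves, the raw scores are exposed as negative_outlier_factor_, and contamination sets the binary threshold. The toy data and parameter values are illustrative, not taken from the commit.

import numpy as np
from sklearn.neighbors import LocalOutlierFactor

rng = np.random.RandomState(42)
# 100 tight inliers plus 20 scattered points, for illustration only
X = np.r_[0.3 * rng.randn(100, 2), rng.uniform(low=-4, high=4, size=(20, 2))]

clf = LocalOutlierFactor(n_neighbors=20, contamination=0.15)
y_pred = clf.fit_predict(X)            # +1 for inliers, -1 for outliers
scores = clf.negative_outlier_factor_  # the lower, the more abnormal
print(clf.n_neighbors_, (y_pred == -1).sum())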

File tree

12 files changed, +710 -33 lines


benchmarks/bench_lof.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
"""
============================
LocalOutlierFactor benchmark
============================

A test of LocalOutlierFactor on classical anomaly detection datasets.

"""

from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle as sh

print(__doc__)

np.random.seed(2)

# datasets available: ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
datasets = ['shuttle']

novelty_detection = True  # if False, training set polluted by outliers

for dataset_name in datasets:
    # loading and vectorization
    print('loading data')
    if dataset_name in ['http', 'smtp', 'SA', 'SF']:
        dataset = fetch_kddcup99(subset=dataset_name, shuffle=True,
                                 percent10=False)
        X = dataset.data
        y = dataset.target

    if dataset_name == 'shuttle':
        dataset = fetch_mldata('shuttle')
        X = dataset.data
        y = dataset.target
        X, y = sh(X, y)
        # we remove data with label 4
        # normal data are then those of class 1
        s = (y != 4)
        X = X[s, :]
        y = y[s]
        y = (y != 1).astype(int)

    if dataset_name == 'forestcover':
        dataset = fetch_covtype(shuffle=True)
        X = dataset.data
        y = dataset.target
        # normal data are those with attribute 2
        # abnormal those with attribute 4
        s = (y == 2) + (y == 4)
        X = X[s, :]
        y = y[s]
        y = (y != 2).astype(int)

    print('vectorizing data')

    if dataset_name == 'SF':
        lb = LabelBinarizer()
        lb.fit(X[:, 1])
        x1 = lb.transform(X[:, 1])
        X = np.c_[X[:, :1], x1, X[:, 2:]]
        y = (y != 'normal.').astype(int)

    if dataset_name == 'SA':
        lb = LabelBinarizer()
        lb.fit(X[:, 1])
        x1 = lb.transform(X[:, 1])
        lb.fit(X[:, 2])
        x2 = lb.transform(X[:, 2])
        lb.fit(X[:, 3])
        x3 = lb.transform(X[:, 3])
        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
        y = (y != 'normal.').astype(int)

    if dataset_name == 'http' or dataset_name == 'smtp':
        y = (y != 'normal.').astype(int)

    n_samples, n_features = np.shape(X)
    n_samples_train = n_samples // 2
    n_samples_test = n_samples - n_samples_train

    X = X.astype(float)
    X_train = X[:n_samples_train, :]
    X_test = X[n_samples_train:, :]
    y_train = y[:n_samples_train]
    y_test = y[n_samples_train:]

    if novelty_detection:
        X_train = X_train[y_train == 0]
        y_train = y_train[y_train == 0]

    print('LocalOutlierFactor processing...')
    model = LocalOutlierFactor(n_neighbors=20)
    tstart = time()
    model.fit(X_train)
    fit_time = time() - tstart
    tstart = time()

    # decision_function is private for LOF in this commit
    scoring = -model._decision_function(X_test)  # the lower, the more normal
    predict_time = time() - tstart
    fpr, tpr, thresholds = roc_curve(y_test, scoring)
    AUC = auc(fpr, tpr)
    plt.plot(fpr, tpr, lw=1,
             label=('ROC for %s (area = %0.3f, train-time: %0.2fs,'
                    ' test-time: %0.2fs)' % (dataset_name, AUC, fit_time,
                                             predict_time)))

plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
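
For context, the evaluation pattern the benchmark uses boils down to ranking test points by an abnormality score and measuring the ranking with ROC AUC; a self-contained toy version (labels and scores invented for illustration):

import numpy as np
from sklearn.metrics import roc_curve, auc

y_true = np.array([0, 0, 0, 1, 0, 1])              # 1 marks an anomaly
abnormality = np.array([.1, .3, .2, .9, .25, .8])  # higher == more abnormal
fpr, tpr, _ = roc_curve(y_true, abnormality)
print('AUC: %0.3f' % auc(fpr, tpr))                # prints 1.0: perfect ranking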

doc/modules/classes.rst

Lines changed: 2 additions & 1 deletion
@@ -1050,7 +1050,8 @@ See the :ref:`metrics` section of the user guide for further details.
    neighbors.LSHForest
    neighbors.DistanceMetric
    neighbors.KernelDensity
-
+   neighbors.LocalOutlierFactor
+
 .. autosummary::
    :toctree: generated/
    :template: function.rst

doc/modules/outlier_detection.rst

Lines changed: 75 additions & 14 deletions
@@ -165,18 +165,76 @@ This strategy is illustrated below.

 * See :ref:`sphx_glr_auto_examples_covariance_plot_outlier_detection.py` for a
   comparison of :class:`ensemble.IsolationForest` with
+  :class:`neighbors.LocalOutlierFactor`,
   :class:`svm.OneClassSVM` (tuned to perform like an outlier detection
   method) and a covariance-based outlier detection with
-  :class:`covariance.MinCovDet`.
+  :class:`covariance.EllipticEnvelope`.

 .. topic:: References:

    .. [LTZ2008] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest."
       Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on.

-One-class SVM versus Elliptic Envelope versus Isolation Forest
---------------------------------------------------------------
+Local Outlier Factor
+--------------------
+Another efficient way to perform outlier detection on moderately high-dimensional
+datasets is to use the Local Outlier Factor (LOF) algorithm.
+
+The :class:`neighbors.LocalOutlierFactor` (LOF) algorithm computes a score
+(called the local outlier factor) reflecting the degree of abnormality of the
+observations.
+It measures the local density deviation of a given data point with respect to
+its neighbors. The idea is to detect the samples that have a substantially
+lower density than their neighbors.
+
+In practice the local density is obtained from the k-nearest neighbors.
+The LOF score of an observation is equal to the ratio of the
+average local density of its k-nearest neighbors to its own local density:
+a normal instance is expected to have a local density similar to that of its
+neighbors, while abnormal data are expected to have a much smaller local density.
+
+The number k of neighbors considered (alias parameter n_neighbors) is typically
+chosen 1) greater than the minimum number of objects a cluster has to contain,
+so that other objects can be local outliers relative to this cluster, and 2)
+smaller than the maximum number of close-by objects that can potentially be
+local outliers.
+In practice, such information is generally not available, and taking
+n_neighbors=20 appears to work well in general.
+When the proportion of outliers is high (i.e. greater than 10%, as in the
+example below), n_neighbors should be greater (n_neighbors=35 in the example
+below).
+
+The strength of the LOF algorithm is that it takes both local and global
+properties of datasets into consideration: it can perform well even in datasets
+where abnormal samples have different underlying densities.
+The question is not how isolated the sample is, but how isolated it is
+with respect to the surrounding neighborhood.
+
+This strategy is illustrated below.
+
+.. figure:: ../auto_examples/neighbors/images/sphx_glr_plot_lof_001.png
+   :target: ../auto_examples/neighbors/plot_lof.html
+   :align: center
+   :scale: 75%
+
+.. topic:: Examples:
+
+   * See :ref:`sphx_glr_auto_examples_neighbors_plot_lof.py` for
+     an illustration of the use of :class:`neighbors.LocalOutlierFactor`.
+
+   * See :ref:`sphx_glr_auto_examples_covariance_plot_outlier_detection.py` for a
+     comparison with other anomaly detection methods.
+
+.. topic:: References:
+
+   .. [BKNS2000] Breunig, Kriegel, Ng, and Sander (2000)
+      `LOF: identifying density-based local outliers.
+      <http://www.dbs.ifi.lmu.de/Publikationen/Papers/LOF.pdf>`_
+      Proc. ACM SIGMOD
+
+One-class SVM versus Elliptic Envelope versus Isolation Forest versus LOF
+-------------------------------------------------------------------------

 Strictly speaking, the One-class SVM is not an outlier-detection method,
 but a novelty-detection method: its training set should not be

@@ -188,7 +246,8 @@ results in these situations.
 The examples below illustrate how the performance of the
 :class:`covariance.EllipticEnvelope` degrades as the data is less and
 less unimodal. The :class:`svm.OneClassSVM` works better on data with
-multiple modes and :class:`ensemble.IsolationForest` performs well in every cases.
+multiple modes, and :class:`ensemble.IsolationForest` and
+:class:`neighbors.LocalOutlierFactor` perform well in every case.

 .. |outlier1| image:: ../auto_examples/covariance/images/sphx_glr_plot_outlier_detection_001.png
    :target: ../auto_examples/covariance/plot_outlier_detection.html

@@ -202,7 +261,7 @@ multiple modes and :class:`ensemble.IsolationForest` performs well in every case
    :target: ../auto_examples/covariance/plot_outlier_detection.html
    :scale: 50%

-.. list-table:: **Comparing One-class SVM approach, and elliptic envelope**
+.. list-table:: **Comparing One-class SVM, Isolation Forest, LOF, and Elliptic Envelope**
    :widths: 40 60

 *

@@ -213,31 +272,33 @@ multiple modes and :class:`ensemble.IsolationForest` performs well in every case
      opposite, the decision rule based on fitting an
      :class:`covariance.EllipticEnvelope` learns an ellipse, which
      fits well the inlier distribution. The :class:`ensemble.IsolationForest`
-     performs as well.
-   - |outlier1|
+     and :class:`neighbors.LocalOutlierFactor` perform as well.
+   - |outlier1|

 *
   - As the inlier distribution becomes bimodal, the
     :class:`covariance.EllipticEnvelope` does not fit well the
-    inliers. However, we can see that both :class:`ensemble.IsolationForest`
-    and :class:`svm.OneClassSVM` have difficulties to detect the two modes,
+    inliers. However, we can see that :class:`ensemble.IsolationForest`,
+    :class:`svm.OneClassSVM` and :class:`neighbors.LocalOutlierFactor`
+    have difficulties detecting the two modes,
     and that the :class:`svm.OneClassSVM`
-    tends to overfit: because it has not model of inliers, it
+    tends to overfit: because it has no model of inliers, it
     interprets a region where, by chance, some outliers are
     clustered, as inliers.
   - |outlier2|

 *
   - If the inlier distribution is strongly non-Gaussian, the
     :class:`svm.OneClassSVM` is able to recover a reasonable
-    approximation as well as :class:`ensemble.IsolationForest`,
+    approximation, as well as :class:`ensemble.IsolationForest`
+    and :class:`neighbors.LocalOutlierFactor`,
     whereas the :class:`covariance.EllipticEnvelope` completely fails.
   - |outlier3|

 .. topic:: Examples:

    * See :ref:`sphx_glr_auto_examples_covariance_plot_outlier_detection.py` for a
      comparison of the :class:`svm.OneClassSVM` (tuned to perform like
-     an outlier detection method), the :class:`ensemble.IsolationForest`
-     and a covariance-based outlier
-     detection with :class:`covariance.MinCovDet`.
+     an outlier detection method), the :class:`ensemble.IsolationForest`,
+     the :class:`neighbors.LocalOutlierFactor`
+     and a covariance-based outlier detection with :class:`covariance.EllipticEnvelope`.
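
For readers of the new documentation section, the score it describes can be written out explicitly, following the cited [BKNS2000] paper (notation ours: N_k(A) is the set of k-nearest neighbors of A, d the distance):

\text{reach-dist}_k(A, B) = \max\{k\text{-distance}(B),\; d(A, B)\}

\mathrm{lrd}_k(A) = \left(\frac{1}{|N_k(A)|} \sum_{B \in N_k(A)} \text{reach-dist}_k(A, B)\right)^{-1}

\mathrm{LOF}_k(A) = \frac{1}{|N_k(A)|} \sum_{B \in N_k(A)} \frac{\mathrm{lrd}_k(B)}{\mathrm{lrd}_k(A)}

A factor close to 1 means the point is about as dense as its neighbors; values noticeably above 1 flag it as a local outlier. This also matches the negative_outlier_factor_ naming the commit introduces: the estimator stores the negated factor, so that lower values mean more abnormal, as for the other anomaly detectors.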

doc/whats_new.rst

Lines changed: 4 additions & 1 deletion
@@ -16,6 +16,9 @@ Changelog
 New features
 ............

+   - Added the :class:`neighbors.LocalOutlierFactor` class for anomaly detection
+     based on nearest neighbors. By `Nicolas Goix`_ and `Alexandre Gramfort`_.
+
 Enhancements
 ............

@@ -4740,7 +4743,7 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.

 .. _Eric Martin: http://www.ericmart.in

-.. _Nicolas Goix: https://webperso.telecom-paristech.fr/front/frontoffice.php?SP_ID=241
+.. _Nicolas Goix: https://perso.telecom-paristech.fr/~goix/

 .. _Cory Lorenz: https://github.com/clorenz7

examples/covariance/plot_outlier_detection.py

Lines changed: 29 additions & 12 deletions
@@ -18,6 +18,9 @@
    hence more adapted to large-dimensional settings, even if it performs
    quite well in the examples below.

+- using the Local Outlier Factor to measure the local deviation of a given
+  data point with respect to its neighbors by comparing their local density.
+
 The ground truth about inliers and outliers is given by the points colors
 while the orange-filled area indicates which points are reported as inliers
 by each method.

@@ -27,7 +30,6 @@
 threshold on the decision_function to separate out the corresponding
 fraction.
 """
-print(__doc__)

 import numpy as np
 from scipy import stats

@@ -37,6 +39,9 @@
 from sklearn import svm
 from sklearn.covariance import EllipticEnvelope
 from sklearn.ensemble import IsolationForest
+from sklearn.neighbors import LocalOutlierFactor
+
+print(__doc__)

 rng = np.random.RandomState(42)

@@ -52,10 +57,13 @@
     "Robust covariance": EllipticEnvelope(contamination=outliers_fraction),
     "Isolation Forest": IsolationForest(max_samples=n_samples,
                                         contamination=outliers_fraction,
-                                        random_state=rng)}
+                                        random_state=rng),
+    "Local Outlier Factor": LocalOutlierFactor(
+        n_neighbors=35,
+        contamination=outliers_fraction)}

 # Compare given classifiers under given settings
-xx, yy = np.meshgrid(np.linspace(-7, 7, 500), np.linspace(-7, 7, 500))
+xx, yy = np.meshgrid(np.linspace(-7, 7, 100), np.linspace(-7, 7, 100))
 n_inliers = int((1. - outliers_fraction) * n_samples)
 n_outliers = int(outliers_fraction * n_samples)
 ground_truth = np.ones(n_samples, dtype=int)

@@ -72,19 +80,27 @@
     X = np.r_[X, np.random.uniform(low=-6, high=6, size=(n_outliers, 2))]

     # Fit the model
-    plt.figure(figsize=(10.8, 3.6))
+    plt.figure(figsize=(9, 7))
     for i, (clf_name, clf) in enumerate(classifiers.items()):
         # fit the data and tag outliers
-        clf.fit(X)
-        scores_pred = clf.decision_function(X)
+        if clf_name == "Local Outlier Factor":
+            y_pred = clf.fit_predict(X)
+            scores_pred = clf.negative_outlier_factor_
+        else:
+            clf.fit(X)
+            scores_pred = clf.decision_function(X)
+            y_pred = clf.predict(X)
         threshold = stats.scoreatpercentile(scores_pred,
                                             100 * outliers_fraction)
-        y_pred = clf.predict(X)
         n_errors = (y_pred != ground_truth).sum()
         # plot the levels lines and the points
-        Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
+        if clf_name == "Local Outlier Factor":
+            # decision_function is private for LOF
+            Z = clf._decision_function(np.c_[xx.ravel(), yy.ravel()])
+        else:
+            Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
         Z = Z.reshape(xx.shape)
-        subplot = plt.subplot(1, 3, i + 1)
+        subplot = plt.subplot(2, 2, i + 1)
         subplot.contourf(xx, yy, Z, levels=np.linspace(Z.min(), threshold, 7),
                          cmap=plt.cm.Blues_r)
         a = subplot.contour(xx, yy, Z, levels=[threshold],

@@ -97,11 +113,12 @@
         subplot.legend(
             [a.collections[0], b, c],
             ['learned decision function', 'true inliers', 'true outliers'],
-            prop=matplotlib.font_manager.FontProperties(size=11),
+            prop=matplotlib.font_manager.FontProperties(size=10),
             loc='lower right')
-        subplot.set_title("%d. %s (errors: %d)" % (i + 1, clf_name, n_errors))
+        subplot.set_xlabel("%d. %s (errors: %d)" % (i + 1, clf_name, n_errors))
         subplot.set_xlim((-7, 7))
         subplot.set_ylim((-7, 7))
-    plt.subplots_adjust(0.04, 0.1, 0.96, 0.92, 0.1, 0.26)
+    plt.subplots_adjust(0.04, 0.1, 0.96, 0.94, 0.1, 0.26)
+    plt.suptitle("Outlier detection")

 plt.show()
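
One design choice in the example above is worth spelling out: given an assumed contamination fraction, the decision threshold is simply the matching percentile of the scores on the training data. A standalone toy version of that step (the score array is invented for illustration):

import numpy as np
from scipy import stats

scores_pred = np.random.RandomState(0).randn(100)  # stand-in: higher == more normal
outliers_fraction = 0.25
threshold = stats.scoreatpercentile(scores_pred, 100 * outliers_fraction)
is_inlier = scores_pred > threshold  # roughly 75% of the points are kept
print(threshold, is_inlier.mean())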
