Skip to content

Commit 6e2a80e

Browse files
committed
Added Cook's Distance
1 parent 76acc6d commit 6e2a80e

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

maths/cooks_distance.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""
2+
Cook's Distance is used to estimate the influence of a data point in
3+
least squares regression.
4+
5+
Cook's Distance removes each data point and measures the effect of removing the
6+
data point.
7+
8+
The algorithm works as follows:
9+
For each data point in the regression, remove the point from the set
10+
and calculate the effect of removing that point.
11+
12+
D_i = (sum over all other points (y_observed - y_fitted)^2) / (rank * MSE)
13+
14+
15+
https://en.wikipedia.org/wiki/Cook's_distance
16+
"""
17+
from machine_learning.loss_functions.mean_squared_error import mean_squared_error
18+
import numpy as np
19+
20+
def calculate_cooks_distance(
    y_observed: np.ndarray, y_fitted: np.ndarray, rank: int
) -> np.ndarray:
    """Calculate a leave-one-out influence score for every observation.

    For each point ``i`` the score is::

        (sum of squared residuals of all *other* points) / (rank * MSE)

    where a residual is ``y_observed - y_fitted`` and ``MSE`` is the mean of
    the squared residuals over all points.

    NOTE(review): the textbook Cook's distance compares fitted values from
    models refit without point ``i`` (or, equivalently, uses leverages).
    Neither can be derived from the arguments of this function, so the
    simplified residual-based score above is what is computed -- confirm that
    this is the intended quantity.

    Args:
        y_observed: observed y values (numpy array or array-like of numbers).
        y_fitted: fitted y values from the regression model, same length as
            ``y_observed``.
        rank: positive integer number of coefficients in the model.

    Returns:
        numpy array with one score per observation, in input order.

    Raises:
        TypeError: if ``rank`` is not an integer.
        ValueError: if ``rank`` is not positive, if the arrays differ in
            length, or if they are empty.
    """
    # bool is a subclass of int; reject it so True/False are not taken as a rank.
    if isinstance(rank, bool) or not isinstance(rank, (int, np.integer)):
        msg = f"Rank is an integer representing the number of predictors. Input: {rank}"
        raise TypeError(msg)

    if rank <= 0:
        # rank is a divisor below; zero would divide by zero and a negative
        # value would flip the sign of every score.
        raise ValueError(f"Rank must be a positive integer. Input: {rank}")

    # Validate the inputs *before* doing any arithmetic on them.
    observed = np.asarray(y_observed, dtype=float)
    fitted = np.asarray(y_fitted, dtype=float)

    if len(observed) != len(fitted):
        msg = (
            "The arrays of observed and fitted values must be equal length. "
            f"Currently observed = {len(observed)} and fitted = {len(fitted)}"
        )
        raise ValueError(msg)

    if len(observed) == 0:
        raise ValueError("The y value arrays must not be empty")

    squared_residuals = (observed - fitted) ** 2
    total_squared_error = squared_residuals.sum()
    # The squared residuals are already at hand, so the MSE is simply their mean.
    mse = squared_residuals.mean()

    if mse == 0:
        # Perfect fit: every numerator is zero as well, so report zero
        # influence instead of letting 0 / 0 produce NaN.
        return np.zeros_like(squared_residuals)

    # Vectorised form of "total minus this point's own contribution".
    return (total_squared_error - squared_residuals) / (rank * mse)
49+

0 commit comments

Comments
 (0)