|
| 1 | +""" |
| 2 | +Cook's Distance is used to estimate the influence of a data point in |
| 3 | +in least squares regression. |
| 4 | +
|
| 5 | +Cook's Distance removes each data point and measures the effect of removing the |
| 6 | +data point. |
| 7 | +
|
| 8 | +The algorithm works as follows: |
| 9 | +For each data point in the regression, remove the point from the set |
| 10 | + and calculate the effect of removing that point. |
| 11 | +
|
| 12 | +D_i = (sum over all other points(y_actual - y_observed)^2) / (rank * MSE^2) |
| 13 | +
|
| 14 | +
|
| 15 | +https://en.wikipedia.org/wiki/Cook's_distance |
| 16 | +""" |
| 17 | +from machine_learning.loss_functions.mean_squared_error import mean_squared_error |
| 18 | +import numpy as np |
| 19 | + |
| 20 | +def calculate_cooks_distance(y_observed: array, y_fitted: array, rank: int) -> array: |
| 21 | + """Calculate Cook's Distance |
| 22 | + Input: |
| 23 | + y_observed: numpy array of observed y values |
| 24 | + y_fitted: numpy array of fitted y values from linear regression model |
| 25 | + rank: int representing the number of coefficients |
| 26 | + Output: |
| 27 | + cooks_distance: numpy array of Cook's distance for each y value. |
| 28 | + |
| 29 | + """ |
| 30 | + import numpy as np |
| 31 | + _mse = mean_squared_error(y_observed, y_fitted) |
| 32 | + _y_difference_squared = (y_observed - y_fitted)**2 |
| 33 | + |
| 34 | + if isinstance(rank) is not int: |
| 35 | + msg = f"Rank is an integer representing the number of predictors. Input: {rank}" |
| 36 | + raise TypeError(msg) |
| 37 | + |
| 38 | + if len(y_observed) != len(y_fitted): |
| 39 | + msg = f"The arrays of observed and fitted values must be equal length. Currently |
| 40 | + observed = {len(y_observed)} and fitted = {len(y_fitted)}" |
| 41 | + raise ValueError(msg) |
| 42 | + |
| 43 | + if len(y_observed) == 0: |
| 44 | + raise ValueError("The y value arrays must not be empty") |
| 45 | + |
| 46 | + _summed_difference = sum(_y_difference_squared) |
| 47 | + for item in np.nditer(_y_difference_squared): |
| 48 | + k = (_summed_difference - item) / (rank * _mse) |
| 49 | + |
0 commit comments