4
4
import numpy as np
5
5
import scipy
6
6
7
- logging .basicConfig (level = logging .INFO , format = ' %(message)s' )
7
+ logging .basicConfig (level = logging .INFO , format = " %(message)s" )
8
8
9
9
10
10
def column_reshape (input_array : np .ndarray ) -> np .ndarray :
@@ -13,7 +13,9 @@ def column_reshape(input_array: np.ndarray) -> np.ndarray:
13
13
return input_array .reshape ((input_array .size , 1 ))
14
14
15
15
16
- def covariance_within_classes (features : np .ndarray , labels : np .ndarray , classes : int ) -> np .ndarray :
16
+ def covariance_within_classes (
17
+ features : np .ndarray , labels : np .ndarray , classes : int
18
+ ) -> np .ndarray :
17
19
"""Function to compute the covariance matrix inside each class"""
18
20
19
21
covariance_sum = np .nan
@@ -32,7 +34,9 @@ def covariance_within_classes(features: np.ndarray, labels: np.ndarray, classes:
32
34
return covariance_sum / features .shape [1 ]
33
35
34
36
35
- def covariance_between_classes (features : np .ndarray , labels : np .ndarray , classes : int ) -> np .ndarray :
37
+ def covariance_between_classes (
38
+ features : np .ndarray , labels : np .ndarray , classes : int
39
+ ) -> np .ndarray :
36
40
"""Function to compute the covariance matrix between multiple classes"""
37
41
38
42
general_data_mean = features .mean (1 )
@@ -43,12 +47,16 @@ def covariance_between_classes(features: np.ndarray, labels: np.ndarray, classes
43
47
data_mean = data .mean (1 )
44
48
if i > 0 :
45
49
# If covariance_sum is not None
46
- covariance_sum += device_data * np .dot (column_reshape (data_mean ) - column_reshape (general_data_mean ),
47
- (column_reshape (data_mean ) - column_reshape (general_data_mean )).T )
50
+ covariance_sum += device_data * np .dot (
51
+ column_reshape (data_mean ) - column_reshape (general_data_mean ),
52
+ (column_reshape (data_mean ) - column_reshape (general_data_mean )).T ,
53
+ )
48
54
else :
49
55
# If covariance_sum is np.nan (i.e. first loop)
50
- covariance_sum = device_data * np .dot (column_reshape (data_mean ) - column_reshape (general_data_mean ),
51
- (column_reshape (data_mean ) - column_reshape (general_data_mean )).T )
56
+ covariance_sum = device_data * np .dot (
57
+ column_reshape (data_mean ) - column_reshape (general_data_mean ),
58
+ (column_reshape (data_mean ) - column_reshape (general_data_mean )).T ,
59
+ )
52
60
53
61
return covariance_sum / features .shape [1 ]
54
62
@@ -76,12 +84,14 @@ def PCA(features: np.ndarray, dimensions: int) -> np.ndarray:
76
84
77
85
return projected_data
78
86
else :
79
- logging .basicConfig (level = logging .ERROR , format = ' %(message)s' , force = True )
87
+ logging .basicConfig (level = logging .ERROR , format = " %(message)s" , force = True )
80
88
logging .error ("Dataset empty" )
81
89
raise AssertionError
82
90
83
91
84
- def LDA (features : np .ndarray , labels : np .ndarray , classes : int , dimensions : int ) -> np .ndarray :
92
+ def LDA (
93
+ features : np .ndarray , labels : np .ndarray , classes : int , dimensions : int
94
+ ) -> np .ndarray :
85
95
"""Linear Discriminant Analysis. \n
86
96
For more details, see here: https://en.wikipedia.org/wiki/Linear_discriminant_analysis \n
87
97
Parameters: \n
@@ -97,7 +107,8 @@ def LDA(features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int)
97
107
if features .any :
98
108
_ , eigenvectors = scipy .linalg .eigh (
99
109
covariance_between_classes (features , labels , classes ),
100
- covariance_within_classes (features , labels , classes ))
110
+ covariance_within_classes (features , labels , classes ),
111
+ )
101
112
filtered_eigenvectors = eigenvectors [:, ::- 1 ][:, :dimensions ]
102
113
svd_matrix , _ , _ = np .linalg .svd (filtered_eigenvectors )
103
114
filtered_svd_matrix = svd_matrix [:, 0 :dimensions ]
@@ -106,6 +117,6 @@ def LDA(features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int)
106
117
107
118
return projected_data
108
119
else :
109
- logging .basicConfig (level = logging .ERROR , format = ' %(message)s' , force = True )
120
+ logging .basicConfig (level = logging .ERROR , format = " %(message)s" , force = True )
110
121
logging .error ("Dataset empty" )
111
122
raise AssertionError
0 commit comments