Skip to content

Commit 041aa1d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5509e7d commit 041aa1d

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

machine_learning/dimensionality_reduction.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import scipy
66

7-
logging.basicConfig(level=logging.INFO, format='%(message)s')
7+
logging.basicConfig(level=logging.INFO, format="%(message)s")
88

99

1010
def column_reshape(input_array: np.ndarray) -> np.ndarray:
@@ -13,7 +13,9 @@ def column_reshape(input_array: np.ndarray) -> np.ndarray:
1313
return input_array.reshape((input_array.size, 1))
1414

1515

16-
def covariance_within_classes(features: np.ndarray, labels: np.ndarray, classes: int) -> np.ndarray:
16+
def covariance_within_classes(
17+
features: np.ndarray, labels: np.ndarray, classes: int
18+
) -> np.ndarray:
1719
"""Function to compute the covariance matrix inside each class"""
1820

1921
covariance_sum = np.nan
@@ -32,7 +34,9 @@ def covariance_within_classes(features: np.ndarray, labels: np.ndarray, classes:
3234
return covariance_sum / features.shape[1]
3335

3436

35-
def covariance_between_classes(features: np.ndarray, labels: np.ndarray, classes: int) -> np.ndarray:
37+
def covariance_between_classes(
38+
features: np.ndarray, labels: np.ndarray, classes: int
39+
) -> np.ndarray:
3640
"""Function to compute the covariance matrix between multiple classes"""
3741

3842
general_data_mean = features.mean(1)
@@ -43,12 +47,16 @@ def covariance_between_classes(features: np.ndarray, labels: np.ndarray, classes
4347
data_mean = data.mean(1)
4448
if i > 0:
4549
# If covariance_sum is not None
46-
covariance_sum += device_data * np.dot(column_reshape(data_mean) - column_reshape(general_data_mean),
47-
(column_reshape(data_mean) - column_reshape(general_data_mean)).T)
50+
covariance_sum += device_data * np.dot(
51+
column_reshape(data_mean) - column_reshape(general_data_mean),
52+
(column_reshape(data_mean) - column_reshape(general_data_mean)).T,
53+
)
4854
else:
4955
# If covariance_sum is np.nan (i.e. first loop)
50-
covariance_sum = device_data * np.dot(column_reshape(data_mean) - column_reshape(general_data_mean),
51-
(column_reshape(data_mean) - column_reshape(general_data_mean)).T)
56+
covariance_sum = device_data * np.dot(
57+
column_reshape(data_mean) - column_reshape(general_data_mean),
58+
(column_reshape(data_mean) - column_reshape(general_data_mean)).T,
59+
)
5260

5361
return covariance_sum / features.shape[1]
5462

@@ -76,12 +84,14 @@ def PCA(features: np.ndarray, dimensions: int) -> np.ndarray:
7684

7785
return projected_data
7886
else:
79-
logging.basicConfig(level=logging.ERROR, format='%(message)s', force=True)
87+
logging.basicConfig(level=logging.ERROR, format="%(message)s", force=True)
8088
logging.error("Dataset empty")
8189
raise AssertionError
8290

8391

84-
def LDA(features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int) -> np.ndarray:
92+
def LDA(
93+
features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int
94+
) -> np.ndarray:
8595
"""Linear Discriminant Analysis. \n
8696
For more details, see here: https://en.wikipedia.org/wiki/Linear_discriminant_analysis \n
8797
Parameters: \n
@@ -97,7 +107,8 @@ def LDA(features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int)
97107
if features.any:
98108
_, eigenvectors = scipy.linalg.eigh(
99109
covariance_between_classes(features, labels, classes),
100-
covariance_within_classes(features, labels, classes))
110+
covariance_within_classes(features, labels, classes),
111+
)
101112
filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
102113
svd_matrix, _, _ = np.linalg.svd(filtered_eigenvectors)
103114
filtered_svd_matrix = svd_matrix[:, 0:dimensions]
@@ -106,6 +117,6 @@ def LDA(features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int)
106117

107118
return projected_data
108119
else:
109-
logging.basicConfig(level=logging.ERROR, format='%(message)s', force=True)
120+
logging.basicConfig(level=logging.ERROR, format="%(message)s", force=True)
110121
logging.error("Dataset empty")
111122
raise AssertionError

0 commit comments

Comments
 (0)