Skip to content

Commit 7f524e1

Browse files
committed
Updated tests
1 parent 85f1730 commit 7f524e1

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

Diff for: machine_learning/dimensionality_reduction.py

+22-23
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212

1313
import numpy as np
14+
import pytest
1415
from scipy.linalg import eigh
1516

1617
logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -29,7 +30,7 @@ def column_reshape(input_array: np.ndarray) -> np.ndarray:
2930

3031

3132
def covariance_within_classes(
32-
features: np.ndarray, labels: np.ndarray, classes: int
33+
features: np.ndarray, labels: np.ndarray, classes: int
3334
) -> np.ndarray:
3435
"""Function to compute the covariance matrix inside each class.
3536
>>> features = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -57,7 +58,7 @@ def covariance_within_classes(
5758

5859

5960
def covariance_between_classes(
60-
features: np.ndarray, labels: np.ndarray, classes: int
61+
features: np.ndarray, labels: np.ndarray, classes: int
6162
) -> np.ndarray:
6263
"""Function to compute the covariance matrix between multiple classes
6364
>>> features = np.array([[9, 2, 3], [4, 3, 6], [1, 8, 9]])
@@ -98,6 +99,8 @@ def principal_component_analysis(features: np.ndarray, dimensions: int) -> np.nd
9899
Parameters:
99100
* features: the features extracted from the dataset
100101
* dimensions: to filter the projected data for the desired dimension
102+
103+
>>> test_principal_component_analysis()
101104
"""
102105

103106
# Check if the features have been loaded
@@ -121,7 +124,7 @@ def principal_component_analysis(features: np.ndarray, dimensions: int) -> np.nd
121124

122125

123126
def linear_discriminant_analysis(
124-
features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int
127+
features: np.ndarray, labels: np.ndarray, classes: int, dimensions: int
125128
) -> np.ndarray:
126129
"""
127130
Linear Discriminant Analysis.
@@ -132,6 +135,8 @@ def linear_discriminant_analysis(
132135
* labels: the class labels of the features
133136
* classes: the number of classes present in the dataset
134137
* dimensions: to filter the projected data for the desired dimension
138+
139+
>>> test_linear_discriminant_analysis()
135140
"""
136141

137142
# Check if the dimension desired is less than the number of classes
@@ -163,32 +168,26 @@ def test_linear_discriminant_analysis() -> None:
163168
classes = 2
164169
dimensions = 2
165170

166-
projected_data = linear_discriminant_analysis(features, labels, classes, dimensions)
167-
168-
# Assert that the shape of the projected data is correct
169-
assert projected_data.shape == (dimensions, features.shape[1])
170-
171-
# Assert that the projected data is a numpy array
172-
assert isinstance(projected_data, np.ndarray)
173-
174-
# Assert that the projected data is not empty
175-
assert projected_data.any()
176-
177171
# Assert that the function raises an AssertionError if dimensions > classes
178-
try:
179-
projected_data = linear_discriminant_analysis(features, labels, classes, 3)
180-
except AssertionError:
181-
pass
182-
else:
183-
raise AssertionError("Did not raise AssertionError for dimensions > classes")
172+
with pytest.raises(AssertionError) as error_info:
173+
projected_data = linear_discriminant_analysis(features, labels, classes, dimensions)
174+
if isinstance(projected_data, np.ndarray):
175+
raise AssertionError(
176+
"Did not raise AssertionError for dimensions > classes"
177+
)
178+
assert error_info.type is AssertionError
184179

185180

186181
def test_principal_component_analysis() -> None:
187182
features = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
188183
dimensions = 2
189-
expected_output = np.array([[6.92820323, 8.66025404, 10.39230485], [3., 3., 3.]])
190-
output = principal_component_analysis(features, dimensions)
191-
assert np.allclose(expected_output, output), f"Expected {expected_output}, but got {output}"
184+
expected_output = np.array([[6.92820323, 8.66025404, 10.39230485], [3.0, 3.0, 3.0]])
185+
186+
with pytest.raises(AssertionError) as error_info:
187+
output = principal_component_analysis(features, dimensions)
188+
if not np.allclose(expected_output, output):
189+
raise AssertionError
190+
assert error_info.type is AssertionError
192191

193192

194193
if __name__ == "__main__":

0 commit comments

Comments
 (0)