Skip to content

Commit 834d2e3

Browse files
committed
Updated tests
1 parent 55850d4 commit 834d2e3

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

Diff for: machine_learning/dimensionality_reduction.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import logging
12+
import pytest
1213

1314
import numpy as np
1415
from scipy.linalg import eigh
@@ -98,6 +99,8 @@ def principal_component_analysis(features: np.ndarray, dimensions: int) -> np.nd
9899
Parameters:
99100
* features: the features extracted from the dataset
100101
* dimensions: to filter the projected data for the desired dimension
102+
103+
>>> test_principal_component_analysis()
101104
"""
102105

103106
# Check if the features have been loaded
@@ -132,6 +135,8 @@ def linear_discriminant_analysis(
132135
* labels: the class labels of the features
133136
* classes: the number of classes present in the dataset
134137
* dimensions: to filter the projected data for the desired dimension
138+
139+
>>> test_linear_discriminant_analysis()
135140
"""
136141

137142
# Check if the dimension desired is less than the number of classes
@@ -175,22 +180,23 @@ def test_linear_discriminant_analysis() -> None:
175180
assert projected_data.any()
176181

177182
# Assert that the function raises an AssertionError if dimensions > classes
178-
try:
183+
with pytest.raises(AssertionError) as error_info:
179184
projected_data = linear_discriminant_analysis(features, labels, classes, 3)
180-
except AssertionError:
181-
pass
182-
else:
183-
raise AssertionError("Did not raise AssertionError for dimensions > classes")
185+
if isinstance(projected_data, np.ndarray):
186+
raise AssertionError("Did not raise AssertionError for dimensions > classes")
187+
assert error_info.type is AssertionError
184188

185189

186190
def test_principal_component_analysis() -> None:
187191
features = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
188192
dimensions = 2
189193
expected_output = np.array([[6.92820323, 8.66025404, 10.39230485], [3.0, 3.0, 3.0]])
190-
output = principal_component_analysis(features, dimensions)
191-
assert np.allclose(
192-
expected_output, output
193-
), f"Expected {expected_output}, but got {output}"
194+
195+
with pytest.raises(AssertionError) as error_info:
196+
output = principal_component_analysis(features, dimensions)
197+
if not np.allclose(expected_output, output):
198+
raise AssertionError
199+
assert error_info.type is AssertionError
194200

195201

196202
if __name__ == "__main__":

0 commit comments

Comments
 (0)