@@ -164,24 +164,13 @@ def linear_discriminant_analysis(
164
164
def test_linear_discriminant_analysis () -> None :
165
165
# Create dummy dataset with 2 classes and 3 features
166
166
features = np .array ([[1 , 2 , 3 , 4 , 5 ], [2 , 3 , 4 , 5 , 6 ], [3 , 4 , 5 , 6 , 7 ]])
167
- labels = np .array ([0 , 2 , 0 , 1 , 1 ])
168
- classes = 3
167
+ labels = np .array ([0 , 0 , 0 , 1 , 1 ])
168
+ classes = 2
169
169
dimensions = 2
170
170
171
- projected_data = linear_discriminant_analysis (features , labels , classes , dimensions )
172
-
173
- # Assert that the shape of the projected data is correct
174
- assert projected_data .shape == (dimensions , features .shape [1 ])
175
-
176
- # Assert that the projected data is a numpy array
177
- assert isinstance (projected_data , np .ndarray )
178
-
179
- # Assert that the projected data is not empty
180
- assert projected_data .any ()
181
-
182
171
# Assert that the function raises an AssertionError if dimensions > classes
183
172
with pytest .raises (AssertionError ) as error_info :
184
- projected_data = linear_discriminant_analysis (features , labels , classes , 3 )
173
+ projected_data = linear_discriminant_analysis (features , labels , classes , dimensions )
185
174
if isinstance (projected_data , np .ndarray ):
186
175
raise AssertionError (
187
176
"Did not raise AssertionError for dimensions > classes"
0 commit comments