Skip to content

Commit 978375f

Browse files
author
Pablo Osorio Lopez
committed
Reduce complexity linear_discriminant_analysis.
1 parent 9b73884 commit 978375f

File tree

1 file changed

+62
-65
lines changed

1 file changed

+62
-65
lines changed

machine_learning/linear_discriminant_analysis.py

Lines changed: 62 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from math import log
4545
from os import name, system
4646
from random import gauss, seed
47+
from typing import TypeVar, Callable
4748

4849

4950
# Make a training dataset drawn from a gaussian distribution
@@ -245,6 +246,38 @@ def accuracy(actual_y: list, predicted_y: list) -> float:
245246
return (correct / len(actual_y)) * 100
246247

247248

249+
num = TypeVar("num")
250+
251+
252+
def valid_input(
253+
input_type: Callable[[object], num], # Usually float or int
254+
input_msg: str,
255+
err_msg: str,
256+
condition: Callable[[num], bool] = lambda x: True,
257+
default: str = None,
258+
) -> num:
259+
"""
260+
Ask for user value and validate that it fulfill a condition.
261+
262+
:input_type: user input expected type of value
263+
:input_msg: message to show user in the screen
264+
:err_msg: message to show in the screen in case of error
265+
:condition: function that represents the condition that user input is valid.
266+
:default: Default value in case the user does not type anything
267+
:return: user's input
268+
"""
269+
while True:
270+
try:
271+
user_input = input_type(input(input_msg).strip() or default)
272+
if condition(user_input):
273+
return user_input
274+
else:
275+
print(f"Your entered value is {user_input}. {err_msg}")
276+
continue
277+
except ValueError:
278+
print("Your entered value is not numerical!")
279+
280+
248281
# Main Function
249282
def main():
250283
""" This function starts execution phase """
@@ -254,87 +287,51 @@ def main():
254287
print("First of all we should specify the number of classes that")
255288
print("we want to generate as training dataset")
256289
# Trying to get number of classes
257-
n_classes = 0
258-
while True:
259-
try:
260-
user_input = int(
261-
input("Enter the number of classes (Data Groupings): ").strip()
262-
)
263-
if user_input > 0:
264-
n_classes = user_input
265-
break
266-
else:
267-
print(
268-
f"Your entered value is {user_input} , Number of classes "
269-
f"should be positive!"
270-
)
271-
continue
272-
except ValueError:
273-
print("Your entered value is not numerical!")
290+
n_classes = valid_input(
291+
input_type=int,
292+
condition=lambda x: x > 0,
293+
input_msg="Enter the number of classes (Data Groupings): ",
294+
err_msg="Number of classes should be positive!",
295+
)
274296

275297
print("-" * 100)
276298

277-
std_dev = 1.0 # Default value for standard deviation of dataset
278299
# Trying to get the value of standard deviation
279-
while True:
280-
try:
281-
user_sd = float(
282-
input(
283-
"Enter the value of standard deviation"
284-
"(Default value is 1.0 for all classes): "
285-
).strip()
286-
or "1.0"
287-
)
288-
if user_sd >= 0.0:
289-
std_dev = user_sd
290-
break
291-
else:
292-
print(
293-
f"Your entered value is {user_sd}, Standard deviation should "
294-
f"not be negative!"
295-
)
296-
continue
297-
except ValueError:
298-
print("Your entered value is not numerical!")
300+
std_dev = valid_input(
301+
input_type=float,
302+
condition=lambda x: x >= 0,
303+
input_msg=(
304+
"Enter the value of standard deviation"
305+
"(Default value is 1.0 for all classes): "
306+
),
307+
err_msg="Standard deviation should not be negative!",
308+
default="1.0",
309+
)
299310

300311
print("-" * 100)
301312

302313
# Trying to get number of instances in classes and theirs means to generate
303314
# dataset
304315
counts = [] # An empty list to store instance counts of classes in dataset
305316
for i in range(n_classes):
306-
while True:
307-
try:
308-
user_count = int(
309-
input(f"Enter The number of instances for class_{i+1}: ")
310-
)
311-
if user_count > 0:
312-
counts.append(user_count)
313-
break
314-
else:
315-
print(
316-
f"Your entered value is {user_count}, Number of "
317-
"instances should be positive!"
318-
)
319-
continue
320-
except ValueError:
321-
print("Your entered value is not numerical!")
317+
user_count = valid_input(
318+
input_type=int,
319+
condition=lambda x: x > 0,
320+
input_msg=(f"Enter The number of instances for class_{i+1}: "),
321+
err_msg="Number of instances should be positive!",
322+
)
323+
counts.append(user_count)
322324
print("-" * 100)
323325

324326
# An empty list to store values of user-entered means of classes
325327
user_means = []
326328
for a in range(n_classes):
327-
while True:
328-
try:
329-
user_mean = float(
330-
input(f"Enter the value of mean for class_{a+1}: ")
331-
)
332-
if isinstance(user_mean, float):
333-
user_means.append(user_mean)
334-
break
335-
print(f"You entered an invalid value: {user_mean}")
336-
except ValueError:
337-
print("Your entered value is not numerical!")
329+
user_mean = valid_input(
330+
input_type=float,
331+
input_msg=(f"Enter the value of mean for class_{a+1}: "),
332+
err_msg="This is an invalid value.",
333+
)
334+
user_means.append(user_mean)
338335
print("-" * 100)
339336

340337
print("Standard deviation: ", std_dev)

0 commit comments

Comments
 (0)