Skip to content

Reduce complexity linear_discriminant_analysis. #2452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 11, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 62 additions & 65 deletions machine_learning/linear_discriminant_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from math import log
from os import name, system
from random import gauss, seed
from typing import TypeVar, Callable


# Make a training dataset drawn from a gaussian distribution
Expand Down Expand Up @@ -245,6 +246,38 @@ def accuracy(actual_y: list, predicted_y: list) -> float:
return (correct / len(actual_y)) * 100


num = TypeVar("num")


def valid_input(
input_type: Callable[[object], num], # Usually float or int
input_msg: str,
err_msg: str,
condition: Callable[[num], bool] = lambda x: True,
default: str = None,
) -> num:
"""
Ask for user value and validate that it fulfill a condition.

:input_type: user input expected type of value
:input_msg: message to show user in the screen
:err_msg: message to show in the screen in case of error
:condition: function that represents the condition that user input is valid.
:default: Default value in case the user does not type anything
:return: user's input
"""
while True:
try:
user_input = input_type(input(input_msg).strip() or default)
if condition(user_input):
return user_input
else:
print(f"Your entered value is {user_input}. {err_msg}")
continue
except ValueError:
print("Your entered value is not numerical!")


# Main Function
def main():
""" This function starts execution phase """
Expand All @@ -254,87 +287,51 @@ def main():
print("First of all we should specify the number of classes that")
print("we want to generate as training dataset")
# Trying to get number of classes
n_classes = 0
while True:
try:
user_input = int(
input("Enter the number of classes (Data Groupings): ").strip()
)
if user_input > 0:
n_classes = user_input
break
else:
print(
f"Your entered value is {user_input} , Number of classes "
f"should be positive!"
)
continue
except ValueError:
print("Your entered value is not numerical!")
n_classes = valid_input(
input_type=int,
condition=lambda x: x > 0,
input_msg="Enter the number of classes (Data Groupings): ",
err_msg="Number of classes should be positive!",
)

print("-" * 100)

std_dev = 1.0 # Default value for standard deviation of dataset
# Trying to get the value of standard deviation
while True:
try:
user_sd = float(
input(
"Enter the value of standard deviation"
"(Default value is 1.0 for all classes): "
).strip()
or "1.0"
)
if user_sd >= 0.0:
std_dev = user_sd
break
else:
print(
f"Your entered value is {user_sd}, Standard deviation should "
f"not be negative!"
)
continue
except ValueError:
print("Your entered value is not numerical!")
std_dev = valid_input(
input_type=float,
condition=lambda x: x >= 0,
input_msg=(
"Enter the value of standard deviation"
"(Default value is 1.0 for all classes): "
),
err_msg="Standard deviation should not be negative!",
default="1.0",
)

print("-" * 100)

# Trying to get number of instances in classes and theirs means to generate
# dataset
counts = [] # An empty list to store instance counts of classes in dataset
for i in range(n_classes):
while True:
try:
user_count = int(
input(f"Enter The number of instances for class_{i+1}: ")
)
if user_count > 0:
counts.append(user_count)
break
else:
print(
f"Your entered value is {user_count}, Number of "
"instances should be positive!"
)
continue
except ValueError:
print("Your entered value is not numerical!")
user_count = valid_input(
input_type=int,
condition=lambda x: x > 0,
input_msg=(f"Enter The number of instances for class_{i+1}: "),
err_msg="Number of instances should be positive!",
)
counts.append(user_count)
print("-" * 100)

# An empty list to store values of user-entered means of classes
user_means = []
for a in range(n_classes):
while True:
try:
user_mean = float(
input(f"Enter the value of mean for class_{a+1}: ")
)
if isinstance(user_mean, float):
user_means.append(user_mean)
break
print(f"You entered an invalid value: {user_mean}")
except ValueError:
print("Your entered value is not numerical!")
user_mean = valid_input(
input_type=float,
input_msg=(f"Enter the value of mean for class_{a+1}: "),
err_msg="This is an invalid value.",
)
user_means.append(user_mean)
print("-" * 100)

print("Standard deviation: ", std_dev)
Expand Down