Skip to content

Commit fb390e3

Browse files
committed
Adding predict function
1 parent 86f00b2 commit fb390e3

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

machine_learning/linear discriminant analysis.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,38 @@ def var_calc(items: list, means: list, total_count: int) -> float:
127127
# one divided by (the number of all instances - number of classes) multiplied by sum of all squared differences
128128
variance = 1 / (total_count - n_classes) * sum(squared_diff)
129129
return variance
130+
131+
132+
# Making predictions
133+
def predict(x_items: list, means: list, variance: float, probabilities: list) -> list:
134+
""" This function predicts new indexes(groups for our data)
135+
:param x_items: a list containing all items(gaussian distribution of all classes)
136+
:param means: a list containing real mean values of each class
137+
:param variance: calculated value of variance by var_calc function
138+
:param probabilities: a list containing all probabilities of classes
139+
:return: a list containing predicted Y values
140+
"""
141+
142+
results = [] # An empty list to store generated discriminant values of all items in dataset for each class
143+
# for loop iterates over number of elements in list
144+
for i in range(len(x_items)):
145+
# for loop iterates over number of inner items of each element
146+
for j in range(len(x_items[i])):
147+
temp = [] # to store all discriminant values of each item as a list
148+
# for loop iterates over number of classes we have in our dataset
149+
for k in range(len(x_items)):
150+
# appending values of discriminants for each class to 'temp' list
151+
temp.append(x_items[i][j] * (means[k] / variance) - (means[k] ** 2 / (2 * variance)) +
152+
log(probabilities[k]))
153+
# appending discriminant values of each item to 'results' list
154+
results.append(temp)
155+
156+
print("Generated Discriminants: \n", results)
157+
158+
predicted_index = [] # An empty list to store predicted indexes
159+
# for loop iterates over elements in 'results'
160+
for l in results:
161+
# after calculating the discriminant value for each class , the class with the largest
162+
# discriminant value is taken as the prediction, than we try to get index of that.
163+
predicted_index.append(l.index(max(l)))
164+
return predicted_index

0 commit comments

Comments
 (0)