Skip to content

Commit b95aec9

Browse files
github-actionsgithub-actions
github-actions
authored and
github-actions
committed
fixup! Format Python code with psf/black push
1 parent 16a1831 commit b95aec9

File tree

1 file changed

+61
-21
lines changed

1 file changed

+61
-21
lines changed

Diff for: machine_learning/linear_discriminant_analysis.py

+61-21
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
# importing modules
4343
from random import gauss
4444
from math import log
45-
from os import system, name # to use < clear > or < cls > commands in terminal or cmd
45+
from os import system, name # to use < clear > or < cls > commands in terminal or cmd
4646

4747

4848
# Making training dataset drawn from a gaussian distribution
@@ -128,23 +128,28 @@ def predict(x_items: list, means: list, variance: float, probabilities: list) ->
128128
:return: a list containing predicted Y values
129129
"""
130130

131-
results = [] # An empty list to store generated discriminant values of all items in dataset for each class
131+
results = (
132+
[]
133+
) # An empty list to store generated discriminant values of all items in dataset for each class
132134
# for loop iterates over number of elements in list
133135
for i in range(len(x_items)):
134136
# for loop iterates over number of inner items of each element
135137
for j in range(len(x_items[i])):
136-
temp = [] # to store all discriminant values of each item as a list
138+
temp = [] # to store all discriminant values of each item as a list
137139
# for loop iterates over number of classes we have in our dataset
138140
for k in range(len(x_items)):
139141
# appending values of discriminants for each class to 'temp' list
140-
temp.append(x_items[i][j] * (means[k] / variance) - (means[k] ** 2 / (2 * variance)) +
141-
log(probabilities[k]))
142+
temp.append(
143+
x_items[i][j] * (means[k] / variance)
144+
- (means[k] ** 2 / (2 * variance))
145+
+ log(probabilities[k])
146+
)
142147
# appending discriminant values of each item to 'results' list
143148
results.append(temp)
144149

145150
print("Generated Discriminants: \n", results)
146151

147-
predicted_index = [] # An empty list to store predicted indexes
152+
predicted_index = [] # An empty list to store predicted indexes
148153
# for loop iterates over elements in 'results'
149154
for l in results:
150155
# after calculating the discriminant value for each class , the class with the largest
@@ -160,7 +165,7 @@ def accuracy(actual_y: list, predicted_y: list) -> float:
160165
:param predicted_y: a list containing predicted Y values generated by 'predict' function
161166
:return: percentage of accuracy
162167
"""
163-
correct = 0 # initial value for number of correct predictions
168+
correct = 0 # initial value for number of correct predictions
164169
# for loop iterates over one element of each list at a time (zip mode)
165170
for i, j in zip(actual_y, predicted_y):
166171
# if actual Y value equals to predicted Y value
@@ -181,19 +186,27 @@ def main():
181186

182187
print(" Linear Discriminant Analysis ".center(100, "*"))
183188
print("*" * 100, "\n")
184-
print("First of all we should specify the number of classes that \n"
185-
"we want to generate as training dataset")
189+
print(
190+
"First of all we should specify the number of classes that \n"
191+
"we want to generate as training dataset"
192+
)
186193

187194
# Trying to get number of classes
188195
n_classes = 0
189196
while True:
190197
try:
191-
user_input = int(input("Enter the number of classes (Data Groupings): "))
198+
user_input = int(
199+
input("Enter the number of classes (Data Groupings): ")
200+
)
192201
if user_input > 0:
193202
n_classes = user_input
194203
break
195204
else:
196-
print("Your entered value is {} , Number of classes should be positive!".format(user_input))
205+
print(
206+
"Your entered value is {} , Number of classes should be positive!".format(
207+
user_input
208+
)
209+
)
197210
continue
198211
except ValueError:
199212
print("Your entered value is not numerical!")
@@ -204,13 +217,22 @@ def main():
204217
# Trying to get the value of standard deviation
205218
while True:
206219
try:
207-
user_sd = float(input("Enter the value of standard deviation"
208-
"(Default value is 1.0 for all classes): ") or "1.0")
220+
user_sd = float(
221+
input(
222+
"Enter the value of standard deviation"
223+
"(Default value is 1.0 for all classes): "
224+
)
225+
or "1.0"
226+
)
209227
if user_sd >= 0.0:
210228
std_dev = user_sd
211229
break
212230
else:
213-
print("Your entered value is {}, Standard deviation should not be negative!".format(user_sd))
231+
print(
232+
"Your entered value is {}, Standard deviation should not be negative!".format(
233+
user_sd
234+
)
235+
)
214236
continue
215237
except ValueError:
216238
print("Your entered value is not numerical!")
@@ -222,28 +244,44 @@ def main():
222244
for i in range(n_classes):
223245
while True:
224246
try:
225-
user_count = int(input("Enter The number of instances for class_{}: ".format(i + 1)))
247+
user_count = int(
248+
input(
249+
"Enter The number of instances for class_{}: ".format(i + 1)
250+
)
251+
)
226252
if user_count > 0:
227253
counts.append(user_count)
228254
break
229255
else:
230-
print("Your entered value is {}, Number of instances should be positive!".format(user_count))
256+
print(
257+
"Your entered value is {}, Number of instances should be positive!".format(
258+
user_count
259+
)
260+
)
231261
continue
232262
except ValueError:
233263
print("Your entered value is not numerical!")
234264

235265
print("-" * 100)
236266

237-
user_means = [] # An empty list to store values of user-entered means of classes
267+
user_means = (
268+
[]
269+
) # An empty list to store values of user-entered means of classes
238270
for a in range(n_classes):
239271
while True:
240272
try:
241-
user_mean = float(input("Enter the value of mean for class_{}: ".format(a + 1)))
273+
user_mean = float(
274+
input("Enter the value of mean for class_{}: ".format(a + 1))
275+
)
242276
if isinstance(user_mean, float):
243277
user_means.append(user_mean)
244278
break
245279
else:
246-
print("Your entered value is {}, And this is not valid!".format(user_mean))
280+
print(
281+
"Your entered value is {}, And this is not valid!".format(
282+
user_mean
283+
)
284+
)
247285

248286
except ValueError:
249287
print("Your entered value is not numerical!")
@@ -293,7 +331,9 @@ def main():
293331
print("-" * 100)
294332

295333
# Calculating the value of probabilities for each class
296-
probabilities = [] # An empty list to store values of probabilities for each class
334+
probabilities = (
335+
[]
336+
) # An empty list to store values of probabilities for each class
297337
# # for loop iterates over number of classes(data groupings)
298338
for l in range(n_classes):
299339
# appending return values of 'prob_calc' function to 'probabilities' list
@@ -334,5 +374,5 @@ def main():
334374
continue
335375

336376

337-
if __name__ == '__main__':
377+
if __name__ == "__main__":
338378
main()

0 commit comments

Comments
 (0)