
Commit 7d710c8

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8a38a5f commit 7d710c8

File tree

1 file changed: +17 -5 lines


machine_learning/gradient_descent_momentum.py (+17 -5)

@@ -2,6 +2,7 @@
 Implementation of gradient descent algorithm using momentum for minimizing cost of a linear hypothesis
 function.
 """
+
 import numpy as np
 
 # List of input, output pairs

@@ -21,16 +22,19 @@
 # Initialize velocity (for momentum)
 velocity = [0] * len(parameter_vector)
 
+
 def _error(example_no, data_set="train"):
     """
     Calculate the error (difference between predicted and actual output) for a given example.
     Args:
         example_no (int): Index of the example in the dataset.
-        data_set (str): The dataset to use, either "train" or "test".
+        data_set (str): The dataset to use, either "train" or "test".
     Returns:
         float: The difference between the predicted output and the actual output.
     """
-    return calculate_hypothesis_value(example_no, data_set) - output(example_no, data_set)
+    return calculate_hypothesis_value(example_no, data_set) - output(
+        example_no, data_set
+    )
 
 
 def _hypothesis_value(data_input_tuple):
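
The data_set docstring line above changes only in whitespace (its visible text is identical, so this is presumably the trailing-whitespace hook); the substantive edit is the formatter wrapping the long return statement. For readers skimming the diff, _error computes the signed residual h(x) - y between the linear hypothesis and the actual output for one example. A minimal standalone sketch of that quantity, with illustrative names rather than this module's actual helpers:

# Hypothetical sketch of what _error computes: the signed residual
# h(x) - y for a single example. Names are illustrative only.
def signed_residual(parameters, features, actual_output):
    # Linear hypothesis: theta_0 + theta_1 * x_1 + ... + theta_n * x_n
    predicted = parameters[0] + sum(
        theta * x for theta, x in zip(parameters[1:], features)
    )
    return predicted - actual_output


print(signed_residual([1.0, 2.0], [3.0], 10.0))  # (1 + 2 * 3) - 10 = -3.0
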
@@ -125,8 +129,13 @@ def run_gradient_descent_with_momentum():
             cost_derivative = get_cost_derivative(i - 1)
             velocity[i] = MOMENTUM * velocity[i] + cost_derivative
             temp_parameter_vector[i] = parameter_vector[i] - LEARNING_RATE * velocity[i]
-
-        if np.allclose(parameter_vector, temp_parameter_vector, atol=absolute_error_limit, rtol=relative_error_limit):
+
+        if np.allclose(
+            parameter_vector,
+            temp_parameter_vector,
+            atol=absolute_error_limit,
+            rtol=relative_error_limit,
+        ):
             break
         parameter_vector = temp_parameter_vector
     print(f"Number of iterations: {iteration}")
@@ -140,7 +149,10 @@ def test_gradient_descent():
         print(f"Actual output value: {output(i, 'test')}")
         print(f"Hypothesis output: {calculate_hypothesis_value(i, 'test')}")
 
+
 if __name__ == "__main__":
     run_gradient_descent_with_momentum()
-    print("\nTesting gradient descent with momentum for a linear hypothesis function.\n")
+    print(
+        "\nTesting gradient descent with momentum for a linear hypothesis function.\n"
+    )
     test_gradient_descent()
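
Every hunk in this commit is a mechanical formatting fix, so behavior should be unchanged. Assuming the rest of the file matches what the context lines show, running the module directly still exercises both paths: python machine_learning/gradient_descent_momentum.py calls run_gradient_descent_with_momentum() and then prints the per-example comparison from test_gradient_descent().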
