Skip to content

Commit 2d81b90

Browse files
AmritK10stokhos
authored andcommitted
Infinite loop was fixed. (TheAlgorithms#1105)
* Infinite loop was fixed. Removed issue of unused variables. * Update logistic_regression.py * Update logistic_regression.py * correct spacing according to PEP8
1 parent e54013a commit 2d81b90

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

Diff for: machine_learning/logistic_regression.py

+7-20
Original file line numberDiff line numberDiff line change
@@ -40,34 +40,20 @@ def logistic_reg(
4040
alpha,
4141
X,
4242
y,
43-
num_steps,
4443
max_iterations=70000,
4544
):
46-
converged = False
47-
iterations = 0
4845
theta = np.zeros(X.shape[1])
4946

50-
while not converged:
47+
for iterations in range(max_iterations):
5148
z = np.dot(X, theta)
5249
h = sigmoid_function(z)
5350
gradient = np.dot(X.T, h - y) / y.size
54-
theta = theta - alpha * gradient
51+
theta = theta - alpha * gradient # updating the weights
5552
z = np.dot(X, theta)
5653
h = sigmoid_function(z)
5754
J = cost_function(h, y)
58-
iterations += 1 # update iterations
59-
weights = np.zeros(X.shape[1])
60-
for step in range(num_steps):
61-
scores = np.dot(X, weights)
62-
predictions = sigmoid_function(scores)
63-
if step % 10000 == 0:
64-
print(log_likelihood(X,y,weights)) # Print log-likelihood every so often
65-
return weights
66-
67-
if iterations == max_iterations:
68-
print('Maximum iterations exceeded!')
69-
print('Minimal cost function J=', J)
70-
converged = True
55+
if iterations % 100 == 0:
56+
print(f'loss: {J} \t') # printing the loss after every 100 iterations
7157
return theta
7258

7359
# In[68]:
@@ -78,8 +64,8 @@ def logistic_reg(
7864
y = (iris.target != 0) * 1
7965

8066
alpha = 0.1
81-
theta = logistic_reg(alpha,X,y,max_iterations=70000,num_steps=30000)
82-
print(theta)
67+
theta = logistic_reg(alpha,X,y,max_iterations=70000)
68+
print("theta: ",theta) # printing the theta i.e our weights vector
8369

8470

8571
def predict_prob(X):
@@ -105,3 +91,4 @@ def predict_prob(X):
10591
)
10692

10793
plt.legend()
94+
plt.show()

0 commit comments

Comments
 (0)