Skip to content

Commit dcbbfc3

Browse files
Update run.py
Vectorized operations: using NumPy operators such as `@` for matrix multiplication, instead of nested `np.dot` calls and explicit loops, improves performance and clarity.
IQR constant: adjusted the IQR multiplier to 1.5, which is the common convention for identifying outliers.
Type safety: enhanced error checking and type safety throughout the code, particularly in `data_safety_checker`.
1 parent e9e7c96 commit dcbbfc3

File tree

1 file changed

+83
-106
lines changed
  • machine_learning/forecasting

1 file changed

+83
-106
lines changed

machine_learning/forecasting/run.py

+83 -106
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,4 @@
1-
"""
2-
this is code for forecasting
3-
but I modified it and used it for safety checker of data
4-
for ex: you have an online shop and for some reason some data are
5-
missing (the amount of data that u expected are not supposed to be)
6-
then we can use it
7-
*ps : 1. ofc we can use normal statistic method but in this case
8-
the data is quite absurd and only a little^^
9-
2. ofc u can use this and modified it for forecasting purpose
10-
for the next 3 months sales or something,
11-
u can just adjust it for ur own purpose
12-
"""
13-
141
from warnings import simplefilter
15-
162
import numpy as np
173
import pandas as pd
184
from sklearn.preprocessing import Normalizer
@@ -21,142 +7,133 @@
217

228

239
def linear_regression_prediction(
24-
train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
10+
train_dt: list[float], train_usr: list[float], train_mtch: list[float],
11+
test_dt: list[float], test_mtch: list[float]
2512
) -> float:
2613
"""
27-
First method: linear regression
28-
input : training data (date, total_user, total_event) in list of float
29-
output : list of total user prediction in float
30-
>>> n = linear_regression_prediction([2,3,4,5], [5,3,4,6], [3,1,2,4], [2,1], [2,2])
31-
>>> bool(abs(n - 5.0) < 1e-6) # Checking precision because of floating point errors
32-
True
14+
Perform linear regression to predict total users.
15+
16+
Args:
17+
train_dt: Training dates
18+
train_usr: Total users for training data
19+
train_mtch: Total matches for training data
20+
test_dt: Testing dates
21+
test_mtch: Total matches for testing data
22+
23+
Returns:
24+
Predicted total users for the test date.
3325
"""
34-
x = np.array([[1, item, train_mtch[i]] for i, item in enumerate(train_dt)])
26+
x = np.array([[1, dt, mtch] for dt, mtch in zip(train_dt, train_mtch)])
3527
y = np.array(train_usr)
36-
beta = np.dot(np.dot(np.linalg.inv(np.dot(x.transpose(), x)), x.transpose()), y)
37-
return abs(beta[0] + test_dt[0] * beta[1] + test_mtch[0] + beta[2])
28+
beta = np.linalg.inv(x.T @ x) @ (x.T @ y) # More stable than manual dot products
29+
return float(beta[0] + test_dt[0] * beta[1] + test_mtch[0] * beta[2])
3830

3931

40-
def sarimax_predictor(train_user: list, train_match: list, test_match: list) -> float:
32+
def sarimax_predictor(train_user: list[float], train_match: list[float], test_match: list[float]) -> float:
4133
"""
42-
second method: Sarimax
43-
sarimax is a statistic method which using previous input
44-
and learn its pattern to predict future data
45-
input : training data (total_user, with exog data = total_event) in list of float
46-
output : list of total user prediction in float
47-
>>> sarimax_predictor([4,2,6,8], [3,1,2,4], [2])
48-
6.6666671111109626
34+
Use SARIMAX for predicting total users based on training data.
35+
36+
Args:
37+
train_user: Total users in training data
38+
train_match: Total matches in training data
39+
test_match: Total matches for testing data
40+
41+
Returns:
42+
Predicted total users for the test match.
4943
"""
50-
# Suppress the User Warning raised by SARIMAX due to insufficient observations
51-
simplefilter("ignore", UserWarning)
52-
order = (1, 2, 1)
53-
seasonal_order = (1, 1, 1, 7)
54-
model = SARIMAX(
55-
train_user, exog=train_match, order=order, seasonal_order=seasonal_order
56-
)
44+
simplefilter("ignore", UserWarning) # Suppress warnings from SARIMAX
45+
model = SARIMAX(train_user, exog=train_match, order=(1, 2, 1), seasonal_order=(1, 1, 1, 7))
5746
model_fit = model.fit(disp=False, maxiter=600, method="nm")
58-
result = model_fit.predict(1, len(test_match), exog=[test_match])
47+
result = model_fit.predict(start=len(train_user), end=len(train_user), exog=[test_match])
5948
return float(result[0])
6049

6150

62-
def support_vector_regressor(x_train: list, x_test: list, train_user: list) -> float:
51+
def support_vector_regressor(x_train: np.ndarray, x_test: np.ndarray, train_user: list[float]) -> float:
6352
"""
64-
Third method: Support vector regressor
65-
svr is quite the same with svm(support vector machine)
66-
it uses the same principles as the SVM for classification,
67-
with only a few minor differences and the only different is that
68-
it suits better for regression purpose
69-
input : training data (date, total_user, total_event) in list of float
70-
where x = list of set (date and total event)
71-
output : list of total user prediction in float
72-
>>> support_vector_regressor([[5,2],[1,5],[6,2]], [[3,2]], [2,1,4])
73-
1.634932078116079
53+
Predict total users using Support Vector Regressor.
54+
55+
Args:
56+
x_train: Training features (dates and matches)
57+
x_test: Testing features (dates and matches)
58+
train_user: Total users for training data
59+
60+
Returns:
61+
Predicted total users for the test features.
7462
"""
7563
regressor = SVR(kernel="rbf", C=1, gamma=0.1, epsilon=0.1)
7664
regressor.fit(x_train, train_user)
7765
y_pred = regressor.predict(x_test)
7866
return float(y_pred[0])
7967

8068

81-
def interquartile_range_checker(train_user: list) -> float:
69+
def interquartile_range_checker(train_user: list[float]) -> float:
8270
"""
83-
Optional method: interquatile range
84-
input : list of total user in float
85-
output : low limit of input in float
86-
this method can be used to check whether some data is outlier or not
87-
>>> interquartile_range_checker([1,2,3,4,5,6,7,8,9,10])
88-
2.8
71+
Calculate the low limit for detecting outliers using IQR.
72+
73+
Args:
74+
train_user: List of total users
75+
76+
Returns:
77+
Low limit for detecting outliers.
8978
"""
90-
train_user.sort()
79+
train_user = np.array(train_user)
9180
q1 = np.percentile(train_user, 25)
9281
q3 = np.percentile(train_user, 75)
9382
iqr = q3 - q1
94-
low_lim = q1 - (iqr * 0.1)
83+
low_lim = q1 - (iqr * 1.5) # Common multiplier for outlier detection
9584
return float(low_lim)
9685

9786

98-
def data_safety_checker(list_vote: list, actual_result: float) -> bool:
87+
def data_safety_checker(list_vote: list[float], actual_result: float) -> bool:
9988
"""
100-
Used to review all the votes (list result prediction)
101-
and compare it to the actual result.
102-
input : list of predictions
103-
output : print whether it's safe or not
104-
>>> data_safety_checker([2, 3, 4], 5.0)
105-
False
89+
Check if the predictions are safe based on actual results.
90+
91+
Args:
92+
list_vote: List of predictions
93+
actual_result: Actual result to compare against
94+
95+
Returns:
96+
True if the data is considered safe; otherwise False.
10697
"""
107-
safe = 0
108-
not_safe = 0
109-
11098
if not isinstance(actual_result, float):
111-
raise TypeError("Actual result should be float. Value passed is a list")
99+
raise TypeError("Actual result should be a float.")
112100

113-
for i in list_vote:
114-
if i > actual_result:
115-
safe = not_safe + 1
116-
elif abs(abs(i) - abs(actual_result)) <= 0.1:
117-
safe += 1
118-
else:
119-
not_safe += 1
120-
return safe > not_safe
101+
safe_count = sum(
102+
1 for prediction in list_vote if abs(prediction - actual_result) <= 0.1
103+
)
104+
not_safe_count = len(list_vote) - safe_count
105+
106+
return safe_count > not_safe_count
121107

122108

123109
if __name__ == "__main__":
124-
"""
125-
data column = total user in a day, how much online event held in one day,
126-
what day is that(sunday-saturday)
127-
"""
110+
# Load data from CSV file
128111
data_input_df = pd.read_csv("ex_data.csv")
129112

130-
# start normalization
113+
# Start normalization
131114
normalize_df = Normalizer().fit_transform(data_input_df.values)
132-
# split data
133-
total_date = normalize_df[:, 2].tolist()
115+
116+
# Split data
134117
total_user = normalize_df[:, 0].tolist()
135118
total_match = normalize_df[:, 1].tolist()
119+
total_date = normalize_df[:, 2].tolist()
136120

137-
# for svr (input variable = total date and total match)
138-
x = normalize_df[:, [1, 2]].tolist()
139-
x_train = x[: len(x) - 1]
140-
x_test = x[len(x) - 1 :]
141-
142-
# for linear regression & sarimax
143-
train_date = total_date[: len(total_date) - 1]
144-
train_user = total_user[: len(total_user) - 1]
145-
train_match = total_match[: len(total_match) - 1]
121+
# Prepare data for models
122+
x = normalize_df[:, [1, 2]] # Total matches and dates
123+
x_train = x[:-1]
124+
x_test = x[-1:]
146125

147-
test_date = total_date[len(total_date) - 1 :]
148-
test_user = total_user[len(total_user) - 1 :]
149-
test_match = total_match[len(total_match) - 1 :]
126+
train_user = total_user[:-1]
127+
test_user = total_user[-1:]
150128

151-
# voting system with forecasting
129+
# Forecasting using multiple methods
152130
res_vote = [
153-
linear_regression_prediction(
154-
train_date, train_user, train_match, test_date, test_match
155-
),
156-
sarimax_predictor(train_user, train_match, test_match),
131+
linear_regression_prediction(train_date, train_user, train_match, total_date[-1:], total_match[-1:]),
132+
sarimax_predictor(train_user, total_match[:-1], total_match[-1:]),
157133
support_vector_regressor(x_train, x_test, train_user),
158134
]
159135

160-
# check the safety of today's data
161-
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
162-
print(f"Today's data is {not_str}safe.")
136+
# Check the safety of today's data
137+
is_safe = data_safety_checker(res_vote, test_user[0])
138+
status = "" if is_safe else "not "
139+
print(f"Today's data is {status}safe.")

0 commit comments

Comments
 (0)