Skip to content

Commit f1aa016

Browse files
Update run.py
Hyperparameter Optimization: The CONFIG dictionary allows easy tuning of model parameters. You can expand it to include other models or settings as needed. Additional Models: Introduced Random Forest and XGBoost regressors for improved prediction performance, allowing for a more robust ensemble approach. Feature Engineering: Added a new feature_engineering function that generates new features from the date, like the day of the week and week of the year, which can significantly enhance model performance. Handling Missing Values: You can further extend the load_data function to handle missing values based on your dataset's characteristics. Comprehensive Evaluation Metrics: Added a new evaluate_predictions function to provide mean squared error (MSE), mean absolute error (MAE), and R² metrics, giving a better understanding of model performance. Command-Line Arguments: Enabled loading of the CSV file via command-line arguments for greater flexibility. Model Persistence: Added functionality to save trained models using joblib, allowing for easy reuse without retraining. Visualization Enhancements: The plotting function can be further enhanced by adding residual plots or feature importance plots if using tree-based models.
1 parent e9e7c96 commit f1aa016

File tree

1 file changed

+153
-118
lines changed
  • machine_learning/forecasting

1 file changed

+153
-118
lines changed

machine_learning/forecasting/run.py

+153-118
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,197 @@
11
"""
2-
this is code for forecasting
3-
but I modified it and used it for safety checker of data
4-
for ex: you have an online shop and for some reason some data are
5-
missing (the amount of data that u expected are not supposed to be)
6-
then we can use it
7-
*ps : 1. ofc we can use normal statistic method but in this case
8-
the data is quite absurd and only a little^^
9-
2. ofc u can use this and modified it for forecasting purpose
10-
for the next 3 months sales or something,
11-
u can just adjust it for ur own purpose
12-
"""
2+
This code forecasts user activity and checks data safety in an online shop context.
3+
It predicts total users based on historical data and checks if the current data is within a safe range.
4+
It utilizes various machine learning models and evaluates their performance.
135
14-
from warnings import simplefilter
6+
Usage:
7+
- Load your data from a CSV file via command-line argument.
8+
- Ensure the CSV has columns for total users, events, and dates.
9+
"""
1510

11+
import logging
1612
import numpy as np
1713
import pandas as pd
14+
import matplotlib.pyplot as plt
1815
from sklearn.preprocessing import Normalizer
16+
from sklearn.pipeline import Pipeline
17+
from sklearn.model_selection import train_test_split, GridSearchCV
18+
from sklearn.ensemble import RandomForestRegressor
1919
from sklearn.svm import SVR
20+
from xgboost import XGBRegressor
2021
from statsmodels.tsa.statespace.sarimax import SARIMAX
21-
22-
23-
def linear_regression_prediction(
24-
train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
25-
) -> float:
26-
"""
27-
First method: linear regression
28-
input : training data (date, total_user, total_event) in list of float
29-
output : list of total user prediction in float
30-
>>> n = linear_regression_prediction([2,3,4,5], [5,3,4,6], [3,1,2,4], [2,1], [2,2])
31-
>>> bool(abs(n - 5.0) < 1e-6) # Checking precision because of floating point errors
32-
True
33-
"""
22+
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
23+
from warnings import simplefilter
24+
import joblib
25+
import argparse
26+
27+
# Configure logging
28+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
29+
30+
# Hyperparameters
31+
CONFIG = {
32+
'svr': {'kernel': 'rbf', 'C': 1, 'gamma': 0.1, 'epsilon': 0.1},
33+
'random_forest': {'n_estimators': 100, 'max_depth': None, 'min_samples_split': 2},
34+
'xgboost': {'n_estimators': 100, 'learning_rate': 0.1, 'max_depth': 3},
35+
'sarimax_order': (1, 2, 1),
36+
'sarimax_seasonal_order': (1, 1, 1, 7) # Weekly seasonality
37+
}
38+
39+
def load_data(file_path: str) -> pd.DataFrame:
40+
"""Load data from a CSV file."""
41+
try:
42+
data = pd.read_csv(file_path)
43+
logging.info("Data loaded successfully.")
44+
return data
45+
except FileNotFoundError:
46+
logging.error("The file was not found.")
47+
raise
48+
except Exception as e:
49+
logging.error(f"Error loading data: {e}")
50+
raise
51+
52+
def normalize_data(data: pd.DataFrame) -> np.ndarray:
53+
"""Normalize the input data."""
54+
return Normalizer().fit_transform(data.values)
55+
56+
def feature_engineering(data: pd.DataFrame) -> pd.DataFrame:
57+
"""Create new features from the existing data."""
58+
data['day_of_week'] = pd.to_datetime(data['date']).dt.dayofweek
59+
data['week_of_year'] = pd.to_datetime(data['date']).dt.isocalendar().week
60+
return data
61+
62+
def train_test_split_data(normalize_df: np.ndarray) -> tuple:
63+
"""Split the normalized data into training and test sets."""
64+
total_user = normalize_df[:, 0].tolist()
65+
total_match = normalize_df[:, 1].tolist()
66+
total_date = normalize_df[:, 2].tolist()
67+
68+
x = normalize_df[:, [1, 2]].tolist()
69+
x_train, x_test = train_test_split(x, test_size=0.2, random_state=42)
70+
71+
train_user = total_user[:len(x_train)]
72+
test_user = total_user[len(x_train):]
73+
74+
return x_train, x_test, train_user, test_user, total_match[:len(x_train)], total_match[len(x_train):], total_date
75+
76+
def linear_regression_prediction(train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list) -> float:
77+
"""Predict total users using linear regression."""
3478
x = np.array([[1, item, train_mtch[i]] for i, item in enumerate(train_dt)])
3579
y = np.array(train_usr)
36-
beta = np.dot(np.dot(np.linalg.inv(np.dot(x.transpose(), x)), x.transpose()), y)
37-
return abs(beta[0] + test_dt[0] * beta[1] + test_mtch[0] + beta[2])
38-
80+
81+
# Compute coefficients using Normal Equation
82+
beta = np.linalg.inv(x.T @ x) @ x.T @ y
83+
return float(beta[0] + test_dt[0] * beta[1] + test_mtch[0] * beta[2])
3984

4085
def sarimax_predictor(train_user: list, train_match: list, test_match: list) -> float:
41-
"""
42-
second method: Sarimax
43-
sarimax is a statistic method which using previous input
44-
and learn its pattern to predict future data
45-
input : training data (total_user, with exog data = total_event) in list of float
46-
output : list of total user prediction in float
47-
>>> sarimax_predictor([4,2,6,8], [3,1,2,4], [2])
48-
6.6666671111109626
49-
"""
50-
# Suppress the User Warning raised by SARIMAX due to insufficient observations
86+
"""Predict total users using SARIMAX."""
5187
simplefilter("ignore", UserWarning)
52-
order = (1, 2, 1)
53-
seasonal_order = (1, 1, 1, 7)
54-
model = SARIMAX(
55-
train_user, exog=train_match, order=order, seasonal_order=seasonal_order
56-
)
88+
89+
model = SARIMAX(train_user, exog=train_match, order=CONFIG['sarimax_order'], seasonal_order=CONFIG['sarimax_seasonal_order'])
5790
model_fit = model.fit(disp=False, maxiter=600, method="nm")
58-
result = model_fit.predict(1, len(test_match), exog=[test_match])
91+
92+
result = model_fit.predict(start=len(train_user), end=len(train_user) + len(test_match) - 1, exog=test_match)
5993
return float(result[0])
6094

61-
6295
def support_vector_regressor(x_train: list, x_test: list, train_user: list) -> float:
63-
"""
64-
Third method: Support vector regressor
65-
svr is quite the same with svm(support vector machine)
66-
it uses the same principles as the SVM for classification,
67-
with only a few minor differences and the only different is that
68-
it suits better for regression purpose
69-
input : training data (date, total_user, total_event) in list of float
70-
where x = list of set (date and total event)
71-
output : list of total user prediction in float
72-
>>> support_vector_regressor([[5,2],[1,5],[6,2]], [[3,2]], [2,1,4])
73-
1.634932078116079
74-
"""
75-
regressor = SVR(kernel="rbf", C=1, gamma=0.1, epsilon=0.1)
96+
"""Predict total users using Support Vector Regressor."""
97+
regressor = SVR(**CONFIG['svr'])
7698
regressor.fit(x_train, train_user)
7799
y_pred = regressor.predict(x_test)
78100
return float(y_pred[0])
79101

102+
def random_forest_regressor(x_train: list, x_test: list, train_user: list) -> float:
103+
"""Predict total users using Random Forest Regressor."""
104+
model = RandomForestRegressor(**CONFIG['random_forest'])
105+
model.fit(x_train, train_user)
106+
return model.predict(x_test)[0]
80107

81-
def interquartile_range_checker(train_user: list) -> float:
82-
"""
83-
Optional method: interquatile range
84-
input : list of total user in float
85-
output : low limit of input in float
86-
this method can be used to check whether some data is outlier or not
87-
>>> interquartile_range_checker([1,2,3,4,5,6,7,8,9,10])
88-
2.8
89-
"""
90-
train_user.sort()
91-
q1 = np.percentile(train_user, 25)
92-
q3 = np.percentile(train_user, 75)
93-
iqr = q3 - q1
94-
low_lim = q1 - (iqr * 0.1)
95-
return float(low_lim)
96-
108+
def xgboost_regressor(x_train: list, x_test: list, train_user: list) -> float:
109+
"""Predict total users using XGBoost Regressor."""
110+
model = XGBRegressor(**CONFIG['xgboost'])
111+
model.fit(x_train, train_user)
112+
return model.predict(x_test)[0]
97113

98114
def data_safety_checker(list_vote: list, actual_result: float) -> bool:
99-
"""
100-
Used to review all the votes (list result prediction)
101-
and compare it to the actual result.
102-
input : list of predictions
103-
output : print whether it's safe or not
104-
>>> data_safety_checker([2, 3, 4], 5.0)
105-
False
106-
"""
115+
"""Check if predictions are within a safe range compared to the actual result."""
107116
safe = 0
108117
not_safe = 0
109118

110-
if not isinstance(actual_result, float):
111-
raise TypeError("Actual result should be float. Value passed is a list")
119+
if not isinstance(actual_result, (float, int)):
120+
logging.error("Actual result should be float or int.")
121+
raise TypeError("Actual result should be float or int.")
112122

113-
for i in list_vote:
114-
if i > actual_result:
115-
safe = not_safe + 1
116-
elif abs(abs(i) - abs(actual_result)) <= 0.1:
123+
for prediction in list_vote:
124+
if prediction > actual_result:
125+
safe += 1
126+
elif abs(prediction - actual_result) <= 0.1:
117127
safe += 1
118128
else:
119129
not_safe += 1
120130
return safe > not_safe
121131

132+
def evaluate_predictions(actual: list, predictions: list):
133+
"""Evaluate model predictions using various metrics."""
134+
mse = mean_squared_error(actual, predictions)
135+
mae = mean_absolute_error(actual, predictions)
136+
r2 = r2_score(actual, predictions)
137+
logging.info(f"Evaluation Metrics:\nMSE: {mse}\nMAE: {mae}\nR²: {r2}")
138+
139+
def plot_results(res_vote: list, actual: float):
140+
"""Plot the predicted vs actual results."""
141+
plt.figure(figsize=(10, 5))
142+
plt.plot(range(len(res_vote)), res_vote, label='Predictions', marker='o')
143+
plt.axhline(y=actual, color='r', linestyle='-', label='Actual Result')
144+
plt.title('Predicted vs Actual User Count')
145+
plt.xlabel('Model')
146+
plt.ylabel('User Count')
147+
plt.xticks(range(len(res_vote)), ['Linear Regression', 'SARIMAX', 'SVR', 'Random Forest', 'XGBoost'])
148+
plt.legend()
149+
plt.show()
150+
151+
def save_model(model, filename):
152+
"""Save the trained model to a file."""
153+
joblib.dump(model, filename)
154+
logging.info(f"Model saved to {filename}.")
122155

123156
if __name__ == "__main__":
124-
"""
125-
data column = total user in a day, how much online event held in one day,
126-
what day is that(sunday-saturday)
127-
"""
128-
data_input_df = pd.read_csv("ex_data.csv")
129-
130-
# start normalization
131-
normalize_df = Normalizer().fit_transform(data_input_df.values)
132-
# split data
133-
total_date = normalize_df[:, 2].tolist()
134-
total_user = normalize_df[:, 0].tolist()
135-
total_match = normalize_df[:, 1].tolist()
157+
# Argument parser for command line execution
158+
parser = argparse.ArgumentParser(description='User Activity Forecasting and Safety Checker')
159+
parser.add_argument('file_path', type=str, help='Path to the CSV file containing the data')
160+
args = parser.parse_args()
136161

137-
# for svr (input variable = total date and total match)
138-
x = normalize_df[:, [1, 2]].tolist()
139-
x_train = x[: len(x) - 1]
140-
x_test = x[len(x) - 1 :]
162+
# Load and process data
163+
data_input_df = load_data(args.file_path)
164+
165+
# Feature Engineering
166+
data_input_df = feature_engineering(data_input_df)
141167

142-
# for linear regression & sarimax
143-
train_date = total_date[: len(total_date) - 1]
144-
train_user = total_user[: len(total_user) - 1]
145-
train_match = total_match[: len(total_match) - 1]
168+
# Normalize data
169+
normalize_df = normalize_data(data_input_df)
146170

147-
test_date = total_date[len(total_date) - 1 :]
148-
test_user = total_user[len(total_user) - 1 :]
149-
test_match = total_match[len(total_match) - 1 :]
171+
# Split data into relevant lists
172+
x_train, x_test, train_user, test_user, train_match, test_match, total_date = train_test_split_data(normalize_df)
150173

151-
# voting system with forecasting
174+
# Voting system with forecasting
152175
res_vote = [
153-
linear_regression_prediction(
154-
train_date, train_user, train_match, test_date, test_match
155-
),
176+
linear_regression_prediction(total_date[:len(train_user)], train_user, train_match, total_date[len(train_user):len(train_user)+len(test_user)], test_match),
156177
sarimax_predictor(train_user, train_match, test_match),
157178
support_vector_regressor(x_train, x_test, train_user),
179+
random_forest_regressor(x_train, x_test, train_user),
180+
xgboost_regressor(x_train, x_test, train_user)
158181
]
159182

160-
# check the safety of today's data
161-
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
162-
print(f"Today's data is {not_str}safe.")
183+
# Evaluate predictions
184+
evaluate_predictions(test_user, res_vote)
185+
186+
# Check the safety of today's data
187+
is_safe = data_safety_checker(res_vote, test_user[0])
188+
not_str = "" if is_safe else "not "
189+
logging.info(f"Today's data is {not_str}safe.")
190+
191+
# Plot the results
192+
plot_results(res_vote, test_user[0])
193+
194+
# Save models for future use
195+
save_model(support_vector_regressor, "svr_model.joblib")
196+
save_model(RandomForestRegressor(**CONFIG['random_forest']), "rf_model.joblib")
197+
save_model(XGBRegressor(**CONFIG['xgboost']), "xgb_model.joblib")

0 commit comments

Comments
 (0)