"""
This is code for forecasting, but I modified it and use it as a safety
checker for data.

For example: you have an online shop and, for some reason, some data are
missing (the amount of data is not what you expected) — then we can use this.

Notes:
    1. Of course we could use a normal statistical method, but in this case
       the data is quite absurd and there is only a little of it.
    2. You can also modify this for forecasting purposes (e.g. the next
       3 months' sales or similar); just adjust it for your own purpose.
"""
| 13 | + |
| 14 | +import numpy as np |
| 15 | +import pandas as pd |
| 16 | +from sklearn.preprocessing import Normalizer |
| 17 | +from sklearn.svm import SVR |
| 18 | +from statsmodels.tsa.statespace.sarimax import SARIMAX |
| 19 | + |
| 20 | + |
def linear_regression_prediction(
    train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
) -> float:
    """
    First method: linear regression (ordinary least squares).

    Fits total_user as a linear function of date and total_event via the
    normal equation, then evaluates the model on the first test point.

    input : training data (date, total_user, total_event) in lists of float
    output : absolute predicted total_user for the first test sample

    >>> pred = linear_regression_prediction([2,3,4,5], [5,3,4,6], [3,1,2,4], [2,1], [2,2])
    >>> abs(pred - 4.0) < 1e-6
    True
    """
    # Design matrix with an intercept column: [1, date, total_event].
    x = np.array([[1, item, train_mtch[i]] for i, item in enumerate(train_dt)])
    y = np.array(train_usr)
    # Normal equation: beta = (X^T X)^-1 X^T y
    beta = np.dot(np.dot(np.linalg.inv(np.dot(x.transpose(), x)), x.transpose()), y)
    # Fixed: the event term must be scaled by its coefficient.
    # (Was `test_mtch[0] + beta[2]` — a typo that added the coefficient
    # instead of multiplying by it.)
    return abs(beta[0] + test_dt[0] * beta[1] + test_mtch[0] * beta[2])
| 36 | + |
| 37 | + |
def sarimax_predictor(train_user: list, train_match: list, test_match: list) -> float:
    """
    Second method: SARIMAX.

    SARIMAX is a statistical method that uses previous input and learns
    its pattern to predict future data (with total_event as exogenous data).

    input : training data (total_user, with exog data = total_event) in list of float
    output : list of total user prediction in float

    >>> sarimax_predictor([4,2,6,8], [3,1,2,4], [2])
    6.6666671111109626
    """
    # (p, d, q) and seasonal (P, D, Q, s) configuration, written inline.
    model = SARIMAX(
        train_user,
        exog=train_match,
        order=(1, 2, 1),
        seasonal_order=(1, 1, 0, 7),
    )
    # Nelder-Mead optimizer, quiet fit, generous iteration budget.
    fitted = model.fit(disp=False, maxiter=600, method="nm")
    forecast = fitted.predict(1, len(test_match), exog=[test_match])
    return forecast[0]
| 56 | + |
| 57 | + |
def support_vector_regressor(x_train: list, x_test: list, train_user: list) -> float:
    """
    Third method: support vector regressor.

    SVR is much the same as SVM (support vector machine): it uses the same
    principles as the SVM classifier, with only a few minor differences,
    and it suits regression purposes better.

    input : training data (date, total_user, total_event) in list of float
            where x = list of set (date and total event)
    output : list of total user prediction in float

    >>> support_vector_regressor([[5,2],[1,5],[6,2]], [[3,2]], [2,1,4])
    1.634932078116079
    """
    svr = SVR(kernel="rbf", C=1, gamma=0.1, epsilon=0.1)
    svr.fit(x_train, train_user)
    return svr.predict(x_test)[0]
| 75 | + |
| 76 | + |
def interquartile_range_checker(train_user: list) -> float:
    """
    Optional method: interquartile range.

    input : list of total user in float
    output : low limit of input in float

    This method can be used to check whether some data is an outlier or not.

    >>> interquartile_range_checker([1,2,3,4,5,6,7,8,9,10])
    2.8
    """
    # np.percentile does not require sorted input, so the previous in-place
    # `train_user.sort()` was unnecessary and mutated the caller's list.
    data = np.asarray(train_user)
    q1 = np.percentile(data, 25)
    q3 = np.percentile(data, 75)
    iqr = q3 - q1
    # 0.1 * IQR is a deliberately tight margin (the classic rule uses 1.5 * IQR).
    low_lim = q1 - iqr * 0.1
    return low_lim
| 92 | + |
| 93 | + |
def data_safety_checker(list_vote: list, actual_result: float) -> None:
    """
    Review all the votes (list of result predictions) and compare them to
    the actual result, then print whether today's data is safe.

    A vote counts as "safe" only when the prediction is within 0.1 of the
    actual value. A prediction above the actual value suggests data is
    missing (per the module's purpose), so it counts as "not safe" —
    NOTE(review): intent inferred from the module docstring; confirm.

    input : list of predictions and the actual observed value
    output : prints whether it's safe or not (returns None)

    >>> data_safety_checker([2,3,4], 5.0)
    Today's data is not safe.
    """
    safe = 0
    not_safe = 0
    for vote in list_vote:
        if vote > actual_result:
            # Fixed: was `safe = not_safe + 1`, an assignment (not an
            # increment) that corrupted the tally and voted the wrong way.
            not_safe += 1
        elif abs(abs(vote) - abs(actual_result)) <= 0.1:
            safe += 1
        else:
            not_safe += 1
    print(f"Today's data is {'not ' if safe <= not_safe else ''}safe.")
| 114 | + |
| 115 | + |
# data_input_df = pd.read_csv("ex_data.csv", header=None)
data_input = [[18231, 0.0, 1], [22621, 1.0, 2], [15675, 0.0, 3], [23583, 1.0, 4]]
data_input_df = pd.DataFrame(data_input, columns=["total_user", "total_event", "days"])

"""
data column = total user in a day, how many online events held in one day,
what day is that (sunday-saturday)
"""

# start normalization
normalize_df = Normalizer().fit_transform(data_input_df.values)
# split data: columns are [total_user, total_event, days]
total_date = normalize_df[:, 2].tolist()
total_user = normalize_df[:, 0].tolist()
total_match = normalize_df[:, 1].tolist()

# for svr (input variable = total date and total match)
x = normalize_df[:, [1, 2]].tolist()
x_train = x[:-1]
x_test = x[-1:]

# for linear regression & sarimax: train on all but the last day,
# test on the last day
trn_date = total_date[:-1]
trn_user = total_user[:-1]
trn_match = total_match[:-1]

tst_date = total_date[-1:]
tst_user = total_user[-1:]
tst_match = total_match[-1:]


# voting system with forecasting
res_vote = [
    linear_regression_prediction(trn_date, trn_user, trn_match, tst_date, tst_match),
    sarimax_predictor(trn_user, trn_match, tst_match),
    support_vector_regressor(x_train, x_test, trn_user),
]

# check the safety of today's data ^^
# Fixed: pass the scalar actual value, not the one-element list —
# comparing a float against a list raises TypeError in Python 3.
data_safety_checker(res_vote, tst_user[0])