This repository was archived by the owner on Nov 23, 2022. It is now read-only.

Commit 85bceb2

Increase weight_decay and decrease model size in the MLP model to avoid overfitting
* Set default hparams without sigma when the MCCR loss is not used
1 parent fa1b1b6 commit 85bceb2

File tree
  mise/ml/mlp_mul_ms.py
  mise/ml/mlp_mul_transformer_mccr.py

2 files changed: +3 -6 lines

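The gist of the hparams change, as a minimal sketch (the use_mccr flag and the default learning rate below are hypothetical, not part of the repository): sigma only parameterizes the MCCR loss, so the plain MLP model's default Namespace drops it, while the MCCR variants keep sigma=1.0 as a default.

from argparse import Namespace

def default_hparams(use_mccr: bool, learning_rate: float = 1e-3) -> Namespace:
    # Defaults for the plain MLP model (mirrors the hparams Namespace in the diff below).
    hparams = Namespace(
        num_layers=1,
        layer_size=128,
        learning_rate=learning_rate,
    )
    # sigma is the MCCR-loss bandwidth; attach it only when the MCCR loss is used.
    if use_mccr:
        hparams.sigma = 1.0
    return hparams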

mise/ml/mlp_mul_ms.py

Lines changed: 2 additions & 6 deletions
@@ -231,7 +231,6 @@ def ml_mlp_mul_ms(station_name="종로구"):
 
     # num_layer == number of hidden layer
     hparams = Namespace(
-        sigma=1.0,
         num_layers=1,
         layer_size=128,
         learning_rate=learning_rate,
@@ -321,7 +320,6 @@ def objective(trial):
         fig_slice.write_image(str(output_dir / "slice.svg"))
 
         # set hparams with optmized value
-        hparams.sigma = trial.params['sigma']
         hparams.num_layers = trial.params['num_layers']
         hparams.layer_size = trial.params['layer_size']
 
@@ -439,12 +437,10 @@ def __init__(self, *args, **kwargs):
         # num_layer == number of hidden layer
         self.layer_sizes = [self.input_size, self.output_size]
         if self.trial:
-            self.hparams.sigma = self.trial.suggest_float(
-                "sigma", 0.5, 1.5, step=0.05)
             self.hparams.num_layers = self.trial.suggest_int(
                 "num_layers", 2, 8)
             self.hparams.layer_size = self.trial.suggest_int(
-                "layer_size", 8, 1024)
+                "layer_size", 8, 512)
 
         for l in range(self.hparams.num_layers):
             # insert another layer_size to end of list of layer_size
@@ -500,7 +496,7 @@ def forward(self, x, x1d):
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(),
                                 lr=self.hparams.learning_rate,
-                                weight_decay=0.001)
+                                weight_decay=0.01)
 
     def training_step(self, batch, batch_idx):
         x, x1d, _y, _y_raw, dates = batch
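
Taken together, the changes in this file shrink the Optuna search space for layer_size and strengthen L2 regularization in Adam. A self-contained sketch of how those two knobs are typically wired up, assuming an Optuna objective around a PyTorch model; the placeholder model, data, and returned loss are illustrative only, and just the 8-512 range and weight_decay=0.01 come from the diff above.

import optuna
import torch

def objective(trial: optuna.trial.Trial) -> float:
    # Narrowed search space: layer_size is now capped at 512 instead of 1024.
    num_layers = trial.suggest_int("num_layers", 2, 8)
    layer_size = trial.suggest_int("layer_size", 8, 512)

    # Placeholder model; the real code builds the MLP from self.layer_sizes.
    model = torch.nn.Sequential(
        *[torch.nn.Linear(layer_size, layer_size) for _ in range(num_layers)]
    )

    # Stronger L2 regularization: weight_decay raised from 0.001 to 0.01.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1e-3, weight_decay=0.01
    )

    # ... train with `optimizer` and return the validation loss here ...
    return 0.0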

mise/ml/mlp_mul_transformer_mccr.py

Lines changed: 1 addition & 0 deletions
@@ -497,6 +497,7 @@ def __init__(self, **kwargs):
         super().__init__()
 
         self.hparams = kwargs.get('hparams', Namespace(
+            sigma=1.0,
             nhead=16,
             head_dim=128,
             d_feedforward=256,
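
sigma stays as a default here because it is the bandwidth of the MCCR (maximum correntropy criterion regression) loss that this transformer model trains with, whereas the plain MLP above does not use that loss. A hedged sketch of one common form of that loss, assuming σ² · (1 − exp(−(y − ŷ)² / σ²)) averaged over the batch; this is an assumption about the form, not a copy of the repository's implementation.

import torch

def mccr_loss(y_hat: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # Assumed MCCR form: sigma^2 * (1 - exp(-(y - y_hat)^2 / sigma^2)), batch-averaged.
    # Large errors saturate toward sigma^2, which makes the loss robust to outliers.
    sq_err = (y - y_hat) ** 2
    return (sigma ** 2 * (1.0 - torch.exp(-sq_err / sigma ** 2))).mean()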
