
Commit 7d7adde

Add Median absolute deviation (MAD) as accuracy metric
* Why MAD? Simple and robust. For a more sophisticated approach, see https://arxiv.org/abs/1906.04280 or median-of-means-style estimators.
* Remove LogCosh loss

MAD reference:
* [Wiki](https://en.wikipedia.org/wiki/Median_absolute_deviation)
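The robustness claim can be seen with a toy example (a standalone sketch with invented numbers, not code from this commit): corrupting a single prediction inflates MSE by several orders of magnitude, while the MAD of the residuals barely moves.

```python
import numpy as np
from scipy.stats import median_abs_deviation
from sklearn.metrics import mean_squared_error

# Invented toy data: five targets, predictions that are roughly correct.
y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1])

# Corrupt a single prediction, e.g. a sensor spike.
y_bad = y_pred.copy()
y_bad[2] = 30.0

for name, pred in [("clean", y_pred), ("one outlier", y_bad)]:
    resid = y_true - pred
    print(f"{name}: MSE={mean_squared_error(y_true, pred):.3f} "
          f"MAD={median_abs_deviation(resid):.3f}")
# clean:       MSE=0.022   MAD=0.100
# one outlier: MSE=145.814 MAD=0.200
```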
1 parent 2dd4af7 commit 7d7adde

10 files changed: +227 -314 lines

mise/ml/mlp_mul_ms.py

Lines changed: 26 additions & 47 deletions
@@ -28,6 +28,8 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
+
+from scipy.stats import median_abs_deviation
 from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
 import sklearn.metrics
 
@@ -236,13 +238,6 @@ def ml_mlp_mul_ms(station_name="종로구"):
                               batch_size=batch_size)
 
     def objective(trial):
-        # PyTorch Lightning will try to restore model parameters from previous trials if checkpoint
-        # filenames match. Therefore, the filenames for each trial must be made unique.
-        checkpoint_callback = pl.callbacks.ModelCheckpoint(
-            os.path.join(model_dir, "trial_{}".format(trial.number)), monitor="val_loss",
-            period=10
-        )
-
         model = BaseMLPModel(trial=trial,
                              hparams=hparams,
                              input_size=sample_size * len(train_features),
@@ -269,15 +264,15 @@ def objective(trial):
            logger=True,
            checkpoint_callback=False,
            callbacks=[PyTorchLightningPruningCallback(
-                trial, monitor="valid/MSE")])
+                trial, monitor="valid/MAD")])
 
        trainer.fit(model)
 
        # Don't Log
        # hyperparameters = model.hparams
        # trainer.logger.log_hyperparams(hyperparameters)
 
-        return trainer.callback_metrics.get("valid/MSE")
+        return trainer.callback_metrics.get("valid/MAD")
 
    if n_trials > 1:
        study = optuna.create_study(direction="minimize")
@@ -363,12 +358,12 @@ def objective(trial):
    test_dataset.to_csv(model.data_dir / ("df_testset_" + target + ".csv"))
 
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        os.path.join(model_dir, "train_{epoch}_{valid/MSE:.2f}"), monitor="valid/MSE",
+        os.path.join(model_dir, "train_{epoch}_{valid/MAD:.2f}"), monitor="valid/MAD",
        period=10
    )
 
    early_stop_callback = EarlyStopping(
-        monitor='valid/MSE',
+        monitor='valid/MAD',
        min_delta=0.001,
        patience=30,
        verbose=True,
@@ -417,7 +412,7 @@ def __init__(self, *args, **kwargs):
        self.features_nonperiodic = kwargs.get('features_nonperiodic',
                                               ["temp", "wind_spd", "wind_cdir", "wind_sdir",
                                                "pres", "humid", "prep"])
-        self.metrics = kwargs.get('metrics', ['MAE', 'MSE', 'R2'])
+        self.metrics = kwargs.get('metrics', ['MAE', 'MSE', 'R2', 'MAD'])
        self.num_workers = kwargs.get('num_workers', 1)
        self.output_dir = kwargs.get(
            'output_dir', Path('/mnt/data/MLPMS2Multivariate/'))
@@ -516,15 +511,17 @@ def training_step(self, batch, batch_idx):
        y_hat = _y_hat.detach().cpu().clone().numpy()
        y_raw = _y_raw.detach().cpu().clone().numpy()
 
-        _mae = mean_absolute_error(y_hat, y)
-        _mse = mean_squared_error(y_hat, y)
-        _r2 = r2_score(y_hat, y)
+        _mae = mean_absolute_error(y, y_hat)
+        _mse = mean_squared_error(y, y_hat)
+        _r2 = r2_score(y, y_hat)
+        _mad = median_abs_deviation(y - y_hat)
 
        return {
            'loss': _loss,
            'metric': {
                'MSE': _mse,
                'MAE': _mae,
+                'MAD': _mad,
                'R2': _r2
            }
        }
@@ -546,6 +543,7 @@ def training_epoch_end(self, outputs):
        # self.log('train/loss', tensorboard_logs['train/loss'].item(), prog_bar=True)
        self.log('train/MSE', tensorboard_logs['train/MSE'].item(), on_epoch=True, logger=self.logger)
        self.log('train/MAE', tensorboard_logs['train/MAE'].item(), on_epoch=True, logger=self.logger)
+        self.log('train/MAD', tensorboard_logs['train/MAD'].item(), on_epoch=True, logger=self.logger)
        self.log('train/avg_loss', _log['loss'], on_epoch=True, logger=self.logger)
 
    def validation_step(self, batch, batch_idx):
@@ -557,15 +555,17 @@ def validation_step(self, batch, batch_idx):
        y_hat = _y_hat.detach().cpu().clone().numpy()
        y_raw = _y_raw.detach().cpu().clone().numpy()
 
-        _mae = mean_absolute_error(y_hat, y)
-        _mse = mean_squared_error(y_hat, y)
-        _r2 = r2_score(y_hat, y)
+        _mae = mean_absolute_error(y, y_hat)
+        _mse = mean_squared_error(y, y_hat)
+        _r2 = r2_score(y, y_hat)
+        _mad = median_abs_deviation(y - y_hat)
 
        return {
            'loss': _loss,
            'metric': {
                'MSE': _mse,
                'MAE': _mae,
+                'MAD': _mad,
                'R2': _r2
            }
        }
@@ -586,6 +586,7 @@ def validation_epoch_end(self, outputs):
 
        self.log('valid/MSE', tensorboard_logs['valid/MSE'].item(), on_epoch=True, logger=self.logger)
        self.log('valid/MAE', tensorboard_logs['valid/MAE'].item(), on_epoch=True, logger=self.logger)
+        self.log('valid/MAD', tensorboard_logs['valid/MAD'].item(), on_epoch=True, logger=self.logger)
        self.log('valid/loss', _log['loss'], on_epoch=True, logger=self.logger)
 
    def test_step(self, batch, batch_idx):
@@ -598,11 +599,12 @@ def test_step(self, batch, batch_idx):
        y_hat = _y_hat.detach().cpu().clone().numpy()
        y_hat2 = relu_mul(
            np.array(self.test_dataset.inverse_transform(y_hat, dates)))
-        _loss = self.loss(torch.as_tensor(y_hat2).to(device), _y_raw)
+        _loss = self.loss(_y_raw, torch.as_tensor(y_hat2).to(device))
 
-        _mae = mean_absolute_error(y_hat2, y_raw)
-        _mse = mean_squared_error(y_hat2, y_raw)
-        _r2 = r2_score(y_hat2, y_raw)
+        _mae = mean_absolute_error(y_raw, y_hat2)
+        _mse = mean_squared_error(y_raw, y_hat2)
+        _r2 = r2_score(y_raw, y_hat2)
+        _mad = median_abs_deviation(y_raw - y_hat2)
 
        return {
            'loss': _loss,
@@ -612,6 +614,7 @@ def test_step(self, batch, batch_idx):
            'metric': {
                'MSE': _mse,
                'MAE': _mae,
+                'MAD': _mad,
                'R2': _r2
            }
        }
@@ -660,6 +663,7 @@ def test_epoch_end(self, outputs):
 
        self.log('test/MSE', tensorboard_logs['test/MSE'].item(), on_epoch=True, logger=self.logger)
        self.log('test/MAE', tensorboard_logs['test/MAE'].item(), on_epoch=True, logger=self.logger)
+        self.log('test/MAD', tensorboard_logs['test/MAD'].item(), on_epoch=True, logger=self.logger)
        self.log('test/loss', avg_loss, on_epoch=True, logger=self.logger)
 
        self.df_obs = df_obs
@@ -1083,31 +1087,6 @@ def _mccr(x):
        return torch.mean(_mccr(input - target))
 
 
-class LogCoshLoss(nn.Module):
-    __constants__ = ['reduction']
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        """
-        Implement numerically stable log-cosh which is used in Keras
-
-        log(cosh(x)) = logaddexp(x, -x) - log(2)
-                     = abs(x) + log1p(exp(-2 * abs(x))) - log(2)
-
-        Reference:
-            * https://stackoverflow.com/a/57786270
-        """
-        # not to compute log(0), add 1e-24 (small value)
-        def _log_cosh(x):
-            return torch.abs(x) + \
-                torch.log1p(torch.exp(-2 * torch.abs(x))) + \
-                torch.log(torch.full_like(x, 2, dtype=x.dtype))
-
-        return torch.mean(_log_cosh(input - target))
-
-
 def relu_mul(x):
    """[fastest method](https://stackoverflow.com/a/32109519/743078)
    """

mise/ml/mlp_mul_ms_mccr.py

Lines changed: 26 additions & 40 deletions
@@ -28,6 +28,8 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
+
+from scipy.stats import median_abs_deviation
 from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
 import sklearn.metrics
 
@@ -262,15 +264,15 @@ def objective(trial):
            logger=True,
            checkpoint_callback=False,
            callbacks=[PyTorchLightningPruningCallback(
-                trial, monitor="valid/MSE")])
+                trial, monitor="valid/MAD")])
 
        trainer.fit(model)
 
        # Don't Log
        # hyperparameters = model.hparams
        # trainer.logger.log_hyperparams(hyperparameters)
 
-        return trainer.callback_metrics.get("valid/MSE")
+        return trainer.callback_metrics.get("valid/MAD")
 
    if n_trials > 1:
        study = optuna.create_study(direction="minimize")
@@ -356,12 +358,12 @@ def objective(trial):
    test_dataset.to_csv(model.data_dir / ("df_testset_" + target + ".csv"))
 
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        os.path.join(model_dir, "train_{epoch}_{valid/MSE:.2f}"), monitor="valid/MSE",
+        os.path.join(model_dir, "train_{epoch}_{valid/MAD:.2f}"), monitor="valid/MAD",
        period=10
    )
 
    early_stop_callback = EarlyStopping(
-        monitor='valid/MSE',
+        monitor='valid/MAD',
        min_delta=0.001,
        patience=30,
        verbose=True,
@@ -412,7 +414,7 @@ def __init__(self, *args, **kwargs):
        self.features_nonperiodic = kwargs.get('features_nonperiodic',
                                               ["temp", "wind_spd", "wind_cdir", "wind_sdir",
                                                "pres", "humid", "prep"])
-        self.metrics = kwargs.get('metrics', ['MAE', 'MSE', 'R2'])
+        self.metrics = kwargs.get('metrics', ['MAE', 'MSE', 'R2', 'MAD'])
        self.num_workers = kwargs.get('num_workers', 1)
        self.output_dir = kwargs.get(
            'output_dir', Path('/mnt/data/MLPMS2Multivariate/'))
@@ -515,15 +517,17 @@ def training_step(self, batch, batch_idx):
        y_hat = _y_hat.detach().cpu().clone().numpy()
        y_raw = _y_raw.detach().cpu().clone().numpy()
 
-        _mae = mean_absolute_error(y_hat, y)
-        _mse = mean_squared_error(y_hat, y)
-        _r2 = r2_score(y_hat, y)
+        _mae = mean_absolute_error(y, y_hat)
+        _mse = mean_squared_error(y, y_hat)
+        _r2 = r2_score(y, y_hat)
+        _mad = median_abs_deviation(y - y_hat)
 
        return {
            'loss': _loss,
            'metric': {
                'MSE': _mse,
                'MAE': _mae,
+                'MAD': _mad,
                'R2': _r2
            }
        }
@@ -545,6 +549,7 @@ def training_epoch_end(self, outputs):
        # self.log('train/loss', tensorboard_logs['train/loss'].item(), prog_bar=True)
        self.log('train/MSE', tensorboard_logs['train/MSE'].item(), on_epoch=True, logger=self.logger)
        self.log('train/MAE', tensorboard_logs['train/MAE'].item(), on_epoch=True, logger=self.logger)
+        self.log('train/MAD', tensorboard_logs['train/MAD'].item(), on_epoch=True, logger=self.logger)
        self.log('train/avg_loss', _log['loss'], on_epoch=True, logger=self.logger)
 
    def validation_step(self, batch, batch_idx):
@@ -556,15 +561,17 @@ def validation_step(self, batch, batch_idx):
        y_hat = _y_hat.detach().cpu().clone().numpy()
        y_raw = _y_raw.detach().cpu().clone().numpy()
 
-        _mae = mean_absolute_error(y_hat, y)
-        _mse = mean_squared_error(y_hat, y)
-        _r2 = r2_score(y_hat, y)
+        _mae = mean_absolute_error(y, y_hat)
+        _mse = mean_squared_error(y, y_hat)
+        _r2 = r2_score(y, y_hat)
+        _mad = median_abs_deviation(y - y_hat)
 
        return {
            'loss': _loss,
            'metric': {
                'MSE': _mse,
                'MAE': _mae,
+                'MAD': _mad,
                'R2': _r2
            }
        }
@@ -585,6 +592,7 @@ def validation_epoch_end(self, outputs):
 
        self.log('valid/MSE', tensorboard_logs['valid/MSE'].item(), on_epoch=True, logger=self.logger)
        self.log('valid/MAE', tensorboard_logs['valid/MAE'].item(), on_epoch=True, logger=self.logger)
+        self.log('valid/MAD', tensorboard_logs['valid/MAD'].item(), on_epoch=True, logger=self.logger)
        self.log('valid/loss', _log['loss'], on_epoch=True, logger=self.logger)
 
    def test_step(self, batch, batch_idx):
@@ -597,11 +605,12 @@ def test_step(self, batch, batch_idx):
        y_hat = _y_hat.detach().cpu().clone().numpy()
        y_hat2 = relu_mul(
            np.array(self.test_dataset.inverse_transform(y_hat, dates)))
-        _loss = self.loss(torch.as_tensor(y_hat2).to(device), _y_raw)
+        _loss = self.loss(_y_raw, torch.as_tensor(y_hat2).to(device))
 
-        _mae = mean_absolute_error(y_hat2, y_raw)
-        _mse = mean_squared_error(y_hat2, y_raw)
-        _r2 = r2_score(y_hat2, y_raw)
+        _mae = mean_absolute_error(y_raw, y_hat2)
+        _mse = mean_squared_error(y_raw, y_hat2)
+        _r2 = r2_score(y_raw, y_hat2)
+        _mad = median_abs_deviation(y_raw - y_hat2)
 
        return {
            'loss': _loss,
@@ -611,6 +620,7 @@ def test_step(self, batch, batch_idx):
            'metric': {
                'MSE': _mse,
                'MAE': _mae,
+                'MAD': _mad,
                'R2': _r2
            }
        }
@@ -659,6 +669,7 @@ def test_epoch_end(self, outputs):
 
        self.log('test/MSE', tensorboard_logs['test/MSE'].item(), on_epoch=True, logger=self.logger)
        self.log('test/MAE', tensorboard_logs['test/MAE'].item(), on_epoch=True, logger=self.logger)
+        self.log('test/MAD', tensorboard_logs['test/MAD'].item(), on_epoch=True, logger=self.logger)
        self.log('test/loss', avg_loss, on_epoch=True, logger=self.logger)
 
        self.df_obs = df_obs
@@ -1079,31 +1090,6 @@ def forward(self, _input: torch.Tensor, _target: torch.Tensor) -> torch.Tensor:
            self.sigma2 * (1-torch.exp(-(_input - _target)**2 / self.sigma2)))
 
 
-class LogCoshLoss(nn.Module):
-    __constants__ = ['reduction']
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        """
-        Implement numerically stable log-cosh which is used in Keras
-
-        log(cosh(x)) = logaddexp(x, -x) - log(2)
-                     = abs(x) + log1p(exp(-2 * abs(x))) - log(2)
-
-        Reference:
-            * https://stackoverflow.com/a/57786270
-        """
-        # not to compute log(0), add 1e-24 (small value)
-        def _log_cosh(x):
-            return torch.abs(x) + \
-                torch.log1p(torch.exp(-2 * torch.abs(x))) + \
-                torch.log(torch.full_like(x, 2, dtype=x.dtype))
-
-        return torch.mean(_log_cosh(input - target))
-
-
 def relu_mul(x):
    """[fastest method](https://stackoverflow.com/a/32109519/743078)
    """
