Skip to content

Commit d4749cd

Browse files
aloctavodiaJunpeng Lao
authored and
Junpeng Lao
committed
improve error messages and round results (#2524)
* improve error messages and round results * change NotImplementedError to ValueError and data points to observations * fix small negative number issue without rounding
1 parent 004ce61 commit d4749cd

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

pymc3/stats.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def bpic(trace, model=None):
342342

343343

344344
def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
345-
alpha=1, seed=None):
345+
alpha=1, seed=None, round_to=2):
346346
"""Compare models based on the widely applicable information criterion (WAIC)
347347
or leave-one-out (LOO) cross-validation.
348348
Read more theory here - in a paper by some of the leading authorities on
@@ -378,6 +378,8 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
378378
If int or RandomState, use it for seeding Bayesian bootstrap. Only
379379
useful when method = 'BB-pseudo-BMA'. If None (the default), the global
380380
np.random state is used.
381+
round_to : int
382+
Number of decimals used to round results (default 2).
381383
382384
Returns
383385
-------
@@ -421,11 +423,11 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
421423

422424
if len(set([len(m.observed_RVs) for m in models])) != 1:
423425
raise ValueError(
424-
'The Observed RVs should be the same across all models')
426+
'The number of observed RVs should be the same across all models')
425427

426428
if method not in ['stacking', 'BB-pseudo-BMA', 'pseudo-BMA']:
427-
raise NotImplementedError(
428-
'The method to compute weights {} is not supported.'.format(method))
429+
raise ValueError('The method {}, to compute weights,'
430+
'is not supported.'.format(method))
429431

430432
warns = np.zeros(len(models))
431433

@@ -449,7 +451,7 @@ def add_warns(*args):
449451
Km = K - 1
450452

451453
def w_fuller(w):
452-
return np.concatenate((w, 1. - np.sum(w, keepdims=True)))
454+
return np.concatenate((w, [max(1. - np.sum(w), 0.)]))
453455

454456
def log_score(w):
455457
w_full = w_fuller(w)
@@ -514,7 +516,12 @@ def gradient(w):
514516
d_se = np.sqrt(len(diff) * np.var(diff))
515517
se = ses[i]
516518
weight = weights[i]
517-
df_comp.at[idx] = (res[0], res[2], d_ic, weight, se, d_se,
519+
df_comp.at[idx] = (round(res[0], round_to),
520+
round(res[2], round_to),
521+
round(d_ic, round_to),
522+
round(weight, round_to),
523+
round(se, round_to),
524+
round(d_se, round_to),
518525
warns[idx])
519526

520527
return df_comp.sort_values(by=ic)
@@ -526,10 +533,15 @@ def _ic_matrix(ics):
526533
"""
527534
N = len(ics[0][1][3])
528535
K = len(ics)
529-
530536
ic_i = np.zeros((N, K))
537+
531538
for i in range(K):
532-
ic_i[:, i] = ics[i][1][3]
539+
ic = ics[i][1][3]
540+
if len(ic) != N:
541+
raise ValueError('The number of observations should be the same '
542+
'across all models')
543+
else:
544+
ic_i[:, i] = ic
533545

534546
return N, K, ic_i
535547

0 commit comments

Comments
 (0)