@@ -309,7 +309,9 @@ def plot_dependence(
309
309
return axes
310
310
311
311
312
- def plot_variable_importance (idata , X , labels = None , figsize = None , samples = 100 , random_seed = None ):
312
+ def plot_variable_importance (
313
+ idata , X , labels = None , sort_vars = True , figsize = None , samples = 100 , random_seed = None
314
+ ):
313
315
"""
314
316
Estimates variable importance from the BART-posterior.
315
317
@@ -319,9 +321,11 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, r
319
321
InferenceData containing a collection of BART_trees in sample_stats group
320
322
X : array-like
321
323
The covariate matrix.
322
- labels: list
324
+ labels : list
323
325
List of the names of the covariates. If X is a DataFrame the names of the covariables will
324
326
be taken from it and this argument will be ignored.
327
+ sort_vars : bool
328
+ Whether to sort the variables according to their variable importance. Defaults to True.
325
329
figsize : tuple
326
330
Figure size. If None it will be defined automatically.
327
331
samples : int
@@ -337,23 +341,29 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, r
337
341
_ , axes = plt .subplots (2 , 1 , figsize = figsize )
338
342
339
343
if hasattr (X , "columns" ) and hasattr (X , "values" ):
340
- labels = list ( X .columns )
344
+ labels = X .columns
341
345
X = X .values
342
346
343
347
VI = idata .sample_stats ["variable_inclusion" ].mean (("chain" , "draw" )).values
344
348
if labels is None :
345
- labels = range (len (VI ))
349
+ labels = np .arange (len (VI ))
350
+ else :
351
+ labels = np .array (labels )
346
352
347
353
ticks = np .arange (len (VI ), dtype = int )
348
354
idxs = np .argsort (VI )
349
355
subsets = [idxs [:- i ] for i in range (1 , len (idxs ))]
350
356
subsets .append (None )
351
357
352
- axes [0 ].plot (VI / VI .sum (), "o-" )
358
+ if sort_vars :
359
+ indices = idxs [::- 1 ]
360
+ else :
361
+ indices = np .arange (len (VI ))
362
+ axes [0 ].plot ((VI / VI .sum ())[indices ], "o-" )
353
363
axes [0 ].set_xticks (ticks )
354
- axes [0 ].set_xticklabels (labels )
355
- axes [0 ].set_xlabel ("variable index " )
356
- axes [0 ].set_ylabel ("relative importance" )
364
+ axes [0 ].set_xticklabels (labels [ indices ] )
365
+ axes [0 ].set_xlabel ("covariables " )
366
+ axes [0 ].set_ylabel ("importance" )
357
367
358
368
predicted_all = predict (idata , rng , X = X , size = samples , excluded = None )
359
369
@@ -363,16 +373,18 @@ def plot_variable_importance(idata, X, labels=None, figsize=None, samples=100, r
363
373
predicted_subset = predict (idata , rng , X = X , size = samples , excluded = subset )
364
374
pearson = np .zeros (samples )
365
375
for j in range (samples ):
366
- pearson [j ] = pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ]
376
+ pearson [j ] = (
377
+ pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ]
378
+ ) ** 2
367
379
EV_mean [idx ] = np .mean (pearson )
368
380
EV_hdi [idx ] = az .hdi (pearson )
369
381
370
382
axes [1 ].errorbar (ticks , EV_mean , np .array ((EV_mean - EV_hdi [:, 0 ], EV_hdi [:, 1 ] - EV_mean )))
371
383
372
384
axes [1 ].set_xticks (ticks )
373
385
axes [1 ].set_xticklabels (ticks + 1 )
374
- axes [1 ].set_xlabel ("number of components " )
375
- axes [1 ].set_ylabel ("correlation" )
386
+ axes [1 ].set_xlabel ("number of covariables " )
387
+ axes [1 ].set_ylabel ("R²" , rotation = 0 , labelpad = 12 )
376
388
axes [1 ].set_ylim (0 , 1 )
377
389
378
390
axes [0 ].set_xlim (- 0.5 , len (VI ) - 0.5 )
0 commit comments