File tree Expand file tree Collapse file tree 2 files changed +15
-4
lines changed Expand file tree Collapse file tree 2 files changed +15
-4
lines changed Original file line number Diff line number Diff line change 1
1
import arviz as az
2
2
import matplotlib .pyplot as plt
3
3
import numpy as np
4
- import pandas as pd
5
4
6
5
from numpy .random import RandomState
7
6
from scipy .interpolate import griddata
@@ -147,13 +146,13 @@ def plot_dependence(
147
146
148
147
rng = RandomState (seed = random_seed )
149
148
150
- if isinstance (X , pd . DataFrame ):
149
+ if hasattr (X , "columns" ) and hasattr ( X , "values" ):
151
150
X_names = list (X .columns )
152
151
X = X .values
153
152
else :
154
153
X_names = []
155
154
156
- if isinstance (Y , pd . DataFrame ):
155
+ if hasattr (Y , "name" ):
157
156
Y_label = f"Predicted { Y .name } "
158
157
else :
159
158
Y_label = "Predicted Y"
Original file line number Diff line number Diff line change 2
2
import pytest
3
3
4
4
from numpy .random import RandomState
5
- from numpy .testing import assert_almost_equal
5
+ from numpy .testing import assert_almost_equal , assert_array_equal
6
6
7
7
import pymc as pm
8
8
@@ -103,6 +103,18 @@ def test_predict(self):
103
103
def test_pdp (self , kwargs ):
104
104
pm .bart .utils .plot_dependence (self .idata , X = self .X , Y = self .Y , ** kwargs )
105
105
106
+ def test_pdp_pandas_labels (self ):
107
+ pd = pytest .importorskip ("pandas" )
108
+
109
+ X_names = ["norm1" , "norm2" , "binom" ]
110
+ X_pd = pd .DataFrame (self .X , columns = X_names )
111
+ Y_pd = pd .Series (self .Y , name = "response" )
112
+ axes = pm .bart .utils .plot_dependence (self .idata , X = X_pd , Y = Y_pd )
113
+
114
+ figure = axes [0 ].figure
115
+ assert figure .texts [0 ].get_text () == "Predicted response"
116
+ assert_array_equal ([ax .get_xlabel () for ax in axes ], X_names )
117
+
106
118
107
119
@pytest .mark .parametrize (
108
120
"size, expected" ,
You can’t perform that action at this time.
0 commit comments