Skip to content

Commit 3ebd239

Browse files
ensuring boolean columns get cast to float for statsmodels
1 parent 18e95bc commit 3ebd239

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

ISLP/models/model_spec.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def fit(self, X, y=None):
107107
cats = self.encoder_.categories_[0]
108108
column_names = [str(n) for n in cats]
109109

110-
111110
if isinstance(X, pd.DataFrame): # expecting a column, we take .iloc[:,0]
112111
X = X.iloc[:,0]
113112

@@ -635,18 +634,23 @@ def build_model(column_info,
635634
if isinstance(X, (pd.Series, pd.DataFrame)):
636635
df = pd.concat(dfs, axis=1)
637636
df.index = X.index
638-
return df
639637
else:
640-
return np.column_stack(dfs)
638+
return np.column_stack(dfs).astype(float)
641639
else: # return a 0 design
642640
zero = np.zeros(X.shape[0])
643641
if isinstance(X, (pd.Series, pd.DataFrame)):
644642
df = pd.DataFrame({'zero': zero})
645643
df.index = X.index
646-
return df
647644
else:
648645
return zero
649646

647+
# if we reach here, we will be returning a DataFrame
648+
649+
for col in df.columns:
650+
if df[col].dtype == bool:
651+
df[col] = df[col].astype(float)
652+
return df
653+
650654
def derived_feature(variables, encoder=None, name=None, use_transform=True):
651655
"""
652656
Create a Feature, optionally

tests/models/test_boolean_columns.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pandas as pd
2+
import statsmodels.api as sm
3+
import numpy as np
4+
from itertools import combinations
5+
6+
from ISLP.models import ModelSpec as MS
7+
8+
rng = np.random.default_rng(0)
9+
10+
df = pd.DataFrame({'A':rng.standard_normal(10),
11+
'B':np.array([1,2,3,2,1,1,1,3,2,1], int),
12+
'C':np.array([True,False,False,True,True]*2, bool),
13+
'D':rng.standard_normal(10)})
14+
Y = rng.standard_normal(10)
15+
16+
def test_all():
17+
18+
for i in range(1, 5):
19+
for comb in combinations(['A','B','C','D'], i):
20+
21+
X = MS(comb).fit_transform(df)
22+
sm.OLS(Y, X).fit()
23+

0 commit comments

Comments
 (0)