Skip to content

Commit 8308e0c

Browse files
committed
AdaBoost API updates
1 parent 3e646c1 commit 8308e0c

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

q2_sample_classifier/tests/test_estimators.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def test_train_adaboost_decision_tree(self):
398398
parameter_tuning=True, classification=True,
399399
missing_samples='ignore', base_estimator="DecisionTree")
400400
self.assertEqual(type(abe.named_steps.est), AdaBoostClassifier)
401-
self.assertEqual(type(abe.named_steps.est.base_estimator),
401+
self.assertEqual(type(abe.named_steps.est.estimator),
402402
DecisionTreeClassifier)
403403

404404
def test_train_adaboost_extra_trees(self):
@@ -408,7 +408,7 @@ def test_train_adaboost_extra_trees(self):
408408
parameter_tuning=True, classification=True,
409409
missing_samples='ignore', base_estimator="ExtraTrees")
410410
self.assertEqual(type(abe.named_steps.est), AdaBoostClassifier)
411-
self.assertEqual(type(abe.named_steps.est.base_estimator),
411+
self.assertEqual(type(abe.named_steps.est.estimator),
412412
ExtraTreeClassifier)
413413

414414
# test some invalid inputs/edge cases

q2_sample_classifier/utilities.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -827,8 +827,9 @@ def _train_adaboost_base_estimator(table, metadata, column, base_estimator,
827827

828828
return Pipeline(
829829
[('dv', estimator.named_steps.dv),
830-
('est', adaboost_estimator(estimator.named_steps.est,
831-
n_estimators, random_state=random_state))])
830+
('est', adaboost_estimator(estimator=estimator.named_steps.est,
831+
n_estimators=n_estimators,
832+
random_state=random_state))])
832833

833834

834835
def _disable_feature_selection(estimator, optimize_feature_selection):

0 commit comments

Comments
 (0)