Skip to content

Commit 7f0ec36

Browse files
committed
AdaBoost API updates
1 parent 2ec10d4 commit 7f0ec36

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

q2_sample_classifier/tests/test_estimators.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def test_train_adaboost_decision_tree(self):
398398
parameter_tuning=True, classification=True,
399399
missing_samples='ignore', base_estimator="DecisionTree")
400400
self.assertEqual(type(abe.named_steps.est), AdaBoostClassifier)
401-
self.assertEqual(type(abe.named_steps.est.base_estimator),
401+
self.assertEqual(type(abe.named_steps.est.estimator),
402402
DecisionTreeClassifier)
403403

404404
def test_train_adaboost_extra_trees(self):
@@ -408,7 +408,7 @@ def test_train_adaboost_extra_trees(self):
408408
parameter_tuning=True, classification=True,
409409
missing_samples='ignore', base_estimator="ExtraTrees")
410410
self.assertEqual(type(abe.named_steps.est), AdaBoostClassifier)
411-
self.assertEqual(type(abe.named_steps.est.base_estimator),
411+
self.assertEqual(type(abe.named_steps.est.estimator),
412412
ExtraTreeClassifier)
413413

414414
# test some invalid inputs/edge cases

q2_sample_classifier/utilities.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -828,8 +828,9 @@ def _train_adaboost_base_estimator(table, metadata, column, base_estimator,
828828

829829
return Pipeline(
830830
[('dv', estimator.named_steps.dv),
831-
('est', adaboost_estimator(estimator.named_steps.est,
832-
n_estimators, random_state=random_state))])
831+
('est', adaboost_estimator(estimator=estimator.named_steps.est,
832+
n_estimators=n_estimators,
833+
random_state=random_state))])
833834

834835

835836
def _disable_feature_selection(estimator, optimize_feature_selection):

0 commit comments

Comments
 (0)