Skip to content

Commit bbb715d

Browse files
authored
feature: add warnings for xgboost specific rules in debugger rules (#3255)
1 parent c8d70a3 commit bbb715d

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/sagemaker/estimator.py

+12
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,8 @@ def _prepare_rules(self):
792792
if self.rules is not None:
793793
for rule in self.rules:
794794
if isinstance(rule, Rule):
795+
# Add check for xgboost rules
796+
self._check_debugger_rule(rule)
795797
self.debugger_rules.append(rule)
796798
elif isinstance(rule, ProfilerRule):
797799
self.profiler_rules.append(rule)
@@ -801,6 +803,16 @@ def _prepare_rules(self):
801803
+ "and sagemaker.debugger.ProfilerRule"
802804
)
803805

806+
def _check_debugger_rule(self, rule):
807+
"""Add warning for incorrectly used xgboost rules."""
808+
_xgboost_specific_rules = ["FeatureImportanceOverweight", "TreeDepth"]
809+
if rule.name in _xgboost_specific_rules:
810+
logger.warning(
811+
"TreeDepth and FeatureImportanceOverweight rules are valid "
812+
"only for the XGBoost algorithm. Please make sure this estimator "
813+
"is used for XGBoost algorithm. "
814+
)
815+
804816
def _prepare_debugger_for_training(self):
805817
"""Prepare debugger rules and debugger configs for training."""
806818
if self.debugger_rules and self.debugger_hook_config is None:

0 commit comments

Comments
 (0)