diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index b6fd68f472..d63d5711f0 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -779,6 +779,8 @@ def _prepare_rules(self): if self.rules is not None: for rule in self.rules: if isinstance(rule, Rule): + # Add check for xgboost rules + self._check_debugger_rule(rule) self.debugger_rules.append(rule) elif isinstance(rule, ProfilerRule): self.profiler_rules.append(rule) @@ -788,6 +790,16 @@ def _prepare_rules(self): + "and sagemaker.debugger.ProfilerRule" ) + def _check_debugger_rule(self, rule): + """Add warning for incorrectly used xgboost rules.""" + _xgboost_specific_rules = ["FeatureImportanceOverweight", "TreeDepth"] + if rule.name in _xgboost_specific_rules: + logger.warning( + "TreeDepth and FeatureImportanceOverweight rules are valid " + "only for the XGBoost algorithm. Please make sure this estimator " + "is used for XGBoost algorithm. " + ) + def _prepare_debugger_for_training(self): """Prepare debugger rules and debugger configs for training.""" if self.debugger_rules and self.debugger_hook_config is None: