@@ -1307,6 +1307,36 @@ def run_explainability(
1307
1307
the Trial Component will be unassociated.
1308
1308
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1309
1309
""" # noqa E501 # pylint: disable=c0301
1310
+ analysis_config = _AnalysisConfigGenerator .explainability (
1311
+ data_config ,
1312
+ model_config ,
1313
+ model_scores ,
1314
+ explainability_config
1315
+ )
1316
+ if job_name is None :
1317
+ if self .job_name_prefix :
1318
+ job_name = utils .name_from_base (self .job_name_prefix )
1319
+ else :
1320
+ job_name = utils .name_from_base ("Clarify-Explainability" )
1321
+ return self ._run (
1322
+ data_config ,
1323
+ analysis_config ,
1324
+ wait ,
1325
+ logs ,
1326
+ job_name ,
1327
+ kms_key ,
1328
+ experiment_config ,
1329
+ )
1330
+
1331
+
1332
+ class _AnalysisConfigGenerator :
1333
+ @staticmethod
1334
+ def explainability (
1335
+ data_config ,
1336
+ model_config ,
1337
+ model_scores ,
1338
+ explainability_config
1339
+ ):
1310
1340
analysis_config = data_config .get_config ()
1311
1341
predictor_config = model_config .get_predictor_config ()
1312
1342
if isinstance (model_scores , ModelPredictedLabelConfig ):
@@ -1329,34 +1359,21 @@ def run_explainability(
1329
1359
if not len (explainability_methods .keys ()) == len (explainability_config ):
1330
1360
raise ValueError ("Duplicate explainability configs are provided" )
1331
1361
if (
1332
- "shap" not in explainability_methods
1333
- and explainability_methods ["pdp" ].get ("features" , None ) is None
1362
+ "shap" not in explainability_methods
1363
+ and explainability_methods ["pdp" ].get ("features" , None ) is None
1334
1364
):
1335
1365
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1336
1366
else :
1337
1367
if (
1338
- isinstance (explainability_config , PDPConfig )
1339
- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1340
- is None
1368
+ isinstance (explainability_config , PDPConfig )
1369
+ and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1370
+ is None
1341
1371
):
1342
1372
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1343
1373
explainability_methods = explainability_config .get_explainability_config ()
1344
1374
analysis_config ["methods" ] = explainability_methods
1345
1375
analysis_config ["predictor" ] = predictor_config
1346
- if job_name is None :
1347
- if self .job_name_prefix :
1348
- job_name = utils .name_from_base (self .job_name_prefix )
1349
- else :
1350
- job_name = utils .name_from_base ("Clarify-Explainability" )
1351
- return self ._run (
1352
- data_config ,
1353
- analysis_config ,
1354
- wait ,
1355
- logs ,
1356
- job_name ,
1357
- kms_key ,
1358
- experiment_config ,
1359
- )
1376
+ return analysis_config
1360
1377
1361
1378
1362
1379
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments