34
34
TRAINING_DATASETS_MAX_SIZE ,
35
35
TRAINING_METRICS_MAX_SIZE ,
36
36
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE ,
37
+ HYPER_PARAMETERS_MAX_SIZE ,
38
+ USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE ,
37
39
EVALUATION_DATASETS_MAX_SIZE ,
38
40
)
39
41
from sagemaker .model_card .helpers import (
@@ -235,6 +237,27 @@ def __init__(
235
237
self .explanations_for_risk_rating = explanations_for_risk_rating
236
238
237
239
240
+ class BusinessDetails (_DefaultToRequestDict , _DefaultFromDict ):
241
+ """The business details of a model."""
242
+
243
+ def __init__ (
244
+ self ,
245
+ business_problem : Optional [str ] = None ,
246
+ business_stakeholders : Optional [str ] = None ,
247
+ line_of_business : Optional [str ] = None ,
248
+ ):
249
+ """Initialize an Business Details object.
250
+
251
+ Args:
252
+ business_problem (str, optional): The business problem of this model (default: None).
253
+ business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254
+ line_of_business (str, optional): The line of business for this model (default: None).
255
+ """ # noqa E501 # pylint: disable=line-too-long
256
+ self .business_problem = business_problem
257
+ self .business_stakeholders = business_stakeholders
258
+ self .line_of_business = line_of_business
259
+
260
+
238
261
class Function (_DefaultToRequestDict , _DefaultFromDict ):
239
262
"""Function details."""
240
263
@@ -363,6 +386,24 @@ def __init__(
363
386
self .notes = notes
364
387
365
388
389
+ class HyperParameter (_DefaultToRequestDict , _DefaultFromDict ):
390
+ """Hyper-Parameters data."""
391
+
392
+ def __init__ (
393
+ self ,
394
+ name : str ,
395
+ value : str ,
396
+ ):
397
+ """Initialize a HyperParameter object.
398
+
399
+ Args:
400
+ name (str): The hyper parameter name.
401
+ value (str): The hyper parameter value.
402
+ """
403
+ self .name = name
404
+ self .value = value
405
+
406
+
366
407
class TrainingJobDetails (_DefaultToRequestDict , _DefaultFromDict ):
367
408
"""The overview of a training job."""
368
409
@@ -371,6 +412,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371
412
user_provided_training_metrics = _IsList (
372
413
TrainingMetric , USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373
414
)
415
+ hyper_parameters = _IsList (HyperParameter , HYPER_PARAMETERS_MAX_SIZE )
416
+ user_provided_hyper_parameters = _IsList (
417
+ HyperParameter , USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
418
+ )
374
419
training_environment = _IsModelCardObject (Environment )
375
420
376
421
def __init__ (
@@ -380,6 +425,8 @@ def __init__(
380
425
training_environment : Optional [Environment ] = None ,
381
426
training_metrics : Optional [List [TrainingMetric ]] = None ,
382
427
user_provided_training_metrics : Optional [List [TrainingMetric ]] = None ,
428
+ hyper_parameters : Optional [List [HyperParameter ]] = None ,
429
+ user_provided_hyper_parameters : Optional [List [HyperParameter ]] = None ,
383
430
):
384
431
"""Initialize a Training Job Details object.
385
432
@@ -389,12 +436,16 @@ def __init__(
389
436
training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390
437
training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391
438
user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
439
+ hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
440
+ user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392
441
""" # noqa E501 # pylint: disable=line-too-long
393
442
self .training_arn = training_arn
394
443
self .training_datasets = training_datasets
395
444
self .training_environment = training_environment
396
445
self .training_metrics = training_metrics
397
446
self .user_provided_training_metrics = user_provided_training_metrics
447
+ self .hyper_parameters = hyper_parameters
448
+ self .user_provided_hyper_parameters = user_provided_hyper_parameters
398
449
399
450
400
451
class TrainingDetails (_DefaultToRequestDict , _DefaultFromDict ):
@@ -442,6 +493,10 @@ def _create_training_details(training_job_data: dict, cls: "TrainingDetails", **
442
493
]
443
494
if "FinalMetricDataList" in training_job_data
444
495
else [],
496
+ "hyper_parameters" : [
497
+ HyperParameter (key , value )
498
+ for key , value in training_job_data ["HyperParameters" ].items ()
499
+ ],
445
500
}
446
501
kwargs .update ({"training_job_details" : TrainingJobDetails (** job )})
447
502
instance = cls (** kwargs )
@@ -568,6 +623,16 @@ def add_metric(self, metric: TrainingMetric):
568
623
self .training_job_details = TrainingJobDetails ()
569
624
self .training_job_details .user_provided_training_metrics .append (metric )
570
625
626
+ def add_parameter (self , parameter : HyperParameter ):
627
+ """Add custom hyper-parameter.
628
+
629
+ Args:
630
+ parameter (HyperParameter): The custom parameter to add.
631
+ """
632
+ if not self .training_job_details :
633
+ self .training_job_details = TrainingJobDetails ()
634
+ self .training_job_details .user_provided_hyper_parameters .append (parameter )
635
+
571
636
572
637
class MetricGroup (_DefaultToRequestDict , _DefaultFromDict ):
573
638
"""Group of metric data"""
@@ -777,6 +842,7 @@ class ModelCard(object):
777
842
status = _OneOf (ModelCardStatusEnum )
778
843
model_overview = _IsModelCardObject (ModelOverview )
779
844
intended_uses = _IsModelCardObject (IntendedUses )
845
+ business_details = _IsModelCardObject (BusinessDetails )
780
846
training_details = _IsModelCardObject (TrainingDetails )
781
847
evaluation_details = _IsList (EvaluationJob )
782
848
additional_information = _IsModelCardObject (AdditionalInformation )
@@ -793,6 +859,7 @@ def __init__(
793
859
last_modified_by : Optional [dict ] = None ,
794
860
model_overview : Optional [ModelOverview ] = None ,
795
861
intended_uses : Optional [IntendedUses ] = None ,
862
+ business_details : Optional [BusinessDetails ] = None ,
796
863
training_details : Optional [TrainingDetails ] = None ,
797
864
evaluation_details : Optional [List [EvaluationJob ]] = None ,
798
865
additional_information : Optional [AdditionalInformation ] = None ,
@@ -811,6 +878,7 @@ def __init__(
811
878
last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812
879
model_overview (ModelOverview, optional): An overview of the model (default: None).
813
880
intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
881
+ business_details (BusinessDetails, optional): The business details of the model (default: None).
814
882
training_details (TrainingDetails, optional): The training details of the model (default: None).
815
883
evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816
884
additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +894,7 @@ def __init__(
826
894
self .last_modified_by = last_modified_by
827
895
self .model_overview = model_overview
828
896
self .intended_uses = intended_uses
897
+ self .business_details = business_details
829
898
self .training_details = training_details
830
899
self .evaluation_details = evaluation_details
831
900
self .additional_information = additional_information
0 commit comments