Skip to content

Commit b6a9f0e

Browse files
AndywangnnHaonian Wang
authored andcommitted
feature: Add business details and hyper parameters fields and update test_model_card.py (aws#3639)
Co-authored-by: Haonian Wang <[email protected]>
1 parent 2e24ccd commit b6a9f0e

File tree

5 files changed

+139
-0
lines changed

5 files changed

+139
-0
lines changed

doc/api/governance/model_card.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ see `Amazon SageMaker Model Cards <https://docs.aws.amazon.com/sagemaker/latest/
4141

4242
.. autoclass:: TrainingJobDetails
4343
:show-inheritance:
44+
45+
.. autoclass:: BusinessDetails
46+
:show-inheritance:
47+
48+
.. autoclass:: HyperParameter
49+
:show-inheritance:

src/sagemaker/model_card/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
Environment,
1717
ModelOverview,
1818
IntendedUses,
19+
BusinessDetails,
1920
ObjectiveFunction,
2021
TrainingMetric,
22+
HyperParameter,
2123
Metric,
2224
Function,
2325
TrainingJobDetails,

src/sagemaker/model_card/model_card.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
TRAINING_DATASETS_MAX_SIZE,
3535
TRAINING_METRICS_MAX_SIZE,
3636
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE,
37+
HYPER_PARAMETERS_MAX_SIZE,
38+
USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE,
3739
EVALUATION_DATASETS_MAX_SIZE,
3840
)
3941
from sagemaker.model_card.helpers import (
@@ -235,6 +237,27 @@ def __init__(
235237
self.explanations_for_risk_rating = explanations_for_risk_rating
236238

237239

240+
class BusinessDetails(_DefaultToRequestDict, _DefaultFromDict):
241+
"""The business details of a model."""
242+
243+
def __init__(
244+
self,
245+
business_problem: Optional[str] = None,
246+
business_stakeholders: Optional[str] = None,
247+
line_of_business: Optional[str] = None,
248+
):
249+
"""Initialize an Business Details object.
250+
251+
Args:
252+
business_problem (str, optional): The business problem of this model (default: None).
253+
business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254+
line_of_business (str, optional): The line of business for this model (default: None).
255+
""" # noqa E501 # pylint: disable=line-too-long
256+
self.business_problem = business_problem
257+
self.business_stakeholders = business_stakeholders
258+
self.line_of_business = line_of_business
259+
260+
238261
class Function(_DefaultToRequestDict, _DefaultFromDict):
239262
"""Function details."""
240263

@@ -363,6 +386,24 @@ def __init__(
363386
self.notes = notes
364387

365388

389+
class HyperParameter(_DefaultToRequestDict, _DefaultFromDict):
390+
"""Hyper-Parameters data."""
391+
392+
def __init__(
393+
self,
394+
name: str,
395+
value: str,
396+
):
397+
"""Initialize a HyperParameter object.
398+
399+
Args:
400+
name (str): The hyper parameter name.
401+
value (str): The hyper parameter value.
402+
"""
403+
self.name = name
404+
self.value = value
405+
406+
366407
class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
367408
"""The overview of a training job."""
368409

@@ -371,6 +412,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371412
user_provided_training_metrics = _IsList(
372413
TrainingMetric, USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373414
)
415+
hyper_parameters = _IsList(HyperParameter, HYPER_PARAMETERS_MAX_SIZE)
416+
user_provided_hyper_parameters = _IsList(
417+
HyperParameter, USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
418+
)
374419
training_environment = _IsModelCardObject(Environment)
375420

376421
def __init__(
@@ -380,6 +425,8 @@ def __init__(
380425
training_environment: Optional[Environment] = None,
381426
training_metrics: Optional[List[TrainingMetric]] = None,
382427
user_provided_training_metrics: Optional[List[TrainingMetric]] = None,
428+
hyper_parameters: Optional[List[HyperParameter]] = None,
429+
user_provided_hyper_parameters: Optional[List[HyperParameter]] = None,
383430
):
384431
"""Initialize a Training Job Details object.
385432
@@ -389,12 +436,16 @@ def __init__(
389436
training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390437
training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391438
user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
439+
hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
440+
user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392441
""" # noqa E501 # pylint: disable=line-too-long
393442
self.training_arn = training_arn
394443
self.training_datasets = training_datasets
395444
self.training_environment = training_environment
396445
self.training_metrics = training_metrics
397446
self.user_provided_training_metrics = user_provided_training_metrics
447+
self.hyper_parameters = hyper_parameters
448+
self.user_provided_hyper_parameters = user_provided_hyper_parameters
398449

399450

400451
class TrainingDetails(_DefaultToRequestDict, _DefaultFromDict):
@@ -442,6 +493,10 @@ def _create_training_details(training_job_data: dict, cls: "TrainingDetails", **
442493
]
443494
if "FinalMetricDataList" in training_job_data
444495
else [],
496+
"hyper_parameters": [
497+
HyperParameter(key, value)
498+
for key, value in training_job_data["HyperParameters"].items()
499+
],
445500
}
446501
kwargs.update({"training_job_details": TrainingJobDetails(**job)})
447502
instance = cls(**kwargs)
@@ -568,6 +623,16 @@ def add_metric(self, metric: TrainingMetric):
568623
self.training_job_details = TrainingJobDetails()
569624
self.training_job_details.user_provided_training_metrics.append(metric)
570625

626+
def add_parameter(self, parameter: HyperParameter):
627+
"""Add custom hyper-parameter.
628+
629+
Args:
630+
parameter (HyperParameter): The custom parameter to add.
631+
"""
632+
if not self.training_job_details:
633+
self.training_job_details = TrainingJobDetails()
634+
self.training_job_details.user_provided_hyper_parameters.append(parameter)
635+
571636

572637
class MetricGroup(_DefaultToRequestDict, _DefaultFromDict):
573638
"""Group of metric data"""
@@ -777,6 +842,7 @@ class ModelCard(object):
777842
status = _OneOf(ModelCardStatusEnum)
778843
model_overview = _IsModelCardObject(ModelOverview)
779844
intended_uses = _IsModelCardObject(IntendedUses)
845+
business_details = _IsModelCardObject(BusinessDetails)
780846
training_details = _IsModelCardObject(TrainingDetails)
781847
evaluation_details = _IsList(EvaluationJob)
782848
additional_information = _IsModelCardObject(AdditionalInformation)
@@ -793,6 +859,7 @@ def __init__(
793859
last_modified_by: Optional[dict] = None,
794860
model_overview: Optional[ModelOverview] = None,
795861
intended_uses: Optional[IntendedUses] = None,
862+
business_details: Optional[BusinessDetails] = None,
796863
training_details: Optional[TrainingDetails] = None,
797864
evaluation_details: Optional[List[EvaluationJob]] = None,
798865
additional_information: Optional[AdditionalInformation] = None,
@@ -811,6 +878,7 @@ def __init__(
811878
last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812879
model_overview (ModelOverview, optional): An overview of the model (default: None).
813880
intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
881+
business_details (BusinessDetails, optional): The business details of the model (default: None).
814882
training_details (TrainingDetails, optional): The training details of the model (default: None).
815883
evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816884
additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +894,7 @@ def __init__(
826894
self.last_modified_by = last_modified_by
827895
self.model_overview = model_overview
828896
self.intended_uses = intended_uses
897+
self.business_details = business_details
829898
self.training_details = training_details
830899
self.evaluation_details = evaluation_details
831900
self.additional_information = additional_information

src/sagemaker/model_card/schema_constraints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,6 @@ class MetricTypeEnum(str, Enum):
8484
TRAINING_DATASETS_MAX_SIZE = 15
8585
TRAINING_METRICS_MAX_SIZE = 50
8686
USER_PROVIDED_TRAINING_METRICS_MAX_SIZE = 50
87+
HYPER_PARAMETERS_MAX_SIZE = 100
88+
USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE = 100
8789
EVALUATION_DATASETS_MAX_SIZE = 10

tests/unit/test_model_card.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
Environment,
2828
ModelOverview,
2929
IntendedUses,
30+
BusinessDetails,
3031
ObjectiveFunction,
3132
TrainingMetric,
33+
HyperParameter,
3234
Metric,
3335
TrainingDetails,
3436
MetricGroup,
@@ -75,6 +77,11 @@
7577
RISK_RATING = schema_constraints.RiskRatingEnum.LOW
7678
EXPLANATIONS_FOR_RISK_RATING = "ramdomly the first example"
7779

80+
# business details auguments
81+
BUSINESS_PROBLEM = "mock model for business problem testing"
82+
BUSINESS_STAKEHOLDERS = "business stakeholders testing"
83+
LINE_OF_BUSINESS = "how many business models"
84+
7885
# training details arguments
7986
OBJECITVE_FUNCTION_FUNC = schema_constraints.ObjectiveFunctionEnum.MINIMIZE
8087
OBJECTIVE_FUNCTION_FACET = schema_constraints.FacetEnum.LOSS
@@ -89,6 +96,10 @@
8996
USER_METRIC_NAME = "test_metric"
9097
USER_METRIC = TrainingMetric(name=USER_METRIC_NAME, value=1)
9198
USER_PROVIDED_TRAINING_METRICS = [USER_METRIC]
99+
HYPER_PARAMETER = [HyperParameter(name="binary_f_beta", value=0.965)]
100+
USER_PARAMETER_NAME = "test_parameter"
101+
USER_PARAMETER = HyperParameter(name=USER_PARAMETER_NAME, value=1)
102+
USER_PROVIDED_HYPER_PARAMETER = [USER_PARAMETER]
92103

93104
# evaluation job arguments
94105
EVALUATION_JOB_NAME = "evaluation job 1"
@@ -350,6 +361,22 @@
350361
"Timestamp": datetime.datetime(2022, 9, 5, 19, 18, 40),
351362
},
352363
],
364+
"HyperParameters": {
365+
"_kfold": "5",
366+
"_tuning_objective_metric": "validation:accuracy",
367+
"alpha": "0.0037170512924477993",
368+
"colsample_bytree": "0.7476726040667319",
369+
"eta": "0.011391935592233605",
370+
"eval_metric": "accuracy,f1,balanced_accuracy,precision_macro,recall_macro,mlogloss",
371+
"gamma": "1.8903517751689445",
372+
"lambda": "0.5098604662224621",
373+
"max_depth": "3",
374+
"min_child_weight": "5.081388147234708e-06",
375+
"num_class": "28",
376+
"num_round": "165",
377+
"objective": "multi:softprob",
378+
"subsample": "0.8828549481113146",
379+
},
353380
"CreatedBy": {},
354381
}
355382
}
@@ -583,6 +610,17 @@ def fixture_fixture_intended_uses_example():
583610
return test_example
584611

585612

613+
@pytest.fixture(name="business_details_example")
614+
def fixture_fixture_business_details_example():
615+
"""Example business details instance."""
616+
test_example = BusinessDetails(
617+
business_problem=BUSINESS_PROBLEM,
618+
business_stakeholders=BUSINESS_STAKEHOLDERS,
619+
line_of_business=LINE_OF_BUSINESS,
620+
)
621+
return test_example
622+
623+
586624
@pytest.fixture(name="training_details_example")
587625
def fixture_fixture_training_details_example():
588626
"""Example training details instance."""
@@ -601,6 +639,7 @@ def fixture_fixture_training_details_example():
601639
training_datasets=TRAINING_DATASETS,
602640
training_environment=TRAINING_ENVIRONMENT,
603641
training_metrics=TRAINING_METRICS,
642+
hyper_parameters=HYPER_PARAMETER,
604643
),
605644
)
606645
return test_example
@@ -637,6 +676,7 @@ def test_create_model_card(
637676
session,
638677
model_overview_example,
639678
intended_uses_example,
679+
business_details_example,
640680
training_details_example,
641681
evaluation_details_example,
642682
additional_information_example,
@@ -649,6 +689,7 @@ def test_create_model_card(
649689
status=MODEL_CARD_STATUS,
650690
model_overview=model_overview_example,
651691
intended_uses=intended_uses_example,
692+
business_details=business_details_example,
652693
training_details=training_details_example,
653694
evaluation_details=evaluation_details_example,
654695
additional_information=additional_information_example,
@@ -1017,6 +1058,9 @@ def test_training_details_autodiscovery_from_model_overview(
10171058
assert len(training_details.training_job_details.training_metrics) == len(
10181059
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["FinalMetricDataList"]
10191060
)
1061+
assert len(training_details.training_job_details.hyper_parameters) == len(
1062+
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["HyperParameters"]
1063+
)
10201064
assert training_details.training_job_details.training_environment.container_image == [
10211065
TRAINING_IMAGE
10221066
]
@@ -1046,7 +1090,10 @@ def test_training_details_autodiscovery_from_model_overview_autopilot(
10461090
model_overview=model_overview_example, sagemaker_session=session
10471091
)
10481092

1093+
# MetricDefinitions is empty
10491094
assert len(training_details.training_job_details.training_metrics) == 0
1095+
# HyperParameters have 3 keys
1096+
assert len(training_details.training_job_details.hyper_parameters) == 3
10501097

10511098

10521099
@patch("sagemaker.Session")
@@ -1063,6 +1110,9 @@ def test_training_details_autodiscovery_from_job_name(session):
10631110
assert len(training_details.training_job_details.training_metrics) == len(
10641111
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["FinalMetricDataList"]
10651112
)
1113+
assert len(training_details.training_job_details.hyper_parameters) == len(
1114+
SEARCH_TRAINING_JOB_EXAMPLE["Results"][0]["TrainingJob"]["HyperParameters"]
1115+
)
10661116
assert training_details.training_job_details.training_environment.container_image == [
10671117
TRAINING_IMAGE
10681118
]
@@ -1091,6 +1141,16 @@ def test_add_user_provided_training_metrics(training_details_example):
10911141
)
10921142

10931143

1144+
def test_add_user_provided_hyper_parameters(training_details_example):
1145+
assert len(training_details_example.training_job_details.user_provided_hyper_parameters) == 0
1146+
training_details_example.add_parameter(USER_PARAMETER)
1147+
assert len(training_details_example.training_job_details.user_provided_hyper_parameters) == 1
1148+
assert (
1149+
training_details_example.training_job_details.user_provided_hyper_parameters[0].name
1150+
== USER_PARAMETER_NAME
1151+
)
1152+
1153+
10941154
def test_add_evaluation_metrics_manually():
10951155
evaluation_job = EvaluationJob(name=EVALUATION_JOB_NAME)
10961156

0 commit comments

Comments
 (0)