Skip to content

Commit 7ec1160

Browse files
Merge branch 'master' into xgb-1.7-1_launch
2 parents 680d76d + 479610d commit 7ec1160

File tree

9 files changed

+175
-3
lines changed

9 files changed

+175
-3
lines changed

src/sagemaker/jumpstart/artifacts.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
15+
from copy import deepcopy
1516
import os
16-
from typing import Dict, Optional
17+
from typing import Dict, List, Optional
1718
from sagemaker import image_uris
1819
from sagemaker.jumpstart.constants import (
1920
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
@@ -363,3 +364,32 @@ def _retrieve_default_environment_variables(
363364
for environment_variable in model_specs.inference_environment_variables:
364365
default_environment_variables[environment_variable.name] = str(environment_variable.default)
365366
return default_environment_variables
367+
368+
369+
def _retrieve_default_training_metric_definitions(
370+
model_id: str,
371+
model_version: str,
372+
region: Optional[str],
373+
) -> Optional[List[Dict[str, str]]]:
374+
"""Retrieves the default training metric definitions for the model.
375+
376+
Args:
377+
model_id (str): JumpStart model ID of the JumpStart model for which to
378+
retrieve the default training metric definitions.
379+
model_version (str): Version of the JumpStart model for which to retrieve the
380+
default training metric definitions.
381+
region (Optional[str]): Region for which to retrieve default training metric
382+
definitions.
383+
384+
Returns:
385+
list: the default training metric definitions to use for the model or None.
386+
"""
387+
388+
if region is None:
389+
region = JUMPSTART_DEFAULT_REGION_NAME
390+
391+
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
392+
region=region, model_id=model_id, version=model_version
393+
)
394+
395+
return deepcopy(model_specs.metrics) if model_specs.metrics else None

src/sagemaker/jumpstart/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
292292
"training_dependencies",
293293
"training_vulnerabilities",
294294
"deprecated",
295+
"metrics",
295296
]
296297

297298
def __init__(self, spec: Dict[str, Any]):
@@ -328,6 +329,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
328329
self.training_dependencies: List[str] = json_obj["training_dependencies"]
329330
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
330331
self.deprecated: bool = bool(json_obj["deprecated"])
332+
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
331333

332334
if self.training_supported:
333335
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

src/sagemaker/metric_definitions.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve metric definition for training jobs."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict, Optional, List
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
27+
region: Optional[str] = None,
28+
model_id: Optional[str] = None,
29+
model_version: Optional[str] = None,
30+
) -> Optional[List[Dict[str, str]]]:
31+
"""Retrieves the default training metric definitions for the model matching the given arguments.
32+
33+
Args:
34+
region (str): The AWS Region for which to retrieve the default default training metric
35+
definitions. Defaults to ``None``.
36+
model_id (str): The model ID of the model for which to
37+
retrieve the default training metric definitions. (Default: None).
38+
model_version (str): The version of the model for which to retrieve the
39+
default training metric definitions. (Default: None).
40+
Returns:
41+
list: The default metric definitions to use for the model or None.
42+
43+
Raises:
44+
ValueError: If the combination of arguments specified is not supported.
45+
"""
46+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
47+
raise ValueError(
48+
"Must specify `model_id` and `model_version` when retrieving default training "
49+
"metric definitions."
50+
)
51+
52+
return artifacts._retrieve_default_training_metric_definitions(model_id, model_version, region)

tests/integ/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4242

4343
TRAINING_DATASET_MODEL_DICT = {
4444
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),
45+
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4546
}
4647

4748

tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414
import os
1515

16-
from sagemaker import hyperparameters, image_uris, model_uris, script_uris
16+
from sagemaker import hyperparameters, metric_definitions, image_uris, model_uris, script_uris
1717
from sagemaker.estimator import Estimator
1818
from sagemaker.jumpstart.constants import (
1919
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
@@ -35,7 +35,7 @@
3535

3636
def test_jumpstart_transfer_learning_estimator_class(setup):
3737

38-
model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
38+
model_id, model_version = "huggingface-spc-bert-base-cased", "1.2.3"
3939
training_instance_type = "ml.p3.2xlarge"
4040
inference_instance_type = "ml.p2.xlarge"
4141
instance_count = 1
@@ -66,6 +66,11 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
6666

6767
default_hyperparameters["epochs"] = "1"
6868

69+
default_metric_definitions = metric_definitions.retrieve_default(
70+
model_id=model_id,
71+
model_version=model_version,
72+
)
73+
6974
estimator = Estimator(
7075
image_uri=image_uri,
7176
source_dir=script_uri,
@@ -78,6 +83,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup):
7883
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7984
instance_count=instance_count,
8085
instance_type=training_instance_type,
86+
metric_definitions=default_metric_definitions,
8187
)
8288

8389
estimator.fit(

tests/unit/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@
11831183
"training_dependencies": [],
11841184
"training_vulnerabilities": [],
11851185
"deprecated": False,
1186+
"metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
11861187
}
11871188

11881189
BASE_HEADER = {

tests/unit/sagemaker/metric_definitions/__init__.py

Whitespace-only changes.

tests/unit/sagemaker/metric_definitions/jumpstart/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
from mock.mock import patch
17+
import pytest
18+
19+
from sagemaker import metric_definitions
20+
21+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
22+
23+
24+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
25+
def test_jumpstart_default_metric_definitions(patched_get_model_specs):
26+
27+
patched_get_model_specs.side_effect = get_spec_from_base_spec
28+
29+
model_id = "pytorch-ic-mobilenet-v2"
30+
region = "us-west-2"
31+
32+
definitions = metric_definitions.retrieve_default(
33+
region=region,
34+
model_id=model_id,
35+
model_version="*",
36+
)
37+
assert definitions == [
38+
{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}
39+
]
40+
41+
patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*")
42+
43+
patched_get_model_specs.reset_mock()
44+
45+
definitions = metric_definitions.retrieve_default(
46+
region=region,
47+
model_id=model_id,
48+
model_version="1.*",
49+
)
50+
assert definitions == [
51+
{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}
52+
]
53+
54+
patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*")
55+
56+
patched_get_model_specs.reset_mock()
57+
58+
with pytest.raises(KeyError):
59+
metric_definitions.retrieve_default(
60+
region=region,
61+
model_id="blah",
62+
model_version="*",
63+
)
64+
65+
with pytest.raises(ValueError):
66+
metric_definitions.retrieve_default(
67+
region="mars-south-1",
68+
model_id=model_id,
69+
model_version="*",
70+
)
71+
72+
with pytest.raises(ValueError):
73+
metric_definitions.retrieve_default(
74+
model_version="*",
75+
)
76+
77+
with pytest.raises(ValueError):
78+
metric_definitions.retrieve_default(
79+
model_id=model_id,
80+
)

0 commit comments

Comments
 (0)