Skip to content

Commit 793fa04

Browse files
authored
Merge branch 'aws:master' into master
2 parents c8d69c8 + a538a1c commit 793fa04

File tree

13 files changed

+206
-89
lines changed

13 files changed

+206
-89
lines changed

src/sagemaker/fw_utils.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,37 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Utility methods used by framework classes"""
13+
"""Utility methods used by framework classes."""
1414
from __future__ import absolute_import
1515

1616
import json
1717
import logging
1818
import os
1919
import re
20-
import time
2120
import shutil
2221
import tempfile
22+
import time
2323
from collections import namedtuple
24-
from typing import List, Optional, Union, Dict
24+
from typing import Dict, List, Optional, Union
25+
2526
from packaging import version
2627

2728
import sagemaker.image_uris
29+
import sagemaker.utils
30+
from sagemaker.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
2831
from sagemaker.instance_group import InstanceGroup
2932
from sagemaker.s3_utils import s3_path_join
3033
from sagemaker.session_settings import SessionSettings
31-
import sagemaker.utils
3234
from sagemaker.workflow import is_pipeline_variable
33-
34-
from sagemaker.deprecations import renamed_warning, renamed_kwargs
3535
from sagemaker.workflow.entities import PipelineVariable
36-
from sagemaker.deprecations import deprecation_warn_base
3736

3837
logger = logging.getLogger(__name__)
3938

4039
_TAR_SOURCE_FILENAME = "source.tar.gz"
4140

4241
UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"])
4342
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
43+
4444
This is for the source code used for the entry point with an ``Estimator``. It can be
4545
instantiated with positional or keyword arguments.
4646
"""
@@ -211,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
211211
git_config: Optional[Dict[str, str]] = None,
212212
enable_network_isolation: Union[bool, PipelineVariable] = False,
213213
):
214-
"""Validate source code input against pipeline variables
214+
"""Validate source code input against pipeline variables.
215215
216216
Args:
217217
entry_point (str or PipelineVariable): The path to the local Python source file that
@@ -481,7 +481,7 @@ def tar_and_upload_dir(
481481

482482

483483
def _list_files_to_compress(script, directory):
484-
"""Placeholder docstring"""
484+
"""Placeholder docstring."""
485485
if directory is None:
486486
return [script]
487487

@@ -585,7 +585,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
585585
The location returned is a potential concatenation of 2 parts
586586
1. code_location_key_prefix if it exists
587587
2. model_name or a name derived from the image
588-
589588
Args:
590589
code_location_key_prefix (str): the s3 key prefix from code_location
591590
model_name (str): the name of the model
@@ -620,8 +619,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
620619
"enabled": True
621620
}
622621
}
623-
624-
625622
"""
626623
if training_instance_type == "local" or distribution is None:
627624
return
@@ -646,7 +643,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
646643
def profiler_config_deprecation_warning(
647644
profiler_config, image_uri, framework_name, framework_version
648645
):
649-
"""Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
646+
"""Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0."""
650647
if profiler_config is None or profiler_config.framework_profile_params is None:
651648
return
652649

@@ -692,6 +689,7 @@ def validate_smdistributed(
692689
framework_name (str): A string representing the name of framework selected.
693690
framework_version (str): A string representing the framework version selected.
694691
py_version (str): A string representing the python version selected.
692+
Ex: `py38, py39, py310, py311`
695693
distribution (dict): A dictionary with information to enable distributed training.
696694
(Defaults to None if distributed training is not enabled.) For example:
697695
@@ -763,7 +761,8 @@ def _validate_smdataparallel_args(
763761
instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
764762
framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
765763
framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
766-
py_version (str): A string representing the python version selected. Ex: `py3`
764+
py_version (str): A string representing the python version selected.
765+
Ex: `py38, py39, py310, py311`
767766
distribution (dict): A dictionary with information to enable distributed training.
768767
(Defaults to None if distributed training is not enabled.) Ex:
769768
@@ -847,6 +846,7 @@ def validate_distribution(
847846
framework_name (str): A string representing the name of framework selected.
848847
framework_version (str): A string representing the framework version selected.
849848
py_version (str): A string representing the python version selected.
849+
Ex: `py38, py39, py310, py311`
850850
image_uri (str): A string representing a Docker image URI.
851851
kwargs(dict): Additional kwargs passed to this function
852852
@@ -953,7 +953,7 @@ def validate_distribution(
953953

954954

955955
def validate_distribution_for_instance_type(instance_type, distribution):
956-
"""Check if the provided distribution strategy is supported for the instance_type
956+
"""Check if the provided distribution strategy is supported for the instance_type.
957957
958958
Args:
959959
instance_type (str): A string representing the type of training instance selected.
@@ -1010,6 +1010,7 @@ def validate_torch_distributed_distribution(
10101010
}
10111011
framework_version (str): A string representing the framework version selected.
10121012
py_version (str): A string representing the python version selected.
1013+
Ex: `py38, py39, py310, py311`
10131014
image_uri (str): A string representing a Docker image URI.
10141015
entry_point (str or PipelineVariable): The absolute or relative path to the local Python
10151016
source file that should be executed as the entry point to
@@ -1072,7 +1073,7 @@ def validate_torch_distributed_distribution(
10721073

10731074

10741075
def _is_gpu_instance(instance_type):
1075-
"""Returns bool indicating whether instance_type supports GPU
1076+
"""Returns bool indicating whether instance_type supports GPU.
10761077
10771078
Args:
10781079
instance_type (str): Name of the instance_type to check against.
@@ -1091,7 +1092,7 @@ def _is_gpu_instance(instance_type):
10911092

10921093

10931094
def _is_trainium_instance(instance_type):
1094-
"""Returns bool indicating whether instance_type is a Trainium instance
1095+
"""Returns bool indicating whether instance_type is a Trainium instance.
10951096
10961097
Args:
10971098
instance_type (str): Name of the instance_type to check against.
@@ -1107,7 +1108,7 @@ def _is_trainium_instance(instance_type):
11071108

11081109

11091110
def python_deprecation_warning(framework, latest_supported_version):
1110-
"""Placeholder docstring"""
1111+
"""Placeholder docstring."""
11111112
return PYTHON_2_DEPRECATION_WARNING.format(
11121113
framework=framework, latest_supported_version=latest_supported_version
11131114
)
@@ -1121,7 +1122,6 @@ def _region_supports_debugger(region_name):
11211122
11221123
Returns:
11231124
bool: Whether or not the region supports Amazon SageMaker Debugger.
1124-
11251125
"""
11261126
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
11271127

@@ -1134,7 +1134,6 @@ def _region_supports_profiler(region_name):
11341134
11351135
Returns:
11361136
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1137-
11381137
"""
11391138
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
11401139

@@ -1162,7 +1161,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
11621161
11631162
Args:
11641163
framework_version (str): The version of the framework.
1165-
py_version (str): The version of Python.
1164+
py_version (str): A string representing the python version selected.
1165+
Ex: `py38, py39, py310, py311`
11661166
image_uri (str): The URI of the image.
11671167
11681168
Raises:
@@ -1194,9 +1194,8 @@ def create_image_uri(
11941194
instance_type (str): SageMaker instance type. Used to determine device
11951195
type (cpu/gpu/family-specific optimized).
11961196
framework_version (str): The version of the framework.
1197-
py_version (str): Optional. Python version. If specified, should be one
1198-
of 'py2' or 'py3'. If not specified, image uri will not include a
1199-
python component.
1197+
py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
1198+
If not specified, image uri will not include a python component.
12001199
account (str): AWS account that contains the image. (default:
12011200
'520713654638')
12021201
accelerator_type (str): SageMaker Elastic Inference accelerator type.

src/sagemaker/huggingface/estimator.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,13 @@
1515

1616
import logging
1717
import re
18-
from typing import Optional, Union, Dict
18+
from typing import Dict, Optional, Union
1919

20-
from sagemaker.estimator import Framework, EstimatorBase
21-
from sagemaker.fw_utils import (
22-
framework_name_from_image,
23-
validate_distribution,
24-
)
20+
from sagemaker.estimator import EstimatorBase, Framework
21+
from sagemaker.fw_utils import framework_name_from_image, validate_distribution
2522
from sagemaker.huggingface.model import HuggingFaceModel
26-
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
27-
2823
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
24+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2925
from sagemaker.workflow.entities import PipelineVariable
3026

3127
logger = logging.getLogger("sagemaker")
@@ -66,7 +62,7 @@ def __init__(
6662
Args:
6763
py_version (str): Python version you want to use for executing your model training
6864
code. Defaults to ``None``. Required unless ``image_uri`` is provided. If
69-
using PyTorch, the current supported version is ``py36``. If using TensorFlow,
65+
using PyTorch, the current supported version is ``py39``. If using TensorFlow,
7066
the current supported version is ``py37``.
7167
entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source
7268
file which should be executed as the entry point to training.

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,7 @@
14791479
"ap-southeast-3": "907027046896",
14801480
"ap-southeast-4": "457447274322",
14811481
"ap-southeast-5": "550225433462",
1482+
"ap-southeast-7": "590183813437",
14821483
"ca-central-1": "763104351884",
14831484
"ca-west-1": "204538143572",
14841485
"cn-north-1": "727897471807",
@@ -1494,6 +1495,7 @@
14941495
"il-central-1": "780543022126",
14951496
"me-central-1": "914824155844",
14961497
"me-south-1": "217643126080",
1498+
"mx-central-1": "637423239942",
14971499
"sa-east-1": "763104351884",
14981500
"us-east-1": "763104351884",
14991501
"us-east-2": "763104351884",

src/sagemaker/jumpstart/accessors.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def get_model_specs(
288288
)
289289
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
290290

291+
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
291292
if hub_arn:
292293
try:
293294
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
@@ -308,11 +309,22 @@ def get_model_specs(
308309
hub_model_arn = construct_hub_model_arn_from_inputs(
309310
hub_arn=hub_arn, model_name=model_id, version=version
310311
)
311-
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
312-
hub_model_arn=hub_model_arn
313-
)
314-
model_specs.set_hub_content_type(HubContentType.MODEL)
315-
return model_specs
312+
313+
# Failed to describe ModelReference, try with Model
314+
try:
315+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
316+
hub_model_arn=hub_model_arn
317+
)
318+
model_specs.set_hub_content_type(HubContentType.MODEL)
319+
320+
return model_specs
321+
except Exception as ex:
322+
# Failed with both, throw a custom error message
323+
raise RuntimeError(
324+
f"Cannot get details for {model_id} in Hub {hub_arn}. \
325+
{model_id} does not exist as a Model or ModelReference: \n"
326+
+ str(ex)
327+
)
316328

317329
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
318330
model_id=model_id, version_str=version, model_type=model_type

src/sagemaker/jumpstart/hub/hub.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,18 +272,21 @@ def delete_model_reference(self, model_name: str) -> None:
272272
def describe_model(
273273
self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None
274274
) -> DescribeHubContentResponse:
275-
"""Describe model in the SageMaker Hub."""
275+
"""Describe Model or ModelReference in a Hub."""
276+
hub_name = hub_name or self.hub_name
277+
278+
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
276279
try:
277280
model_version = get_hub_model_version(
278281
hub_model_name=model_name,
279282
hub_model_type=HubContentType.MODEL_REFERENCE.value,
280-
hub_name=self.hub_name if not hub_name else hub_name,
283+
hub_name=hub_name,
281284
sagemaker_session=self._sagemaker_session,
282285
hub_model_version=model_version,
283286
)
284287

285288
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
286-
hub_name=self.hub_name if not hub_name else hub_name,
289+
hub_name=hub_name,
287290
hub_content_name=model_name,
288291
hub_content_version=model_version,
289292
hub_content_type=HubContentType.MODEL_REFERENCE.value,
@@ -294,19 +297,32 @@ def describe_model(
294297
"Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: "
295298
+ str(ex)
296299
)
297-
model_version = get_hub_model_version(
298-
hub_model_name=model_name,
299-
hub_model_type=HubContentType.MODEL.value,
300-
hub_name=self.hub_name if not hub_name else hub_name,
301-
sagemaker_session=self._sagemaker_session,
302-
hub_model_version=model_version,
303-
)
304300

305-
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
306-
hub_name=self.hub_name if not hub_name else hub_name,
307-
hub_content_name=model_name,
308-
hub_content_version=model_version,
309-
hub_content_type=HubContentType.MODEL.value,
310-
)
301+
# Failed to describe ModelReference, try with Model
302+
try:
303+
model_version = get_hub_model_version(
304+
hub_model_name=model_name,
305+
hub_model_type=HubContentType.MODEL.value,
306+
hub_name=hub_name,
307+
sagemaker_session=self._sagemaker_session,
308+
hub_model_version=model_version,
309+
)
310+
311+
hub_content_description: Dict[str, Any] = (
312+
self._sagemaker_session.describe_hub_content(
313+
hub_name=hub_name,
314+
hub_content_name=model_name,
315+
hub_content_version=model_version,
316+
hub_content_type=HubContentType.MODEL.value,
317+
)
318+
)
319+
320+
except Exception as ex:
321+
# Failed with both, throw a custom error message
322+
raise RuntimeError(
323+
f"Cannot get details for {model_name} in Hub {hub_name}. \
324+
{model_name} does not exist as a Model or ModelReference in {hub_name}: \n"
325+
+ str(ex)
326+
)
311327

312328
return DescribeHubContentResponse(hub_content_description)

src/sagemaker/jumpstart/types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,9 +1363,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13631363
self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
13641364
self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
13651365
JumpStartPredictorSpecs(
1366-
json_obj["predictor_specs"], is_hub_content=self._is_hub_content
1366+
json_obj.get("predictor_specs"),
1367+
is_hub_content=self._is_hub_content,
13671368
)
1368-
if "predictor_specs" in json_obj
1369+
if json_obj.get("predictor_specs")
13691370
else None
13701371
)
13711372
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -1501,6 +1502,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
15011502
"incremental_training_supported",
15021503
]
15031504

1505+
# Map of HubContent fields that map to custom names in MetadataBaseFields
1506+
CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"}
1507+
15041508
__slots__ = slots + JumpStartMetadataBaseFields.__slots__
15051509

15061510
def __init__(
@@ -1532,6 +1536,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
15321536
if field in self.__slots__:
15331537
setattr(self, field, json_obj[field])
15341538

1539+
# Handle custom fields
1540+
for custom_field, field in self.CUSTOM_FIELD_MAP.items():
1541+
if custom_field in json_obj:
1542+
setattr(self, field, json_obj.get(custom_field))
1543+
15351544

15361545
class JumpStartMetadataConfig(JumpStartDataHolderType):
15371546
"""Data class of JumpStart metadata config."""

0 commit comments

Comments
 (0)