Skip to content

Commit 3280e1a

Browse files
authored
Merge branch 'master' into feat/py-typed
2 parents 453e53e + 2be822c commit 3280e1a

File tree

16 files changed

+184
-37
lines changed

16 files changed

+184
-37
lines changed

doc/overview.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,15 +1958,15 @@ Make sure to have a Compose Version compatible with your Docker Engine installat
19581958
Local mode configuration
19591959
========================
19601960

1961-
The local mode uses a YAML configuration file located at ``~/.sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/config/config_schema.py>`_.
1961+
The local mode uses a YAML configuration file located at ``${user_config_directory}/sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/config/config_schema.py>`_.
19621962

19631963
.. code:: yaml
19641964
19651965
local:
19661966
local_code: true # Using everything locally
19671967
region_name: "us-west-2" # Name of the region
19681968
container_config: # Additional docker container config
1969-
shm_size: "128M
1969+
shm_size: "128M"
19701970
19711971
If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways:
19721972

src/sagemaker/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ def __init__(
387387
source_dir (str or PipelineVariable): The absolute, relative, or S3 URI Path to
388388
a directory with any other training source code dependencies aside from the entry
389389
point file (default: None). If ``source_dir`` is an S3 URI, it must
390-
point to a tar.gz file. The structure within this directory is preserved
391-
when training on Amazon SageMaker. If 'git_config' is provided,
390+
point to a file with name ``sourcedir.tar.gz``. The structure within this directory
391+
is preserved when training on Amazon SageMaker. If 'git_config' is provided,
392392
'source_dir' should be a relative location to a directory in the Git
393393
repo.
394394
With the following GitHub repo directory structure:
@@ -3421,8 +3421,8 @@ def __init__(
34213421
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI)
34223422
to a directory with any other training source code dependencies aside from
34233423
the entry point file (default: None). If ``source_dir`` is an S3 URI, it must
3424-
point to a tar.gz file. Structure within this directory are preserved
3425-
when training on Amazon SageMaker. If 'git_config' is provided,
3424+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
3425+
are preserved when training on Amazon SageMaker. If 'git_config' is provided,
34263426
'source_dir' should be a relative location to a directory in the Git
34273427
repo.
34283428

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def validate_source_code_input_against_pipeline_variables(
252252
logger.warning(
253253
"The source_dir is a pipeline variable: %s. During pipeline execution, "
254254
"the interpreted value of source_dir has to be an S3 URI and "
255-
"must point to a tar.gz file",
255+
"must point to a file with name ``sourcedir.tar.gz``",
256256
type(source_dir),
257257
)
258258

src/sagemaker/huggingface/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def __init__(
8484
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a
8585
directory with any other training source code dependencies aside from the entry
8686
point file (default: None). If ``source_dir`` is an S3 URI, it must
87-
point to a tar.gz file. Structure within this directory are preserved
88-
when training on Amazon SageMaker.
87+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory are
88+
preserved when training on Amazon SageMaker.
8989
hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
9090
that will be used for training (default: None). The hyperparameters are made
9191
accessible as a dict[str, str] to the training code on

src/sagemaker/jumpstart/cache.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _model_id_retrieval_function(
262262
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
263263

264264
versions_incompatible_with_sagemaker = [
265-
Version(header.version)
265+
header.version
266266
for header in manifest.values() # type: ignore
267267
if header.model_id == model_id
268268
]
@@ -540,9 +540,7 @@ def _select_version(
540540
"""
541541

542542
if version_str == "*":
543-
if len(available_versions) == 0:
544-
return None
545-
return str(max(available_versions))
543+
return utils.get_latest_version(available_versions)
546544

547545
if model_type == JumpStartModelType.PROPRIETARY:
548546
if "*" in version_str:

src/sagemaker/jumpstart/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,8 @@ def __init__(
350350
source_dir (Optional[Union[str, PipelineVariable]]): The absolute, relative, or
351351
S3 URI Path to a directory with any other training source code dependencies
352352
aside from the entry point file. If ``source_dir`` is an S3 URI, it must
353-
point to a tar.gz file. Structure within this directory is preserved
354-
when training on Amazon SageMaker. If 'git_config' is provided,
353+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
354+
is preserved when training on Amazon SageMaker. If 'git_config' is provided,
355355
'source_dir' should be a relative location to a directory in the Git
356356
repo.
357357
(Default: None).
@@ -947,8 +947,8 @@ def deploy(
947947
source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory
948948
with any other training source code dependencies aside from the entry
949949
point file (Default: None). If ``source_dir`` is an S3 URI, it must
950-
point to a tar.gz file. Structure within this directory is preserved
951-
when training on Amazon SageMaker. If 'git_config' is provided,
950+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory is
951+
preserved when training on Amazon SageMaker. If 'git_config' is provided,
952952
'source_dir' should be a relative location to a directory in the Git repo.
953953
If the directory points to S3, no code is uploaded and the S3 location
954954
is used instead. (Default: None).

src/sagemaker/jumpstart/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def __init__(
178178
source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory
179179
with any other training source code dependencies aside from the entry
180180
point file (Default: None). If ``source_dir`` is an S3 URI, it must
181-
point to a tar.gz file. Structure within this directory is preserved
182-
when training on Amazon SageMaker. If 'git_config' is provided,
181+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory is
182+
preserved when training on Amazon SageMaker. If 'git_config' is provided,
183183
'source_dir' should be a relative location to a directory in the Git repo.
184184
If the directory points to S3, no code is uploaded and the S3 location
185185
is used instead. (Default: None).

src/sagemaker/jumpstart/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from urllib.parse import urlparse
2222
import boto3
2323
from botocore.exceptions import ClientError
24-
from packaging.version import Version
24+
from packaging.version import Version, InvalidVersion
2525
import botocore
2626
from sagemaker_core.shapes import ModelAccessConfig
2727
import sagemaker
@@ -1630,3 +1630,11 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16301630
return get_jumpstart_gated_content_bucket(region=region)
16311631
return get_jumpstart_content_bucket(region=region)
16321632
return neo_bucket
1633+
1634+
1635+
def get_latest_version(versions: List[str]) -> Optional[str]:
1636+
"""Returns the latest version using sem-ver when possible."""
1637+
try:
1638+
return None if not versions else max(versions, key=Version)
1639+
except InvalidVersion:
1640+
return max(versions)

src/sagemaker/model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def __init__(
215215
source_dir (str): The absolute, relative, or S3 URI Path to a directory
216216
with any other training source code dependencies aside from the entry
217217
point file (default: None). If ``source_dir`` is an S3 URI, it must
218-
point to a tar.gz file. Structure within this directory is preserved
219-
when training on Amazon SageMaker. If 'git_config' is provided,
218+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
219+
is preserved when training on Amazon SageMaker. If 'git_config' is provided,
220220
'source_dir' should be a relative location to a directory in the Git repo.
221221
If the directory points to S3, no code is uploaded and the S3 location
222222
is used instead.
@@ -1996,11 +1996,11 @@ def __init__(
19961996
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
19971997
with any other training source code dependencies aside from the entry
19981998
point file (default: None). If ``source_dir`` is an S3 URI, it must
1999-
point to a tar.gz file. Structure within this directory are preserved
2000-
when training on Amazon SageMaker. If 'git_config' is provided,
2001-
'source_dir' should be a relative location to a directory in the Git repo.
2002-
If the directory points to S3, no code will be uploaded and the S3 location
2003-
will be used instead.
1999+
point to a file with name ``sourcedir.tar.gz``. Structure within this
2000+
directory are preserved when training on Amazon SageMaker. If 'git_config'
2001+
is provided, 'source_dir' should be a relative location to a directory in the
2002+
Git repo. If the directory points to S3, no code will be uploaded and the S3
2003+
location will be used instead.
20042004
20052005
.. admonition:: Example
20062006

src/sagemaker/mxnet/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def __init__(
8484
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to
8585
a directory with any other training source code dependencies aside from the entry
8686
point file (default: None). If ``source_dir`` is an S3 URI, it must
87-
point to a tar.gz file. Structure within this directory are preserved
88-
when training on Amazon SageMaker.
87+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
88+
are preserved when training on Amazon SageMaker.
8989
hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
9090
that will be used for training (default: None). The hyperparameters are made
9191
accessible as a dict[str, str] to the training code on

src/sagemaker/pytorch/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def __init__(
182182
unless ``image_uri`` is provided.
183183
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to
184184
a directory with any other training source code dependencies aside from the entry
185-
point file (default: None). If ``source_dir`` is an S3 URI, it must
186-
point to a tar.gz file. Structure within this directory are preserved
185+
point file (default: None). If ``source_dir`` is an S3 URI, it must point to a
186+
file with name ``sourcedir.tar.gz``. Structure within this directory are preserved
187187
when training on Amazon SageMaker. Must be a local path when using training_recipe.
188188
hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
189189
that will be used for training (default: None). The hyperparameters are made

src/sagemaker/rl/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def __init__(
120120
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI)
121121
to a directory with any other training source code dependencies aside from
122122
the entry point file (default: None). If ``source_dir`` is an S3 URI, it must
123-
point to a tar.gz file. Structure within this directory are preserved
124-
when training on Amazon SageMaker.
123+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
124+
are preserved when training on Amazon SageMaker.
125125
hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
126126
that will be used for training (default: None). The hyperparameters are made
127127
accessible as a dict[str, str] to the training code on

src/sagemaker/sklearn/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def __init__(
8383
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to
8484
a directory with any other training source code dependencies aside from the entry
8585
point file (default: None). If ``source_dir`` is an S3 URI, it must
86-
point to a tar.gz file. Structure within this directory are preserved
87-
when training on Amazon SageMaker.
86+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
87+
are preserved when training on Amazon SageMaker.
8888
hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
8989
that will be used for training (default: None). The hyperparameters are made
9090
accessible as a dict[str, str] to the training code on

src/sagemaker/xgboost/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def __init__(
7878
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to
7979
a directory with any other training source code dependencies aside from the entry
8080
point file (default: None). If ``source_dir`` is an S3 URI, it must
81-
point to a tar.gz file. Structure within this directory are preserved
82-
when training on Amazon SageMaker.
81+
point to a file with name ``sourcedir.tar.gz``. Structure within this directory
82+
are preserved when training on Amazon SageMaker.
8383
hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
8484
that will be used for training (default: None).
8585
The hyperparameters are made accessible as a dict[str, str] to the training code

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from mock.mock import MagicMock
2323
import pytest
2424
from mock import patch
25+
from packaging.version import Version
2526

27+
28+
from sagemaker.jumpstart import utils
2629
from sagemaker.jumpstart.cache import (
2730
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2831
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
@@ -33,6 +36,7 @@
3336
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
3437
)
3538
from sagemaker.jumpstart.types import (
39+
JumpStartCachedContentValue,
3640
JumpStartModelHeader,
3741
JumpStartModelSpecs,
3842
JumpStartVersionedModelId,
@@ -1119,3 +1123,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11191123
),
11201124
]
11211125
)
1126+
1127+
1128+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1129+
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1130+
retrieval_function: Mock,
1131+
):
1132+
sm_version = Version(utils.get_sagemaker_version())
1133+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1134+
print(str(new_sm_version))
1135+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1136+
manifest = [
1137+
{
1138+
"model_id": "test-model",
1139+
"version": version,
1140+
"min_version": "2.49.0",
1141+
"spec_key": "spec_key",
1142+
}
1143+
for version in versions
1144+
]
1145+
1146+
manifest.append(
1147+
{
1148+
"model_id": "test-model",
1149+
"version": "3.0.0",
1150+
"min_version": str(new_sm_version),
1151+
"spec_key": "spec_key",
1152+
}
1153+
)
1154+
1155+
manifest_dict = {}
1156+
for header in manifest:
1157+
header_obj = JumpStartModelHeader(header)
1158+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1159+
header_obj
1160+
)
1161+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
1162+
key = JumpStartVersionedModelId("test-model", "*")
1163+
1164+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1165+
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)
1166+
1167+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1168+
1169+
assert result == assert_key
1170+
1171+
1172+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1173+
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1174+
retrieval_function: Mock,
1175+
):
1176+
sm_version = Version(utils.get_sagemaker_version())
1177+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1178+
print(str(new_sm_version))
1179+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1180+
manifest = [
1181+
{
1182+
"model_id": "test-model",
1183+
"version": version,
1184+
"min_version": "2.49.0",
1185+
"spec_key": "spec_key",
1186+
}
1187+
for version in versions
1188+
]
1189+
1190+
manifest.append(
1191+
{
1192+
"model_id": "test-model",
1193+
"version": "3.0.0",
1194+
"min_version": str(new_sm_version),
1195+
"spec_key": "spec_key",
1196+
}
1197+
)
1198+
1199+
manifest_dict = {}
1200+
for header in manifest:
1201+
header_obj = JumpStartModelHeader(header)
1202+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1203+
header_obj
1204+
)
1205+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
1206+
key = JumpStartVersionedModelId("test-model", "*")
1207+
1208+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1209+
result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None)
1210+
1211+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1212+
1213+
assert result == assert_key
1214+
1215+
1216+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1217+
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock):
1218+
sm_version = Version(utils.get_sagemaker_version())
1219+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1220+
print(str(new_sm_version))
1221+
versions = ["abc", "2.9.1", "2.16.0"]
1222+
manifest = [
1223+
{
1224+
"model_id": "test-model",
1225+
"version": version,
1226+
"min_version": "2.49.0",
1227+
"spec_key": "spec_key",
1228+
}
1229+
for version in versions
1230+
]
1231+
1232+
manifest_dict = {}
1233+
for header in manifest:
1234+
header_obj = JumpStartModelHeader(header)
1235+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1236+
header_obj
1237+
)
1238+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
1239+
key = JumpStartVersionedModelId("test-model", "*")
1240+
1241+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1242+
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)
1243+
1244+
assert_key = JumpStartVersionedModelId("test-model", "abc")
1245+
1246+
assert result == assert_key

0 commit comments

Comments
 (0)