Skip to content

Commit ad7cd9e

Browse files
authored
Merge branch 'master' into processing-job-codeartifact-support
2 parents 01c1c40 + ddd06bb commit ad7cd9e

File tree

5 files changed

+223
-8
lines changed

5 files changed

+223
-8
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
 # language governing permissions and limitations under the License.
 """This module stores JumpStart Model factory methods."""
 from __future__ import absolute_import
+import json


 from typing import Any, Dict, List, Optional, Union
@@ -206,9 +207,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
 def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
     """Sets model data based on default or override, returns full kwargs."""

-    model_data = kwargs.model_data
-
-    kwargs.model_data = model_data or model_uris.retrieve(
+    model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve(
         model_scope=JumpStartScriptScope.INFERENCE,
         model_id=kwargs.model_id,
         model_version=kwargs.model_version,
@@ -218,6 +217,25 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
         sagemaker_session=kwargs.sagemaker_session,
     )

+    if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
+        old_model_data_str = model_data
+        model_data = {
+            "S3DataSource": {
+                "S3Uri": model_data,
+                "S3DataType": "S3Prefix",
+                "CompressionType": "None",
+            }
+        }
+        if kwargs.model_data:
+            JUMPSTART_LOGGER.info(
+                "S3 prefix model_data detected for JumpStartModel: '%s'. "
+                "Converting to S3DataSource dictionary: '%s'.",
+                old_model_data_str,
+                json.dumps(model_data),
+            )
+
+    kwargs.model_data = model_data
+
     return kwargs
222240

223241

@@ -496,7 +514,7 @@ def get_init_kwargs(
     instance_type: Optional[str] = None,
     region: Optional[str] = None,
     image_uri: Optional[Union[str, PipelineVariable]] = None,
-    model_data: Optional[Union[str, PipelineVariable]] = None,
+    model_data: Optional[Union[str, PipelineVariable, dict]] = None,
     role: Optional[str] = None,
     predictor_cls: Optional[callable] = None,
     env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
         region: Optional[str] = None,
         instance_type: Optional[str] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
-        model_data: Optional[Union[str, PipelineVariable]] = None,
+        model_data: Optional[Union[str, PipelineVariable, dict]] = None,
         role: Optional[str] = None,
         predictor_cls: Optional[callable] = None,
         env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -95,8 +95,8 @@ def __init__(
             instance_type (Optional[str]): The EC2 instance type to use when provisioning a hosting
                 endpoint. (Default: None).
             image_uri (Optional[Union[str, PipelineVariable]]): A Docker image URI. (Default: None).
-            model_data (Optional[Union[str, PipelineVariable]]): The S3 location of a SageMaker
-                model data ``.tar.gz`` file. (Default: None).
+            model_data (Optional[Union[str, PipelineVariable, dict]]): Location
+                of SageMaker model data. (Default: None).
             role (Optional[str]): An AWS IAM role (either name or full ARN). The Amazon
                 SageMaker training jobs and APIs that create Amazon SageMaker
                 endpoints use this role to access training data and model

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def __init__(
         region: Optional[str] = None,
         instance_type: Optional[str] = None,
         image_uri: Optional[Union[str, Any]] = None,
-        model_data: Optional[Union[str, Any]] = None,
+        model_data: Optional[Union[str, Any, dict]] = None,
         role: Optional[str] = None,
         predictor_cls: Optional[callable] = None,
         env: Optional[Dict[str, Union[str, Any]]] = None,

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,93 @@
17081708
"default_accept_type": "application/json",
17091709
},
17101710
},
1711+
"model_data_s3_prefix_model": {
1712+
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
1713+
"url": "https://huggingface.co/google/flan-t5-xxl",
1714+
"version": "1.0.1",
1715+
"min_sdk_version": "2.130.0",
1716+
"training_supported": False,
1717+
"incremental_training_supported": False,
1718+
"hosting_ecr_specs": {
1719+
"framework": "pytorch",
1720+
"framework_version": "1.12.0",
1721+
"py_version": "py38",
1722+
"huggingface_transformers_version": "4.17.0",
1723+
},
1724+
"hosting_artifact_key": "huggingface-infer/",
1725+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz",
1726+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/",
1727+
"hosting_prepacked_artifact_version": "1.0.1",
1728+
"inference_vulnerable": False,
1729+
"inference_dependencies": [
1730+
"accelerate==0.16.0",
1731+
"bitsandbytes==0.37.0",
1732+
"filelock==3.9.0",
1733+
"huggingface_hub==0.12.0",
1734+
"regex==2022.7.9",
1735+
"tokenizers==0.13.2",
1736+
"transformers==4.26.0",
1737+
],
1738+
"inference_vulnerabilities": [],
1739+
"training_vulnerable": False,
1740+
"training_dependencies": [],
1741+
"training_vulnerabilities": [],
1742+
"deprecated": False,
1743+
"inference_environment_variables": [
1744+
{
1745+
"name": "SAGEMAKER_PROGRAM",
1746+
"type": "text",
1747+
"default": "inference.py",
1748+
"scope": "container",
1749+
},
1750+
{
1751+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1752+
"type": "text",
1753+
"default": "/opt/ml/model/code",
1754+
"scope": "container",
1755+
},
1756+
{
1757+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1758+
"type": "text",
1759+
"default": "20",
1760+
"scope": "container",
1761+
},
1762+
{
1763+
"name": "MODEL_CACHE_ROOT",
1764+
"type": "text",
1765+
"default": "/opt/ml/model",
1766+
"scope": "container",
1767+
},
1768+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1769+
{
1770+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1771+
"type": "text",
1772+
"default": "1",
1773+
"scope": "container",
1774+
},
1775+
{
1776+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1777+
"type": "text",
1778+
"default": "3600",
1779+
"scope": "container",
1780+
},
1781+
],
1782+
"metrics": [],
1783+
"default_inference_instance_type": "ml.g5.12xlarge",
1784+
"supported_inference_instance_types": [
1785+
"ml.g5.12xlarge",
1786+
"ml.g5.24xlarge",
1787+
"ml.p3.8xlarge",
1788+
"ml.p3.16xlarge",
1789+
"ml.g4dn.12xlarge",
1790+
],
1791+
"predictor_specs": {
1792+
"supported_content_types": ["application/x-text"],
1793+
"supported_accept_types": ["application/json;verbose", "application/json"],
1794+
"default_content_type": "application/x-text",
1795+
"default_accept_type": "application/json",
1796+
},
1797+
},
17111798
"no-supported-instance-types-model": {
17121799
"model_id": "pytorch-ic-mobilenet-v2",
17131800
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,116 @@ def test_jumpstart_model_package_arn_unsupported_region(
678678
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
679679
)
680680

681+
@mock.patch("sagemaker.utils.sagemaker_timestamp")
682+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
683+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
684+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
685+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
686+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
687+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
688+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info")
689+
def test_model_data_s3_prefix_override(
690+
self,
691+
mock_js_info_logger: mock.Mock,
692+
mock_model_deploy: mock.Mock,
693+
mock_model_init: mock.Mock,
694+
mock_get_model_specs: mock.Mock,
695+
mock_session: mock.Mock,
696+
mock_is_valid_model_id: mock.Mock,
697+
mock_sagemaker_timestamp: mock.Mock,
698+
):
699+
mock_model_deploy.return_value = default_predictor
700+
701+
mock_sagemaker_timestamp.return_value = "7777"
702+
703+
mock_is_valid_model_id.return_value = True
704+
model_id, _ = "js-trainable-model", "*"
705+
706+
mock_get_model_specs.side_effect = get_special_model_spec
707+
708+
mock_session.return_value = sagemaker_session
709+
710+
JumpStartModel(model_id=model_id, model_data="s3://some-bucket/path/to/prefix/")
711+
712+
mock_model_init.assert_called_once_with(
713+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/"
714+
"autogluon-inference:0.4.3-gpu-py38",
715+
model_data={
716+
"S3DataSource": {
717+
"S3Uri": "s3://some-bucket/path/to/prefix/",
718+
"S3DataType": "S3Prefix",
719+
"CompressionType": "None",
720+
}
721+
},
722+
source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-"
723+
"tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz",
724+
entry_point="inference.py",
725+
env={
726+
"SAGEMAKER_PROGRAM": "inference.py",
727+
"ENDPOINT_SERVER_TIMEOUT": "3600",
728+
"MODEL_CACHE_ROOT": "/opt/ml/model",
729+
"SAGEMAKER_ENV": "1",
730+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
731+
},
732+
predictor_cls=Predictor,
733+
role=execution_role,
734+
sagemaker_session=sagemaker_session,
735+
enable_network_isolation=False,
736+
name="blahblahblah-7777",
737+
)
738+
739+
mock_js_info_logger.assert_called_with(
740+
"S3 prefix model_data detected for JumpStartModel: '%s'. "
741+
"Converting to S3DataSource dictionary: '%s'.",
742+
"s3://some-bucket/path/to/prefix/",
743+
'{"S3DataSource": {"S3Uri": "s3://some-bucket/path/to/prefix/", '
744+
'"S3DataType": "S3Prefix", "CompressionType": "None"}}',
745+
)
746+
747+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
748+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
749+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
750+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
751+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
752+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
753+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info")
754+
def test_model_data_s3_prefix_model(
755+
self,
756+
mock_js_info_logger: mock.Mock,
757+
mock_model_deploy: mock.Mock,
758+
mock_model_init: mock.Mock,
759+
mock_get_model_specs: mock.Mock,
760+
mock_session: mock.Mock,
761+
mock_is_valid_model_id: mock.Mock,
762+
):
763+
mock_model_deploy.return_value = default_predictor
764+
765+
mock_is_valid_model_id.return_value = True
766+
model_id, _ = "model_data_s3_prefix_model", "*"
767+
768+
mock_get_model_specs.side_effect = get_special_model_spec
769+
770+
mock_session.return_value = sagemaker_session
771+
772+
JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge")
773+
774+
mock_model_init.assert_called_once_with(
775+
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38",
776+
model_data={
777+
"S3DataSource": {
778+
"S3Uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/",
779+
"S3DataType": "S3Prefix",
780+
"CompressionType": "None",
781+
}
782+
},
783+
predictor_cls=Predictor,
784+
role=execution_role,
785+
sagemaker_session=sagemaker_session,
786+
enable_network_isolation=False,
787+
)
788+
789+
mock_js_info_logger.assert_not_called()
790+
681791

682792
def test_jumpstart_model_requires_model_id():
683793
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)