Skip to content

Commit 770d6a7

Browse files
evakravibenieric
authored andcommitted
chore: address PR comments
1 parent 37f8f6e commit 770d6a7

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores JumpStart Model factory methods."""
1414
from __future__ import absolute_import
15+
import json
1516

1617

1718
from typing import Any, Dict, List, Optional, Union
@@ -217,19 +218,21 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
217218
)
218219

219220
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
220-
if kwargs.model_data:
221-
JUMPSTART_LOGGER.info(
222-
"S3 prefix model_data detected for JumpStartModel: '%s'. "
223-
"Converting to S3DataSource dictionary.",
224-
model_data,
225-
)
221+
old_model_data_str = model_data
226222
model_data = {
227223
"S3DataSource": {
228224
"S3Uri": model_data,
229225
"S3DataType": "S3Prefix",
230226
"CompressionType": "None",
231227
}
232228
}
229+
if kwargs.model_data:
230+
JUMPSTART_LOGGER.info(
231+
"S3 prefix model_data detected for JumpStartModel: '%s'. "
232+
"Converting to S3DataSource dictionary: '%s'.",
233+
old_model_data_str,
234+
json.dumps(model_data),
235+
)
233236

234237
kwargs.model_data = model_data
235238

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1721,7 +1721,7 @@
17211721
"py_version": "py38",
17221722
"huggingface_transformers_version": "4.17.0",
17231723
},
1724-
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
1724+
"hosting_artifact_key": "huggingface-infer/",
17251725
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz",
17261726
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/",
17271727
"hosting_prepacked_artifact_version": "1.0.1",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,8 +737,11 @@ def test_model_data_s3_prefix_override(
737737
)
738738

739739
mock_js_info_logger.assert_called_with(
740-
"S3 prefix model_data detected for JumpStartModel: '%s'. Converting to S3DataSource dictionary.",
740+
"S3 prefix model_data detected for JumpStartModel: '%s'. "
741+
"Converting to S3DataSource dictionary: '%s'.",
741742
"s3://some-bucket/path/to/prefix/",
743+
'{"S3DataSource": {"S3Uri": "s3://some-bucket/path/to/prefix/", '
744+
'"S3DataType": "S3Prefix", "CompressionType": "None"}}',
742745
)
743746

744747
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")

0 commit comments

Comments
 (0)