Skip to content

Commit cb9889a

Browse files
committed
[fix]: Pass sagemaker session to downstream s3 calls
1 parent ca107d4 commit cb9889a

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/sagemaker/djl_inference/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,17 @@ def _read_existing_serving_properties(directory: str):
135135
return properties
136136

137137

138-
def _get_model_config_properties_from_s3(model_s3_uri: str):
138+
def _get_model_config_properties_from_s3(model_s3_uri: str, sagemaker_session: Session):
139139
"""Placeholder docstring"""
140140

141-
s3_files = s3.S3Downloader.list(model_s3_uri)
141+
s3_files = s3.S3Downloader.list(model_s3_uri, sagemaker_session=sagemaker_session)
142142
model_config = None
143143
for config in defaults.VALID_MODEL_CONFIG_FILES:
144144
config_file = os.path.join(model_s3_uri, config)
145145
if config_file in s3_files:
146-
model_config = json.loads(s3.S3Downloader.read_file(config_file))
146+
model_config = json.loads(
147+
s3.S3Downloader.read_file(config_file, sagemaker_session=sagemaker_session)
148+
)
147149
break
148150
if not model_config:
149151
raise ValueError(
@@ -198,7 +200,8 @@ def __new__(
198200
"containing folder"
199201
)
200202
if model_id.startswith("s3://"):
201-
model_config = _get_model_config_properties_from_s3(model_id)
203+
sagemaker_session = kwargs.get("sagemaker_session")
204+
model_config = _get_model_config_properties_from_s3(model_id, sagemaker_session)
202205
else:
203206
model_config = _get_model_config_properties_from_hf(model_id)
204207
if model_config.get("_class_name") == "StableDiffusionPipeline":

tests/unit/test_djl_inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def test_create_model_automatic_engine_selection(mock_s3_list, mock_read_file, s
174174
sagemaker_session=sagemaker_session,
175175
number_of_partitions=2,
176176
)
177+
mock_s3_list.assert_any_call(
178+
VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session
179+
)
177180
if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE:
178181
assert ds_model.engine == DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION
179182
else:

0 commit comments

Comments
 (0)