diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index d57420a31d..89f15a54ab 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -135,15 +135,17 @@ def _read_existing_serving_properties(directory: str): return properties -def _get_model_config_properties_from_s3(model_s3_uri: str): +def _get_model_config_properties_from_s3(model_s3_uri: str, sagemaker_session: Session): """Placeholder docstring""" - s3_files = s3.S3Downloader.list(model_s3_uri) + s3_files = s3.S3Downloader.list(model_s3_uri, sagemaker_session=sagemaker_session) model_config = None for config in defaults.VALID_MODEL_CONFIG_FILES: config_file = os.path.join(model_s3_uri, config) if config_file in s3_files: - model_config = json.loads(s3.S3Downloader.read_file(config_file)) + model_config = json.loads( + s3.S3Downloader.read_file(config_file, sagemaker_session=sagemaker_session) + ) break if not model_config: raise ValueError( @@ -198,7 +200,8 @@ def __new__( "containing folder" ) if model_id.startswith("s3://"): - model_config = _get_model_config_properties_from_s3(model_id) + sagemaker_session = kwargs.get("sagemaker_session") + model_config = _get_model_config_properties_from_s3(model_id, sagemaker_session) else: model_config = _get_model_config_properties_from_hf(model_id) if model_config.get("_class_name") == "StableDiffusionPipeline": diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index 4c4ec27f49..c4f03ae502 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -174,6 +174,9 @@ def test_create_model_automatic_engine_selection(mock_s3_list, mock_read_file, s sagemaker_session=sagemaker_session, number_of_partitions=2, ) + mock_s3_list.assert_any_call( + VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session + ) if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE: assert ds_model.engine == DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION else: