Skip to content

Commit e4d0874

Browse files
rubanh (Ruban Hussain)
and
Ruban Hussain
authored
fix: load_sagemaker_config should lazy initialize a default S3 resource (#3766)
Co-authored-by: Ruban Hussain <[email protected]>
1 parent 4bc5fb6 commit e4d0874

File tree

4 files changed

+34
-27
lines changed

4 files changed

+34
-27
lines changed

src/sagemaker/config/config.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,10 @@
4343
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
4444
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
4545

46-
_BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session()
47-
# The default Boto3 S3 Resource. This is constructed from the default Boto3 session. This will be
48-
# used to fetch SageMakerConfig from S3. Users can override this by passing their own S3 Resource
49-
# as the constructor parameter for SageMakerConfig.
50-
_DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3")
5146
S3_PREFIX = "s3://"
5247

5348

54-
def load_sagemaker_config(
55-
additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE
56-
) -> dict:
49+
def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource=None) -> dict:
5750
"""Loads config files and merges them.
5851
5952
By default, this method first searches for config files in the default locations
@@ -95,8 +88,9 @@ def load_sagemaker_config(
9588
9689
Note: S3 URI follows the format ``s3://<bucket>/<Key prefix>``
9790
s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
98-
config files from S3. If it is not provided, this method creates a default S3 resource.
99-
See `Boto3 Session documentation <https://boto3.amazonaws.com/v1/documentation/api\
91+
config files from S3. If it is not provided but config files are present in S3,
92+
this method creates a default S3 resource. See `Boto3 Session documentation
93+
<https://boto3.amazonaws.com/v1/documentation/api\
10094
/latest/reference/core/session.html#boto3.session.Session.resource>`__.
10195
This argument is not needed if the config files are present in the local file system.
10296
"""
@@ -161,7 +155,15 @@ def _load_config_from_file(file_path: str) -> dict:
161155
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:
162156
"""Placeholder docstring"""
163157
if not s3_resource_for_config:
164-
raise RuntimeError("No S3 client found. Provide a S3 client to load the config file.")
158+
# Constructing a default Boto3 S3 Resource from a default Boto3 session.
159+
boto_session = boto3.DEFAULT_SESSION or boto3.Session()
160+
boto_region_name = boto_session.region_name
161+
if boto_region_name is None:
162+
raise ValueError(
163+
"Must setup local AWS configuration with a region supported by SageMaker."
164+
)
165+
s3_resource_for_config = boto_session.resource("s3", region_name=boto_region_name)
166+
165167
logger.debug("Fetching config file from the S3 URI: %s", s3_uri)
166168
inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config)
167169
parsed_url = urlparse(inferred_s3_uri)

src/sagemaker/local/local_session.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -674,22 +674,19 @@ def _initialize(
674674
self.sagemaker_client = LocalSagemakerClient(self)
675675
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
676676
self.local_mode = True
677-
sagemaker_config = kwargs.get("sagemaker_config", None)
678-
if sagemaker_config:
679-
validate_sagemaker_config(sagemaker_config)
680677

681678
if self.s3_endpoint_url is not None:
682679
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
683680
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
684-
self.sagemaker_config = (
685-
sagemaker_config
686-
if sagemaker_config
687-
else load_sagemaker_config(s3_resource=self.s3_resource)
688-
)
681+
682+
sagemaker_config = kwargs.get("sagemaker_config", None)
683+
if sagemaker_config:
684+
validate_sagemaker_config(sagemaker_config)
685+
self.sagemaker_config = sagemaker_config
689686
else:
690-
self.sagemaker_config = (
691-
sagemaker_config if sagemaker_config else load_sagemaker_config()
692-
)
687+
# self.s3_resource might be None. If it is None, load_sagemaker_config will
688+
# create a default S3 resource, but only if it needs to fetch from S3
689+
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
693690

694691
sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
695692
if os.path.exists(sagemaker_config_file):

src/sagemaker/session.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,14 @@ def _initialize(
266266
prepend_user_agent(self.sagemaker_metrics_client)
267267

268268
self.local_mode = False
269+
269270
if sagemaker_config:
270271
validate_sagemaker_config(sagemaker_config)
271272
self.sagemaker_config = sagemaker_config
272273
else:
273-
if self.s3_resource is None:
274-
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
275-
else:
276-
s3 = self.s3_resource
277-
self.sagemaker_config = load_sagemaker_config(s3_resource=s3)
274+
# self.s3_resource might be None. If it is None, load_sagemaker_config will
275+
# create a default S3 resource, but only if it needs to fetch from S3
276+
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
278277

279278
@property
280279
def boto_region_name(self):

tests/integ/test_sagemaker_config.py

+9
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,20 @@ def test_config_download_from_s3_and_merge(
172172
sagemaker_session=sagemaker_session,
173173
)
174174

175+
# Set env variable so load_sagemaker_config can construct an S3 resource in the right region
176+
previous_env_value = os.getenv("AWS_DEFAULT_REGION")
177+
os.environ["AWS_DEFAULT_REGION"] = sagemaker_session.boto_session.region_name
178+
175179
# The thing being tested.
176180
sagemaker_config = load_sagemaker_config(
177181
additional_config_paths=[s3_uri_config_1, config_file_2_local_path]
178182
)
179183

184+
# Reset the env variable to what it was before (if it was set before)
185+
os.unsetenv("AWS_DEFAULT_REGION")
186+
if previous_env_value is not None:
187+
os.environ["AWS_DEFAULT_REGION"] = previous_env_value
188+
180189
assert sagemaker_config == expected_merged_config
181190

182191

0 commit comments

Comments
 (0)