Skip to content

fix: load_sagemaker_config should lazily initialize a default S3 resource #3766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions src/sagemaker/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,10 @@
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"

_BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session()
# The default Boto3 S3 Resource. This is constructed from the default Boto3 session. This will be
# used to fetch SageMakerConfig from S3. Users can override this by passing their own S3 Resource
# as the constructor parameter for SageMakerConfig.
_DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3")
S3_PREFIX = "s3://"


def load_sagemaker_config(
additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE
) -> dict:
def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource=None) -> dict:
"""Loads config files and merges them.

By default, this method first searches for config files in the default locations
Expand Down Expand Up @@ -95,8 +88,9 @@ def load_sagemaker_config(

Note: S3 URI follows the format ``s3://<bucket>/<Key prefix>``
s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
config files from S3. If it is not provided, this method creates a default S3 resource.
See `Boto3 Session documentation <https://boto3.amazonaws.com/v1/documentation/api\
config files from S3. If it is not provided but config files are present in S3,
this method creates a default S3 resource. See `Boto3 Session documentation
<https://boto3.amazonaws.com/v1/documentation/api\
/latest/reference/core/session.html#boto3.session.Session.resource>`__.
This argument is not needed if the config files are present in the local file system.
"""
Expand Down Expand Up @@ -161,7 +155,15 @@ def _load_config_from_file(file_path: str) -> dict:
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:
"""Placeholder docstring"""
if not s3_resource_for_config:
raise RuntimeError("No S3 client found. Provide a S3 client to load the config file.")
# Constructing a default Boto3 S3 Resource from a default Boto3 session.
boto_session = boto3.DEFAULT_SESSION or boto3.Session()
boto_region_name = boto_session.region_name
if boto_region_name is None:
raise ValueError(
"Must setup local AWS configuration with a region supported by SageMaker."
)
s3_resource_for_config = boto_session.resource("s3", region_name=boto_region_name)

logger.debug("Fetching config file from the S3 URI: %s", s3_uri)
inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config)
parsed_url = urlparse(inferred_s3_uri)
Expand Down
19 changes: 8 additions & 11 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,22 +674,19 @@ def _initialize(
self.sagemaker_client = LocalSagemakerClient(self)
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.local_mode = True
sagemaker_config = kwargs.get("sagemaker_config", None)
if sagemaker_config:
validate_sagemaker_config(sagemaker_config)

if self.s3_endpoint_url is not None:
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
self.sagemaker_config = (
sagemaker_config
if sagemaker_config
else load_sagemaker_config(s3_resource=self.s3_resource)
)

sagemaker_config = kwargs.get("sagemaker_config", None)
if sagemaker_config:
validate_sagemaker_config(sagemaker_config)
self.sagemaker_config = sagemaker_config
else:
self.sagemaker_config = (
sagemaker_config if sagemaker_config else load_sagemaker_config()
)
# self.s3_resource might be None. If it is None, load_sagemaker_config will
# create a default S3 resource, but only if it needs to fetch from S3
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)

sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
if os.path.exists(sagemaker_config_file):
Expand Down
9 changes: 4 additions & 5 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,14 @@ def _initialize(
prepend_user_agent(self.sagemaker_metrics_client)

self.local_mode = False

if sagemaker_config:
validate_sagemaker_config(sagemaker_config)
self.sagemaker_config = sagemaker_config
else:
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
self.sagemaker_config = load_sagemaker_config(s3_resource=s3)
# self.s3_resource might be None. If it is None, load_sagemaker_config will
# create a default S3 resource, but only if it needs to fetch from S3
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)

@property
def boto_region_name(self):
Expand Down
9 changes: 9 additions & 0 deletions tests/integ/test_sagemaker_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,20 @@ def test_config_download_from_s3_and_merge(
sagemaker_session=sagemaker_session,
)

# Set env variable so load_sagemaker_config can construct an S3 resource in the right region
previous_env_value = os.getenv("AWS_DEFAULT_REGION")
os.environ["AWS_DEFAULT_REGION"] = sagemaker_session.boto_session.region_name

# The thing being tested.
sagemaker_config = load_sagemaker_config(
additional_config_paths=[s3_uri_config_1, config_file_2_local_path]
)

# Reset the env variable to what it was before (if it was set before)
os.unsetenv("AWS_DEFAULT_REGION")
if previous_env_value is not None:
os.environ["AWS_DEFAULT_REGION"] = previous_env_value

assert sagemaker_config == expected_merged_config


Expand Down