|
43 | 43 | ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
|
44 | 44 | ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
|
45 | 45 |
|
46 |
| -_BOTO_SESSION = boto3.DEFAULT_SESSION or boto3.Session() |
47 |
| -# The default Boto3 S3 Resource. This is constructed from the default Boto3 session. This will be |
48 |
| -# used to fetch SageMakerConfig from S3. Users can override this by passing their own S3 Resource |
49 |
| -# as the constructor parameter for SageMakerConfig. |
50 |
| -_DEFAULT_S3_RESOURCE = _BOTO_SESSION.resource("s3") |
51 | 46 | S3_PREFIX = "s3://"
|
52 | 47 |
|
53 | 48 |
|
54 |
| -def load_sagemaker_config( |
55 |
| - additional_config_paths: List[str] = None, s3_resource=_DEFAULT_S3_RESOURCE |
56 |
| -) -> dict: |
| 49 | +def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource=None) -> dict: |
57 | 50 | """Loads config files and merges them.
|
58 | 51 |
|
59 | 52 | By default, this method first searches for config files in the default locations
|
@@ -95,8 +88,9 @@ def load_sagemaker_config(
|
95 | 88 |
|
96 | 89 | Note: S3 URI follows the format ``s3://<bucket>/<Key prefix>``
|
97 | 90 | s3_resource (boto3.resource("s3")): The Boto3 S3 resource. This is used to fetch
|
98 |
| - config files from S3. If it is not provided, this method creates a default S3 resource. |
99 |
| - See `Boto3 Session documentation <https://boto3.amazonaws.com/v1/documentation/api\ |
| 91 | + config files from S3. If it is not provided but config files are present in S3, |
| 92 | + this method creates a default S3 resource. See `Boto3 Session documentation |
| 93 | + <https://boto3.amazonaws.com/v1/documentation/api\ |
100 | 94 | /latest/reference/core/session.html#boto3.session.Session.resource>`__.
|
101 | 95 | This argument is not needed if the config files are present in the local file system.
|
102 | 96 | """
|
@@ -161,7 +155,15 @@ def _load_config_from_file(file_path: str) -> dict:
|
161 | 155 | def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:
|
162 | 156 | """Placeholder docstring"""
|
163 | 157 | if not s3_resource_for_config:
|
164 |
| - raise RuntimeError("No S3 client found. Provide a S3 client to load the config file.") |
| 158 | + # Constructing a default Boto3 S3 Resource from a default Boto3 session. |
| 159 | + boto_session = boto3.DEFAULT_SESSION or boto3.Session() |
| 160 | + boto_region_name = boto_session.region_name |
| 161 | + if boto_region_name is None: |
| 162 | + raise ValueError( |
| 163 | + "Must setup local AWS configuration with a region supported by SageMaker." |
| 164 | + ) |
| 165 | + s3_resource_for_config = boto_session.resource("s3", region_name=boto_region_name) |
| 166 | + |
165 | 167 | logger.debug("Fetching config file from the S3 URI: %s", s3_uri)
|
166 | 168 | inferred_s3_uri = _get_inferred_s3_uri(s3_uri, s3_resource_for_config)
|
167 | 169 | parsed_url = urlparse(inferred_s3_uri)
|
|
0 commit comments