Skip to content

Commit 7eaa84b

Browse files
committed
fix logic for s3 output config
1 parent a6fc9b3 commit 7eaa84b

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,35 +94,37 @@ def _get_s3_client(self) -> BaseClient:
9494
"""Returns an S3 client used for creating a HubContentDocument."""
9595
return boto3.client("s3", region_name=self.region)
9696

97-
def _fetch_hub_bucket_name(self) -> str:
97+
def _fetch_hub_storage_location(self) -> S3ObjectLocation:
9898
"""Retrieves hub bucket name from Hub config if exists"""
9999
try:
100100
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
101101
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
102+
print("aaaaa", hub_output_location)
102103
if hub_output_location:
103104
location = create_s3_object_reference_from_uri(hub_output_location)
104-
return location.bucket
105+
return location
105106
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
107+
curr_timestamp = datetime.now().timestamp()
106108
JUMPSTART_LOGGER.warning(
107109
"There is not a Hub bucket associated with %s. Using %s",
108110
self.hub_name,
109111
default_bucket_name,
110112
)
111-
return default_bucket_name
113+
return S3ObjectLocation(bucket=default_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
112114
except exceptions.ClientError:
113115
hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
116+
curr_timestamp = datetime.now().timestamp()
114117
JUMPSTART_LOGGER.warning(
115118
"There is not a Hub bucket associated with %s. Using %s",
116119
self.hub_name,
117120
hub_bucket_name,
118121
)
119-
return hub_bucket_name
122+
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
120123

121124
def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None:
122125
"""Generates an ``S3ObjectLocation`` given a Hub name."""
123-
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
124126
curr_timestamp = datetime.now().timestamp()
125-
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
127+
return S3ObjectLocation(bucket=bucket_name, key=f"{self.hub_name}-{curr_timestamp}") if bucket_name else self._fetch_hub_storage_location()
126128

127129
def create(
128130
self,

0 commit comments

Comments
 (0)