Skip to content

Commit 1fb8997

Browse files
authored
Merge branch 'master' into doc_clarify_1.0.9
2 parents f77d109 + 116bcce commit 1fb8997

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

src/sagemaker/estimator.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -2319,9 +2319,13 @@ def _model_source_dir(self):
23192319
str: Either a local or an S3 path pointing to the ``source_dir`` to be
23202320
used for code by the model to be deployed
23212321
"""
2322-
return (
2323-
self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
2324-
)
2322+
if self.sagemaker_session.local_mode:
2323+
return self.source_dir
2324+
2325+
if self.uploaded_code is not None:
2326+
return self.uploaded_code.s3_prefix
2327+
2328+
return None
23252329

23262330
def _model_entry_point(self):
23272331
"""Get the appropriate value to pass as ``entry_point`` to a model constructor.
@@ -2333,7 +2337,10 @@ def _model_entry_point(self):
23332337
if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
23342338
return self.entry_point
23352339

2336-
return self.uploaded_code.script_name
2340+
if self.uploaded_code is not None:
2341+
return self.uploaded_code.script_name
2342+
2343+
return None
23372344

23382345
def hyperparameters(self):
23392346
"""Return the hyperparameters as a dictionary to use for training.

src/sagemaker/workflow/_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ def arguments(self) -> RequestType:
312312
model = self.estimator.create_model(**self.kwargs)
313313
self.image_uri = model.image_uri
314314

315+
if self.model_data is None:
316+
self.model_data = model.model_data
317+
315318
# reset placeholder
316319
self.estimator.output_path = output_path
317320

src/sagemaker/workflow/step_collections.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def __init__(
146146
else:
147147
sagemaker_session = model_entity.sagemaker_session
148148
role = model_entity.role
149-
if hasattr(model_entity, "entry_point"):
149+
if hasattr(model_entity, "entry_point") and model_entity.entry_point is not None:
150150
repack_model = True
151151
entry_point = model_entity.entry_point
152152
source_dir = model_entity.source_dir
@@ -169,6 +169,7 @@ def __init__(
169169
model_entity.model_data = (
170170
repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
171171
)
172+
172173
# remove kwargs consumed by model repacking step
173174
kwargs.pop("output_kms_key", None)
174175

0 commit comments

Comments
 (0)