Skip to content

Commit b47b1d5

Browse files
committed
chore: resolve git comments
1 parent 5677dcb commit b47b1d5

File tree

6 files changed

+25
-9
lines changed

6 files changed

+25
-9
lines changed

doc/overview.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,8 @@ Deployment may take about 5 minutes.
773773
   instance_type=instance_type,
774774
)
775775
776-
Because the model and script URIs are owned by JumpStart, the endpoint,
777-
endpoint config and model resources will be prefixed with
776+
Because the model and script URIs are distributed by SageMaker JumpStart,
777+
the endpoint, endpoint config and model resources will be prefixed with
778778
``sagemaker-jumpstart``. Refer to the model ``Tags`` to inspect the
779779
JumpStart artifacts involved in the model creation.
780780

src/sagemaker/estimator.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,11 @@ def prepare_workflow_for_training(self, job_name=None):
570570
def _ensure_base_job_name(self):
571571
"""Set ``self.base_job_name`` if it is not set already."""
572572
# honor supplied base_job_name or generate it
573-
if self.base_job_name is None:
574-
self.base_job_name = get_jumpstart_base_name_if_jumpstart_model(
575-
self.source_dir, self.model_uri
576-
) or base_name_from_image(self.training_image_uri())
573+
self.base_job_name = (
574+
self.base_job_name
575+
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
576+
or base_name_from_image(self.training_image_uri())
577+
)
577578

578579
def _get_or_create_name(self, name=None):
579580
"""Generate a name based on the base job name or training image if needed.

src/sagemaker/jumpstart/cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _get_manifest_key_from_model_id_semantic_version(
229229
)
230230

231231
else:
232-
possible_model_ids = [header.model_id for header in manifest.values()]
232+
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
233233
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
234234
error_msg += f"Did you mean to use model ID '{closest_model_id}'?"
235235

src/sagemaker/jumpstart/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def add_single_jumpstart_tag(
233233

234234

235235
def get_jumpstart_base_name_if_jumpstart_model(
236-
*uris,
236+
*uris: Optional[str],
237237
) -> Optional[str]:
238238
"""Return default JumpStart base name if a URI belongs to JumpStart.
239239

src/sagemaker/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,10 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
532532
)
533533

534534
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
535-
"""Create a base name from the image URI if there is no model name provided."""
535+
"""Create a base name from the image URI if there is no model name provided.
536+
537+
If a JumpStart script or model uri is used, select the JumpStart base name.
538+
"""
536539
if self.name is None:
537540
self._base_name = (
538541
self._base_name

tests/unit/sagemaker/jumpstart/test_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -883,3 +883,15 @@ def test_get_jumpstart_base_name_if_jumpstart_model():
883883

884884
uris = ["s3://not-jumpstart-bucket/some-key" for _ in range(random.randint(0, 10))]
885885
assert utils.get_jumpstart_base_name_if_jumpstart_model(*uris) is None
886+
887+
uris = ["s3://not-jumpstart-bucket/some-key" for _ in range(random.randint(1, 10))] + [
888+
random_jumpstart_s3_uri("random_key")
889+
]
890+
assert JUMPSTART_RESOURCE_BASE_NAME == utils.get_jumpstart_base_name_if_jumpstart_model(*uris)
891+
892+
uris = (
893+
["s3://not-jumpstart-bucket/some-key" for _ in range(random.randint(1, 10))]
894+
+ [random_jumpstart_s3_uri("random_key")]
895+
+ ["s3://not-jumpstart-bucket/some-key-2" for _ in range(random.randint(1, 10))]
896+
)
897+
assert JUMPSTART_RESOURCE_BASE_NAME == utils.get_jumpstart_base_name_if_jumpstart_model(*uris)

0 commit comments

Comments
 (0)