Skip to content

Commit 80b0464

Browse files
committed
chore: always include training script if available
1 parent c09435f commit 80b0464

File tree

3 files changed

+5
-42
lines changed

3 files changed

+5
-42
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def _retrieve_script_uri(
239239
region: Optional[str] = False,
240240
tolerate_vulnerable_model: bool = False,
241241
tolerate_deprecated_model: bool = False,
242-
include_training_script: bool = False,
243242
):
244243
"""Retrieves the script S3 URI associated with the model matching the given arguments.
245244
@@ -260,8 +259,6 @@ def _retrieve_script_uri(
260259
tolerate_deprecated_model (bool): True if deprecated versions of model
261260
specifications should be tolerated (exception not raised). If False, raises
262261
an exception if the version of the model is deprecated.
263-
include_training_script (bool): True if training script should be packaged along with
264-
inference script. (Default: False.)
265262
Returns:
266263
str: the model script URI for the corresponding model.
267264
@@ -284,17 +281,11 @@ def _retrieve_script_uri(
284281
)
285282

286283
if script_scope == JumpStartScriptScope.INFERENCE:
287-
if not include_training_script:
288-
model_script_key = model_specs.hosting_script_key
289-
else:
290-
model_script_key = getattr(model_specs, "training_prepacked_script_key", None)
291-
if model_script_key is None:
292-
raise ValueError(
293-
f"Cannot include training script for {model_id} with version {model_version}."
294-
)
284+
model_script_key = (
285+
getattr(model_specs, "training_prepacked_script_key") or model_specs.hosting_script_key
286+
)
287+
295288
elif script_scope == JumpStartScriptScope.TRAINING:
296-
if include_training_script:
297-
raise ValueError("Can only include training script for inference jobs.")
298289
model_script_key = model_specs.training_script_key
299290

300291
bucket = os.environ.get(

src/sagemaker/script_uris.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def retrieve(
2929
script_scope=None,
3030
tolerate_vulnerable_model: bool = False,
3131
tolerate_deprecated_model: bool = False,
32-
include_training_script: bool = False,
3332
) -> str:
3433
"""Retrieves the script S3 URI associated with the model matching the given arguments.
3534
@@ -48,8 +47,6 @@ def retrieve(
4847
tolerate_deprecated_model (bool): ``True`` if deprecated models should be tolerated
4948
without raising an exception. ``False`` if these models should raise an exception.
5049
(Default: False).
51-
include_training_script (bool): True if training script should be packaged along with
52-
inference script. (Default: False.)
5350
Returns:
5451
str: The model script URI for the corresponding model.
5552
@@ -70,5 +67,4 @@ def retrieve(
7067
region,
7168
tolerate_vulnerable_model,
7269
tolerate_deprecated_model,
73-
include_training_script,
7470
)

tests/unit/sagemaker/script_uris/jumpstart/test_combined_script_artifact.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
from mock.mock import patch
1616

1717
from sagemaker import script_uris
18-
import pytest
1918

20-
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec, get_special_model_spec
19+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
2120

2221

2322
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -32,31 +31,8 @@ def test_jumpstart_combined_artifacts(patched_get_model_specs):
3231
script_scope="inference",
3332
model_id=model_id_combined_script_artifact,
3433
model_version="*",
35-
include_training_script=True,
3634
)
3735
assert (
3836
uri == "s3://jumpstart-cache-prod-us-west-2/some/key/to/"
3937
"training_prepacked_script_key.tar.gz"
4038
)
41-
42-
with pytest.raises(ValueError):
43-
script_uris.retrieve(
44-
region="us-west-2",
45-
script_scope="training",
46-
model_id=model_id_combined_script_artifact,
47-
model_version="*",
48-
include_training_script=True,
49-
)
50-
51-
patched_get_model_specs.side_effect = get_prototype_model_spec
52-
53-
model_id_combined_script_artifact_unsupported = "xgboost-classification-model"
54-
55-
with pytest.raises(ValueError):
56-
script_uris.retrieve(
57-
region="us-west-2",
58-
script_scope="inference",
59-
model_id=model_id_combined_script_artifact_unsupported,
60-
model_version="*",
61-
include_training_script=True,
62-
)

0 commit comments

Comments
 (0)