Skip to content

Commit 49e09c3

Browse files
keerthanvasistAditi2424
authored andcommitted
feature: allow choosing js payload by alias in private method
1 parent 0075fb3 commit 49e09c3

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

src/sagemaker/jumpstart/payload_utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def _construct_payload(
6262
tolerate_deprecated_model: bool = False,
6363
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
6464
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
65+
alias: Optional[str] = None,
6566
) -> Optional[JumpStartSerializablePayload]:
6667
"""Returns example payload from prompt.
6768
@@ -102,7 +103,9 @@ def _construct_payload(
102103
if payloads is None or len(payloads) == 0:
103104
return None
104105

105-
payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0]
106+
payload_to_use: JumpStartSerializablePayload = (
107+
payloads[alias] if alias else list(payloads.values())[0]
108+
)
106109

107110
prompt_key: Optional[str] = payload_to_use.prompt_key
108111
if prompt_key is None:

tests/unit/sagemaker/jumpstart/test_payload_utils.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,36 @@ def test_construct_payload(self, patched_get_model_specs):
3232
region = "us-west-2"
3333

3434
constructed_payload_body = _construct_payload(
35-
prompt="kobebryant",
36-
model_id=model_id,
37-
model_version="*",
38-
region=region,
35+
prompt="kobebryant", model_id=model_id, model_version="*", region=region
36+
).body
37+
38+
self.assertEqual(
39+
{
40+
"hello": {"prompt": "kobebryant"},
41+
"seed": 43,
42+
},
43+
constructed_payload_body,
44+
)
45+
46+
# Unsupported model
47+
self.assertIsNone(
48+
_construct_payload(
49+
prompt="blah",
50+
model_id="default_payloads",
51+
model_version="*",
52+
region=region,
53+
)
54+
)
55+
56+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
57+
def test_construct_payload_with_specific_alias(self, patched_get_model_specs):
58+
patched_get_model_specs.side_effect = get_special_model_spec
59+
60+
model_id = "prompt-key"
61+
region = "us-west-2"
62+
63+
constructed_payload_body = _construct_payload(
64+
prompt="kobebryant", model_id=model_id, model_version="*", region=region, alias="Dog"
3965
).body
4066

4167
self.assertEqual(

0 commit comments

Comments
 (0)