Skip to content

Commit 313cd18

Browse files
ErnevSharmashaernev
and
shaernev
authored
fix: removing kwargs as this is breaking predictor_cls param for mode… (#4816)
* fix: removing kwargs as this is breaking predictor_cls param for model deploy * fix: include all known fields directly in create_sagemaker_model function signature * chore: fixing model deploy tests to incorporate new params * chore: fix formatting --------- Co-authored-by: shaernev <[email protected]>
1 parent 5416993 commit 313cd18

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

src/sagemaker/model.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,7 @@ def deploy(
13771377
managed_instance_scaling: Optional[str] = None,
13781378
inference_component_name=None,
13791379
routing_config: Optional[Dict[str, Any]] = None,
1380+
model_reference_arn: Optional[str] = None,
13801381
**kwargs,
13811382
):
13821383
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1483,6 +1484,8 @@ def deploy(
14831484
{
14841485
"RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
14851486
}
1487+
model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type
1488+
content (default: None).
14861489
Raises:
14871490
ValueError: If arguments combination check failed in these circumstances:
14881491
- If no role is specified or
@@ -1697,7 +1700,8 @@ def deploy(
16971700
accelerator_type=accelerator_type,
16981701
tags=tags,
16991702
serverless_inference_config=serverless_inference_config,
1700-
**kwargs,
1703+
accept_eula=accept_eula,
1704+
model_reference_arn=model_reference_arn,
17011705
)
17021706
serverless_inference_config_dict = (
17031707
serverless_inference_config._to_request_dict() if is_serverless else None

tests/unit/sagemaker/jumpstart/model/test_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
771771
and reach out to JumpStart team."""
772772

773773
init_args_to_skip: Set[str] = set(["model_reference_arn"])
774-
deploy_args_to_skip: Set[str] = set(["kwargs"])
774+
deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"])
775775

776776
parent_class_init = Model.__init__
777777
parent_class_init_args = set(signature(parent_class_init).parameters.keys())

tests/unit/sagemaker/model/test_deploy.py

+6
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def test_deploy_accelerator_type(
179179
accelerator_type=ACCELERATOR_TYPE,
180180
tags=None,
181181
serverless_inference_config=None,
182+
accept_eula=None,
183+
model_reference_arn=None,
182184
)
183185
production_variant.assert_called_with(
184186
MODEL_NAME,
@@ -299,6 +301,8 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base,
299301
accelerator_type=None,
300302
tags=tags,
301303
serverless_inference_config=None,
304+
accept_eula=None,
305+
model_reference_arn=None,
302306
)
303307
sagemaker_session.endpoint_from_production_variants.assert_called_with(
304308
name=ENDPOINT_NAME,
@@ -502,6 +506,8 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
502506
accelerator_type=None,
503507
tags=None,
504508
serverless_inference_config=serverless_inference_config,
509+
accept_eula=None,
510+
model_reference_arn=None,
505511
)
506512
production_variant.assert_called_with(
507513
MODEL_NAME,

0 commit comments

Comments
 (0)