Skip to content

Commit 9800681

Browse files
authored
Merge branch 'master' into processing-job-codeartifact-support
2 parents 7d78c28 + cbbbb32 commit 9800681

File tree

11 files changed

+148
-4
lines changed

11 files changed

+148
-4
lines changed

src/sagemaker/enums.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,15 @@ class EndpointType(Enum):
2828
INFERENCE_COMPONENT_BASED = (
2929
"InferenceComponentBased" # Amazon SageMaker Inference Component Based Endpoint
3030
)
31+
32+
33+
class RoutingStrategy(Enum):
    """Strategy for routing HTTPS traffic."""

    #: The endpoint routes each request to a randomly chosen instance.
    RANDOM = "RANDOM"

    #: The endpoint routes requests to the specific instances that have
    #: more capacity to process them.
    LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"

src/sagemaker/huggingface/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def deploy(
334334
endpoint_type=kwargs.get("endpoint_type", None),
335335
resources=kwargs.get("resources", None),
336336
managed_instance_scaling=kwargs.get("managed_instance_scaling", None),
337+
routing_config=kwargs.get("routing_config", None),
337338
)
338339

339340
def register(

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1315,7 +1315,8 @@
13151315
"1.13": "1.13.1",
13161316
"2.0": "2.0.1",
13171317
"2.1": "2.1.0",
1318-
"2.2": "2.2.0"
1318+
"2.2": "2.2.0",
1319+
"2.3": "2.3.0"
13191320
},
13201321
"versions": {
13211322
"0.4.0": {
@@ -2288,6 +2289,47 @@
22882289
"us-west-2": "763104351884"
22892290
},
22902291
"repository": "pytorch-training"
2292+
},
2293+
"2.3.0": {
2294+
"py_versions": [
2295+
"py311"
2296+
],
2297+
"registries": {
2298+
"af-south-1": "626614931356",
2299+
"ap-east-1": "871362719292",
2300+
"ap-northeast-1": "763104351884",
2301+
"ap-northeast-2": "763104351884",
2302+
"ap-northeast-3": "364406365360",
2303+
"ap-south-1": "763104351884",
2304+
"ap-south-2": "772153158452",
2305+
"ap-southeast-1": "763104351884",
2306+
"ap-southeast-2": "763104351884",
2307+
"ap-southeast-3": "907027046896",
2308+
"ap-southeast-4": "457447274322",
2309+
"ca-central-1": "763104351884",
2310+
"ca-west-1": "204538143572",
2311+
"cn-north-1": "727897471807",
2312+
"cn-northwest-1": "727897471807",
2313+
"eu-central-1": "763104351884",
2314+
"eu-central-2": "380420809688",
2315+
"eu-north-1": "763104351884",
2316+
"eu-south-1": "692866216735",
2317+
"eu-south-2": "503227376785",
2318+
"eu-west-1": "763104351884",
2319+
"eu-west-2": "763104351884",
2320+
"eu-west-3": "763104351884",
2321+
"il-central-1": "780543022126",
2322+
"me-central-1": "914824155844",
2323+
"me-south-1": "217643126080",
2324+
"sa-east-1": "763104351884",
2325+
"us-east-1": "763104351884",
2326+
"us-east-2": "763104351884",
2327+
"us-gov-east-1": "446045086412",
2328+
"us-gov-west-1": "442386744353",
2329+
"us-west-1": "763104351884",
2330+
"us-west-2": "763104351884"
2331+
},
2332+
"repository": "pytorch-training"
22912333
}
22922334
}
22932335
}

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ def get_deploy_kwargs(
555555
resources: Optional[ResourceRequirements] = None,
556556
managed_instance_scaling: Optional[str] = None,
557557
endpoint_type: Optional[EndpointType] = None,
558+
routing_config: Optional[Dict[str, Any]] = None,
558559
) -> JumpStartModelDeployKwargs:
559560
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
560561

@@ -586,6 +587,7 @@ def get_deploy_kwargs(
586587
accept_eula=accept_eula,
587588
endpoint_logging=endpoint_logging,
588589
resources=resources,
590+
routing_config=routing_config,
589591
)
590592

591593
deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)

src/sagemaker/jumpstart/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import absolute_import
1616

17-
from typing import Dict, List, Optional, Union
17+
from typing import Dict, List, Optional, Union, Any
1818
from botocore.exceptions import ClientError
1919

2020
from sagemaker import payloads
@@ -496,6 +496,7 @@ def deploy(
496496
resources: Optional[ResourceRequirements] = None,
497497
managed_instance_scaling: Optional[str] = None,
498498
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
499+
routing_config: Optional[Dict[str, Any]] = None,
499500
) -> PredictorBase:
500501
"""Creates endpoint by calling base ``Model`` class `deploy` method.
501502
@@ -590,6 +591,8 @@ def deploy(
590591
endpoint.
591592
endpoint_type (EndpointType): The type of endpoint used to deploy models.
592593
(Default: EndpointType.MODEL_BASED).
594+
routing_config (Optional[Dict]): Settings that control how the endpoint routes
595+
incoming traffic to the instances that the endpoint hosts.
593596
594597
Raises:
595598
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
@@ -625,6 +628,7 @@ def deploy(
625628
managed_instance_scaling=managed_instance_scaling,
626629
endpoint_type=endpoint_type,
627630
model_type=self.model_type,
631+
routing_config=routing_config,
628632
)
629633
if (
630634
self.model_type == JumpStartModelType.PROPRIETARY

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,6 +1614,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
16141614
"endpoint_logging",
16151615
"resources",
16161616
"endpoint_type",
1617+
"routing_config",
16171618
]
16181619

16191620
SERIALIZATION_EXCLUSION_SET = {
@@ -1658,6 +1659,7 @@ def __init__(
16581659
endpoint_logging: Optional[bool] = None,
16591660
resources: Optional[ResourceRequirements] = None,
16601661
endpoint_type: Optional[EndpointType] = None,
1662+
routing_config: Optional[Dict[str, Any]] = None,
16611663
) -> None:
16621664
"""Instantiates JumpStartModelDeployKwargs object."""
16631665

@@ -1690,6 +1692,7 @@ def __init__(
16901692
self.endpoint_logging = endpoint_logging
16911693
self.resources = resources
16921694
self.endpoint_type = endpoint_type
1695+
self.routing_config = routing_config
16931696

16941697

16951698
class JumpStartEstimatorInitKwargs(JumpStartKwargs):

src/sagemaker/model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import os
2121
import re
2222
import copy
23-
from typing import List, Dict, Optional, Union
23+
from typing import List, Dict, Optional, Union, Any
2424

2525
import sagemaker
2626
from sagemaker import (
@@ -66,6 +66,7 @@
6666
resolve_nested_dict_value_from_config,
6767
format_tags,
6868
Tags,
69+
_resolve_routing_config,
6970
)
7071
from sagemaker.async_inference import AsyncInferenceConfig
7172
from sagemaker.predictor_async import AsyncPredictor
@@ -1309,6 +1310,7 @@ def deploy(
13091310
resources: Optional[ResourceRequirements] = None,
13101311
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
13111312
managed_instance_scaling: Optional[str] = None,
1313+
routing_config: Optional[Dict[str, Any]] = None,
13121314
**kwargs,
13131315
):
13141316
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1406,6 +1408,15 @@ def deploy(
14061408
Endpoint. (Default: None).
14071409
endpoint_type (Optional[EndpointType]): The type of an endpoint used to deploy models.
14081410
(Default: EndpointType.MODEL_BASED).
1411+
routing_config (Optional[Dict[str, Any]]): Settings that control how the endpoint routes incoming
1412+
traffic to the instances that the endpoint hosts.
1413+
Currently, the only supported dictionary key is ``RoutingStrategy``.
1414+
1415+
.. code:: python
1416+
1417+
{
1418+
"RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
1419+
}
14091420
Raises:
14101421
ValueError: If arguments combination check failed in these circumstances:
14111422
- If no role is specified or
@@ -1458,6 +1469,8 @@ def deploy(
14581469
if self.role is None:
14591470
raise ValueError("Role can not be null for deploying a model")
14601471

1472+
routing_config = _resolve_routing_config(routing_config)
1473+
14611474
if (
14621475
inference_recommendation_id is not None
14631476
or self.inference_recommender_job_results is not None
@@ -1543,6 +1556,7 @@ def deploy(
15431556
model_data_download_timeout=model_data_download_timeout,
15441557
container_startup_health_check_timeout=container_startup_health_check_timeout,
15451558
managed_instance_scaling=managed_instance_scaling_config,
1559+
routing_config=routing_config,
15461560
)
15471561

15481562
self.sagemaker_session.endpoint_from_production_variants(
@@ -1625,6 +1639,7 @@ def deploy(
16251639
volume_size=volume_size,
16261640
model_data_download_timeout=model_data_download_timeout,
16271641
container_startup_health_check_timeout=container_startup_health_check_timeout,
1642+
routing_config=routing_config,
16281643
)
16291644
if endpoint_name:
16301645
self.endpoint_name = endpoint_name

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
169169
in order for model builder to build the artifacts correctly (according
170170
to the model server). Possible values for this argument are
171171
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
172-
``TRITON``,``TGI``, and ``TEI``.
172+
``TRITON``, ``TGI``, and ``TEI``.
173173
model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata.
174174
Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for
175175
new models without task metadata in the Hub, adding unsupported task types will throw

src/sagemaker/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
_log_sagemaker_config_single_substitution,
4545
_log_sagemaker_config_merge,
4646
)
47+
from sagemaker.enums import RoutingStrategy
4748
from sagemaker.session_settings import SessionSettings
4849
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
4950
from sagemaker.workflow.entities import PipelineVariable
@@ -1655,3 +1656,33 @@ def deep_override_dict(
16551656
)
16561657
flattened_dict1.update(flattened_dict2)
16571658
return unflatten_dict(flattened_dict1) if flattened_dict1 else {}
1659+
1660+
1661+
def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
1662+
"""Resolve Routing Config
1663+
1664+
Args:
1665+
routing_config (Optional[Dict[str, Any]]): The routing config.
1666+
1667+
Returns:
1668+
Optional[Dict[str, Any]]: The resolved routing config.
1669+
1670+
Raises:
1671+
ValueError: If the RoutingStrategy is invalid.
1672+
"""
1673+
1674+
if routing_config:
1675+
routing_strategy = routing_config.get("RoutingStrategy", None)
1676+
if routing_strategy:
1677+
if isinstance(routing_strategy, RoutingStrategy):
1678+
return {"RoutingStrategy": routing_strategy.name}
1679+
if isinstance(routing_strategy, str) and (
1680+
routing_strategy.upper() == RoutingStrategy.RANDOM.name
1681+
or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name
1682+
):
1683+
return {"RoutingStrategy": routing_strategy.upper()}
1684+
raise ValueError(
1685+
"RoutingStrategy must be either RoutingStrategy.RANDOM "
1686+
"or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
1687+
)
1688+
return None

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem
125125
volume_size=None,
126126
model_data_download_timeout=None,
127127
container_startup_health_check_timeout=None,
128+
routing_config=None,
128129
)
129130

130131
sagemaker_session.create_model.assert_called_with(
@@ -184,6 +185,7 @@ def test_deploy_accelerator_type(
184185
volume_size=None,
185186
model_data_download_timeout=None,
186187
container_startup_health_check_timeout=None,
188+
routing_config=None,
187189
)
188190

189191
sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -506,6 +508,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
506508
volume_size=None,
507509
model_data_download_timeout=None,
508510
container_startup_health_check_timeout=None,
511+
routing_config=None,
509512
)
510513

511514
sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -938,6 +941,7 @@ def test_deploy_customized_volume_size_and_timeout(
938941
volume_size=volume_size_gb,
939942
model_data_download_timeout=model_data_download_timeout_sec,
940943
container_startup_health_check_timeout=startup_health_check_timeout_sec,
944+
routing_config=None,
941945
)
942946

943947
sagemaker_session.create_model.assert_called_with(
@@ -987,6 +991,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var
987991
volume_size=None,
988992
model_data_download_timeout=None,
989993
container_startup_health_check_timeout=None,
994+
routing_config=None,
990995
)
991996
sagemaker_session.endpoint_from_production_variants.assert_called_with(
992997
name=name_from_base(MODEL_NAME),

tests/unit/test_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from mock import call, patch, Mock, MagicMock, PropertyMock
3131

3232
import sagemaker
33+
from sagemaker.enums import RoutingStrategy
3334
from sagemaker.experiments._run_context import _RunContext
3435
from sagemaker.session_settings import SessionSettings
3536
from sagemaker.utils import (
@@ -50,6 +51,7 @@
5051
_is_bad_link,
5152
custom_extractall_tarfile,
5253
can_model_package_source_uri_autopopulate,
54+
_resolve_routing_config,
5355
)
5456
from tests.unit.sagemaker.workflow.helpers import CustomStep
5557
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
@@ -1866,3 +1868,30 @@ def test_deep_override_skip_keys(self):
18661868
expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]}
18671869

18681870
self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result)
1871+
1872+
1873+
@pytest.mark.parametrize(
    "routing_config, expected",
    [
        # Enum members resolve to their name.
        ({"RoutingStrategy": RoutingStrategy.RANDOM}, {"RoutingStrategy": "RANDOM"}),
        (
            {"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS},
            {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        ),
        # Exact string names pass through unchanged.
        ({"RoutingStrategy": "RANDOM"}, {"RoutingStrategy": "RANDOM"}),
        (
            {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
            {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        ),
        # String names are matched case-insensitively and upper-cased.
        ({"RoutingStrategy": "random"}, {"RoutingStrategy": "RANDOM"}),
        (
            {"RoutingStrategy": "Least_Outstanding_Requests"},
            {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        ),
        # A missing or empty config, or an empty strategy, resolves to None.
        ({"RoutingStrategy": None}, None),
        ({}, None),
        (None, None),
    ],
)
def test_resolve_routing_config(routing_config, expected):
    """Check that each supported routing config resolves to the expected dict or None."""
    res = _resolve_routing_config(routing_config)

    assert res == expected
1894+
1895+
1896+
def test_resolve_routing_config_ex():
    """Check that an unrecognized RoutingStrategy string raises ValueError."""
    invalid_config = {"RoutingStrategy": "Invalid"}

    with pytest.raises(ValueError):
        _resolve_routing_config(invalid_config)

0 commit comments

Comments
 (0)