Skip to content

Commit c01dde7

Browse files
committed
Add environment variable with VMARGS
1 parent 3774c1a commit c01dde7

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

src/sagemaker_inference/environment.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DEFAULT_MODEL_SERVER_TIMEOUT = "60"
2626
DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes
2727
DEFAULT_HTTP_PORT = "8080"
28+
DEFAULT_VMARGS = "-XX:-UseContainerSupport"
2829

2930
SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str
3031

@@ -70,15 +71,12 @@ def __init__(self):
7071
os.environ.get(parameters.MODEL_SERVER_TIMEOUT_ENV, DEFAULT_MODEL_SERVER_TIMEOUT)
7172
)
7273
self._model_server_workers = os.environ.get(parameters.MODEL_SERVER_WORKERS_ENV)
73-
self._startup_timeout = int(
74-
os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT)
75-
)
76-
self._default_accept = os.environ.get(
77-
parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON
78-
)
74+
self._startup_timeout = int(os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT))
75+
self._default_accept = os.environ.get(parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON)
7976
self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
8077
self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
8178
self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV)
79+
self._vmargs = os.environ.get(parameters.VMARGS, DEFAULT_VMARGS)
8280

8381
@staticmethod
8482
def _parse_module_name(program_param):
@@ -140,3 +138,8 @@ def safe_port_range(self): # type: () -> str
140138
specified by SageMaker for handling pings and invocations.
141139
"""
142140
return self._safe_port_range
141+
142+
@property
143+
def vmargs(self): # type: () -> str
144+
"""str: vmargs can be provided for the JVM, to be overriden"""
145+
return self._vmargs

src/sagemaker_inference/model_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _generate_mms_config_properties(env, handler_service=None):
159159
"default_workers_per_model": env.model_server_workers,
160160
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
161161
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
162-
"vmargs": "-XX:-UseContainerSupport",
162+
"vmargs": env.vmargs,
163163
}
164164
# If provided, add handler service to user config
165165
if handler_service:

src/sagemaker_inference/parameters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str
2525
SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str
2626
MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str
27+
VMARGS = "VMARGS" # type: str

test/unit/test_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html",
2929
parameters.BIND_TO_PORT_ENV: "1738",
3030
parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
31+
parameters.VMARGS: "-XX:-UseContainerSupport",
3132
},
3233
clear=True,
3334
)
@@ -45,6 +46,7 @@ def test_env():
4546
assert env.inference_http_port == "1738"
4647
assert env.management_http_port == "1738"
4748
assert env.safe_port_range == "1111-2222"
49+
assert "-XX:-UseContainerSupport" in env.vmargs
4850

4951

5052
@pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])

0 commit comments

Comments
 (0)