Skip to content

Commit d431778

Browse files
authored
fix: xgboost, sklearn network isolation for jumpstart (#3060)
1 parent 5bc8580 commit d431778

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def _script_mode_env_vars(self):
478478
dir_name = None
479479
if self.uploaded_code:
480480
script_name = self.uploaded_code.script_name
481-
if self.enable_network_isolation():
481+
if self.repacked_model_data or self.enable_network_isolation():
482482
dir_name = "/opt/ml/model/code"
483483
else:
484484
dir_name = self.uploaded_code.s3_prefix

tests/unit/sagemaker/model/test_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,3 +665,25 @@ def test_all_framework_models_add_jumpstart_base_name(
665665

666666
sagemaker_session.create_model.reset_mock()
667667
sagemaker_session.endpoint_from_production_variants.reset_mock()
668+
669+
670+
@patch("sagemaker.utils.repack_model")
671+
def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagemaker_session):
672+
673+
source_dir = "s3://blah/blah/blah"
674+
t = Model(
675+
entry_point=ENTRY_POINT_INFERENCE,
676+
role=ROLE,
677+
sagemaker_session=sagemaker_session,
678+
source_dir=source_dir,
679+
image_uri=IMAGE_URI,
680+
model_data=MODEL_DATA,
681+
)
682+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
683+
684+
assert (
685+
sagemaker_session.create_model.call_args_list[0][0][2]["Environment"][
686+
"SAGEMAKER_SUBMIT_DIRECTORY"
687+
]
688+
== "/opt/ml/model/code"
689+
)

0 commit comments

Comments
 (0)