Skip to content

Commit 31528ac

Browse files
committed
move test to model dir, update test to test properties based model
1 parent 48ca435 commit 31528ac

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

tests/unit/sagemaker/model/test_framework_model.py

-17
Original file line numberDiff line numberDiff line change
@@ -112,23 +112,6 @@ def test_prepare_container_def(time, sagemaker_session):
112112
"ModelDataUrl": MODEL_DATA,
113113
}
114114

115-
@patch("shutil.rmtree", MagicMock())
116-
@patch("tarfile.open", MagicMock())
117-
@patch("os.listdir", MagicMock(return_value=["blah.py"]))
118-
@patch("time.strftime", return_value=TIMESTAMP)
119-
def test_prepare_container_def_s3_src(time, sagemaker_session):
120-
model = DummyFrameworkModel(sagemaker_session, source_dir=S3_SOURCE_DIR)
121-
assert model.prepare_container_def(INSTANCE_TYPE) == {
122-
"Environment": {
123-
"SAGEMAKER_PROGRAM": ENTRY_POINT,
124-
"SAGEMAKER_SUBMIT_DIRECTORY": "s3://somebucket/sourcedir.tar.gz",
125-
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
126-
"SAGEMAKER_REGION": REGION,
127-
},
128-
"Image": MODEL_IMAGE,
129-
"ModelDataUrl": MODEL_DATA,
130-
}
131-
S3_SOURCE_DIR
132115

133116
@patch("shutil.rmtree", MagicMock())
134117
@patch("tarfile.open", MagicMock())

tests/unit/sagemaker/model/test_model.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from sagemaker.sklearn.model import SKLearnModel
2727
from sagemaker.tensorflow.model import TensorFlowModel
2828
from sagemaker.xgboost.model import XGBoostModel
29+
from sagemaker.workflow.properties import Properties
30+
2931

3032
MODEL_DATA = "s3://bucket/model.tar.gz"
3133
MODEL_IMAGE = "mi"
@@ -42,7 +44,6 @@
4244
BRANCH = "test-branch-git-config"
4345
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
4446
ENTRY_POINT_INFERENCE = "inference.py"
45-
4647
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
4748
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
4849

@@ -71,6 +72,23 @@ def sagemaker_session():
7172
return sms
7273

7374

75+
@patch("shutil.rmtree", MagicMock())
76+
@patch("tarfile.open", MagicMock())
77+
@patch("os.listdir", MagicMock(return_value=[ENTRY_POINT_INFERENCE]))
78+
def test_prepare_container_def_with_model_src_s3_returns_correct_url(sagemaker_session):
79+
model = Model(
80+
entry_point=ENTRY_POINT_INFERENCE,
81+
role=ROLE,
82+
sagemaker_session=sagemaker_session,
83+
source_dir=SCRIPT_URI,
84+
image_uri=MODEL_IMAGE,
85+
model_data=Properties("Steps.MyStep"),
86+
)
87+
container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium")
88+
89+
assert container_def["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == SCRIPT_URI
90+
91+
7492
def test_prepare_container_def_with_model_data():
7593
model = Model(MODEL_IMAGE)
7694
container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium")

0 commit comments

Comments
 (0)