27
27
from sagemaker .mxnet import MXNetPredictor , MXNetModel
28
28
29
29
DATA_DIR = os .path .join (os .path .dirname (__file__ ), ".." , "data" )
30
- SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
30
+ SCRIPT_NAME = "dummy_script.py"
31
+ SCRIPT_PATH = os .path .join (DATA_DIR , SCRIPT_NAME )
31
32
SERVING_SCRIPT_FILE = "another_dummy_script.py"
32
33
MODEL_DATA = "s3://mybucket/model"
33
34
ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
@@ -189,7 +190,8 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
189
190
base_job_name = "job"
190
191
191
192
mx = MXNet (
192
- entry_point = SCRIPT_PATH ,
193
+ entry_point = SCRIPT_NAME ,
194
+ source_dir = source_dir ,
193
195
framework_version = mxnet_version ,
194
196
py_version = mxnet_py_version ,
195
197
role = ROLE ,
@@ -198,7 +200,6 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
198
200
instance_type = INSTANCE_TYPE ,
199
201
container_log_level = container_log_level ,
200
202
base_job_name = base_job_name ,
201
- source_dir = source_dir ,
202
203
)
203
204
204
205
mx .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
@@ -210,7 +211,7 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
210
211
assert model .sagemaker_session == sagemaker_session
211
212
assert model .framework_version == mxnet_version
212
213
assert model .py_version == mxnet_py_version
213
- assert model .entry_point == SCRIPT_PATH
214
+ assert model .entry_point == SCRIPT_NAME
214
215
assert model .role == ROLE
215
216
assert model .name == model_name
216
217
assert model .container_log_level == container_log_level
@@ -226,7 +227,8 @@ def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxn
226
227
source_dir = "s3://mybucket/source"
227
228
enable_cloudwatch_metrics = "true"
228
229
mx = MXNet (
229
- entry_point = SCRIPT_PATH ,
230
+ entry_point = SCRIPT_NAME ,
231
+ source_dir = source_dir ,
230
232
framework_version = mxnet_version ,
231
233
py_version = mxnet_py_version ,
232
234
role = ROLE ,
@@ -235,7 +237,6 @@ def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxn
235
237
instance_type = INSTANCE_TYPE ,
236
238
container_log_level = container_log_level ,
237
239
base_job_name = "job" ,
238
- source_dir = source_dir ,
239
240
enable_cloudwatch_metrics = enable_cloudwatch_metrics ,
240
241
)
241
242
@@ -270,7 +271,8 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
270
271
base_job_name = "job"
271
272
272
273
mx = MXNet (
273
- entry_point = SCRIPT_PATH ,
274
+ entry_point = SCRIPT_NAME ,
275
+ source_dir = source_dir ,
274
276
framework_version = "2.0" ,
275
277
py_version = "py3" ,
276
278
role = ROLE ,
@@ -280,7 +282,6 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
280
282
image_uri = custom_image ,
281
283
container_log_level = container_log_level ,
282
284
base_job_name = base_job_name ,
283
- source_dir = source_dir ,
284
285
)
285
286
286
287
mx .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
@@ -291,7 +292,7 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
291
292
292
293
assert model .sagemaker_session == sagemaker_session
293
294
assert model .image_uri == custom_image
294
- assert model .entry_point == SCRIPT_PATH
295
+ assert model .entry_point == SCRIPT_NAME
295
296
assert model .role == ROLE
296
297
assert model .name == model_name
297
298
assert model .container_log_level == container_log_level
@@ -730,7 +731,6 @@ def test_model_py2_warning(warning, sagemaker_session):
730
731
731
732
def test_create_model_with_custom_hosting_image (sagemaker_session ):
732
733
container_log_level = '"logging.INFO"'
733
- source_dir = "s3://mybucket/source"
734
734
custom_image = "mxnet:2.0"
735
735
custom_hosting_image = "mxnet_hosting:2.0"
736
736
mx = MXNet (
@@ -744,7 +744,6 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
744
744
image_uri = custom_image ,
745
745
container_log_level = container_log_level ,
746
746
base_job_name = "job" ,
747
- source_dir = source_dir ,
748
747
)
749
748
750
749
mx .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
0 commit comments