@@ -317,6 +317,9 @@ def test_repack_model_without_source_dir(tmpdir):
317
317
script_path = os .path .join (source_dir , 'inference.py' )
318
318
write_file (script_path , 'inference script' )
319
319
320
+ script_path = os .path .join (source_dir , 'this-file-should-not-be-included.py' )
321
+ write_file (script_path , 'This file should not be included' )
322
+
320
323
contents = [model_path ]
321
324
322
325
sagemaker_session = MagicMock ()
@@ -334,6 +337,44 @@ def test_repack_model_without_source_dir(tmpdir):
334
337
assert re .match (r'^s3://fake/model-\d+-\d+.tar.gz$' , new_model_uri )
335
338
336
339
340
+ def test_repack_model_with_entry_point_without_path_without_source_dir (tmpdir ):
341
+
342
+ tmp = str (tmpdir )
343
+
344
+ model_path = os .path .join (tmp , 'model' )
345
+ write_file (model_path , 'model data' )
346
+
347
+ source_dir = os .path .join (tmp , 'source-dir' )
348
+ os .mkdir (source_dir )
349
+ script_path = os .path .join (source_dir , 'inference.py' )
350
+ write_file (script_path , 'inference script' )
351
+
352
+ script_path = os .path .join (source_dir , 'this-file-should-not-be-included.py' )
353
+ write_file (script_path , 'This file should not be included' )
354
+
355
+ contents = [model_path ]
356
+
357
+ sagemaker_session = MagicMock ()
358
+ mock_s3_model_tar (contents , sagemaker_session , tmp )
359
+ fake_upload_path = mock_s3_upload (sagemaker_session , tmp )
360
+
361
+ model_uri = 's3://fake/location'
362
+
363
+ cwd = os .getcwd ()
364
+ try :
365
+ os .chdir (source_dir )
366
+
367
+ new_model_uri = sagemaker .utils .repack_model ('inference.py' ,
368
+ None ,
369
+ model_uri ,
370
+ sagemaker_session )
371
+ finally :
372
+ os .chdir (cwd )
373
+
374
+ assert list_tar_files (fake_upload_path , tmpdir ) == {'/code/inference.py' , '/model' }
375
+ assert re .match (r'^s3://fake/model-\d+-\d+.tar.gz$' , new_model_uri )
376
+
377
+
337
378
def test_repack_model_from_s3_saved_model_to_s3 (tmpdir ):
338
379
339
380
tmp = str (tmpdir )
@@ -346,6 +387,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
346
387
script_path = os .path .join (source_dir , 'inference.py' )
347
388
write_file (script_path , 'inference script' )
348
389
390
+ script_path = os .path .join (source_dir , 'this-file-should-be-included.py' )
391
+ write_file (script_path , 'This file should be included' )
392
+
349
393
contents = [model_path ]
350
394
351
395
sagemaker_session = MagicMock ()
@@ -359,7 +403,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
359
403
model_uri ,
360
404
sagemaker_session )
361
405
362
- assert list_tar_files (fake_upload_path , tmpdir ) == {'/code/inference.py' , '/model' }
406
+ assert list_tar_files (fake_upload_path , tmpdir ) == {'/code/this-file-should-be-included.py' ,
407
+ '/code/inference.py' ,
408
+ '/model' }
363
409
assert re .match (r'^s3://fake/model-\d+-\d+.tar.gz$' , new_model_uri )
364
410
365
411
0 commit comments