|
32 | 32 | from sagemaker.tensorflow import TensorFlowModel
|
33 | 33 | from sagemaker.transformer import Transformer
|
34 | 34 | from sagemaker.tuner import HyperparameterTuner
|
| 35 | +from sagemaker.workflow._utils import REPACK_SCRIPT_LAUNCHER |
35 | 36 | from sagemaker.workflow.condition_step import ConditionStep
|
36 | 37 | from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
|
37 | 38 | from sagemaker.workflow.model_step import (
|
@@ -188,7 +189,9 @@ def test_register_model_with_runtime_repack(pipeline_session, model_data_param,
|
188 | 189 | }
|
189 | 190 | assert arguments["HyperParameters"]["inference_script"] == '"dummy_script.py"'
|
190 | 191 | assert arguments["HyperParameters"]["model_archive"] == {"Get": "Parameters.ModelData"}
|
191 |
| - assert arguments["HyperParameters"]["sagemaker_program"] == '"_repack_model.py"' |
| 192 | + assert ( |
| 193 | + arguments["HyperParameters"]["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' |
| 194 | + ) |
192 | 195 | assert "s3://" in arguments["HyperParameters"]["sagemaker_submit_directory"]
|
193 | 196 | assert arguments["HyperParameters"]["dependencies"] == "null"
|
194 | 197 | assert step["RetryPolicies"] == [
|
@@ -269,7 +272,9 @@ def test_create_model_with_runtime_repack(pipeline_session, model_data_param, mo
|
269 | 272 | }
|
270 | 273 | assert arguments["HyperParameters"]["inference_script"] == '"dummy_script.py"'
|
271 | 274 | assert arguments["HyperParameters"]["model_archive"] == {"Get": "Parameters.ModelData"}
|
272 |
| - assert arguments["HyperParameters"]["sagemaker_program"] == '"_repack_model.py"' |
| 275 | + assert ( |
| 276 | + arguments["HyperParameters"]["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' |
| 277 | + ) |
273 | 278 | assert "s3://" in arguments["HyperParameters"]["sagemaker_submit_directory"]
|
274 | 279 | assert arguments["HyperParameters"]["dependencies"] == "null"
|
275 | 280 | assert "repack a model with customer scripts" in step["Description"]
|
@@ -360,7 +365,9 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_
|
360 | 365 | }
|
361 | 366 | assert arguments["HyperParameters"]["inference_script"] == '"dummy_script.py"'
|
362 | 367 | assert arguments["HyperParameters"]["model_archive"] == {"Get": "Parameters.ModelData"}
|
363 |
| - assert arguments["HyperParameters"]["sagemaker_program"] == '"_repack_model.py"' |
| 368 | + assert ( |
| 369 | + arguments["HyperParameters"]["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' |
| 370 | + ) |
364 | 371 | assert "s3://" in arguments["HyperParameters"]["sagemaker_submit_directory"]
|
365 | 372 | assert arguments["HyperParameters"]["dependencies"] == "null"
|
366 | 373 | assert step["RetryPolicies"] == [
|
@@ -460,7 +467,9 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
|
460 | 467 | }
|
461 | 468 | assert arguments["HyperParameters"]["inference_script"] == '"dummy_script.py"'
|
462 | 469 | assert arguments["HyperParameters"]["model_archive"] == {"Get": "Parameters.ModelData"}
|
463 |
| - assert arguments["HyperParameters"]["sagemaker_program"] == '"_repack_model.py"' |
| 470 | + assert ( |
| 471 | + arguments["HyperParameters"]["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' |
| 472 | + ) |
464 | 473 | assert "s3://" in arguments["HyperParameters"]["sagemaker_submit_directory"]
|
465 | 474 | assert arguments["HyperParameters"]["dependencies"] == "null"
|
466 | 475 | elif step["Type"] == "RegisterModel":
|
@@ -641,7 +650,9 @@ def test_conditional_model_create_and_regis(
|
641 | 650 | }
|
642 | 651 | assert arguments["HyperParameters"]["inference_script"] == '"dummy_script.py"'
|
643 | 652 | assert arguments["HyperParameters"]["model_archive"] == {"Get": "Parameters.ModelData"}
|
644 |
| - assert arguments["HyperParameters"]["sagemaker_program"] == '"_repack_model.py"' |
| 653 | + assert ( |
| 654 | + arguments["HyperParameters"]["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' |
| 655 | + ) |
645 | 656 | assert "s3://" in arguments["HyperParameters"]["sagemaker_submit_directory"]
|
646 | 657 | assert arguments["HyperParameters"]["dependencies"] == "null"
|
647 | 658 | elif step["Type"] == "RegisterModel":
|
|
0 commit comments