|
20 | 20 | import pytest
|
21 | 21 |
|
22 | 22 | from sagemaker.drift_check_baselines import DriftCheckBaselines
|
| 23 | +from sagemaker.workflow.condition_step import ConditionStep |
| 24 | +from sagemaker.workflow.conditions import ConditionEquals |
23 | 25 | from sagemaker.workflow.model_step import (
|
24 | 26 | ModelStep,
|
25 | 27 | _CREATE_MODEL_NAME_BASE,
|
@@ -360,6 +362,154 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
|
360 | 362 | )
|
361 | 363 |
|
362 | 364 |
|
| 365 | +def test_step_collection_in_condition_branch_is_depended_on(pipeline_session, sagemaker_session): |
| 366 | + custom_step1 = CustomStep(name="MyStep1") |
| 367 | + |
| 368 | + # Define a step collection which will be inserted into the ConditionStep |
| 369 | + model_name = "MyModel" |
| 370 | + model = Model( |
| 371 | + name=model_name, |
| 372 | + image_uri=IMAGE_URI, |
| 373 | + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), |
| 374 | + sagemaker_session=pipeline_session, |
| 375 | + entry_point=f"{DATA_DIR}/dummy_script.py", |
| 376 | + source_dir=f"{DATA_DIR}", |
| 377 | + role=ROLE, |
| 378 | + ) |
| 379 | + step_args = model.create( |
| 380 | + instance_type="c4.4xlarge", |
| 381 | + accelerator_type="ml.eia1.medium", |
| 382 | + ) |
| 383 | + model_step_name = "MyModelStep" |
| 384 | + model_step = ModelStep( |
| 385 | + name=model_step_name, |
| 386 | + step_args=step_args, |
| 387 | + ) |
| 388 | + |
| 389 | + # Define another step collection which will be inserted into the ConditionStep |
| 390 | + # This StepCollection object depends on a StepCollection object in the ConditionStep |
| 391 | + # And a normal step outside ConditionStep |
| 392 | + model.sagemaker_session = sagemaker_session |
| 393 | + register_model_name = "RegisterModelStep" |
| 394 | + register_model = RegisterModel( |
| 395 | + name=register_model_name, |
| 396 | + model=model, |
| 397 | + model_data="s3://", |
| 398 | + content_types=["content_type"], |
| 399 | + response_types=["response_type"], |
| 400 | + inference_instances=["inference_instance"], |
| 401 | + transform_instances=["transform_instance"], |
| 402 | + model_package_group_name="mpg", |
| 403 | + depends_on=["MyStep1", model_step], |
| 404 | + ) |
| 405 | + |
| 406 | + # StepCollection objects are depended on by a normal step in the ConditionStep |
| 407 | + custom_step2 = CustomStep( |
| 408 | + name="MyStep2", depends_on=["MyStep1", model_step, register_model_name] |
| 409 | + ) |
| 410 | + # StepCollection objects are depended on by a normal step outside the ConditionStep |
| 411 | + custom_step3 = CustomStep( |
| 412 | + name="MyStep3", depends_on=[custom_step1, model_step_name, register_model] |
| 413 | + ) |
| 414 | + |
| 415 | + cond_step = ConditionStep( |
| 416 | + name="CondStep", |
| 417 | + conditions=[ConditionEquals(left=2, right=1)], |
| 418 | + if_steps=[], |
| 419 | + else_steps=[model_step, register_model, custom_step2], |
| 420 | + ) |
| 421 | + |
| 422 | + pipeline = Pipeline( |
| 423 | + name="MyPipeline", |
| 424 | + steps=[cond_step, custom_step1, custom_step3], |
| 425 | + ) |
| 426 | + step_list = json.loads(pipeline.definition())["Steps"] |
| 427 | + assert len(step_list) == 3 |
| 428 | + for step in step_list: |
| 429 | + if step["Name"] == "MyStep1": |
| 430 | + assert "DependsOn" not in step |
| 431 | + elif step["Name"] == "CondStep": |
| 432 | + assert not step["Arguments"]["IfSteps"] |
| 433 | + for sub_step in step["Arguments"]["ElseSteps"]: |
| 434 | + if sub_step["Name"] == f"{model_name}-RepackModel": |
| 435 | + assert set(sub_step["DependsOn"]) == { |
| 436 | + "MyStep1", |
| 437 | + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", |
| 438 | + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", |
| 439 | + } |
| 440 | + if sub_step["Name"] == "MyStep2": |
| 441 | + assert set(sub_step["DependsOn"]) == { |
| 442 | + "MyStep1", |
| 443 | + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", |
| 444 | + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", |
| 445 | + f"{model_name}-RepackModel", |
| 446 | + f"{register_model_name}-RegisterModel", |
| 447 | + } |
| 448 | + else: |
| 449 | + assert set(step["DependsOn"]) == { |
| 450 | + "MyStep1", |
| 451 | + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", |
| 452 | + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", |
| 453 | + f"{model_name}-RepackModel", |
| 454 | + f"{register_model_name}-RegisterModel", |
| 455 | + } |
| 456 | + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list |
| 457 | + assert ordered(adjacency_list) == ordered( |
| 458 | + { |
| 459 | + "CondStep": ["MyModel-RepackModel", "MyModelStep-RepackModel-MyModel", "MyStep2"], |
| 460 | + "MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"], |
| 461 | + "MyStep2": [], |
| 462 | + "MyStep3": [], |
| 463 | + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], |
| 464 | + "MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"], |
| 465 | + "MyModel-RepackModel": [], |
| 466 | + "RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"], |
| 467 | + } |
| 468 | + ) |
| 469 | + |
| 470 | + |
| 471 | +def test_condition_step_depends_on_step_collection(): |
| 472 | + step1 = CustomStep(name="MyStep1") |
| 473 | + step2 = CustomStep(name="MyStep2", input_data=step1.properties) |
| 474 | + step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2]) |
| 475 | + cond_step = ConditionStep( |
| 476 | + name="MyConditionStep", |
| 477 | + depends_on=[step_collection], |
| 478 | + conditions=[ConditionEquals(left=2, right=1)], |
| 479 | + if_steps=[], |
| 480 | + else_steps=[], |
| 481 | + ) |
| 482 | + pipeline = Pipeline( |
| 483 | + name="MyPipeline", |
| 484 | + steps=[step_collection, cond_step], |
| 485 | + ) |
| 486 | + step_list = json.loads(pipeline.definition())["Steps"] |
| 487 | + assert len(step_list) == 3 |
| 488 | + for step in step_list: |
| 489 | + if step["Name"] != "MyConditionStep": |
| 490 | + continue |
| 491 | + assert step == { |
| 492 | + "Name": "MyConditionStep", |
| 493 | + "Type": "Condition", |
| 494 | + "DependsOn": ["MyStep1", "MyStep2"], |
| 495 | + "Arguments": { |
| 496 | + "Conditions": [ |
| 497 | + { |
| 498 | + "Type": "Equals", |
| 499 | + "LeftValue": 2, |
| 500 | + "RightValue": 1, |
| 501 | + }, |
| 502 | + ], |
| 503 | + "IfSteps": [], |
| 504 | + "ElseSteps": [], |
| 505 | + }, |
| 506 | + } |
| 507 | + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list |
| 508 | + assert ordered(adjacency_list) == ordered( |
| 509 | + [("MyConditionStep", []), ("MyStep1", ["MyStep2"]), ("MyStep2", ["MyConditionStep"])] |
| 510 | + ) |
| 511 | + |
| 512 | + |
363 | 513 | def test_register_model(estimator, model_metrics, drift_check_baselines):
|
364 | 514 | model_data = f"s3://{BUCKET}/model.tar.gz"
|
365 | 515 | register_model = RegisterModel(
|
|
0 commit comments