Skip to content

Commit 20df3d7

Browse files
staubhpPayton Staub
and
Payton Staub
authored
fix: Remove sagemaker_job_name from hyperparameters in TrainingStep (#2950)
Co-authored-by: Payton Staub <[email protected]>
1 parent 28fd737 commit 20df3d7

File tree

3 files changed

+8
-17
lines changed

3 files changed

+8
-17
lines changed

src/sagemaker/workflow/steps.py

+2
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def arguments(self) -> RequestType:
301301
)
302302
request_dict = self.estimator.sagemaker_session._get_train_request(**train_args)
303303
request_dict.pop("TrainingJobName")
304+
if "HyperParameters" in request_dict:
305+
request_dict["HyperParameters"].pop("sagemaker_job_name", None)
304306

305307
return request_dict
306308

tests/unit/sagemaker/workflow/test_step_collections.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,8 @@ def test_register_model_with_model_repack_with_estimator(
457457
assert len(request_dict["DependsOn"]) == 1
458458
assert request_dict["DependsOn"][0] == "TestStep"
459459
arguments = request_dict["Arguments"]
460-
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
460+
assert BUCKET in arguments["HyperParameters"]["sagemaker_submit_directory"]
461+
arguments["HyperParameters"].pop("sagemaker_submit_directory")
461462
assert ordered(arguments) == ordered(
462463
{
463464
"AlgorithmSpecification": {
@@ -472,12 +473,8 @@ def test_register_model_with_model_repack_with_estimator(
472473
"inference_script": '"dummy_script.py"',
473474
"dependencies": f'"{dummy_requirements}"',
474475
"model_archive": '"model.tar.gz"',
475-
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
476-
BUCKET, repacker_job_name.replace('"', "")
477-
),
478476
"sagemaker_program": '"_repack_model.py"',
479477
"sagemaker_container_log_level": "20",
480-
"sagemaker_job_name": repacker_job_name,
481478
"sagemaker_region": f'"{REGION}"',
482479
"source_dir": "null",
483480
},
@@ -585,7 +582,8 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
585582
assert len(request_dict["DependsOn"]) == 1
586583
assert request_dict["DependsOn"][0] == "TestStep"
587584
arguments = request_dict["Arguments"]
588-
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
585+
assert BUCKET in arguments["HyperParameters"]["sagemaker_submit_directory"]
586+
arguments["HyperParameters"].pop("sagemaker_submit_directory")
589587
assert ordered(arguments) == ordered(
590588
{
591589
"AlgorithmSpecification": {
@@ -599,12 +597,8 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
599597
"HyperParameters": {
600598
"inference_script": '"dummy_script.py"',
601599
"model_archive": '"model.tar.gz"',
602-
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
603-
BUCKET, repacker_job_name.replace('"', "")
604-
),
605600
"sagemaker_program": '"_repack_model.py"',
606601
"sagemaker_container_log_level": "20",
607-
"sagemaker_job_name": repacker_job_name,
608602
"sagemaker_region": f'"{REGION}"',
609603
"dependencies": "null",
610604
"source_dir": "null",
@@ -717,7 +711,8 @@ def test_register_model_with_model_repack_with_pipeline_model(
717711
assert len(request_dict["DependsOn"]) == 1
718712
assert request_dict["DependsOn"][0] == "TestStep"
719713
arguments = request_dict["Arguments"]
720-
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
714+
assert BUCKET in arguments["HyperParameters"]["sagemaker_submit_directory"]
715+
arguments["HyperParameters"].pop("sagemaker_submit_directory")
721716
assert ordered(arguments) == ordered(
722717
{
723718
"AlgorithmSpecification": {
@@ -732,12 +727,8 @@ def test_register_model_with_model_repack_with_pipeline_model(
732727
"dependencies": "null",
733728
"inference_script": '"dummy_script.py"',
734729
"model_archive": '"model.tar.gz"',
735-
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
736-
BUCKET, repacker_job_name.replace('"', "")
737-
),
738730
"sagemaker_program": '"_repack_model.py"',
739731
"sagemaker_container_log_level": "20",
740-
"sagemaker_job_name": repacker_job_name,
741732
"sagemaker_region": f'"{REGION}"',
742733
"source_dir": "null",
743734
},
@@ -917,7 +908,6 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator):
917908
arguments = request_dict["Arguments"]
918909
# pop out the dynamic generated fields
919910
arguments["HyperParameters"].pop("sagemaker_submit_directory")
920-
arguments["HyperParameters"].pop("sagemaker_job_name")
921911
assert arguments == {
922912
"AlgorithmSpecification": {
923913
"TrainingInputMode": "File",

tests/unit/sagemaker/workflow/test_steps.py

-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@ def test_training_step_tensorflow(sagemaker_session):
399399
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
400400
)
401401
step_request = step.to_request()
402-
step_request["Arguments"]["HyperParameters"].pop("sagemaker_job_name", None)
403402
step_request["Arguments"]["HyperParameters"].pop("sagemaker_program", None)
404403
step_request["Arguments"].pop("ProfilerRuleConfigurations", None)
405404
assert step_request == {

0 commit comments

Comments
 (0)