diff --git a/src/sagemaker/workflow/callback_step.py b/src/sagemaker/workflow/callback_step.py index 2c0f0f6355..3517b62c1f 100644 --- a/src/sagemaker/workflow/callback_step.py +++ b/src/sagemaker/workflow/callback_step.py @@ -49,7 +49,7 @@ class CallbackOutput: """ output_name: str = attr.ib(default=None) - output_type: CallbackOutputTypeEnum = attr.ib(default=CallbackOutputTypeEnum.String.value) + output_type: CallbackOutputTypeEnum = attr.ib(default=CallbackOutputTypeEnum.String) def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" diff --git a/tests/unit/sagemaker/workflow/test_callback_step.py b/tests/unit/sagemaker/workflow/test_callback_step.py index e8161ac16b..d02b1cf7bf 100644 --- a/tests/unit/sagemaker/workflow/test_callback_step.py +++ b/tests/unit/sagemaker/workflow/test_callback_step.py @@ -53,6 +53,29 @@ def test_callback_step(): } +def test_callback_step_default_values(): + param = ParameterInteger(name="MyInt") + outputParam1 = CallbackOutput(output_name="output1") + cb_step = CallbackStep( + name="MyCallbackStep", + depends_on=["TestStep"], + sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue", + inputs={"arg1": "foo", "arg2": 5, "arg3": param}, + outputs=[outputParam1], + ) + cb_step.add_depends_on(["SecondTestStep"]) + assert cb_step.to_request() == { + "Name": "MyCallbackStep", + "Type": "Callback", + "DependsOn": ["TestStep", "SecondTestStep"], + "SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue", + "OutputParameters": [ + {"OutputName": "output1", "OutputType": "String"}, + ], + "Arguments": {"arg1": "foo", "arg2": 5, "arg3": param}, + } + + def test_callback_step_output_expr(): param = ParameterInteger(name="MyInt") outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)