diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 40d38e7339..3b03fdf5fc 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -20,13 +20,15 @@ import abc from enum import Enum -from typing import List, Union +from typing import Dict, List, Union import attr +from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import ( DefaultEnumMeta, Entity, + Expression, PrimitiveType, RequestType, ) @@ -289,3 +291,18 @@ def _referenced_steps(self) -> List[str]: for condition in self.conditions: steps.extend(condition._referenced_steps) return steps + + +def primitive_or_expr( + value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties] +) -> Union[Dict[str, str], PrimitiveType]: + """Provide the expression of the value or return value if it is a primitive. + + Args: + value (Union[ConditionValueType, PrimitiveType]): The value to evaluate. + Returns: + Either the expression of the value or the primitive value. + """ + if is_pipeline_variable(value): + return value.expr + return value