|
23 | 23 |
|
24 | 24 | from sagemaker import Processor, Model
|
25 | 25 | from sagemaker.parameter import IntegerParameter
|
| 26 | +from sagemaker.processing import ProcessingOutput |
26 | 27 | from sagemaker.transformer import Transformer
|
27 | 28 | from sagemaker.tuner import HyperparameterTuner
|
28 | 29 | from sagemaker.workflow.pipeline_context import _PipelineConfig
|
29 | 30 | from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
|
| 31 | +from sagemaker.workflow.properties import PropertyFile |
30 | 32 |
|
31 |
| -from sagemaker.workflow.steps import TrainingStep |
| 33 | +from sagemaker.workflow.steps import ProcessingStep, TrainingStep |
32 | 34 | from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
|
33 | 35 | from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig
|
34 | 36 | from sagemaker.workflow.utilities import hash_files_or_dirs
|
35 |
| -from sagemaker.workflow.functions import Join |
| 37 | +from sagemaker.workflow.functions import Join, JsonGet |
36 | 38 |
|
37 | 39 | from sagemaker.estimator import Estimator
|
38 | 40 | from sagemaker.sklearn.estimator import SKLearn
|
@@ -871,3 +873,55 @@ def test_training_step_with_estimator_using_custom_prefixes(
|
871 | 873 | "Type": "Training",
|
872 | 874 | "Arguments": step_args,
|
873 | 875 | }
|
| 876 | + |
| 877 | + |
| 878 | +def test_training_step_with_jsonget_instance_type(pipeline_session): |
| 879 | + property_file = PropertyFile( |
| 880 | + name="my-property-file", output_name="TestOutputName", path="processing_output.json" |
| 881 | + ) |
| 882 | + processor = Processor( |
| 883 | + image_uri=IMAGE_URI, |
| 884 | + role=ROLE, |
| 885 | + instance_count=1, |
| 886 | + instance_type="c4.4xlarge", |
| 887 | + sagemaker_session=pipeline_session, |
| 888 | + ) |
| 889 | + process_arg = processor.run(outputs=[ProcessingOutput(output_name="TestOutputName")]) |
| 890 | + processing_step = ProcessingStep( |
| 891 | + name="inputProcessingStep", |
| 892 | + step_args=process_arg, |
| 893 | + property_files=[property_file], |
| 894 | + ) |
| 895 | + |
| 896 | + json_get_function = JsonGet( |
| 897 | + step_name=processing_step.name, property_file=property_file.name, json_path="mse" |
| 898 | + ) |
| 899 | + |
| 900 | + estimator = Estimator( |
| 901 | + image_uri=IMAGE_URI, |
| 902 | + role=ROLE, |
| 903 | + instance_count=1, |
| 904 | + instance_type=json_get_function, |
| 905 | + sagemaker_session=pipeline_session, |
| 906 | + ) |
| 907 | + |
| 908 | + training_step = TrainingStep( |
| 909 | + name="MyTrainingStep", |
| 910 | + step_args=estimator.fit(inputs=TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")), |
| 911 | + ) |
| 912 | + pipeline = Pipeline( |
| 913 | + name="MyPipeline", |
| 914 | + steps=[processing_step, training_step], |
| 915 | + sagemaker_session=pipeline_session, |
| 916 | + ) |
| 917 | + |
| 918 | + steps = json.loads(pipeline.definition())["Steps"] |
| 919 | + for step in steps: |
| 920 | + if step["Type"] == "Processing": |
| 921 | + continue |
| 922 | + assert step["Arguments"]["ResourceConfig"]["InstanceType"] == { |
| 923 | + "Std:JsonGet": { |
| 924 | + "Path": "mse", |
| 925 | + "PropertyFile": {"Get": "Steps.inputProcessingStep.PropertyFiles.my-property-file"}, |
| 926 | + } |
| 927 | + } |
0 commit comments