Skip to content

Commit b246739

Browse files
authored
Merge branch 'master' into fix/js-cache-s3-client
2 parents b2d281c + 63f39e1 commit b246739

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

src/sagemaker/estimator.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@
9999
)
100100
from sagemaker.workflow import is_pipeline_variable
101101
from sagemaker.workflow.entities import PipelineVariable
102-
from sagemaker.workflow.parameters import ParameterString
103102
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
104103

105104
logger = logging.getLogger(__name__)
@@ -614,16 +613,21 @@ def __init__(
614613
self.output_kms_key = resolve_value_from_config(
615614
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
616615
)
616+
use_volume_kms_config: bool = False
617617
if instance_type is None or isinstance(instance_type, str):
618618
instance_type_for_volume_kms = instance_type
619-
elif isinstance(instance_type, ParameterString):
620-
instance_type_for_volume_kms = instance_type.default_value
619+
elif isinstance(instance_type, PipelineVariable):
620+
use_volume_kms_config = True
621+
instance_type_for_volume_kms = instance_type
621622
else:
622623
raise ValueError(f"Bad value for instance type: '{instance_type}'")
623624

624625
# KMS can only be attached to supported instances
625626
use_volume_kms_config = (
626-
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms))
627+
use_volume_kms_config
628+
or (
629+
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)
630+
)
627631
or instance_groups is not None
628632
and any(
629633
[

tests/unit/sagemaker/workflow/test_training_step.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@
2323

2424
from sagemaker import Processor, Model
2525
from sagemaker.parameter import IntegerParameter
26+
from sagemaker.processing import ProcessingOutput
2627
from sagemaker.transformer import Transformer
2728
from sagemaker.tuner import HyperparameterTuner
2829
from sagemaker.workflow.pipeline_context import _PipelineConfig
2930
from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
31+
from sagemaker.workflow.properties import PropertyFile
3032

31-
from sagemaker.workflow.steps import TrainingStep
33+
from sagemaker.workflow.steps import ProcessingStep, TrainingStep
3234
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
3335
from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig
3436
from sagemaker.workflow.utilities import hash_files_or_dirs
35-
from sagemaker.workflow.functions import Join
37+
from sagemaker.workflow.functions import Join, JsonGet
3638

3739
from sagemaker.estimator import Estimator
3840
from sagemaker.sklearn.estimator import SKLearn
@@ -871,3 +873,55 @@ def test_training_step_with_estimator_using_custom_prefixes(
871873
"Type": "Training",
872874
"Arguments": step_args,
873875
}
876+
877+
878+
def test_training_step_with_jsonget_instance_type(pipeline_session):
879+
property_file = PropertyFile(
880+
name="my-property-file", output_name="TestOutputName", path="processing_output.json"
881+
)
882+
processor = Processor(
883+
image_uri=IMAGE_URI,
884+
role=ROLE,
885+
instance_count=1,
886+
instance_type="c4.4xlarge",
887+
sagemaker_session=pipeline_session,
888+
)
889+
process_arg = processor.run(outputs=[ProcessingOutput(output_name="TestOutputName")])
890+
processing_step = ProcessingStep(
891+
name="inputProcessingStep",
892+
step_args=process_arg,
893+
property_files=[property_file],
894+
)
895+
896+
json_get_function = JsonGet(
897+
step_name=processing_step.name, property_file=property_file.name, json_path="mse"
898+
)
899+
900+
estimator = Estimator(
901+
image_uri=IMAGE_URI,
902+
role=ROLE,
903+
instance_count=1,
904+
instance_type=json_get_function,
905+
sagemaker_session=pipeline_session,
906+
)
907+
908+
training_step = TrainingStep(
909+
name="MyTrainingStep",
910+
step_args=estimator.fit(inputs=TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")),
911+
)
912+
pipeline = Pipeline(
913+
name="MyPipeline",
914+
steps=[processing_step, training_step],
915+
sagemaker_session=pipeline_session,
916+
)
917+
918+
steps = json.loads(pipeline.definition())["Steps"]
919+
for step in steps:
920+
if step["Type"] == "Processing":
921+
continue
922+
assert step["Arguments"]["ResourceConfig"]["InstanceType"] == {
923+
"Std:JsonGet": {
924+
"Path": "mse",
925+
"PropertyFile": {"Get": "Steps.inputProcessingStep.PropertyFiles.my-property-file"},
926+
}
927+
}

0 commit comments

Comments
 (0)