Skip to content

fix: pipeline variable kms key #4065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -614,16 +613,21 @@ def __init__(
self.output_kms_key = resolve_value_from_config(
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
)
use_volume_kms_config: bool = False
if instance_type is None or isinstance(instance_type, str):
instance_type_for_volume_kms = instance_type
elif isinstance(instance_type, ParameterString):
instance_type_for_volume_kms = instance_type.default_value
elif isinstance(instance_type, PipelineVariable):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update the merge checklist

use_volume_kms_config = True
instance_type_for_volume_kms = instance_type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a warning here? See previous comment #4065 (comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline, we will not include a comment to adhere to existing code style.

else:
raise ValueError(f"Bad value for instance type: '{instance_type}'")

# KMS can only be attached to supported instances
use_volume_kms_config = (
(instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms))
use_volume_kms_config
or (
instance_type_for_volume_kms and instance_supports_kms(instance_type_for_volume_kms)
)
or instance_groups is not None
and any(
[
Expand Down
58 changes: 56 additions & 2 deletions tests/unit/sagemaker/workflow/test_training_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@

from sagemaker import Processor, Model
from sagemaker.parameter import IntegerParameter
from sagemaker.processing import ProcessingOutput
from sagemaker.transformer import Transformer
from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import _PipelineConfig
from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
from sagemaker.workflow.properties import PropertyFile

from sagemaker.workflow.steps import TrainingStep
from sagemaker.workflow.steps import ProcessingStep, TrainingStep
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig
from sagemaker.workflow.utilities import hash_files_or_dirs
from sagemaker.workflow.functions import Join
from sagemaker.workflow.functions import Join, JsonGet

from sagemaker.estimator import Estimator
from sagemaker.sklearn.estimator import SKLearn
Expand Down Expand Up @@ -871,3 +873,55 @@ def test_training_step_with_estimator_using_custom_prefixes(
"Type": "Training",
"Arguments": step_args,
}


def test_training_step_with_jsonget_instance_type(pipeline_session):
property_file = PropertyFile(
name="my-property-file", output_name="TestOutputName", path="processing_output.json"
)
processor = Processor(
image_uri=IMAGE_URI,
role=ROLE,
instance_count=1,
instance_type="c4.4xlarge",
sagemaker_session=pipeline_session,
)
process_arg = processor.run(outputs=[ProcessingOutput(output_name="TestOutputName")])
processing_step = ProcessingStep(
name="inputProcessingStep",
step_args=process_arg,
property_files=[property_file],
)

json_get_function = JsonGet(
step_name=processing_step.name, property_file=property_file.name, json_path="mse"
)

estimator = Estimator(
image_uri=IMAGE_URI,
role=ROLE,
instance_count=1,
instance_type=json_get_function,
sagemaker_session=pipeline_session,
)

training_step = TrainingStep(
name="MyTrainingStep",
step_args=estimator.fit(inputs=TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")),
)
pipeline = Pipeline(
name="MyPipeline",
steps=[processing_step, training_step],
sagemaker_session=pipeline_session,
)

steps = json.loads(pipeline.definition())["Steps"]
for step in steps:
if step["Type"] == "Processing":
continue
assert step["Arguments"]["ResourceConfig"]["InstanceType"] == {
"Std:JsonGet": {
"Path": "mse",
"PropertyFile": {"Get": "Steps.inputProcessingStep.PropertyFiles.my-property-file"},
}
}