fix: Fix Pipeline variables related customer issues #2959

Merged: 1 commit, Mar 30, 2022

4 changes: 3 additions & 1 deletion src/sagemaker/estimator.py

@@ -1879,7 +1879,9 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
         if estimator.use_spot_instances:
             if local_mode:
                 raise ValueError("Spot training is not supported in local mode.")
-            train_args["use_spot_instances"] = True
+            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
+            train_args["use_spot_instances"] = estimator.use_spot_instances

         if estimator.checkpoint_s3_uri:
             if local_mode:

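For reviewers, a minimal sketch of what this enables: use_spot_instances supplied as a pipeline parameter rather than a plain bool. The image URI, role, and parameter name below are placeholders, not part of this PR:

from sagemaker.estimator import Estimator
from sagemaker.workflow.parameters import ParameterBoolean

# A pipeline parameter whose value is supplied per pipeline execution.
use_spot = ParameterBoolean(name="UseSpotInstances", default_value=False)

estimator = Estimator(
    image_uri="<training-image-uri>",  # placeholder
    role="<execution-role-arn>",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    use_spot_instances=use_spot,  # previously collapsed to a hard-coded True
    max_run=3600,
    max_wait=7200,  # spot training also expects a max wait time
)
# _add_spot_checkpoint_args now forwards the ParameterBoolean itself, so the
# flag is resolved when the pipeline executes, not when it is defined.
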
3 changes: 2 additions & 1 deletion src/sagemaker/model.py

@@ -37,6 +37,7 @@
 from sagemaker.utils import unique_name_from_base
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.predictor_async import AsyncPredictor
+from sagemaker.workflow.entities import PipelineVariable

 LOGGER = logging.getLogger("sagemaker")

@@ -443,7 +444,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
         )

         if repack and self.model_data is not None and self.entry_point is not None:
-            if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
+            if isinstance(self.model_data, PipelineVariable):
                 # model is not yet there, defer repacking to later during pipeline execution
                 return

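For context, a self-contained sketch of the case this guards; a ParameterString stands in for the Properties object a TrainingStep would produce (both derive from PipelineVariable), and the placeholder values are illustrative:

from sagemaker.model import Model
from sagemaker.workflow.parameters import ParameterString

model_data = ParameterString(name="ModelData", default_value="s3://my-bucket/model.tar.gz")

model = Model(
    image_uri="<inference-image-uri>",  # placeholder
    model_data=model_data,  # a PipelineVariable, not a concrete S3 URI
    entry_point="inference.py",  # would normally trigger repacking
    role="<execution-role-arn>",  # placeholder
)
# When repacking is requested, _upload_code() now sees a PipelineVariable and
# returns early: the artifact does not exist at definition time, so the repack
# is deferred to pipeline execution.
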
8 changes: 7 additions & 1 deletion src/sagemaker/processing.py

@@ -36,7 +36,7 @@
 from sagemaker.session import Session
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.parameters import Parameter
-from sagemaker.workflow.entities import Expression
+from sagemaker.workflow.entities import Expression, PipelineVariable
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
 from sagemaker.s3 import S3Uploader
@@ -233,6 +233,12 @@ def _normalize_args(
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
         """
+        if code and isinstance(code, PipelineVariable):
+            raise ValueError(
+                "code argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         self._current_job_name = self._generate_current_job_name(job_name=job_name)

         inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)

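A runnable mirror of the new guard, under the same assumption the diff makes (a pipeline variable cannot be resolved to a concrete script at definition time); the parameter name is illustrative:

from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString

code = ParameterString(name="ProcessingCode", default_value="s3://my-bucket/script.py")

# `code` must point at a real file or S3 object when the job is normalized,
# so a pipeline variable is rejected immediately rather than failing later.
if isinstance(code, PipelineVariable):
    raise ValueError(
        "code argument has to be a valid S3 URI or local file path "
        "rather than a pipeline variable"
    )
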
12 changes: 9 additions & 3 deletions src/sagemaker/session.py

@@ -763,6 +763,8 @@ def _get_train_request(  # noqa: C901
         train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic

         if use_spot_instances:
+            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
             train_request["EnableManagedSpotTraining"] = use_spot_instances

         if checkpoint_s3_uri:
@@ -2338,13 +2340,17 @@ def _map_training_config(
         training_job_definition["VpcConfig"] = vpc_config

     if enable_network_isolation:
-        training_job_definition["EnableNetworkIsolation"] = True
+        training_job_definition["EnableNetworkIsolation"] = enable_network_isolation

     if encrypt_inter_container_traffic:
-        training_job_definition["EnableInterContainerTrafficEncryption"] = True
+        training_job_definition[
+            "EnableInterContainerTrafficEncryption"
+        ] = encrypt_inter_container_traffic

     if use_spot_instances:
-        training_job_definition["EnableManagedSpotTraining"] = True
+        # use_spot_instances may be a Pipeline ParameterBoolean object
+        # which is parsed during the Pipeline execution runtime
+        training_job_definition["EnableManagedSpotTraining"] = use_spot_instances

     if checkpoint_s3_uri:
         checkpoint_config = {"S3Uri": checkpoint_s3_uri}

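Why assigning the variable matters, as a short sketch (parameter name illustrative): a hard-coded True froze the decision at definition time, while the variable survives into the request and serializes to its Get-expression.

from sagemaker.workflow.parameters import ParameterBoolean

use_spot = ParameterBoolean(name="UseSpot", default_value=False)

# The request dict now carries the variable; at serialization it becomes a
# Get-expression that SageMaker resolves for each execution.
training_job_definition = {"EnableManagedSpotTraining": use_spot}
print(use_spot.expr)  # {'Get': 'Parameters.UseSpot'}
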
5 changes: 4 additions & 1 deletion src/sagemaker/tensorflow/model.py

@@ -21,6 +21,7 @@
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
+from sagemaker.workflow.entities import PipelineVariable


 class TensorFlowPredictor(Predictor):
@@ -330,7 +331,9 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         image_uri = self._get_image_uri(instance_type, accelerator_type)
         env = self._get_container_env()

-        if self.entry_point:
+        # If self.model_data is pipeline variable, model is not yet there.
+        # So defer repacking to later during pipeline execution
+        if self.entry_point and not isinstance(self.model_data, PipelineVariable):
             key_prefix = sagemaker.fw_utils.model_code_key_prefix(
                 self.key_prefix, self.name, image_uri
             )

15 changes: 0 additions & 15 deletions src/sagemaker/tuner.py

@@ -39,9 +39,6 @@
     ParameterRange,
 )
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.parameters import Parameter as PipelineParameter
-from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
-from sagemaker.workflow.functions import Join as PipelineJoin

 from sagemaker.session import Session
 from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
@@ -64,18 +61,6 @@
 logger = logging.getLogger(__name__)


-def is_pipeline_parameters(value):
-    """Determine if a value is a pipeline parameter or function representation
-
-    Args:
-        value (float or int): The value to be verified.
-
-    Returns:
-        bool: True if it is, False otherwise.
-    """
-    return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
-
-
 class WarmStartTypes(Enum):
     """Warm Start Configuration type.

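The deleted helper is subsumed by the PipelineVariable base class: a single isinstance check covers the old three-type tuple. A sketch, assuming (as this PR does elsewhere) that parameters and workflow functions all derive from PipelineVariable:

from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterInteger

max_jobs = ParameterInteger(name="MaxJobs", default_value=10)
s3_uri = Join(on="/", values=["s3:/", "my-bucket", "prefix"])

# One check replaces isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin)).
assert isinstance(max_jobs, PipelineVariable)
assert isinstance(s3_uri, PipelineVariable)
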
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/_repack_model.py

@@ -39,14 +39,14 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
     Args:
         inference_script (str): The path to the custom entry point.
-        model_archive (str): The name of the model TAR archive.
+        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
         dependencies (str): A space-delimited string of paths to custom dependencies.
         source_dir (str): The path to a custom source directory.
     """

     # the data directory contains a model archive generated by a previous training job
     data_directory = "/opt/ml/input/data/training"
-    model_path = os.path.join(data_directory, model_archive)
+    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

     # create a temporary directory
     with tempfile.TemporaryDirectory() as tmp:

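A standalone sketch of the behavior change: model_archive may now be a bare archive name or a full S3 URI, and taking the last path segment handles both.

import os

data_directory = "/opt/ml/input/data/training"

for model_archive in ("model.tar.gz", "s3://my-bucket/prefix/model.tar.gz"):
    # Both forms resolve to the archive staged locally by the training channel.
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])
    print(model_path)  # /opt/ml/input/data/training/model.tar.gz in both cases
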
10 changes: 2 additions & 8 deletions src/sagemaker/workflow/_utils.py

@@ -137,12 +137,6 @@ def __init__(
         self._model_data = model_data
         self.sagemaker_session = sagemaker_session
         self.role = role
-        if isinstance(model_data, Properties):
-            self._model_prefix = model_data
-            self._model_archive = "model.tar.gz"
-        else:
-            self._model_prefix = "/".join(self._model_data.split("/")[:-1])
-            self._model_archive = self._model_data.split("/")[-1]
         self._entry_point = entry_point
         self._entry_point_basename = os.path.basename(self._entry_point)
         self._source_dir = source_dir
@@ -164,7 +158,7 @@ def __init__(
             role=self.role,
             hyperparameters={
                 "inference_script": self._entry_point_basename,
-                "model_archive": self._model_archive,
+                "model_archive": self._model_data,
                 "dependencies": dependencies_hyperparameter,
                 "source_dir": self._source_dir,
             },
@@ -173,7 +167,7 @@ def __init__(
             **kwargs,
         )
         repacker.disable_profiler = True
-        inputs = TrainingInput(self._model_prefix)
+        inputs = TrainingInput(self._model_data)

         # super!
         super(_RepackModelStep, self).__init__(

4 changes: 3 additions & 1 deletion src/sagemaker/workflow/airflow.py

@@ -184,7 +184,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
         train_config["VpcConfig"] = job_config["vpc_config"]

     if estimator.use_spot_instances:
-        train_config["EnableManagedSpotTraining"] = True
+        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+        # which is parsed during the Pipeline execution runtime
+        train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances

     if estimator.hyperparameters() is not None:
         hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}

10 changes: 4 additions & 6 deletions src/sagemaker/workflow/clarify_check_step.py

@@ -37,8 +37,8 @@
 from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
 from sagemaker.utils import name_from_base
-from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
-from sagemaker.workflow.entities import RequestType, Expression
+from sagemaker.workflow import PipelineNonPrimitiveInputTypes
+from sagemaker.workflow.entities import RequestType, PipelineVariable
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
 from sagemaker.workflow.check_job_config import CheckJobConfig
@@ -194,17 +194,15 @@ def __init__(
         )

         if isinstance(
-            clarify_check_config.data_config.s3_analysis_config_output_path,
-            (ExecutionVariable, Expression, Parameter, Properties),
+            clarify_check_config.data_config.s3_analysis_config_output_path, PipelineVariable
         ):
             raise RuntimeError(
                 "s3_analysis_config_output_path cannot be of type "
                 + "ExecutionVariable/Expression/Parameter/Properties"
             )

         if not clarify_check_config.data_config.s3_analysis_config_output_path and isinstance(
-            clarify_check_config.data_config.s3_output_path,
-            (ExecutionVariable, Expression, Parameter, Properties),
+            clarify_check_config.data_config.s3_output_path, PipelineVariable
         ):
             raise RuntimeError(
                 "`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"

3 changes: 2 additions & 1 deletion src/sagemaker/workflow/conditions.py

@@ -28,6 +28,7 @@
     Expression,
     PrimitiveType,
     RequestType,
+    PipelineVariable,
 )
 from sagemaker.workflow.execution_variables import ExecutionVariable
 from sagemaker.workflow.parameters import Parameter
@@ -261,6 +262,6 @@ def primitive_or_expr(
     Returns:
         Either the expression of the value or the primitive value.
     """
-    if isinstance(value, (ExecutionVariable, Expression, Parameter, Properties)):
+    if isinstance(value, PipelineVariable):
         return value.expr
     return value

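Usage sketch for the tightened check, assuming primitive_or_expr is importable at module level as the hunk header suggests (values illustrative):

from sagemaker.workflow.conditions import primitive_or_expr
from sagemaker.workflow.parameters import ParameterInteger

count = ParameterInteger(name="InstanceCount", default_value=1)

# Any PipelineVariable serializes to its Get-expression...
print(primitive_or_expr(count))  # {'Get': 'Parameters.InstanceCount'}
# ...while plain primitives pass through unchanged.
print(primitive_or_expr(1))  # 1
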
6 changes: 3 additions & 3 deletions src/sagemaker/workflow/quality_check_step.py

@@ -22,9 +22,9 @@
 from sagemaker import s3
 from sagemaker.model_monitor import ModelMonitor
 from sagemaker.processing import ProcessingOutput, ProcessingJob, Processor, ProcessingInput
-from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
+from sagemaker.workflow import PipelineNonPrimitiveInputTypes

-from sagemaker.workflow.entities import RequestType, Expression
+from sagemaker.workflow.entities import RequestType, PipelineVariable
 from sagemaker.workflow.properties import (
     Properties,
 )
@@ -279,7 +279,7 @@ def _generate_baseline_job_inputs(self):
                 _CONTAINER_BASE_PATH, _CONTAINER_INPUT_PATH, _BASELINE_DATASET_INPUT_NAME
             )
         )
-        if isinstance(baseline_dataset, (ExecutionVariable, Expression, Parameter, Properties)):
+        if isinstance(baseline_dataset, PipelineVariable):
             baseline_dataset_input = ProcessingInput(
                 source=self.quality_check_config.baseline_dataset,
                 destination=baseline_dataset_des,

99 changes: 99 additions & 0 deletions tests/integ/sagemaker/workflow/test_model_registration.py

@@ -20,6 +20,7 @@
 from botocore.exceptions import WaiterError

 import tests
+from sagemaker.tensorflow import TensorFlow, TensorFlowModel
 from tests.integ.retry import retries
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker import (
@@ -745,3 +746,101 @@ def test_model_registration_with_model_repack(
         pipeline.delete()
     except Exception:
         pass
+
+
+def test_model_registration_with_tensorflow_model_with_pipeline_model(
+    sagemaker_session, role, tf_full_version, tf_full_py_version, pipeline_name, region_name
+):
+    base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
+    entry_point = os.path.join(base_dir, "mnist_v2.py")
+    input_path = sagemaker_session.upload_data(
+        path=os.path.join(base_dir, "data"),
+        key_prefix="integ-test-data/tf-scriptmode/mnist/training",
+    )
+    inputs = TrainingInput(s3_data=input_path)
+
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+
+    tensorflow_estimator = TensorFlow(
+        entry_point=entry_point,
+        role=role,
+        instance_count=instance_count,
+        instance_type=instance_type,
+        framework_version=tf_full_version,
+        py_version=tf_full_py_version,
+        sagemaker_session=sagemaker_session,
+    )
+    step_train = TrainingStep(
+        name="MyTrain",
+        estimator=tensorflow_estimator,
+        inputs=inputs,
+    )
+
+    model = TensorFlowModel(
+        entry_point=entry_point,
+        framework_version="2.4",
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        role=role,
+        sagemaker_session=sagemaker_session,
+    )
+
+    pipeline_model = PipelineModel(
+        name="MyModelPipeline", models=[model], role=role, sagemaker_session=sagemaker_session
+    )
+
+    step_register_model = RegisterModel(
+        name="MyRegisterModel",
+        model=pipeline_model,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        content_types=["application/json"],
+        response_types=["application/json"],
+        inference_instances=["ml.t2.medium", "ml.m5.large"],
+        transform_instances=["ml.m5.large"],
+        model_package_group_name=f"{pipeline_name}TestModelPackageGroup",
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[
+            instance_count,
+            instance_type,
+        ],
+        steps=[step_train, step_register_model],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        for _ in retries(
+            max_retry_count=5,
+            exception_message_prefix="Waiting for a successful execution of pipeline",
+            seconds_to_sleep=10,
+        ):
+            execution = pipeline.start(parameters={})
+            assert re.match(
+                rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+                execution.arn,
+            )
+            try:
+                execution.wait(delay=30, max_attempts=60)
+            except WaiterError:
+                pass
+            execution_steps = execution.list_steps()
+
+            assert len(execution_steps) == 3
+            for step in execution_steps:
+                assert step["StepStatus"] == "Succeeded"
+            break
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass