
Commit ba0e1da

Merge branch 'master' into pipeline-experiment-config
2 parents 331d24f + 2a61c41

File tree

8 files changed: +483 -1 lines changed

src/sagemaker/workflow/_repack_model.py

Lines changed: 19 additions & 1 deletion
@@ -19,6 +19,13 @@
 import tarfile
 import tempfile

+# Repack Model
+# The following script is run via a training job which takes an existing model and a custom
+# entry point script as arguments. The script creates a new model archive with the custom
+# entry point in the "code" directory along with the existing model. Subsequently, when the model
+# is unpacked for inference, the custom entry point will be used.
+# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
+
 # distutils.dir_util.copy_tree works way better than the half-baked
 # shutil.copytree which bombs on previously existing target dirs...
 # alas ... https://bugs.python.org/issue10948
@@ -33,17 +40,28 @@
     parser.add_argument("--model_archive", type=str, default="model.tar.gz")
     args = parser.parse_args()

+    # the data directory contains a model archive generated by a previous training job
     data_directory = "/opt/ml/input/data/training"
     model_path = os.path.join(data_directory, args.model_archive)

+    # create a temporary directory
     with tempfile.TemporaryDirectory() as tmp:
         local_path = os.path.join(tmp, "local.tar.gz")
+        # copy the previous training job's model archive to the temporary directory
         shutil.copy2(model_path, local_path)
         src_dir = os.path.join(tmp, "src")
+        # create the "code" directory which will contain the inference script
+        os.makedirs(os.path.join(src_dir, "code"))
+        # extract the contents of the previous training job's model archive to the "src"
+        # directory of this training job
         with tarfile.open(name=local_path, mode="r:gz") as tf:
             tf.extractall(path=src_dir)

+        # generate a path to the custom inference script
         entry_point = os.path.join("/opt/ml/code", args.inference_script)
-        shutil.copy2(entry_point, os.path.join(src_dir, args.inference_script))
+        # copy the custom inference script to the "src" dir
+        shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))

+        # copy the "src" dir, which includes the previous training job's model and the
+        # custom inference script, to the output of this training job
         copy_tree(src_dir, "/opt/ml/model")
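
For context, the repack script writes the combined contents to /opt/ml/model, which SageMaker then archives as the training job's model artifact. A minimal sketch of how the result could be verified, assuming the repacked archive has been downloaded locally as model.tar.gz and the entry point was named inference.py (both names are illustrative assumptions, not values from this commit):

import tarfile

# Illustrative helper: list the members of a repacked model archive so you can
# confirm the custom entry point now lives under "code/" next to the model files.
def list_repacked_members(archive_path):
    with tarfile.open(archive_path, mode="r:gz") as tf:
        return sorted(member.name for member in tf.getmembers())

# Expected to include the original model artifacts plus "code/inference.py".
print(list_repacked_members("model.tar.gz"))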

src/sagemaker/wrangler/__init__.py

Whitespace-only changes.

src/sagemaker/wrangler/processing.py

Lines changed: 153 additions & 0 deletions
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The process definitions for data wrangler."""

from __future__ import absolute_import

from typing import Dict, List

from sagemaker.network import NetworkConfig
from sagemaker.processing import (
    ProcessingInput,
    Processor,
)
from sagemaker import image_uris
from sagemaker.session import Session


class DataWranglerProcessor(Processor):
    """Handles Amazon SageMaker DataWrangler tasks."""

    def __init__(
        self,
        role: str,
        data_wrangler_flow_source: str,
        instance_count: int,
        instance_type: str,
        volume_size_in_gb: int = 30,
        volume_kms_key: str = None,
        output_kms_key: str = None,
        max_runtime_in_seconds: int = None,
        base_job_name: str = None,
        sagemaker_session: Session = None,
        env: Dict[str, str] = None,
        tags: List[dict] = None,
        network_config: NetworkConfig = None,
    ):
        """Initializes a ``Processor`` instance.

        The ``Processor`` handles Amazon SageMaker Processing tasks.

        Args:
            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
                uses this role to access AWS resources, such as
                data stored in Amazon S3.
            data_wrangler_flow_source (str): The source of the DataWrangler flow which will be
                used for the DataWrangler job. If a local path is provided, it will automatically
                be uploaded to S3 under:
                "s3://<default-bucket-name>/<job-name>/input/<input-name>".
            instance_count (int): The number of instances to run
                a processing job with.
            instance_type (str): The type of EC2 instance to use for
                processing, for example, 'ml.c4.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing
                volume (default: None).
            output_kms_key (str): The KMS key ID for processing job outputs (default: None).
            max_runtime_in_seconds (int): Timeout in seconds (default: None).
                After this amount of time, Amazon SageMaker terminates the job,
                regardless of its current status. If `max_runtime_in_seconds` is not
                specified, the default value is 24 hours.
            base_job_name (str): Prefix for processing job name. If not specified,
                the processor generates a default job name, based on the
                processing image name and current timestamp.
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain.
            env (dict[str, str]): Environment variables to be passed to
                the processing jobs (default: None).
            tags (list[dict]): List of tags to be passed to the processing job
                (default: None). For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            network_config (:class:`~sagemaker.network.NetworkConfig`):
                A :class:`~sagemaker.network.NetworkConfig`
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
        """
        self.data_wrangler_flow_source = data_wrangler_flow_source
        self.sagemaker_session = sagemaker_session or Session()
        image_uri = image_uris.retrieve(
            "data-wrangler", region=self.sagemaker_session.boto_region_name
        )
        super().__init__(
            role,
            image_uri,
            instance_count,
            instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            sagemaker_session=sagemaker_session,
            env=env,
            tags=tags,
            network_config=network_config,
        )

    def _normalize_args(
        self,
        job_name=None,
        arguments=None,
        inputs=None,
        outputs=None,
        code=None,
        kms_key=None,
    ):
        """Normalizes the arguments so that they can be passed to the job run.

        Args:
            job_name (str): Name of the processing job to be created. If not specified, one
                is generated, using the base name given to the constructor, if applicable
                (default: None).
            arguments (list[str]): A list of string arguments to be passed to a
                processing job (default: None).
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job. These must be provided as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job. These can be specified as either path strings or
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            code (str): This can be an S3 URI or a local path to a file with the framework
                script to run (default: None). A no-op in the base class.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
        """
        inputs = inputs or []
        found = any(element.input_name == "flow" for element in inputs)
        if not found:
            inputs.append(self._get_recipe_input())
        return super()._normalize_args(job_name, arguments, inputs, outputs, code, kms_key)

    def _get_recipe_input(self):
        """Creates a ProcessingInput with the Data Wrangler recipe URI to be appended to the job's inputs."""
        return ProcessingInput(
            source=self.data_wrangler_flow_source,
            destination="/opt/ml/processing/flow",
            input_name="flow",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
        )
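
To show how the new processor is meant to be used, here is a minimal usage sketch modeled on the integration test added below in tests/integ/test_workflow.py; the role ARN, flow path, data file name, bucket, and output node id are placeholders, not values from this commit:

import json
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.wrangler.processing import DataWranglerProcessor

# Placeholder values; substitute your own role, flow file, bucket, and node id.
output_name = "<source-node-id>.default"
output_config = {output_name: {"content_type": "CSV"}}

processor = DataWranglerProcessor(
    role="arn:aws:iam::123456789012:role/MySageMakerRole",
    data_wrangler_flow_source="my_recipe.flow",  # local path or S3 URI to the .flow file
    instance_count=1,
    instance_type="ml.m5.4xlarge",
)

# The "flow" input is appended automatically by _normalize_args when absent,
# so only the data inputs and the desired outputs need to be listed here.
processor.run(
    inputs=[
        ProcessingInput(
            input_name="my_data.csv",
            source="my_data.csv",
            destination="/opt/ml/processing/my_data.csv",
        )
    ],
    outputs=[
        ProcessingOutput(
            output_name=output_name,
            source="/opt/ml/processing/output",
            destination="s3://my-bucket/output",
        )
    ],
    arguments=[f"--output-config '{json.dumps(output_config)}'"],
)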

tests/data/workflow/dummy_data.csv

Lines changed: 7 additions & 0 deletions
Class,Age,Sex,SurvivalStatus
1st,"Quantity[29., ""Years""]",female,survived
1st,"Quantity[0.9167, ""Years""]",male,survived
2nd,"Quantity[30., ""Years""]",male,died
2nd,"Quantity[28., ""Years""]",female,survived
3rd,"Quantity[16., ""Years""]",male,died
3rd,"Quantity[35., ""Years""]",female,survived

tests/data/workflow/dummy_recipe.flow

Lines changed: 62 additions & 0 deletions
{
  "metadata": {
    "version": 1
  },
  "nodes": [
    {
      "node_id": "3f74973c-fd1e-4845-89f8-0dd400031be9",
      "type": "SOURCE",
      "operator": "sagemaker.s3_source_0.1",
      "parameters": {
        "dataset_definition": {
          "__typename": "S3CreateDatasetDefinitionOutput",
          "datasetSourceType": "S3",
          "name": "dummy_data.csv",
          "description": null,
          "s3ExecutionContext": {
            "__typename": "S3ExecutionContext",
            "s3Uri": "s3://bucket/dummy_data.csv",
            "s3ContentType": "csv",
            "s3HasHeader": true
          }
        }
      },
      "inputs": [],
      "outputs": [
        {
          "name": "default",
          "sampling": {
            "sampling_method": "sample_by_limit",
            "limit_rows": 50000
          }
        }
      ]
    },
    {
      "node_id": "67c18cb1-0192-445a-86f4-31e4c3553c60",
      "type": "TRANSFORM",
      "operator": "sagemaker.spark.infer_and_cast_type_0.1",
      "parameters": {},
      "trained_parameters": {
        "schema": {
          "Class": "string",
          "Age": "string",
          "Sex": "string",
          "SurvivalStatus": "string"
        }
      },
      "inputs": [
        {
          "name": "default",
          "node_id": "3f74973c-fd1e-4845-89f8-0dd400031be9",
          "output_name": "default"
        }
      ],
      "outputs": [
        {
          "name": "default"
        }
      ]
    }
  ]
}
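
A short note on how this recipe ties into the integration test below: the test's processing output name is the source node's node_id plus ".default". A quick sketch of that derivation (the local file path is a placeholder):

import json

# Placeholder path; in the test suite this file lives at tests/data/workflow/dummy_recipe.flow.
with open("dummy_recipe.flow") as f:
    flow = json.load(f)

# For this recipe, prints "3f74973c-fd1e-4845-89f8-0dd400031be9.default",
# the same output name used by test_one_step_data_wrangler_processing_pipeline.
output_name = f"{flow['nodes'][0]['node_id']}.default"
print(output_name)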

tests/integ/test_workflow.py

Lines changed: 112 additions & 0 deletions
@@ -28,6 +28,7 @@
     rule_configs,
 )
 from datetime import datetime
+from sagemaker import image_uris
 from sagemaker.inputs import CreateModelInput, TrainingInput
 from sagemaker.model import Model
 from sagemaker.processing import ProcessingInput, ProcessingOutput
@@ -39,6 +40,7 @@
 from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
 from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 from sagemaker.workflow.condition_step import ConditionStep
+from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.functions import Join
@@ -1076,3 +1078,113 @@ def test_two_processing_job_depends_on(
             pipeline.delete()
         except Exception:
             pass
+
+
+def test_one_step_data_wrangler_processing_pipeline(
+    sagemaker_session,
+    role,
+    pipeline_name,
+    region_name,
+):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.4xlarge")
+
+    recipe_file_path = os.path.join(DATA_DIR, "workflow", "dummy_recipe.flow")
+    input_file_path = os.path.join(DATA_DIR, "workflow", "dummy_data.csv")
+
+    output_name = "3f74973c-fd1e-4845-89f8-0dd400031be9.default"
+    output_content_type = "CSV"
+    output_config = {output_name: {"content_type": output_content_type}}
+    job_argument = [f"--output-config '{json.dumps(output_config)}'"]
+
+    inputs = [
+        ProcessingInput(
+            input_name="dummy_data.csv",
+            source=input_file_path,
+            destination="/opt/ml/processing/dummy_data.csv",
+        )
+    ]
+
+    output_s3_uri = f"s3://{sagemaker_session.default_bucket()}/output"
+    outputs = [
+        ProcessingOutput(
+            output_name=output_name,
+            source="/opt/ml/processing/output",
+            destination=output_s3_uri,
+            s3_upload_mode="EndOfJob",
+        )
+    ]
+
+    data_wrangler_processor = DataWranglerProcessor(
+        role=role,
+        data_wrangler_flow_source=recipe_file_path,
+        instance_count=instance_count,
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        max_runtime_in_seconds=86400,
+    )
+
+    data_wrangler_step = ProcessingStep(
+        name="data-wrangler-step",
+        processor=data_wrangler_processor,
+        inputs=inputs,
+        outputs=outputs,
+        job_arguments=job_argument,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count, instance_type],
+        steps=[data_wrangler_step],
+        sagemaker_session=sagemaker_session,
+    )
+
+    definition = json.loads(pipeline.definition())
+    expected_image_uri = image_uris.retrieve(
+        "data-wrangler", region=sagemaker_session.boto_region_name
+    )
+    assert len(definition["Steps"]) == 1
+    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] is not None
+    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] == expected_image_uri
+
+    assert definition["Steps"][0]["Arguments"]["ProcessingInputs"] is not None
+    processing_inputs = definition["Steps"][0]["Arguments"]["ProcessingInputs"]
+    assert len(processing_inputs) == 2
+    for processing_input in processing_inputs:
+        if processing_input["InputName"] == "flow":
+            assert processing_input["S3Input"]["S3Uri"].endswith(".flow")
+            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/flow"
+        elif processing_input["InputName"] == "dummy_data.csv":
+            assert processing_input["S3Input"]["S3Uri"].endswith(".csv")
+            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/dummy_data.csv"
+        else:
+            raise AssertionError("Unknown input name")
+    assert definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"] is not None
+    processing_outputs = definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"]["Outputs"]
+    assert len(processing_outputs) == 1
+    assert processing_outputs[0]["OutputName"] == output_name
+    assert processing_outputs[0]["S3Output"] is not None
+    assert processing_outputs[0]["S3Output"]["LocalPath"] == "/opt/ml/processing/output"
+    assert processing_outputs[0]["S3Output"]["S3Uri"] == output_s3_uri
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+
+        execution = pipeline.start()
+        response = execution.describe()
+        assert response["PipelineArn"] == create_arn
+
+        try:
+            execution.wait(delay=60, max_attempts=10)
+        except WaiterError:
+            pass
+
+        execution_steps = execution.list_steps()
+        assert len(execution_steps) == 1
+        assert execution_steps[0]["StepName"] == "data-wrangler-step"
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass

tests/unit/sagemaker/wrangler/__init__.py

Whitespace-only changes.
