
Commit d28b89f

Merge branch 'inference-url' of https://github.com/HappyAmazonian/sagemaker-python-sdk into inference-url
2 parents edbf61f + 1d5a5dc commit d28b89f

File tree

10 files changed: +423 -128

CHANGELOG.md

+9
@@ -1,5 +1,14 @@
 # Changelog
 
+## v2.72.1 (2021-12-20)
+
+### Bug Fixes and Other Changes
+
+ * typos and broken link
+ * S3Input - add support for instance attributes
+ * Prevent repack_model script from referencing nonexistent directories
+ * Set ProcessingStep upload locations deterministically to avoid c…
+
 ## v2.72.0 (2021-12-13)
 
 ### Features

VERSION

+1 -1

@@ -1 +1 @@
-2.72.1.dev0
+2.72.2.dev0

doc/frameworks/pytorch/using_pytorch.rst

+3 -3

@@ -80,7 +80,7 @@ with the following:
 
     # ... load from args.train and args.test, train a model, write model to args.model_dir.
 
-Because the SageMaker imports your training script, you should put your training code in a main guard
+Because SageMaker imports your training script, you should put your training code in a main guard
 (``if __name__=='__main__':``) if you are using the same script to host your model, so that SageMaker does not
 inadvertently run your training code at the wrong point in execution.
 
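Illustration (not part of the diff): a minimal sketch of the main-guard pattern this paragraph describes, assuming a hypothetical train.py entry point that both trains and hosts the model. The argument names mirror common SageMaker environment variables but are otherwise placeholders.

import argparse
import os


def train(train_dir, model_dir):
    # ... load data from train_dir, fit a model, write artifacts to model_dir ...
    pass


def model_fn(model_dir):
    # hosting hook that SageMaker may import from this same script
    # ... load and return the trained model ...
    pass


if __name__ == "__main__":
    # executed only when the script runs as the training entry point,
    # not when SageMaker imports it to serve the model
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR"))
    args = parser.parse_args()
    train(args.train, args.model_dir)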
@@ -177,7 +177,7 @@ fit Required Arguments
   case, the S3 objects rooted at the ``my-training-data`` prefix will
   be available in the default ``train`` channel. A dict from
   string channel names to S3 URIs. In this case, the objects rooted at
-  each S3 prefix will available as files in each channel directory.
+  each S3 prefix will be available as files in each channel directory.
 
   For example:
 
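Illustration (not part of the diff): passing a dict of channel names to S3 prefixes to ``fit``, as the corrected sentence describes. The role ARN, bucket names, and framework version below are hypothetical placeholders.

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    framework_version="1.8.0",
    py_version="py3",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# objects rooted at each S3 prefix become files in the matching channel directory
estimator.fit({
    "train": "s3://my-bucket/my-training-data",
    "test": "s3://my-bucket/my-test-data",
})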
@@ -391,7 +391,7 @@ If you are using PyTorch Elastic Inference 1.5.1, you should provide ``model_fn`
 The client-side Elastic Inference framework is CPU-only, even though inference still happens in a CUDA context on the server. Thus, the default ``model_fn`` for Elastic Inference loads the model to CPU. Tracing models may lead to tensor creation on a specific device, which may cause device-related errors when loading a model onto a different device. Providing an explicit ``map_location=torch.device('cpu')`` argument forces all tensors to CPU.
 
 For more information on the default inference handler functions, please refer to:
-`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
+`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-inference-toolkit/blob/master/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py>`_.
 
 Serve a PyTorch Model
 ---------------------
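Illustration (not part of the diff): a sketch of the explicit ``map_location`` usage the paragraph above recommends for Elastic Inference, assuming a traced model saved as model.pt (the artifact name is an assumption).

import os

import torch


def model_fn(model_dir):
    # the Elastic Inference client-side framework is CPU-only, so load the
    # traced model onto CPU explicitly to avoid device-related errors
    model = torch.jit.load(
        os.path.join(model_dir, "model.pt"),  # assumed artifact name
        map_location=torch.device("cpu"),
    )
    return model.eval()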

src/sagemaker/dataset_definition/inputs.py

+152 -84

@@ -26,94 +26,147 @@ class RedshiftDatasetDefinition(ApiObject):
     """DatasetDefinition for Redshift.
 
     With this input, SQL queries will be executed using Redshift to generate datasets to S3.
-
-    Parameters:
-        cluster_id (str): The Redshift cluster Identifier.
-        database (str): The name of the Redshift database used in Redshift query execution.
-        db_user (str): The database user name used in Redshift query execution.
-        query_string (str): The SQL query statements to be executed.
-        cluster_role_arn (str): The IAM role attached to your Redshift cluster that
-            Amazon SageMaker uses to generate datasets.
-        output_s3_uri (str): The location in Amazon S3 where the Redshift query
-            results are stored.
-        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
-            SageMaker uses to encrypt data from a Redshift execution.
-        output_format (str): The data storage format for Redshift query results.
-            Valid options are "PARQUET", "CSV"
-        output_compression (str): The compression used for Redshift query results.
-            Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
     """
 
-    cluster_id = None
-    database = None
-    db_user = None
-    query_string = None
-    cluster_role_arn = None
-    output_s3_uri = None
-    kms_key_id = None
-    output_format = None
-    output_compression = None
+    def __init__(
+        self,
+        cluster_id=None,
+        database=None,
+        db_user=None,
+        query_string=None,
+        cluster_role_arn=None,
+        output_s3_uri=None,
+        kms_key_id=None,
+        output_format=None,
+        output_compression=None,
+    ):
+        """Initialize RedshiftDatasetDefinition.
+
+        Args:
+            cluster_id (str, default=None): The Redshift cluster Identifier.
+            database (str, default=None):
+                The name of the Redshift database used in Redshift query execution.
+            db_user (str, default=None): The database user name used in Redshift query execution.
+            query_string (str, default=None): The SQL query statements to be executed.
+            cluster_role_arn (str, default=None): The IAM role attached to your Redshift cluster
+                that Amazon SageMaker uses to generate datasets.
+            output_s3_uri (str, default=None): The location in Amazon S3 where the Redshift query
+                results are stored.
+            kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
+                SageMaker uses to encrypt data from a Redshift execution.
+            output_format (str, default=None): The data storage format for Redshift query results.
+                Valid options are "PARQUET", "CSV"
+            output_compression (str, default=None): The compression used for Redshift query results.
+                Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
+        """
+        super(RedshiftDatasetDefinition, self).__init__(
+            cluster_id=cluster_id,
+            database=database,
+            db_user=db_user,
+            query_string=query_string,
+            cluster_role_arn=cluster_role_arn,
+            output_s3_uri=output_s3_uri,
+            kms_key_id=kms_key_id,
+            output_format=output_format,
+            output_compression=output_compression,
+        )
 
 
 class AthenaDatasetDefinition(ApiObject):
     """DatasetDefinition for Athena.
 
     With this input, SQL queries will be executed using Athena to generate datasets to S3.
-
-    Parameters:
-        catalog (str): The name of the data catalog used in Athena query execution.
-        database (str): The name of the database used in the Athena query execution.
-        query_string (str): The SQL query statements, to be executed.
-        output_s3_uri (str): The location in Amazon S3 where Athena query results are stored.
-        work_group (str): The name of the workgroup in which the Athena query is being started.
-        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
-            SageMaker uses to encrypt data generated from an Athena query execution.
-        output_format (str): The data storage format for Athena query results.
-            Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
-        output_compression (str): The compression used for Athena query results.
-            Valid options are "GZIP", "SNAPPY", "ZLIB"
     """
 
-    catalog = None
-    database = None
-    query_string = None
-    output_s3_uri = None
-    work_group = None
-    kms_key_id = None
-    output_format = None
-    output_compression = None
+    def __init__(
+        self,
+        catalog=None,
+        database=None,
+        query_string=None,
+        output_s3_uri=None,
+        work_group=None,
+        kms_key_id=None,
+        output_format=None,
+        output_compression=None,
+    ):
+        """Initialize AthenaDatasetDefinition.
+
+        Args:
+            catalog (str, default=None): The name of the data catalog used in Athena query
+                execution.
+            database (str, default=None): The name of the database used in the Athena query
+                execution.
+            query_string (str, default=None): The SQL query statements, to be executed.
+            output_s3_uri (str, default=None):
+                The location in Amazon S3 where Athena query results are stored.
+            work_group (str, default=None):
+                The name of the workgroup in which the Athena query is being started.
+            kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
+                SageMaker uses to encrypt data generated from an Athena query execution.
+            output_format (str, default=None): The data storage format for Athena query results.
+                Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
+            output_compression (str, default=None): The compression used for Athena query results.
+                Valid options are "GZIP", "SNAPPY", "ZLIB"
+        """
+        super(AthenaDatasetDefinition, self).__init__(
+            catalog=catalog,
+            database=database,
+            query_string=query_string,
+            output_s3_uri=output_s3_uri,
+            work_group=work_group,
+            kms_key_id=kms_key_id,
+            output_format=output_format,
+            output_compression=output_compression,
+        )
 
 
 class DatasetDefinition(ApiObject):
-    """DatasetDefinition input.
-
-    Parameters:
-        data_distribution_type (str): Whether the generated dataset is FullyReplicated or
-            ShardedByS3Key (default).
-        input_mode (str): Whether to use File or Pipe input mode. In File (default) mode, Amazon
-            SageMaker copies the data from the input source onto the local Amazon Elastic Block
-            Store (Amazon EBS) volumes before starting your training algorithm. This is the most
-            commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
-            source directly to your algorithm without using the EBS volume.
-        local_path (str): The local path where you want Amazon SageMaker to download the Dataset
-            Definition inputs to run a processing job. LocalPath is an absolute path to the input
-            data. This is a required parameter when `AppManaged` is False (default).
-        redshift_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`):
-            Configuration for Redshift Dataset Definition input.
-        athena_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`):
-            Configuration for Athena Dataset Definition input.
-    """
+    """DatasetDefinition input."""
 
     _custom_boto_types = {
         "redshift_dataset_definition": (RedshiftDatasetDefinition, True),
        "athena_dataset_definition": (AthenaDatasetDefinition, True),
     }
 
-    data_distribution_type = "ShardedByS3Key"
-    input_mode = "File"
-    local_path = None
-    redshift_dataset_definition = None
-    athena_dataset_definition = None
+    def __init__(
+        self,
+        data_distribution_type="ShardedByS3Key",
+        input_mode="File",
+        local_path=None,
+        redshift_dataset_definition=None,
+        athena_dataset_definition=None,
+    ):
+        """Initialize DatasetDefinition.
+
+        Parameters:
+            data_distribution_type (str, default="ShardedByS3Key"):
+                Whether the generated dataset is FullyReplicated or ShardedByS3Key (default).
+            input_mode (str, default="File"):
+                Whether to use File or Pipe input mode. In File (default) mode, Amazon
+                SageMaker copies the data from the input source onto the local Amazon Elastic Block
+                Store (Amazon EBS) volumes before starting your training algorithm. This is the most
+                commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
+                source directly to your algorithm without using the EBS volume.
+            local_path (str, default=None):
+                The local path where you want Amazon SageMaker to download the Dataset
+                Definition inputs to run a processing job. LocalPath is an absolute path to the
+                input data. This is a required parameter when `AppManaged` is False (default).
+            redshift_dataset_definition
+                (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`,
+                default=None):
+                Configuration for Redshift Dataset Definition input.
+            athena_dataset_definition
+                (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`,
+                default=None):
+                Configuration for Athena Dataset Definition input.
+        """
+        super(DatasetDefinition, self).__init__(
+            data_distribution_type=data_distribution_type,
+            input_mode=input_mode,
+            local_path=local_path,
+            redshift_dataset_definition=redshift_dataset_definition,
+            athena_dataset_definition=athena_dataset_definition,
+        )
 
 
 class S3Input(ApiObject):
@@ -124,20 +177,35 @@ class S3Input(ApiObject):
     Note: Strong consistency is not guaranteed if S3Prefix is provided here.
     S3 list operations are not strongly consistent.
     Use ManifestFile if strong consistency is required.
-
-    Parameters:
-        s3_uri (str): the path to a specific S3 object or a S3 prefix
-        local_path (str): the path to a local directory. If not provided, skips data download
-            by SageMaker platform.
-        s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
-        s3_input_mode (str): Valid options are "Pipe" or "File".
-        s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
-        s3_compression_type (str): Valid options are "None" or "Gzip".
     """
 
-    s3_uri = None
-    local_path = None
-    s3_data_type = "S3Prefix"
-    s3_input_mode = "File"
-    s3_data_distribution_type = "FullyReplicated"
-    s3_compression_type = None
+    def __init__(
+        self,
+        s3_uri=None,
+        local_path=None,
+        s3_data_type="S3Prefix",
+        s3_input_mode="File",
+        s3_data_distribution_type="FullyReplicated",
+        s3_compression_type=None,
+    ):
+        """Initialize S3Input.
+
+        Parameters:
+            s3_uri (str, default=None): the path to a specific S3 object or a S3 prefix
+            local_path (str, default=None):
+                the path to a local directory. If not provided, skips data download
+                by SageMaker platform.
+            s3_data_type (str, default="S3Prefix"): Valid options are "ManifestFile" or "S3Prefix".
+            s3_input_mode (str, default="File"): Valid options are "Pipe" or "File".
+            s3_data_distribution_type (str, default="FullyReplicated"):
+                Valid options are "FullyReplicated" or "ShardedByS3Key".
+            s3_compression_type (str, default=None): Valid options are "None" or "Gzip".
+        """
+        super(S3Input, self).__init__(
+            s3_uri=s3_uri,
+            local_path=local_path,
+            s3_data_type=s3_data_type,
+            s3_input_mode=s3_input_mode,
+            s3_data_distribution_type=s3_data_distribution_type,
+            s3_compression_type=s3_compression_type,
+        )
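
Illustration (not part of the diff): with these classes now exposing constructor arguments instead of class-level attributes, per-instance configuration looks like the sketch below. Bucket, database, and workgroup values are hypothetical; only the parameter names come from the code above.

from sagemaker.dataset_definition.inputs import (
    AthenaDatasetDefinition,
    DatasetDefinition,
    S3Input,
)

# an Athena-backed dataset definition for a processing job
dataset_def = DatasetDefinition(
    local_path="/opt/ml/processing/input/athena",
    data_distribution_type="FullyReplicated",
    athena_dataset_definition=AthenaDatasetDefinition(
        catalog="AwsDataCatalog",
        database="my_database",
        query_string="SELECT * FROM my_table",
        output_s3_uri="s3://my-bucket/athena/results/",
        work_group="primary",
        output_format="PARQUET",
    ),
)

# a plain S3 input, now configured per instance rather than via class attributes
s3_input = S3Input(
    s3_uri="s3://my-bucket/processing/input/",
    local_path="/opt/ml/processing/input/data",
    s3_data_type="S3Prefix",
    s3_input_mode="File",
)

A DatasetDefinition is typically handed to a processing job through ``ProcessingInput``'s ``dataset_definition`` argument; ``S3Input`` mirrors the S3 portion of the processing input request.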

src/sagemaker/workflow/_repack_model.py

+15 -12

@@ -62,15 +62,15 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
     with tarfile.open(name=local_path, mode="r:gz") as tf:
         tf.extractall(path=src_dir)
 
-    # copy the custom inference script to code/
-    entry_point = os.path.join("/opt/ml/code", inference_script)
-    shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))
-
-    # copy source_dir to code/
     if source_dir:
+        # copy /opt/ml/code to code/
         if os.path.exists(code_dir):
             shutil.rmtree(code_dir)
-        shutil.copytree(source_dir, code_dir)
+        shutil.copytree("/opt/ml/code", code_dir)
+    else:
+        # copy the custom inference script to code/
+        entry_point = os.path.join("/opt/ml/code", inference_script)
+        shutil.copy2(entry_point, os.path.join(code_dir, inference_script))
 
     # copy any dependencies to code/lib/
     if dependencies:

@@ -79,13 +79,16 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
             lib_dir = os.path.join(code_dir, "lib")
             if not os.path.exists(lib_dir):
                 os.mkdir(lib_dir)
-            if os.path.isdir(actual_dependency_path):
-                shutil.copytree(
-                    actual_dependency_path,
-                    os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
-                )
-            else:
+            if os.path.isfile(actual_dependency_path):
                 shutil.copy2(actual_dependency_path, lib_dir)
+            else:
+                if os.path.exists(lib_dir):
+                    shutil.rmtree(lib_dir)
+                # a directory is in the dependencies. we have to copy
+                # all of /opt/ml/code into the lib dir because the original directory
+                # was flattened by the SDK training job upload..
+                shutil.copytree("/opt/ml/code", lib_dir)
+                break
 
     # copy the "src" dir, which includes the previous training job's model and the
     # custom inference script, to the output of this training job
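
Illustration (not part of the diff): a self-contained sketch of the copy decisions the updated repack logic makes. Nothing here is SDK code; dependency paths are checked directly for simplicity, whereas the real script resolves them under /opt/ml/code.

import os


def plan_copies(inference_script, source_dir=None, dependencies=None):
    # return (source, destination) pairs mirroring the new branches in repack()
    actions = []
    if source_dir:
        # a source_dir means the whole uploaded /opt/ml/code tree replaces code/
        actions.append(("/opt/ml/code", "code/"))
    else:
        # otherwise only the named inference script lands in code/
        actions.append(("/opt/ml/code/" + inference_script, "code/" + inference_script))
    for dependency in dependencies or []:
        if os.path.isfile(dependency):
            # file dependencies are copied individually into code/lib/
            actions.append((dependency, "code/lib/"))
        else:
            # a directory dependency pulls in all of /opt/ml/code, because the SDK's
            # training-job upload flattened the original directory layout
            actions.append(("/opt/ml/code", "code/lib/"))
            break
    return actions


print(plan_copies("inference.py", source_dir="src"))
print(plan_copies("inference.py", dependencies=["my_module_dir"]))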
