feature: Allow custom output for RepackModelStep #2804


Closed
wants to merge 24 commits into dev from repack_output
Changes from 8 commits
Commits (24)
dde8d00
fix: Set ProcessingStep upload locations deterministically to avoid c…
staubhp Dec 8, 2021
0f72907
fix: Prevent repack_model script from referencing nonexistent directo…
staubhp Dec 9, 2021
0bae071
fix: S3Input - add support for instance attributes (#2754)
mufaddal-rohawala Dec 15, 2021
17fe93e
fix: typos and broken link (#2765)
mohamed-ali Dec 16, 2021
f0efd27
feature: Add output path parameter for _RepackModelStep
tuliocasagrande Dec 17, 2021
7a1f4f8
fix: Fix role parameter for _RepackModelStep
tuliocasagrande Dec 17, 2021
ee6afcf
fix: Remove entry_point before calling Model on EstimatorTransformer
tuliocasagrande Jan 4, 2022
faf4ad5
feature: Add tests for RegisterModel with repack output
tuliocasagrande Jan 4, 2022
8210375
fix: fixes unnecessary session call while generating pipeline definit…
xchen909 Jan 10, 2022
972a6d2
feature: Add models_v2 under lineage context (#2800)
yzhu0 Jan 10, 2022
7206b9e
feature: enable python 3.9 (#2802)
mufaddal-rohawala Jan 10, 2022
127c964
change: Update CHANGELOG.md (#2842)
shreyapandit Jan 11, 2022
554d735
fix: update pricing link (#2805)
ahsan-z-khan Jan 11, 2022
88e4d68
doc: Document the available ExecutionVariables (#2807)
tuliocasagrande Jan 12, 2022
b3c19d8
fix: Remove duplicate vertex/edge in query lineage (#2784)
yzhu0 Jan 12, 2022
fd7a335
feature: Support model pipelines in CreateModelStep (#2845)
staubhp Jan 12, 2022
ccfcbe7
feature: support JsonGet/Join parameterization in tuning step Hyperpa…
jerrypeng7773 Jan 13, 2022
71c5617
doc: Enhance smddp 1.2.2 doc (#2852)
mchoi8739 Jan 13, 2022
975e031
feature: support checkpoint to be passed from estimator (#2849)
marckarp Jan 13, 2022
b377b52
fix: allow kms_key to be passed for processing step (#2779)
jayatalr Jan 13, 2022
9d259b3
feature: Adds support for Serverless inference (#2831)
bhaoz Jan 14, 2022
b82fb8a
feature: Add support for SageMaker lineage queries in action (#2853)
yzhu0 Jan 14, 2022
0489b59
Merge branch 'dev' into repack_output
shreyapandit Jan 14, 2022
ed9131b
Merge remote-tracking branch 'upstream/dev' into repack_output
May 13, 2022
6 changes: 3 additions & 3 deletions doc/frameworks/pytorch/using_pytorch.rst
@@ -80,7 +80,7 @@ with the following:

# ... load from args.train and args.test, train a model, write model to args.model_dir.

Because the SageMaker imports your training script, you should put your training code in a main guard
Because SageMaker imports your training script, you should put your training code in a main guard
(``if __name__=='__main__':``) if you are using the same script to host your model, so that SageMaker does not
inadvertently run your training code at the wrong point in execution.

@@ -177,7 +177,7 @@ fit Required Arguments
case, the S3 objects rooted at the ``my-training-data`` prefix will
be available in the default ``train`` channel. A dict from
string channel names to S3 URIs. In this case, the objects rooted at
each S3 prefix will available as files in each channel directory.
each S3 prefix will be available as files in each channel directory.

For example:

@@ -391,7 +391,7 @@ If you are using PyTorch Elastic Inference 1.5.1, you should provide ``model_fn``
The client-side Elastic Inference framework is CPU-only, even though inference still happens in a CUDA context on the server. Thus, the default ``model_fn`` for Elastic Inference loads the model to CPU. Tracing models may lead to tensor creation on a specific device, which may cause device-related errors when loading a model onto a different device. Providing an explicit ``map_location=torch.device('cpu')`` argument forces all tensors to CPU.

For more information on the default inference handler functions, please refer to:
`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-inference-toolkit/blob/master/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py>`_.

Serve a PyTorch Model
---------------------
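As an aside to the main-guard note corrected in this file: a minimal sketch of a training entry point that follows that guidance is shown below. The file name, argument names, and the placeholder model are illustrative only and are not part of this diff.

```python
# train.py - illustrative SageMaker PyTorch entry point.
# Training only runs under the main guard, so importing this module
# for hosting does not inadvertently re-run the training code.
import argparse
import os

import torch


def train(args):
    # ... load data from args.train and args.test, fit a real model ...
    model = torch.nn.Linear(10, 1)  # placeholder model
    torch.save(model.state_dict(), os.path.join(args.model_dir, "model.pth"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test", default=os.environ.get("SM_CHANNEL_TEST"))
    train(parser.parse_args())
```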
236 changes: 152 additions & 84 deletions src/sagemaker/dataset_definition/inputs.py
@@ -26,94 +26,147 @@ class RedshiftDatasetDefinition(ApiObject):
"""DatasetDefinition for Redshift.

With this input, SQL queries will be executed using Redshift to generate datasets to S3.

Parameters:
cluster_id (str): The Redshift cluster Identifier.
database (str): The name of the Redshift database used in Redshift query execution.
db_user (str): The database user name used in Redshift query execution.
query_string (str): The SQL query statements to be executed.
cluster_role_arn (str): The IAM role attached to your Redshift cluster that
Amazon SageMaker uses to generate datasets.
output_s3_uri (str): The location in Amazon S3 where the Redshift query
results are stored.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data from a Redshift execution.
output_format (str): The data storage format for Redshift query results.
Valid options are "PARQUET", "CSV"
output_compression (str): The compression used for Redshift query results.
Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
"""

cluster_id = None
database = None
db_user = None
query_string = None
cluster_role_arn = None
output_s3_uri = None
kms_key_id = None
output_format = None
output_compression = None
def __init__(
self,
cluster_id=None,
database=None,
db_user=None,
query_string=None,
cluster_role_arn=None,
output_s3_uri=None,
kms_key_id=None,
output_format=None,
output_compression=None,
):
"""Initialize RedshiftDatasetDefinition.

Args:
cluster_id (str, default=None): The Redshift cluster Identifier.
database (str, default=None):
The name of the Redshift database used in Redshift query execution.
db_user (str, default=None): The database user name used in Redshift query execution.
query_string (str, default=None): The SQL query statements to be executed.
cluster_role_arn (str, default=None): The IAM role attached to your Redshift cluster
that Amazon SageMaker uses to generate datasets.
output_s3_uri (str, default=None): The location in Amazon S3 where the Redshift query
results are stored.
kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data from a Redshift execution.
output_format (str, default=None): The data storage format for Redshift query results.
Valid options are "PARQUET", "CSV"
output_compression (str, default=None): The compression used for Redshift query results.
Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
"""
super(RedshiftDatasetDefinition, self).__init__(
cluster_id=cluster_id,
database=database,
db_user=db_user,
query_string=query_string,
cluster_role_arn=cluster_role_arn,
output_s3_uri=output_s3_uri,
kms_key_id=kms_key_id,
output_format=output_format,
output_compression=output_compression,
)


class AthenaDatasetDefinition(ApiObject):
"""DatasetDefinition for Athena.

With this input, SQL queries will be executed using Athena to generate datasets to S3.

Parameters:
catalog (str): The name of the data catalog used in Athena query execution.
database (str): The name of the database used in the Athena query execution.
query_string (str): The SQL query statements, to be executed.
output_s3_uri (str): The location in Amazon S3 where Athena query results are stored.
work_group (str): The name of the workgroup in which the Athena query is being started.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data generated from an Athena query execution.
output_format (str): The data storage format for Athena query results.
Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
output_compression (str): The compression used for Athena query results.
Valid options are "GZIP", "SNAPPY", "ZLIB"
"""

catalog = None
database = None
query_string = None
output_s3_uri = None
work_group = None
kms_key_id = None
output_format = None
output_compression = None
def __init__(
self,
catalog=None,
database=None,
query_string=None,
output_s3_uri=None,
work_group=None,
kms_key_id=None,
output_format=None,
output_compression=None,
):
"""Initialize AthenaDatasetDefinition.

Args:
catalog (str, default=None): The name of the data catalog used in Athena query
execution.
database (str, default=None): The name of the database used in the Athena query
execution.
query_string (str, default=None): The SQL query statements, to be executed.
output_s3_uri (str, default=None):
The location in Amazon S3 where Athena query results are stored.
work_group (str, default=None):
The name of the workgroup in which the Athena query is being started.
kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
SageMaker uses to encrypt data generated from an Athena query execution.
output_format (str, default=None): The data storage format for Athena query results.
Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
output_compression (str, default=None): The compression used for Athena query results.
Valid options are "GZIP", "SNAPPY", "ZLIB"
"""
super(AthenaDatasetDefinition, self).__init__(
catalog=catalog,
database=database,
query_string=query_string,
output_s3_uri=output_s3_uri,
work_group=work_group,
kms_key_id=kms_key_id,
output_format=output_format,
output_compression=output_compression,
)


class DatasetDefinition(ApiObject):
"""DatasetDefinition input.

Parameters:
data_distribution_type (str): Whether the generated dataset is FullyReplicated or
ShardedByS3Key (default).
input_mode (str): Whether to use File or Pipe input mode. In File (default) mode, Amazon
SageMaker copies the data from the input source onto the local Amazon Elastic Block
Store (Amazon EBS) volumes before starting your training algorithm. This is the most
commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
source directly to your algorithm without using the EBS volume.
local_path (str): The local path where you want Amazon SageMaker to download the Dataset
Definition inputs to run a processing job. LocalPath is an absolute path to the input
data. This is a required parameter when `AppManaged` is False (default).
redshift_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`):
Configuration for Redshift Dataset Definition input.
athena_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`):
Configuration for Athena Dataset Definition input.
"""
"""DatasetDefinition input."""

_custom_boto_types = {
"redshift_dataset_definition": (RedshiftDatasetDefinition, True),
"athena_dataset_definition": (AthenaDatasetDefinition, True),
}

data_distribution_type = "ShardedByS3Key"
input_mode = "File"
local_path = None
redshift_dataset_definition = None
athena_dataset_definition = None
def __init__(
self,
data_distribution_type="ShardedByS3Key",
input_mode="File",
local_path=None,
redshift_dataset_definition=None,
athena_dataset_definition=None,
):
"""Initialize DatasetDefinition.

Parameters:
data_distribution_type (str, default="ShardedByS3Key"):
Whether the generated dataset is FullyReplicated or ShardedByS3Key (default).
input_mode (str, default="File"):
Whether to use File or Pipe input mode. In File (default) mode, Amazon
SageMaker copies the data from the input source onto the local Amazon Elastic Block
Store (Amazon EBS) volumes before starting your training algorithm. This is the most
commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
source directly to your algorithm without using the EBS volume.
local_path (str, default=None):
The local path where you want Amazon SageMaker to download the Dataset
Definition inputs to run a processing job. LocalPath is an absolute path to the
input data. This is a required parameter when `AppManaged` is False (default).
redshift_dataset_definition
(:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`,
default=None):
Configuration for Redshift Dataset Definition input.
athena_dataset_definition
(:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`,
default=None):
Configuration for Athena Dataset Definition input.
"""
super(DatasetDefinition, self).__init__(
data_distribution_type=data_distribution_type,
input_mode=input_mode,
local_path=local_path,
redshift_dataset_definition=redshift_dataset_definition,
athena_dataset_definition=athena_dataset_definition,
)


class S3Input(ApiObject):
@@ -124,20 +177,35 @@ class S3Input(ApiObject):
Note: Strong consistency is not guaranteed if S3Prefix is provided here.
S3 list operations are not strongly consistent.
Use ManifestFile if strong consistency is required.

Parameters:
s3_uri (str): the path to a specific S3 object or a S3 prefix
local_path (str): the path to a local directory. If not provided, skips data download
by SageMaker platform.
s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
s3_input_mode (str): Valid options are "Pipe" or "File".
s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
s3_compression_type (str): Valid options are "None" or "Gzip".
"""

s3_uri = None
local_path = None
s3_data_type = "S3Prefix"
s3_input_mode = "File"
s3_data_distribution_type = "FullyReplicated"
s3_compression_type = None
def __init__(
self,
s3_uri=None,
local_path=None,
s3_data_type="S3Prefix",
s3_input_mode="File",
s3_data_distribution_type="FullyReplicated",
s3_compression_type=None,
):
"""Initialize S3Input.

Parameters:
s3_uri (str, default=None): the path to a specific S3 object or a S3 prefix
local_path (str, default=None):
the path to a local directory. If not provided, skips data download
by SageMaker platform.
s3_data_type (str, default="S3Prefix"): Valid options are "ManifestFile" or "S3Prefix".
s3_input_mode (str, default="File"): Valid options are "Pipe" or "File".
s3_data_distribution_type (str, default="FullyReplicated"):
Valid options are "FullyReplicated" or "ShardedByS3Key".
s3_compression_type (str, default=None): Valid options are "None" or "Gzip".
"""
super(S3Input, self).__init__(
s3_uri=s3_uri,
local_path=local_path,
s3_data_type=s3_data_type,
s3_input_mode=s3_input_mode,
s3_data_distribution_type=s3_data_distribution_type,
s3_compression_type=s3_compression_type,
)
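With the class-level attributes replaced by explicit constructors, a short usage sketch might look like the following; the bucket names, database, query, and local paths are placeholders, not values taken from this PR.

```python
from sagemaker.dataset_definition.inputs import (
    AthenaDatasetDefinition,
    DatasetDefinition,
    S3Input,
)

# Athena-backed dataset definition built with the new keyword constructor.
athena = AthenaDatasetDefinition(
    catalog="AwsDataCatalog",
    database="my_database",                           # placeholder
    query_string="SELECT * FROM my_table",            # placeholder
    output_s3_uri="s3://my-bucket/athena-results/",   # placeholder
    work_group="primary",
    output_format="PARQUET",
)

dataset_definition = DatasetDefinition(
    data_distribution_type="ShardedByS3Key",
    input_mode="File",
    local_path="/opt/ml/processing/input/dataset",
    athena_dataset_definition=athena,
)

# Plain S3 input with the documented defaults spelled out.
s3_input = S3Input(
    s3_uri="s3://my-bucket/raw-data/",                # placeholder
    local_path="/opt/ml/processing/input/raw",
    s3_data_type="S3Prefix",
    s3_input_mode="File",
)
```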
27 changes: 15 additions & 12 deletions src/sagemaker/workflow/_repack_model.py
@@ -62,15 +62,15 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
with tarfile.open(name=local_path, mode="r:gz") as tf:
tf.extractall(path=src_dir)

# copy the custom inference script to code/
entry_point = os.path.join("/opt/ml/code", inference_script)
shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))

# copy source_dir to code/
if source_dir:
# copy /opt/ml/code to code/
if os.path.exists(code_dir):
shutil.rmtree(code_dir)
shutil.copytree(source_dir, code_dir)
shutil.copytree("/opt/ml/code", code_dir)
else:
# copy the custom inference script to code/
entry_point = os.path.join("/opt/ml/code", inference_script)
shutil.copy2(entry_point, os.path.join(code_dir, inference_script))

# copy any dependencies to code/lib/
if dependencies:
@@ -79,13 +79,16 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
lib_dir = os.path.join(code_dir, "lib")
if not os.path.exists(lib_dir):
os.mkdir(lib_dir)
if os.path.isdir(actual_dependency_path):
shutil.copytree(
actual_dependency_path,
os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
)
else:
if os.path.isfile(actual_dependency_path):
shutil.copy2(actual_dependency_path, lib_dir)
else:
if os.path.exists(lib_dir):
shutil.rmtree(lib_dir)
# a directory is in the dependencies. we have to copy
# all of /opt/ml/code into the lib dir because the original directory
# was flattened by the SDK training job upload..
shutil.copytree("/opt/ml/code", lib_dir)
break

# copy the "src" dir, which includes the previous training job's model and the
# custom inference script, to the output of this training job
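For orientation, a hedged sketch of how this repack helper is driven inside the repack training job; the archive and entry-point names below are assumptions, and /opt/ml/code is where the SDK uploads the job's source.

```python
# Hypothetical invocation inside the repack job container. The previous
# training job's model.tar.gz is extracted, the contents of /opt/ml/code
# (the uploaded source_dir, or just the single entry point) are copied
# into code/, and the result is re-archived under this job's model dir.
repack(
    inference_script="inference.py",      # placeholder entry point name
    model_archive="model.tar.gz",         # archive produced by the original training job
    source_dir="source_dir.tar.gz",       # placeholder; presence selects the copytree branch
    dependencies=["lib/my_helpers"],      # placeholder extra libraries copied under code/lib/
)
```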
6 changes: 6 additions & 0 deletions src/sagemaker/workflow/_utils.py
@@ -59,6 +59,7 @@ def __init__(
display_name: str = None,
description: str = None,
source_dir: str = None,
repack_output_path: str = None,
dependencies: List = None,
depends_on: Union[List[str], List[Step]] = None,
retry_policies: List[RetryPolicy] = None,
@@ -101,6 +102,9 @@ def __init__(
or model hosting source code dependencies aside from the entry point
file in the Git repo (default: None). Structure within this
directory are preserved when training on Amazon SageMaker.
repack_output_path (str): The S3 prefix URI where the repacked model will be
uploaded, without a trailing slash (default: None). If not specified, the
default location is s3://default-bucket/job-name.
dependencies (list[str]): A list of paths to directories (absolute
or relative) with any additional libraries that will be exported
to the container (default: []). The library folders will be
@@ -170,6 +174,8 @@ def __init__(
},
subnets=subnets,
security_group_ids=security_group_ids,
output_path=repack_output_path,
code_location=repack_output_path,
**kwargs,
)
repacker.disable_profiler = True
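Finally, a hedged sketch of how the new `repack_output_path` might surface when registering a model in a pipeline. The estimator, training step, bucket, and model package group names are placeholders assumed to be defined elsewhere, and the exact public wiring of the parameter depends on the RegisterModel changes in the rest of this PR.

```python
from sagemaker.workflow.step_collections import RegisterModel

# Hypothetical usage: send the repacked model.tar.gz to a fixed prefix
# instead of the default s3://default-bucket/job-name location.
register_step = RegisterModel(
    name="RegisterMyModel",
    estimator=estimator,                         # placeholder: previously defined Estimator
    model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,  # placeholder training step
    entry_point="inference.py",                  # an entry point triggers the repack step
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.m5.large"],
    transform_instances=["ml.m5.large"],
    model_package_group_name="MyModelPackageGroup",
    repack_output_path="s3://my-bucket/repacked-models",  # new parameter from this PR (assumed pass-through)
)
```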