
Commit 0bae071

mufaddal-rohawala authored and shreyapandit committed
fix: S3Input - add support for instance attributes (#2754)
1 parent 0f72907 commit 0bae071
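
This commit replaces the class-level attribute defaults on RedshiftDatasetDefinition, AthenaDatasetDefinition, DatasetDefinition, and S3Input with explicit __init__ methods, so every field is stored as an instance attribute with a per-instance default. A minimal sketch of the difference between the two patterns (the class names below are hypothetical, not the SDK's ApiObject types):

```python
# Minimal sketch contrasting the two patterns this commit switches between.
# These class names are hypothetical illustrations, not SDK types.

class WithClassAttrs:
    mode = "File"  # one value stored on the class, shared by every instance


class WithInstanceAttrs:
    def __init__(self, mode="File"):
        self.mode = mode  # each instance stores its own copy of the value


a = WithClassAttrs()
print(vars(a))  # {} -- the default never appears in the instance's __dict__

b = WithInstanceAttrs(mode="Pipe")
print(vars(b))  # {'mode': 'Pipe'} -- the attribute lives on the instance
```

With the class-attribute pattern, nothing appears in an instance's __dict__ until it is assigned, which is presumably what the commit title means by adding support for instance attributes.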

File tree: 2 files changed (+171 −84 lines)

src/sagemaker/dataset_definition/inputs.py (+152 −84)
```diff
@@ -26,94 +26,147 @@ class RedshiftDatasetDefinition(ApiObject):
     """DatasetDefinition for Redshift.
 
     With this input, SQL queries will be executed using Redshift to generate datasets to S3.
-
-    Parameters:
-        cluster_id (str): The Redshift cluster Identifier.
-        database (str): The name of the Redshift database used in Redshift query execution.
-        db_user (str): The database user name used in Redshift query execution.
-        query_string (str): The SQL query statements to be executed.
-        cluster_role_arn (str): The IAM role attached to your Redshift cluster that
-            Amazon SageMaker uses to generate datasets.
-        output_s3_uri (str): The location in Amazon S3 where the Redshift query
-            results are stored.
-        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
-            SageMaker uses to encrypt data from a Redshift execution.
-        output_format (str): The data storage format for Redshift query results.
-            Valid options are "PARQUET", "CSV"
-        output_compression (str): The compression used for Redshift query results.
-            Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
     """
 
-    cluster_id = None
-    database = None
-    db_user = None
-    query_string = None
-    cluster_role_arn = None
-    output_s3_uri = None
-    kms_key_id = None
-    output_format = None
-    output_compression = None
+    def __init__(
+        self,
+        cluster_id=None,
+        database=None,
+        db_user=None,
+        query_string=None,
+        cluster_role_arn=None,
+        output_s3_uri=None,
+        kms_key_id=None,
+        output_format=None,
+        output_compression=None,
+    ):
+        """Initialize RedshiftDatasetDefinition.
+
+        Args:
+            cluster_id (str, default=None): The Redshift cluster Identifier.
+            database (str, default=None):
+                The name of the Redshift database used in Redshift query execution.
+            db_user (str, default=None): The database user name used in Redshift query execution.
+            query_string (str, default=None): The SQL query statements to be executed.
+            cluster_role_arn (str, default=None): The IAM role attached to your Redshift cluster
+                that Amazon SageMaker uses to generate datasets.
+            output_s3_uri (str, default=None): The location in Amazon S3 where the Redshift query
+                results are stored.
+            kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
+                SageMaker uses to encrypt data from a Redshift execution.
+            output_format (str, default=None): The data storage format for Redshift query results.
+                Valid options are "PARQUET", "CSV"
+            output_compression (str, default=None): The compression used for Redshift query results.
+                Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
+        """
+        super(RedshiftDatasetDefinition, self).__init__(
+            cluster_id=cluster_id,
+            database=database,
+            db_user=db_user,
+            query_string=query_string,
+            cluster_role_arn=cluster_role_arn,
+            output_s3_uri=output_s3_uri,
+            kms_key_id=kms_key_id,
+            output_format=output_format,
+            output_compression=output_compression,
+        )
 
 
 class AthenaDatasetDefinition(ApiObject):
     """DatasetDefinition for Athena.
 
     With this input, SQL queries will be executed using Athena to generate datasets to S3.
-
-    Parameters:
-        catalog (str): The name of the data catalog used in Athena query execution.
-        database (str): The name of the database used in the Athena query execution.
-        query_string (str): The SQL query statements, to be executed.
-        output_s3_uri (str): The location in Amazon S3 where Athena query results are stored.
-        work_group (str): The name of the workgroup in which the Athena query is being started.
-        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
-            SageMaker uses to encrypt data generated from an Athena query execution.
-        output_format (str): The data storage format for Athena query results.
-            Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
-        output_compression (str): The compression used for Athena query results.
-            Valid options are "GZIP", "SNAPPY", "ZLIB"
     """
 
-    catalog = None
-    database = None
-    query_string = None
-    output_s3_uri = None
-    work_group = None
-    kms_key_id = None
-    output_format = None
-    output_compression = None
+    def __init__(
+        self,
+        catalog=None,
+        database=None,
+        query_string=None,
+        output_s3_uri=None,
+        work_group=None,
+        kms_key_id=None,
+        output_format=None,
+        output_compression=None,
+    ):
+        """Initialize AthenaDatasetDefinition.
+
+        Args:
+            catalog (str, default=None): The name of the data catalog used in Athena query
+                execution.
+            database (str, default=None): The name of the database used in the Athena query
+                execution.
+            query_string (str, default=None): The SQL query statements, to be executed.
+            output_s3_uri (str, default=None):
+                The location in Amazon S3 where Athena query results are stored.
+            work_group (str, default=None):
+                The name of the workgroup in which the Athena query is being started.
+            kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
+                SageMaker uses to encrypt data generated from an Athena query execution.
+            output_format (str, default=None): The data storage format for Athena query results.
+                Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
+            output_compression (str, default=None): The compression used for Athena query results.
+                Valid options are "GZIP", "SNAPPY", "ZLIB"
+        """
+        super(AthenaDatasetDefinition, self).__init__(
+            catalog=catalog,
+            database=database,
+            query_string=query_string,
+            output_s3_uri=output_s3_uri,
+            work_group=work_group,
+            kms_key_id=kms_key_id,
+            output_format=output_format,
+            output_compression=output_compression,
+        )
 
 
 class DatasetDefinition(ApiObject):
-    """DatasetDefinition input.
-
-    Parameters:
-        data_distribution_type (str): Whether the generated dataset is FullyReplicated or
-            ShardedByS3Key (default).
-        input_mode (str): Whether to use File or Pipe input mode. In File (default) mode, Amazon
-            SageMaker copies the data from the input source onto the local Amazon Elastic Block
-            Store (Amazon EBS) volumes before starting your training algorithm. This is the most
-            commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
-            source directly to your algorithm without using the EBS volume.
-        local_path (str): The local path where you want Amazon SageMaker to download the Dataset
-            Definition inputs to run a processing job. LocalPath is an absolute path to the input
-            data. This is a required parameter when `AppManaged` is False (default).
-        redshift_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`):
-            Configuration for Redshift Dataset Definition input.
-        athena_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`):
-            Configuration for Athena Dataset Definition input.
-    """
+    """DatasetDefinition input."""
 
     _custom_boto_types = {
         "redshift_dataset_definition": (RedshiftDatasetDefinition, True),
         "athena_dataset_definition": (AthenaDatasetDefinition, True),
     }
 
-    data_distribution_type = "ShardedByS3Key"
-    input_mode = "File"
-    local_path = None
-    redshift_dataset_definition = None
-    athena_dataset_definition = None
+    def __init__(
+        self,
+        data_distribution_type="ShardedByS3Key",
+        input_mode="File",
+        local_path=None,
+        redshift_dataset_definition=None,
+        athena_dataset_definition=None,
+    ):
+        """Initialize DatasetDefinition.
+
+        Parameters:
+            data_distribution_type (str, default="ShardedByS3Key"):
+                Whether the generated dataset is FullyReplicated or ShardedByS3Key (default).
+            input_mode (str, default="File"):
+                Whether to use File or Pipe input mode. In File (default) mode, Amazon
+                SageMaker copies the data from the input source onto the local Amazon Elastic Block
+                Store (Amazon EBS) volumes before starting your training algorithm. This is the most
+                commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
+                source directly to your algorithm without using the EBS volume.
+            local_path (str, default=None):
+                The local path where you want Amazon SageMaker to download the Dataset
+                Definition inputs to run a processing job. LocalPath is an absolute path to the
+                input data. This is a required parameter when `AppManaged` is False (default).
+            redshift_dataset_definition
+                (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`,
+                default=None):
+                Configuration for Redshift Dataset Definition input.
+            athena_dataset_definition
+                (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`,
+                default=None):
+                Configuration for Athena Dataset Definition input.
+        """
+        super(DatasetDefinition, self).__init__(
+            data_distribution_type=data_distribution_type,
+            input_mode=input_mode,
+            local_path=local_path,
+            redshift_dataset_definition=redshift_dataset_definition,
+            athena_dataset_definition=athena_dataset_definition,
+        )
 
 
 class S3Input(ApiObject):
@@ -124,20 +177,35 @@ class S3Input(ApiObject):
     Note: Strong consistency is not guaranteed if S3Prefix is provided here.
     S3 list operations are not strongly consistent.
     Use ManifestFile if strong consistency is required.
-
-    Parameters:
-        s3_uri (str): the path to a specific S3 object or a S3 prefix
-        local_path (str): the path to a local directory. If not provided, skips data download
-            by SageMaker platform.
-        s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
-        s3_input_mode (str): Valid options are "Pipe" or "File".
-        s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
-        s3_compression_type (str): Valid options are "None" or "Gzip".
     """
 
-    s3_uri = None
-    local_path = None
-    s3_data_type = "S3Prefix"
-    s3_input_mode = "File"
-    s3_data_distribution_type = "FullyReplicated"
-    s3_compression_type = None
+    def __init__(
+        self,
+        s3_uri=None,
+        local_path=None,
+        s3_data_type="S3Prefix",
+        s3_input_mode="File",
+        s3_data_distribution_type="FullyReplicated",
+        s3_compression_type=None,
+    ):
+        """Initialize S3Input.
+
+        Parameters:
+            s3_uri (str, default=None): the path to a specific S3 object or a S3 prefix
+            local_path (str, default=None):
+                the path to a local directory. If not provided, skips data download
+                by SageMaker platform.
+            s3_data_type (str, default="S3Prefix"): Valid options are "ManifestFile" or "S3Prefix".
+            s3_input_mode (str, default="File"): Valid options are "Pipe" or "File".
+            s3_data_distribution_type (str, default="FullyReplicated"):
+                Valid options are "FullyReplicated" or "ShardedByS3Key".
+            s3_compression_type (str, default=None): Valid options are "None" or "Gzip".
+        """
+        super(S3Input, self).__init__(
+            s3_uri=s3_uri,
+            local_path=local_path,
+            s3_data_type=s3_data_type,
+            s3_input_mode=s3_input_mode,
+            s3_data_distribution_type=s3_data_distribution_type,
+            s3_compression_type=s3_compression_type,
+        )
```
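
With the new constructors in place, each input can be built directly with keyword arguments. A usage sketch based on the signatures above; the bucket, catalog, database, workgroup, and query values are placeholders:

```python
from sagemaker.dataset_definition.inputs import (
    AthenaDatasetDefinition,
    DatasetDefinition,
    S3Input,
)

# Placeholder values for illustration only.
s3_input = S3Input(
    s3_uri="s3://my-bucket/prefix",
    local_path="/opt/ml/processing/input/data",
    s3_data_type="S3Prefix",  # "ManifestFile" is the other valid option
)

dataset = DatasetDefinition(
    data_distribution_type="FullyReplicated",
    athena_dataset_definition=AthenaDatasetDefinition(
        catalog="AwsDataCatalog",
        database="default",
        query_string='SELECT * FROM "my_table"',
        output_s3_uri="s3://my-bucket/athena-results/",
        work_group="primary",
        output_format="PARQUET",
    ),
)
```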

tests/integ/test_processing.py (+19)
```diff
@@ -747,6 +747,14 @@ def _get_processing_inputs_with_all_parameters(bucket):
             destination="/opt/ml/processing/input/data/",
             input_name="my_dataset",
         ),
+        ProcessingInput(
+            input_name="s3_input_wo_defaults",
+            s3_input=S3Input(
+                s3_uri=f"s3://{bucket}",
+                local_path="/opt/ml/processing/input/s3_input_wo_defaults",
+                s3_data_type="S3Prefix",
+            ),
+        ),
         ProcessingInput(
             input_name="s3_input",
             s3_input=S3Input(
@@ -822,6 +830,17 @@ def _get_processing_job_inputs_and_outputs(bucket, output_kms_key):
                 "S3CompressionType": "None",
             },
         },
+        {
+            "InputName": "s3_input_wo_defaults",
+            "AppManaged": False,
+            "S3Input": {
+                "S3Uri": f"s3://{bucket}",
+                "LocalPath": "/opt/ml/processing/input/s3_input_wo_defaults",
+                "S3DataType": "S3Prefix",
+                "S3InputMode": "File",
+                "S3DataDistributionType": "FullyReplicated",
+            },
+        },
         {
             "InputName": "s3_input",
             "AppManaged": False,
```
