Skip to content

Commit 57f2000

Browse files
fix: S3Input - add support for instance attributes
1 parent 7268e82 commit 57f2000

File tree

2 files changed

+141
-29
lines changed

2 files changed

+141
-29
lines changed

src/sagemaker/dataset_definition/inputs.py

+122-29
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,20 @@ class RedshiftDatasetDefinition(ApiObject):
2828
With this input, SQL queries will be executed using Redshift to generate datasets to S3.
2929
3030
Parameters:
31-
cluster_id (str): The Redshift cluster Identifier.
32-
database (str): The name of the Redshift database used in Redshift query execution.
33-
db_user (str): The database user name used in Redshift query execution.
34-
query_string (str): The SQL query statements to be executed.
35-
cluster_role_arn (str): The IAM role attached to your Redshift cluster that
31+
cluster_id (str, default=None): The Redshift cluster Identifier.
32+
database (str, default=None):
33+
The name of the Redshift database used in Redshift query execution.
34+
db_user (str, default=None): The database user name used in Redshift query execution.
35+
query_string (str, default=None): The SQL query statements to be executed.
36+
cluster_role_arn (str, default=None): The IAM role attached to your Redshift cluster that
3637
Amazon SageMaker uses to generate datasets.
37-
output_s3_uri (str): The location in Amazon S3 where the Redshift query
38+
output_s3_uri (str, default=None): The location in Amazon S3 where the Redshift query
3839
results are stored.
39-
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
40+
kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
4041
SageMaker uses to encrypt data from a Redshift execution.
41-
output_format (str): The data storage format for Redshift query results.
42+
output_format (str, default=None): The data storage format for Redshift query results.
4243
Valid options are "PARQUET", "CSV"
43-
output_compression (str): The compression used for Redshift query results.
44+
output_compression (str, default=None): The compression used for Redshift query results.
4445
Valid options are "None", "GZIP", "SNAPPY", "ZSTD", "BZIP2"
4546
"""
4647

@@ -54,23 +55,50 @@ class RedshiftDatasetDefinition(ApiObject):
5455
output_format = None
5556
output_compression = None
5657

58+
def __init__(
    self,
    cluster_id=None,
    database=None,
    db_user=None,
    query_string=None,
    cluster_role_arn=None,
    output_s3_uri=None,
    kms_key_id=None,
    output_format=None,
    output_compression=None,
):
    """Initialize a RedshiftDatasetDefinition.

    Every parameter is forwarded, by keyword, to the ``ApiObject`` base
    initializer unchanged; this class adds no logic of its own.
    See the class docstring for the meaning of each parameter.
    """
    # Collect the keyword arguments once so the forwarding call stays compact.
    forwarded = dict(
        cluster_id=cluster_id,
        database=database,
        db_user=db_user,
        query_string=query_string,
        cluster_role_arn=cluster_role_arn,
        output_s3_uri=output_s3_uri,
        kms_key_id=kms_key_id,
        output_format=output_format,
        output_compression=output_compression,
    )
    super(RedshiftDatasetDefinition, self).__init__(**forwarded)
82+
5783

5884
class AthenaDatasetDefinition(ApiObject):
5985
"""DatasetDefinition for Athena.
6086
6187
With this input, SQL queries will be executed using Athena to generate datasets to S3.
6288
6389
Parameters:
64-
catalog (str): The name of the data catalog used in Athena query execution.
65-
database (str): The name of the database used in the Athena query execution.
66-
query_string (str): The SQL query statements, to be executed.
67-
output_s3_uri (str): The location in Amazon S3 where Athena query results are stored.
68-
work_group (str): The name of the workgroup in which the Athena query is being started.
69-
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon
90+
catalog (str, default=None): The name of the data catalog used in Athena query execution.
91+
database (str, default=None): The name of the database used in the Athena query execution.
92+
query_string (str, default=None): The SQL query statements, to be executed.
93+
output_s3_uri (str, default=None):
94+
The location in Amazon S3 where Athena query results are stored.
95+
work_group (str, default=None):
96+
The name of the workgroup in which the Athena query is being started.
97+
kms_key_id (str, default=None): The AWS Key Management Service (AWS KMS) key that Amazon
7098
SageMaker uses to encrypt data generated from an Athena query execution.
71-
output_format (str): The data storage format for Athena query results.
99+
output_format (str, default=None): The data storage format for Athena query results.
72100
Valid options are "PARQUET", "ORC", "AVRO", "JSON", "TEXTFILE"
73-
output_compression (str): The compression used for Athena query results.
101+
output_compression (str, default=None): The compression used for Athena query results.
74102
Valid options are "GZIP", "SNAPPY", "ZLIB"
75103
"""
76104

@@ -83,24 +111,51 @@ class AthenaDatasetDefinition(ApiObject):
83111
output_format = None
84112
output_compression = None
85113

114+
def __init__(
    self,
    catalog=None,
    database=None,
    query_string=None,
    output_s3_uri=None,
    work_group=None,
    kms_key_id=None,
    output_format=None,
    output_compression=None,
):
    """Initialize an AthenaDatasetDefinition.

    Every parameter is forwarded, by keyword, to the ``ApiObject`` base
    initializer unchanged; this class adds no logic of its own.
    See the class docstring for the meaning of each parameter.
    """
    # Collect the keyword arguments once so the forwarding call stays compact.
    forwarded = dict(
        catalog=catalog,
        database=database,
        query_string=query_string,
        output_s3_uri=output_s3_uri,
        work_group=work_group,
        kms_key_id=kms_key_id,
        output_format=output_format,
        output_compression=output_compression,
    )
    super(AthenaDatasetDefinition, self).__init__(**forwarded)
136+
86137

87138
class DatasetDefinition(ApiObject):
88139
"""DatasetDefinition input.
89140
90141
Parameters:
91-
data_distribution_type (str): Whether the generated dataset is FullyReplicated or
92-
ShardedByS3Key (default).
93-
input_mode (str): Whether to use File or Pipe input mode. In File (default) mode, Amazon
142+
data_distribution_type (str, default="ShardedByS3Key"):
143+
Whether the generated dataset is FullyReplicated or ShardedByS3Key (default).
144+
input_mode (str, default="File"):
145+
Whether to use File or Pipe input mode. In File (default) mode, Amazon
94146
SageMaker copies the data from the input source onto the local Amazon Elastic Block
95147
Store (Amazon EBS) volumes before starting your training algorithm. This is the most
96148
commonly used input mode. In Pipe mode, Amazon SageMaker streams input data from the
97149
source directly to your algorithm without using the EBS volume.
98-
local_path (str): The local path where you want Amazon SageMaker to download the Dataset
150+
local_path (str, default=None):
151+
The local path where you want Amazon SageMaker to download the Dataset
99152
Definition inputs to run a processing job. LocalPath is an absolute path to the input
100153
data. This is a required parameter when `AppManaged` is False (default).
101-
redshift_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`):
154+
redshift_dataset_definition
155+
(:class:`~sagemaker.dataset_definition.inputs.RedshiftDatasetDefinition`, default=None):
102156
Configuration for Redshift Dataset Definition input.
103-
athena_dataset_definition (:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`):
157+
athena_dataset_definition
158+
(:class:`~sagemaker.dataset_definition.inputs.AthenaDatasetDefinition`, default=None):
104159
Configuration for Athena Dataset Definition input.
105160
"""
106161

@@ -115,6 +170,23 @@ class DatasetDefinition(ApiObject):
115170
redshift_dataset_definition = None
116171
athena_dataset_definition = None
117172

173+
def __init__(
    self,
    data_distribution_type="ShardedByS3Key",
    input_mode="File",
    local_path=None,
    redshift_dataset_definition=None,
    athena_dataset_definition=None,
):
    """Initialize a DatasetDefinition.

    Every parameter is forwarded, by keyword, to the ``ApiObject`` base
    initializer unchanged; this class adds no logic of its own.
    See the class docstring for the meaning of each parameter.
    """
    # Collect the keyword arguments once so the forwarding call stays compact.
    forwarded = dict(
        data_distribution_type=data_distribution_type,
        input_mode=input_mode,
        local_path=local_path,
        redshift_dataset_definition=redshift_dataset_definition,
        athena_dataset_definition=athena_dataset_definition,
    )
    super(DatasetDefinition, self).__init__(**forwarded)
189+
118190

119191
class S3Input(ApiObject):
120192
"""Metadata of data objects stored in S3.
@@ -126,13 +198,15 @@ class S3Input(ApiObject):
126198
Use ManifestFile if strong consistency is required.
127199
128200
Parameters:
129-
s3_uri (str): the path to a specific S3 object or a S3 prefix
130-
local_path (str): the path to a local directory. If not provided, skips data download
201+
s3_uri (str, default=None): the path to a specific S3 object or a S3 prefix
202+
local_path (str, default=None):
203+
the path to a local directory. If not provided, skips data download
131204
by SageMaker platform.
132-
s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
133-
s3_input_mode (str): Valid options are "Pipe" or "File".
134-
s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key".
135-
s3_compression_type (str): Valid options are "None" or "Gzip".
205+
s3_data_type (str, default="S3Prefix"): Valid options are "ManifestFile" or "S3Prefix".
206+
s3_input_mode (str, default="File"): Valid options are "Pipe" or "File".
207+
s3_data_distribution_type (str, default="FullyReplicated"):
208+
Valid options are "FullyReplicated" or "ShardedByS3Key".
209+
s3_compression_type (str, default=None): Valid options are "None" or "Gzip".
136210
"""
137211

138212
s3_uri = None
@@ -141,3 +215,22 @@ class S3Input(ApiObject):
141215
s3_input_mode = "File"
142216
s3_data_distribution_type = "FullyReplicated"
143217
s3_compression_type = None
218+
219+
def __init__(
    self,
    s3_uri=None,
    local_path=None,
    s3_data_type="S3Prefix",
    s3_input_mode="File",
    s3_data_distribution_type="FullyReplicated",
    s3_compression_type=None,
):
    """Initialize an S3Input.

    Every parameter is forwarded, by keyword, to the ``ApiObject`` base
    initializer unchanged; this class adds no logic of its own.
    See the class docstring for the meaning of each parameter.
    """
    # Collect the keyword arguments once so the forwarding call stays compact.
    forwarded = dict(
        s3_uri=s3_uri,
        local_path=local_path,
        s3_data_type=s3_data_type,
        s3_input_mode=s3_input_mode,
        s3_data_distribution_type=s3_data_distribution_type,
        s3_compression_type=s3_compression_type,
    )
    super(S3Input, self).__init__(**forwarded)

tests/integ/test_processing.py

+19
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,14 @@ def _get_processing_inputs_with_all_parameters(bucket):
747747
destination="/opt/ml/processing/input/data/",
748748
input_name="my_dataset",
749749
),
750+
ProcessingInput(
751+
input_name="s3_input_wo_defaults",
752+
s3_input=S3Input(
753+
s3_uri=f"s3://{bucket}",
754+
local_path="/opt/ml/processing/input/s3_input_wo_defaults",
755+
s3_data_type="S3Prefix",
756+
),
757+
),
750758
ProcessingInput(
751759
input_name="s3_input",
752760
s3_input=S3Input(
@@ -822,6 +830,17 @@ def _get_processing_job_inputs_and_outputs(bucket, output_kms_key):
822830
"S3CompressionType": "None",
823831
},
824832
},
833+
{
834+
"InputName": "s3_input_wo_defaults",
835+
"AppManaged": False,
836+
"S3Input": {
837+
"S3Uri": f"s3://{bucket}",
838+
"LocalPath": "/opt/ml/processing/input/s3_input_wo_defaults",
839+
"S3DataType": "S3Prefix",
840+
"S3InputMode": "File",
841+
"S3DataDistributionType": "FullyReplicated",
842+
},
843+
},
825844
{
826845
"InputName": "s3_input",
827846
"AppManaged": False,

0 commit comments

Comments (0)