Skip to content

Commit 9091b21

Browse files
authored
change: Add "distribution" parameter into record_set (aws#4408)
* Add distribution as input for RecordSet * bug fix * bug fix * bug fix * Add desc * fix line too long * fix format * reformat * reformat
1 parent 5559ba3 commit 9091b21

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

src/sagemaker/amazon/amazon_estimator.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,14 @@ def fit(
269269
if wait:
270270
self.latest_training_job.wait(logs=logs)
271271

272-
def record_set(self, train, labels=None, channel="train", encrypt=False):
272+
def record_set(
273+
self,
274+
train,
275+
labels=None,
276+
channel="train",
277+
encrypt=False,
278+
distribution="ShardedByS3Key",
279+
):
273280
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
274281
275282
For the 2D ``ndarray`` ``train``, each row is converted to a
@@ -294,6 +301,8 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
294301
should be assigned to.
295302
encrypt (bool): Specifies whether the objects uploaded to S3 are
296303
encrypted on the server side using AES-256 (default: ``False``).
304+
distribution (str): The SageMaker TrainingJob channel s3 data
305+
distribution type (default: ``ShardedByS3Key``).
297306
298307
Returns:
299308
RecordSet: A RecordSet referencing the encoded, uploading training
@@ -316,6 +325,7 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
316325
num_records=train.shape[0],
317326
feature_dim=train.shape[1],
318327
channel=channel,
328+
distribution=distribution,
319329
)
320330

321331
def _get_default_mini_batch_size(self, num_records: int):
@@ -343,6 +353,7 @@ def __init__(
343353
feature_dim: int,
344354
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
345355
channel: Union[str, PipelineVariable] = "train",
356+
distribution: str = "ShardedByS3Key",
346357
):
347358
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.
348359
@@ -358,12 +369,15 @@ def __init__(
358369
single s3 manifest file, listing each s3 object to train on.
359370
channel (str or PipelineVariable): The SageMaker Training Job channel this RecordSet
360371
should be bound to
372+
distribution (str): The SageMaker TrainingJob S3 data distribution type.
373+
Valid values: 'ShardedByS3Key', 'FullyReplicated'.
361374
"""
362375
self.s3_data = s3_data
363376
self.feature_dim = feature_dim
364377
self.num_records = num_records
365378
self.s3_data_type = s3_data_type
366379
self.channel = channel
380+
self.distribution = distribution
367381

368382
def __repr__(self):
369383
"""Return an unambiguous representation of this RecordSet"""
@@ -377,7 +391,7 @@ def data_channel(self):
377391
def records_s3_input(self):
378392
"""Return a TrainingInput to represent the training data"""
379393
return TrainingInput(
380-
self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type
394+
self.s3_data, distribution=self.distribution, s3_data_type=self.s3_data_type
381395
)
382396

383397

0 commit comments

Comments
 (0)