@@ -269,7 +269,14 @@ def fit(
269
269
if wait :
270
270
self .latest_training_job .wait (logs = logs )
271
271
272
- def record_set (self , train , labels = None , channel = "train" , encrypt = False ):
272
+ def record_set (
273
+ self ,
274
+ train ,
275
+ labels = None ,
276
+ channel = "train" ,
277
+ encrypt = False ,
278
+ distribution = "ShardedByS3Key" ,
279
+ ):
273
280
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
274
281
275
282
For the 2D ``ndarray`` ``train``, each row is converted to a
@@ -294,6 +301,8 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
294
301
should be assigned to.
295
302
encrypt (bool): Specifies whether the objects uploaded to S3 are
296
303
encrypted on the server side using AES-256 (default: ``False``).
304
+ distribution (str): The SageMaker TrainingJob channel s3 data
305
+ distribution type (default: ``ShardedByS3Key``).
297
306
298
307
Returns:
299
308
RecordSet: A RecordSet referencing the encoded, uploading training
@@ -316,6 +325,7 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
316
325
num_records = train .shape [0 ],
317
326
feature_dim = train .shape [1 ],
318
327
channel = channel ,
328
+ distribution = distribution ,
319
329
)
320
330
321
331
def _get_default_mini_batch_size (self , num_records : int ):
@@ -343,6 +353,7 @@ def __init__(
343
353
feature_dim : int ,
344
354
s3_data_type : Union [str , PipelineVariable ] = "ManifestFile" ,
345
355
channel : Union [str , PipelineVariable ] = "train" ,
356
+ distribution : str = "ShardedByS3Key" ,
346
357
):
347
358
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.
348
359
@@ -358,12 +369,15 @@ def __init__(
358
369
single s3 manifest file, listing each s3 object to train on.
359
370
channel (str or PipelineVariable): The SageMaker Training Job channel this RecordSet
360
371
should be bound to
372
+ distribution (str): The SageMaker TrainingJob S3 data distribution type.
373
+ Valid values: 'ShardedByS3Key', 'FullyReplicated'.
361
374
"""
362
375
self .s3_data = s3_data
363
376
self .feature_dim = feature_dim
364
377
self .num_records = num_records
365
378
self .s3_data_type = s3_data_type
366
379
self .channel = channel
380
+ self .distribution = distribution
367
381
368
382
def __repr__ (self ):
369
383
"""Return an unambiguous representation of this RecordSet"""
@@ -377,7 +391,7 @@ def data_channel(self):
377
391
def records_s3_input (self ):
378
392
"""Return a TrainingInput to represent the training data"""
379
393
return TrainingInput (
380
- self .s3_data , distribution = "ShardedByS3Key" , s3_data_type = self .s3_data_type
394
+ self .s3_data , distribution = self . distribution , s3_data_type = self .s3_data_type
381
395
)
382
396
383
397
0 commit comments