Skip to content

change: Add "distribution" parameter into record_set #4408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 19, 2024
18 changes: 16 additions & 2 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,14 @@ def fit(
if wait:
self.latest_training_job.wait(logs=logs)

def record_set(self, train, labels=None, channel="train", encrypt=False):
def record_set(
self,
train,
labels=None,
channel="train",
encrypt=False,
distribution="ShardedByS3Key",
):
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.

For the 2D ``ndarray`` ``train``, each row is converted to a
Expand All @@ -294,6 +301,8 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
should be assigned to.
encrypt (bool): Specifies whether the objects uploaded to S3 are
encrypted on the server side using AES-256 (default: ``False``).
distribution (str): The SageMaker TrainingJob channel s3 data
distribution type (default: ``ShardedByS3Key``).

Returns:
RecordSet: A RecordSet referencing the encoded, uploading training
Expand All @@ -316,6 +325,7 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
num_records=train.shape[0],
feature_dim=train.shape[1],
channel=channel,
distribution=distribution,
)

def _get_default_mini_batch_size(self, num_records: int):
Expand Down Expand Up @@ -343,6 +353,7 @@ def __init__(
feature_dim: int,
s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
channel: Union[str, PipelineVariable] = "train",
distribution: str = "ShardedByS3Key",
):
"""A collection of Amazon :class:~`Record` objects serialized and stored in S3.

Expand All @@ -358,12 +369,15 @@ def __init__(
single s3 manifest file, listing each s3 object to train on.
channel (str or PipelineVariable): The SageMaker Training Job channel this RecordSet
should be bound to
distribution (str): The SageMaker TrainingJob S3 data distribution type.
Valid values: 'ShardedByS3Key', 'FullyReplicated'.
"""
self.s3_data = s3_data
self.feature_dim = feature_dim
self.num_records = num_records
self.s3_data_type = s3_data_type
self.channel = channel
self.distribution = distribution

def __repr__(self):
"""Return an unambiguous representation of this RecordSet"""
Expand All @@ -377,7 +391,7 @@ def data_channel(self):
def records_s3_input(self):
"""Return a TrainingInput to represent the training data"""
return TrainingInput(
self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type
self.s3_data, distribution=self.distribution, s3_data_type=self.s3_data_type
)


Expand Down