Skip to content

Commit 9781aba

Browse files
committed
reverted changes
1 parent 13138ea commit 9781aba

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

src/sagemaker/huggingface/estimator.py

+7-18
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14+
# commend to retrigger pipeeline
1415
from __future__ import absolute_import
1516

1617
import logging
@@ -153,14 +154,10 @@ def __init__(
153154
self._validate_args(image_uri=image_uri)
154155

155156
if distribution is not None:
156-
instance_type = renamed_kwargs(
157-
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
158-
)
157+
instance_type = renamed_kwargs("train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs)
159158

160159
base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
161-
base_framework_version = (
162-
tensorflow_version if tensorflow_version is not None else pytorch_version
163-
)
160+
base_framework_version = tensorflow_version if tensorflow_version is not None else pytorch_version
164161

165162
validate_smdistributed(
166163
instance_type=instance_type,
@@ -171,18 +168,14 @@ def __init__(
171168
image_uri=image_uri,
172169
)
173170

174-
warn_if_parameter_server_with_multi_gpu(
175-
training_instance_type=instance_type, distribution=distribution
176-
)
171+
warn_if_parameter_server_with_multi_gpu(training_instance_type=instance_type, distribution=distribution)
177172

178173
if "enable_sagemaker_metrics" not in kwargs:
179174
kwargs["enable_sagemaker_metrics"] = True
180175

181176
kwargs["py_version"] = self.py_version
182177

183-
super(HuggingFace, self).__init__(
184-
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
185-
)
178+
super(HuggingFace, self).__init__(entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs)
186179
self.distribution = distribution or {}
187180

188181
def _validate_args(self, image_uri):
@@ -220,9 +213,7 @@ def _validate_args(self, image_uri):
220213
def hyperparameters(self):
221214
"""Return hyperparameters used by your custom PyTorch code during model training."""
222215
hyperparameters = super(HuggingFace, self).hyperparameters()
223-
additional_hyperparameters = self._distribution_configuration(
224-
distribution=self.distribution
225-
)
216+
additional_hyperparameters = self._distribution_configuration(distribution=self.distribution)
226217
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
227218
return hyperparameters
228219

@@ -332,9 +323,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
332323

333324
if framework != cls._framework_name:
334325
raise ValueError(
335-
"Training job: {} didn't use image for requested framework".format(
336-
job_details["TrainingJobName"]
337-
)
326+
"Training job: {} didn't use image for requested framework".format(job_details["TrainingJobName"])
338327
)
339328

340329
return init_params

0 commit comments

Comments
 (0)