Skip to content

Commit 0d85338

Browse files
committed
breaking: updates based on PR feedback
1 parent ddca516 commit 0d85338

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

src/sagemaker/fw_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -690,12 +690,12 @@ def validate_version_or_image_args(framework_version, py_version, image_name):
690690
691691
Args:
692692
framework_version (str): The version of the framework.
693-
py_version (str): The version of python.
693+
py_version (str): The version of Python.
694694
image_name (str): The URI of the image.
695695
696696
Raises:
697697
ValueError: if `image_name` is None and either `framework_version` or `py_version` is
698-
None.
698+
None.
699699
"""
700700
if (framework_version is None or py_version is None) and image_name is None:
701701
raise ValueError(

src/sagemaker/mxnet/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ def __init__(
8888
must point to a file located at the root of ``source_dir``.
8989
framework_version (str): MXNet version you want to use for executing
9090
your model training code. Defaults to ``None``. Required unless
91-
``image_name`` is provided.
91+
``image`` is provided.
9292
py_version (str): Python version you want to use for executing your
9393
model training code. Defaults to ``None``. Required unless
94-
``image_name`` is provided.
94+
``image`` is provided.
9595
image (str): A Docker image URI (default: None). If not specified, a
9696
default image for MXNet will be used.
9797
9898
If ``framework_version`` or ``py_version`` are ``None``, then
99-
``image_name`` is required. If also ``None``, then a ``ValueError``
99+
``image`` is required. If also ``None``, then a ``ValueError``
100100
will be raised.
101101
predictor_cls (callable[str, sagemaker.session.Session]): A function
102102
to call to create a predictor with an endpoint name and

tests/unit/test_fw_utils.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1274,13 +1274,10 @@ def test_warn_if_parameter_server_with_multi_gpu(caplog):
12741274
assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text
12751275

12761276

1277-
def test_validate_version_or_image_args():
1277+
def test_validate_version_or_image_args_not_raises():
12781278
good_args = [("1.0", "py3", None), (None, "py3", "my:uri"), ("1.0", None, "my:uri")]
12791279
for framework_version, py_version, image_name in good_args:
1280-
assert (
1281-
fw_utils.validate_version_or_image_args(framework_version, py_version, image_name)
1282-
is None
1283-
)
1280+
fw_utils.validate_version_or_image_args(framework_version, py_version, image_name)
12841281

12851282

12861283
def test_validate_version_or_image_args_raises():

0 commit comments

Comments
 (0)