Skip to content

doc: add link to framework-related parent classes to clarify **kwargs #1180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,20 @@ def __init__(
image_name (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
py_version. It can be an ECR url or dockerhub image and tag. ..
admonition:: Examples
py_version. It can be an ECR url or dockerhub image and tag.

Examples
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
logger.warning(empty_framework_version_warning(CHAINER_VERSION, self.LATEST_VERSION))
Expand Down
12 changes: 9 additions & 3 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -93,8 +93,14 @@ def __init__(
model_server_workers (int): Optional. The number of worker processes
used by the inference server. If None, server will use one
worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel``
initializer.
**kwargs: Keyword arguments passed to the
:class:`~sagemaker.model.FrameworkModel` initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(ChainerModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
Expand Down
22 changes: 17 additions & 5 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,22 @@ def __init__(
image_name (str): If specified, the estimator will use this image for training and
hosting, instead of selecting the appropriate SageMaker official image based on
framework_version and py_version. It can be an ECR url or dockerhub image and tag.

Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

distributions (dict): A dictionary with information on how to run distributed
training (default: None). To have parameter servers launched for training,
set this value to be ``{'parameter_server': {'enabled': True}}``.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
logger.warning(empty_framework_version_warning(MXNET_VERSION, self.LATEST_VERSION))
Expand Down Expand Up @@ -161,8 +168,10 @@ def create_model(
role from the Estimator will be used.
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
the model. Default: use subnets and security groups from this Estimator.

* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.

entry_point (str): Path (absolute or relative) to the local Python source file which
should be executed as the entry point to training. If not specified, the training
entry point is used.
Expand All @@ -175,10 +184,13 @@ def create_model(
image_name (str): If specified, the estimator will use this image for hosting, instead
of selecting the appropriate SageMaker official image based on framework_version
and py_version. It can be an ECR url or dockerhub image and tag.

Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
**kwargs: Additional kwargs passed to the MXNetModel constructor.
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

**kwargs: Additional kwargs passed to the :class:`~sagemaker.mxnet.model.MXNetModel`
constructor.

Returns:
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -96,6 +96,12 @@ def __init__(
worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel``
initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(MXNetModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
Expand Down
15 changes: 12 additions & 3 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,19 @@ def __init__(
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
py_version. It can be an ECR url or dockerhub image and tag.

Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
logger.warning(empty_framework_version_warning(PYTORCH_VERSION, PYTORCH_VERSION))
Expand Down Expand Up @@ -146,7 +154,8 @@ def create_model(
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the PyTorchModel constructor.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.pytorch.model.PyTorchModel`
constructor.

Returns:
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -97,6 +97,12 @@ def __init__(
worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel``
initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
Expand Down
22 changes: 16 additions & 6 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ def __init__(
don't use an Amazon algorithm.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
self._validate_images_args(toolkit, toolkit_version, framework, image_name)

Expand Down Expand Up @@ -174,8 +180,10 @@ def create_model(
role from the Estimator will be used.
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
the model. Default: use subnets and security groups from this Estimator.

* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.

entry_point (str): Path (absolute or relative) to the Python source
file which should be executed as the entry point for MXNet
hosting. This should be compatible with Python 3.5 (default:
Expand All @@ -190,18 +198,20 @@ def create_model(
folders will be copied to SageMaker in the same folder where the
entry_point is copied. If the ```source_dir``` points to S3,
code will be uploaded and the S3 location will be used instead.
**kwargs: Additional kwargs passed to the FrameworkModel constructor.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.model.FrameworkModel`
constructor.

Returns:
sagemaker.model.FrameworkModel: Depending on input parameters returns
one of the following:

* sagemaker.model.FrameworkModel - in case image_name was specified
* :class:`~sagemaker.model.FrameworkModel` - if ``image_name`` was specified
on the estimator;
* sagemaker.mxnet.MXNetModel - if image_name wasn't specified and
MXNet was used as RL backend;
* sagemaker.tensorflow.serving.Model - if image_name wasn't specified and
TensorFlow was used as RL backend.
* :class:`~sagemaker.mxnet.MXNetModel` - if ``image_name`` wasn't specified and
MXNet was used as the RL backend;
* :class:`~sagemaker.tensorflow.serving.Model` - if ``image_name`` wasn't specified
and TensorFlow was used as the RL backend.

Raises:
ValueError: If image_name was not specified and framework enum is not valid.
"""
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def __init__(
custom-image:latest.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
# SciKit-Learn does not support distributed training or training on GPU instance types.
# Fail fast.
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -96,6 +96,12 @@ def __init__(
worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel``
initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(SKLearnModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
Expand Down
10 changes: 8 additions & 2 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ def __init__(
}

**kwargs: Additional kwargs passed to the Framework constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
Expand Down Expand Up @@ -551,8 +557,8 @@ def create_model(
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
set to ``None``.
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
**kwargs: Additional kwargs passed to ``sagemaker.tensorflow.serving.Model`` constructor
and ``sagemaker.tensorflow.model.TensorFlowModel`` constructor.
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.

Returns:
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -94,6 +94,12 @@ def __init__(
worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel``
initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(TensorFlowModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/tensorflow/serving.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -165,6 +165,12 @@ def __init__(
SageMaker ``Session``. If specified, ``deploy()`` returns the
result of invoking this function on the created endpoint name.
**kwargs: Keyword arguments passed to the ``Model`` initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(Model, self).__init__(
model_data=model_data,
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def __init__(
custom-image:latest.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
super(XGBoost, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -87,6 +87,12 @@ def __init__(
model_server_workers (int): Optional. The number of worker processes used by the
inference server. If None, server will use one worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.

.. tip::

You can find additional parameters for initializing this class at
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(XGBoostModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
Expand Down