Commit aa349fd

upate with aws master
2 parents 1c9dd16 + 2d38df9 commit aa349fd

9 files changed: +116 additions, -105 deletions

doc/using_pytorch.rst

Lines changed: 69 additions & 68 deletions
@@ -90,7 +90,7 @@ Note that SageMaker doesn't support argparse actions. If you want to use, for ex
 you need to specify `type` as `bool` in your script and provide an explicit `True` or `False` value for this hyperparameter
 when instantiating PyTorch Estimator.

-For more on training environment variables, please visit `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_.
+For more on training environment variables, see the `SageMaker Training Toolkit <https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md>`_.

 Save the Model
 --------------
@@ -115,7 +115,7 @@ to a certain filesystem path called ``model_dir``. This value is accessible thro
         with open(os.path.join(args.model_dir, 'model.pth'), 'wb') as f:
             torch.save(model.state_dict(), f)

-After your training job is complete, SageMaker will compress and upload the serialized model to S3, and your model data
+After your training job is complete, SageMaker compresses and uploads the serialized model to S3, and your model data
 will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator.

 If you are using Elastic Inference, you must convert your models to the TorchScript format and use ``torch.jit.save`` to save the model.
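As a rough sketch of the two save paths described in this hunk (the input shape and the reuse of ``model.pth`` as the file name are assumptions, not part of the documented API):

.. code:: python

    import os

    import torch

    def save_model(model, model_dir):
        # Default case: persist the state_dict so model_fn can rebuild the model later.
        with open(os.path.join(model_dir, 'model.pth'), 'wb') as f:
            torch.save(model.state_dict(), f)

    def save_model_for_elastic_inference(model, model_dir):
        # Elastic Inference case: trace to TorchScript and save with torch.jit.save.
        model.eval()
        example_input = torch.zeros(1, 3, 224, 224)  # assumed input shape
        traced = torch.jit.trace(model, example_input)
        torch.jit.save(traced, os.path.join(model_dir, 'model.pth'))
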
@@ -566,12 +566,76 @@ The function should return a byte array of data serialized to content_type.
 The default implementation expects ``prediction`` to be a torch.Tensor and can serialize the result to JSON, CSV, or NPY.
 It accepts response content types of "application/json", "text/csv", and "application/x-npy".

-Working with Existing Model Data and Training Jobs
-==================================================

-Attach to existing training jobs
+Bring your own model
+====================
+
+You can deploy a PyTorch model that you trained outside of SageMaker by using the ``PyTorchModel`` class.
+Typically, you save a PyTorch model as a file with extension ``.pt`` or ``.pth``.
+To do this, you need to:
+
+* Write an inference script.
+* Create the directory structure for your model files.
+* Create the ``PyTorchModel`` object.
+
+Write an inference script
+-------------------------
+
+You must create an inference script that implements (at least) the ``model_fn`` function, which calls the loaded model to get a prediction.
+
+**Note**: If you use elastic inference with PyTorch, you can use the default ``model_fn`` implementation provided in the serving container.
+
+Optionally, you can also implement ``input_fn`` and ``output_fn`` to process input and output,
+and ``predict_fn`` to customize how the model server gets predictions from the loaded model.
+For information about how to write an inference script, see `Serve a PyTorch Model <#serve-a-pytorch-model>`_.
+Save the inference script in the same folder where you saved your PyTorch model.
+Pass the filename of the inference script as the ``entry_point`` parameter when you create the ``PyTorchModel`` object.
+
+Create the directory structure for your model files
+---------------------------------------------------
+
+You have to create a directory structure and place your model files in the correct location.
+The ``PyTorchModel`` constructor packs the files into a ``tar.gz`` file and uploads it to S3.
+
+The directory structure where you saved your PyTorch model should look something like the following:
+
+**Note:** This directory structure is for PyTorch versions 1.2 and higher.
+For the directory structure for versions 1.1 and lower,
+see `For versions 1.1 and lower <#for-versions-1.1-and-lower>`_.
+
+::
+
+    | my_model
+    |   |--model.pth
+    |
+    |   code
+    |     |--inference.py
+    |     |--requirements.txt
+
+Where ``requirements.txt`` is an optional file that specifies dependencies on third-party libraries.
+
+Create a ``PyTorchModel`` object
 --------------------------------

+Now call the :class:`sagemaker.pytorch.model.PyTorchModel` constructor to create a model object, and then call its ``deploy()`` method to deploy your model for inference.
+
+.. code:: python
+
+    from sagemaker import get_execution_role
+    role = get_execution_role()
+
+    pytorch_model = PyTorchModel(model_data='s3://my-bucket/my-path/model.tar.gz', role=role,
+                                 entry_point='inference.py')
+
+    predictor = pytorch_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1)
+
+
+Now you can call the ``predict()`` method to get predictions from your deployed model.
+
+***********************************************
+Attach an estimator to an existing training job
+***********************************************
+
 You can attach a PyTorch Estimator to an existing training job using the
 ``attach`` method.

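To make the "Write an inference script" guidance added above concrete, here is a minimal sketch of what such an ``inference.py`` could look like; the ``Net`` architecture, the ``model.pth`` file name, and the optional ``predict_fn`` are illustrative assumptions, not part of this commit:

.. code:: python

    # inference.py -- minimal sketch, assuming the model was saved with
    # torch.save(model.state_dict(), ...) and that Net defines the architecture.
    import os

    import torch
    import torch.nn as nn

    class Net(nn.Module):  # assumed architecture; replace with your own
        def __init__(self):
            super(Net, self).__init__()
            self.fc = nn.Linear(784, 10)

        def forward(self, x):
            return self.fc(x)

    def model_fn(model_dir):
        # Required: load the model from model_dir and return it.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = Net()
        with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
            model.load_state_dict(torch.load(f, map_location=device))
        return model.to(device).eval()

    def predict_fn(input_data, model):
        # Optional: customize how the model server gets predictions from the model.
        device = next(model.parameters()).device
        with torch.no_grad():
            return model(input_data.to(device))
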
@@ -592,69 +656,6 @@ The ``attach`` method accepts the following arguments:
 - ``sagemaker_session:`` The Session used
   to interact with SageMaker

-Deploy Endpoints from model data
---------------------------------
-
-In addition to attaching to existing training jobs, you can deploy models directly from model data in S3.
-The following code sample shows how to do this, using the ``PyTorchModel`` class.
-
-.. code:: python
-
-    pytorch_model = PyTorchModel(model_data='s3://bucket/model.tar.gz', role='SageMakerRole',
-                                 entry_point='transform_script.py')
-
-    predictor = pytorch_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1)
-
-The PyTorchModel constructor takes the following arguments:
-
-- ``model_dat:`` An S3 location of a SageMaker model data
-  .tar.gz file
-- ``image:`` A Docker image URI
-- ``role:`` An IAM role name or Arn for SageMaker to access AWS
-  resources on your behalf.
-- ``predictor_cls:`` A function to
-  call to create a predictor. If not None, ``deploy`` will return the
-  result of invoking this function on the created endpoint name
-- ``env:`` Environment variables to run with
-  ``image`` when hosted in SageMaker.
-- ``name:`` The model name. If None, a default model name will be
-  selected on each ``deploy.``
-- ``entry_point:`` Path (absolute or relative) to the Python file
-  which should be executed as the entry point to model hosting.
-- ``source_dir:`` Optional. Path (absolute or relative) to a
-  directory with any other training source code dependencies including
-  the entry point file. Structure within this directory will be
-  preserved when training on SageMaker.
-- ``enable_cloudwatch_metrics:`` Optional. If true, training
-  and hosting containers will generate Cloudwatch metrics under the
-  AWS/SageMakerContainer namespace.
-- ``container_log_level:`` Log level to use within the container.
-  Valid values are defined in the Python logging module.
-- ``code_location:`` Optional. Name of the S3 bucket where your
-  custom code will be uploaded to. If not specified, will use the
-  SageMaker default bucket created by sagemaker.Session.
-- ``sagemaker_session:`` The SageMaker Session
-  object, used for SageMaker interaction
-
-Your model data must be a .tar.gz file in S3. SageMaker Training Job model data is saved to .tar.gz files in S3,
-however if you have local data you want to deploy, you can prepare the data yourself.
-
-Assuming you have a local directory containg your model data named "my_model" you can tar and gzip compress the file and
-upload to S3 using the following commands:
-
-::
-
-    tar -czf model.tar.gz my_model
-    aws s3 cp model.tar.gz s3://my-bucket/my-path/model.tar.gz
-
-This uploads the contents of my_model to a gzip compressed tar file to S3 in the bucket "my-bucket", with the key
-"my-path/model.tar.gz".
-
-To run this command, you'll need the AWS CLI tool installed. Please refer to our `FAQ`_ for more information on
-installing this.
-
-.. _FAQ: ../../../README.rst#faq
-
 *************************
 PyTorch Training Examples
 *************************
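The shell-based packaging instructions removed above still describe a useful workflow. As a rough Python-only equivalent (the ``my_model`` layout follows the new "Bring your own model" section; the key prefix, role name, and script path are assumptions):

.. code:: python

    import tarfile

    import sagemaker
    from sagemaker.pytorch import PyTorchModel

    # Package the contents of the local "my_model" directory as model.tar.gz.
    with tarfile.open('model.tar.gz', 'w:gz') as tar:
        tar.add('my_model', arcname='.')

    # Upload the archive to the session's default SageMaker bucket.
    session = sagemaker.Session()
    model_data = session.upload_data('model.tar.gz', key_prefix='my-path')

    pytorch_model = PyTorchModel(model_data=model_data, role='SageMakerRole',
                                 entry_point='inference.py')  # local path to the inference script
    predictor = pytorch_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1)
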

src/sagemaker/estimator.py

Lines changed: 8 additions & 6 deletions
@@ -1478,12 +1478,14 @@ def __init__(
                 >>> |----- test.py

                 You can assign entry_point='src/train.py'.
-            source_dir (str): Path (absolute, relative, or an S3 URI) to a directory with
-                any other training source code dependencies aside from the entry
-                point file (default: None). Structure within this directory are
-                preserved when training on Amazon SageMaker. If 'git_config' is
-                provided, 'source_dir' should be a relative location to a
-                directory in the Git repo. .. admonition:: Example
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker. If 'git_config' is provided,
+                'source_dir' should be a relative location to a directory in the Git
+                repo.
+                .. admonition:: Example

                 With the following GitHub repo directory structure:
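To illustrate the updated ``source_dir`` behavior documented in this docstring, a hedged sketch of passing an S3 URI that points to a tar.gz archive of training code; the bucket, paths, and framework version are assumptions, and the same pattern applies to the other framework estimators changed below:

.. code:: python

    from sagemaker.pytorch import PyTorch

    # source_dir may now be an S3 URI, but it must point to a tar.gz archive
    # that contains the entry point script (train.py here) at its root.
    estimator = PyTorch(entry_point='train.py',
                        source_dir='s3://my-bucket/my-code/sourcedir.tar.gz',
                        role='SageMakerRole',
                        framework_version='1.3.1',
                        train_instance_count=1,
                        train_instance_type='ml.p2.xlarge')
    estimator.fit('s3://my-bucket/my-training-data')
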

src/sagemaker/model.py

Lines changed: 8 additions & 7 deletions
@@ -659,13 +659,14 @@ def __init__(
                 >>> |----- test.py

                 You can assign entry_point='src/inference.py'.
-            source_dir (str): Path (absolute or relative) to a directory with
-                any other training source code dependencies aside from the entry
-                point file (default: None). Structure within this directory will
-                be preserved when training on SageMaker. If 'git_config' is
-                provided, 'source_dir' should be a relative location to a
-                directory in the Git repo. If the directory points to S3, no
-                code will be uploaded and the S3 location will be used instead.
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker. If 'git_config' is provided,
+                'source_dir' should be a relative location to a directory in the Git repo.
+                If the directory points to S3, no code will be uploaded and the S3 location
+                will be used instead.
                 .. admonition:: Example

                 With the following GitHub repo directory structure:

src/sagemaker/mxnet/estimator.py

Lines changed: 5 additions & 4 deletions
@@ -71,10 +71,11 @@ def __init__(
             entry_point (str): Path (absolute or relative) to the Python source
                 file which should be executed as the entry point to training.
                 This should be compatible with either Python 2.7 or Python 3.5.
-            source_dir (str): Path (absolute or relative) to a directory with
-                any other training source code dependencies aside from the entry
-                point file (default: None). Structure within this directory are
-                preserved when training on Amazon SageMaker.
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker.
             hyperparameters (dict): Hyperparameters that will be used for
                 training (default: None). The hyperparameters are made
                 accessible as a dict[str, str] to the training code on

src/sagemaker/pytorch/estimator.py

Lines changed: 5 additions & 4 deletions
@@ -68,10 +68,11 @@ def __init__(
             entry_point (str): Path (absolute or relative) to the Python source
                 file which should be executed as the entry point to training.
                 This should be compatible with either Python 2.7 or Python 3.5.
-            source_dir (str): Path (absolute or relative) to a directory with
-                any other training source code dependencies aside from the entry
-                point file (default: None). Structure within this directory are
-                preserved when training on Amazon SageMaker.
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker.
             hyperparameters (dict): Hyperparameters that will be used for
                 training (default: None). The hyperparameters are made
                 accessible as a dict[str, str] to the training code on

src/sagemaker/rl/estimator.py

Lines changed: 5 additions & 4 deletions
@@ -109,10 +109,11 @@ def __init__(
             framework (sagemaker.rl.RLFramework): Framework (MXNet or
                 TensorFlow) you want to be used as a toolkit backed for
                 reinforcement learning training.
-            source_dir (str): Path (absolute or relative) to a directory with
-                any other training source code dependencies aside from the entry
-                point file (default: None). Structure within this directory is
-                preserved when training on Amazon SageMaker.
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker.
             hyperparameters (dict): Hyperparameters that will be used for
                 training (default: None). The hyperparameters are made
                 accessible as a dict[str, str] to the training code on

src/sagemaker/sklearn/estimator.py

Lines changed: 5 additions & 4 deletions
@@ -69,10 +69,11 @@ def __init__(
             framework_version (str): Scikit-learn version you want to use for
                 executing your model training code. List of supported versions
                 https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
-            source_dir (str): Path (absolute or relative) to a directory with
-                any other training source code dependencies aside from the entry
-                point file (default: None). Structure within this directory are
-                preserved when training on Amazon SageMaker.
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker.
             hyperparameters (dict): Hyperparameters that will be used for
                 training (default: None). The hyperparameters are made
                 accessible as a dict[str, str] to the training code on

src/sagemaker/tensorflow/estimator.py

Lines changed: 6 additions & 5 deletions
@@ -567,11 +567,12 @@ def create_model(
                 should be executed as the entry point to training. If not specified and
                 ``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
                 ``endpoint_type`` is also ``None``, then the training entry point is used.
-            source_dir (str): Path (absolute or relative) to a directory with any other serving
-                source code dependencies aside from the entry point file. If not specified and
-                ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
-                ``endpoint_type`` is also ``None``, then the model source directory from training
-                is used.
+            source_dir (str): Path (absolute or relative or an S3 URI) to a directory with any
+                other serving source code dependencies aside from the entry point file. If
+                ``source_dir`` is an S3 URI, it must point to a tar.gz file. If not specified
+                and ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
+                ``endpoint_type`` is also ``None``, then the model source directory from
+                training is used.
             dependencies (list[str]): A list of paths to directories (absolute or relative) with
                 any additional libraries that will be exported to the container.
                 If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
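For the serving-side counterpart documented in this hunk, a sketch of passing an S3 ``source_dir`` to ``create_model``; the trained ``estimator``, script name, bucket, and instance type are assumptions:

.. code:: python

    # Assuming `estimator` is a trained sagemaker.tensorflow.TensorFlow estimator;
    # for serving, source_dir may likewise be an S3 URI to a tar.gz archive of code.
    model = estimator.create_model(entry_point='inference.py',
                                   source_dir='s3://my-bucket/my-serving-code/sourcedir.tar.gz')
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')
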

src/sagemaker/xgboost/estimator.py

Lines changed: 5 additions & 3 deletions
@@ -75,9 +75,11 @@ def __init__(
             framework_version (str): XGBoost version you want to use for executing your model
                 training code. List of supported versions
                 https://github.com/aws/sagemaker-python-sdk#xgboost-sagemaker-estimators
-            source_dir (str): Path (absolute or relative) to a directory with any other training
-                source code dependencies aside from the entry point file (default: None).
-                Structure within this directory are preserved when training on Amazon SageMaker.
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory are preserved
+                when training on Amazon SageMaker.
             hyperparameters (dict): Hyperparameters that will be used for training (default: None).
                 The hyperparameters are made accessible as a dict[str, str] to the training code
                 on SageMaker. For convenience, this accepts other types for keys and values, but