.github/PULL_REQUEST_TEMPLATE.md (+1 -1)

@@ -12,7 +12,7 @@ _Put an `x` in the boxes that apply. You can also fill these out after creating
 - [ ] I have read the [CONTRIBUTING](https://github.com/aws/sagemaker-python-sdk/blob/master/CONTRIBUTING.md) doc
 - [ ] I used the commit message format described in [CONTRIBUTING](https://github.com/aws/sagemaker-python-sdk/blob/master/CONTRIBUTING.md#committing-your-change)
-- [ ] I have used the regional endpoint when creating S3 and/or STS clients (if appropriate)
+- [ ] I have passed the region in to any/all clients that I've initialized as part of this change.
 - [ ] I have updated any necessary documentation, including [READMEs](https://github.com/aws/sagemaker-python-sdk/blob/master/README.rst) and [API docs](https://github.com/aws/sagemaker-python-sdk/tree/master/doc) (if appropriate)
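For context on the new checklist item, here is a minimal sketch of what passing the region in to a client looks like with ``boto3``; the region and endpoint values are placeholders and this snippet is illustrative only, not part of the diff:

.. code:: python

    import boto3

    region = "us-west-2"  # placeholder region

    # Pass the region explicitly rather than relying on ambient configuration.
    s3 = boto3.client("s3", region_name=region)

    # For STS, also point at the regional endpoint instead of the global one.
    sts = boto3.client(
        "sts",
        region_name=region,
        endpoint_url="https://sts.us-west-2.amazonaws.com",
    )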
doc/using_pytorch.rst (+33 -3)

@@ -118,10 +118,22 @@ to a certain filesystem path called ``model_dir``. This value is accessible thro
 After your training job is complete, SageMaker will compress and upload the serialized model to S3, and your model data
 will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator.

+If you are using Elastic Inference, you must convert your models to the TorchScript format and use ``torch.jit.save`` to save the model.
+For example:
+
+.. code:: python
+
+    import os
+    import torch
+
+    # ... train `model`, then save it to `model_dir`
+    model_dir = os.path.join(model_dir, "model.pt")
+    torch.jit.save(model, model_dir)
+
 Using third-party libraries
 ---------------------------

-When running your training script on SageMaker, it will have access to some pre-installed third-party libraries including ``torch``, ``torchvisopm``, and ``numpy``.
+When running your training script on SageMaker, it will have access to some pre-installed third-party libraries including ``torch``, ``torchvision``, and ``numpy``.
 For more information on the runtime environment, including specific package versions, see `SageMaker PyTorch Docker containers <#id4>`__.

 If there are other packages you want to use with your script, you can include a ``requirements.txt`` file in the same directory as your training script to install other dependencies at runtime. Both ``requirements.txt`` and your training script should be put in the same folder. You must specify this folder in the ``source_dir`` argument when creating a PyTorch estimator.
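As a sketch of the ``requirements.txt`` workflow described in that last context line (the entry point, directory, role, versions, and S3 path below are placeholders, not part of this diff):

.. code:: python

    from sagemaker.pytorch import PyTorch

    # `source_dir` is the folder containing both train.py and requirements.txt;
    # the listed dependencies are installed in the container at runtime.
    estimator = PyTorch(
        entry_point="train.py",
        source_dir="my_training_code",
        role="MySageMakerRole",
        framework_version="1.3.1",
        train_instance_count=1,
        train_instance_type="ml.c5.xlarge",
    )
    estimator.fit("s3://my-bucket/training-data")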
@@ -303,7 +315,8 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d

 However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn`` since the PyTorch serving
 container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
-yor parameter file as ``model.pt`` instead of ``model.pth``. For more information on inference script, please refer to:
+your ScriptModule as ``model.pt``. If you are implementing your own ``model_fn``, please use TorchScript and ``torch.jit.save``
+to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load``. For more information on the inference script, please refer to:
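To make the added guidance concrete, here is a minimal sketch of such a custom ``model_fn``; it is illustrative only (not part of this diff) and assumes the ScriptModule was saved as ``model.pt``, per the text above:

.. code:: python

    import os
    import torch

    def model_fn(model_dir):
        # Load the TorchScript ScriptModule saved with torch.jit.save.
        model = torch.jit.load(
            os.path.join(model_dir, "model.pt"),
            map_location=torch.device("cpu"),
        )
        return model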
@@ -461,6 +474,23 @@ If you implement your own prediction function, you should take care to ensure th
 first argument to ``output_fn``. If you use the default
 ``output_fn``, this should be a torch.Tensor.

+The default Elastic Inference ``predict_fn`` is similar but runs the TorchScript model using ``torch.jit.optimized_execution``.
+If you are implementing your own ``predict_fn``, please also use the ``torch.jit.optimized_execution``
+block, for example:
+
+.. code:: python
+
+    import torch
+
+    def predict_fn(input_data, model):
+        device = torch.device("cpu")
+        model = model.to(device)
+        input_data = input_data.to(device)
+        model.eval()
+        # Run the TorchScript model in an optimized_execution block that
+        # targets the attached Elastic Inference accelerator.
+        with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
+            output = model(input_data)
+        return output
+
 Process Model Output
 ^^^^^^^^^^^^^^^^^^^^

@@ -671,6 +701,6 @@ The following are optional arguments. When you create a ``PyTorch`` object, you
 SageMaker PyTorch Docker Containers
 ***********************************

-For information about SageMaker PyTorch containers, see `the SageMaker PyTorch containers repository <https://github.com/aws/sagemaker-pytorch-container>`_.
+For information about SageMaker PyTorch containers, see the `SageMaker PyTorch container repository <https://github.com/aws/sagemaker-pytorch-container>`_ and the `SageMaker PyTorch Serving container repository <https://github.com/aws/sagemaker-pytorch-serving-container>`__.

 For information about SageMaker PyTorch container dependencies, see `SageMaker PyTorch Containers <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/pytorch#sagemaker-pytorch-docker-containers>`_.
@@ … @@
-- ``hyperparameters (dict[str, ANY])`` Hyperparameters that will be used for training.
-  Will be made accessible as command line arguments.
-- ``train_volume_size (int)`` Size in GB of the EBS volume to use for storing
-  input data during training. Must be large enough to the store training
-  data.
-- ``train_max_run (int)`` Timeout in seconds for training, after which Amazon
-  SageMaker terminates the job regardless of its current status.
-- ``output_path (str)`` S3 location where you want the training result (model
-  artifacts and optional output files) saved. If not specified, results
-  are stored to a default bucket. If the bucket with the specific name
-  does not exist, the estimator creates the bucket during the ``fit``
-  method execution.
-- ``output_kms_key`` Optional KMS key ID to optionally encrypt training
-  output with.
-- ``base_job_name`` Name to assign for the training job that the ``fit``
-  method launches. If not specified, the estimator generates a default
-  job name, based on the training image name and current timestamp.
-- ``image_name`` An alternative docker image to use for training and
-  serving. If specified, the estimator will use this image for training and
-  hosting, instead of selecting the appropriate SageMaker official image based on
-  ``framework_version`` and ``py_version``. Refer to: `SageMaker TensorFlow Docker containers <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/tensorflow#sagemaker-tensorflow-docker-containers>`_ for details on what the official images support
-  and where to find the source code to build your custom image.
-- ``script_mode (bool)`` Whether to use Script Mode or not. Script mode is the only available training mode in Python 3,
-  setting ``py_version`` to ``py3`` automatically sets ``script_mode`` to True.
-- ``model_dir (str)`` Location where model data, checkpoint data, and TensorBoard checkpoints should be saved during training.
-  If not specified a S3 location will be generated under the training job's default bucket. And ``model_dir`` will be
-  passed in your training script as one of the command line arguments.
-- ``distributions (dict)`` Configure your distribution strategy with this argument.
+For information about the different TensorFlow-related classes in the SageMaker Python SDK, see https://sagemaker.readthedocs.io/en/stable/sagemaker.tensorflow.html.
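For illustration, here is a sketch of a ``TensorFlow`` estimator using several of the optional arguments described in the removed documentation above; the entry point, role, bucket, and version values are placeholders, and the argument names reflect this era of the SDK rather than any later release:

.. code:: python

    from sagemaker.tensorflow import TensorFlow

    estimator = TensorFlow(
        entry_point="train.py",
        role="MySageMakerRole",
        train_instance_count=1,
        train_instance_type="ml.p3.2xlarge",
        framework_version="1.15.2",
        py_version="py3",            # Python 3 supports only script mode
        script_mode=True,
        hyperparameters={"epochs": 10, "batch-size": 64},  # exposed as CLI args
        train_volume_size=50,        # GB of EBS storage for input data
        train_max_run=24 * 60 * 60,  # terminate the job after 24 hours
        output_path="s3://my-bucket/output",
        base_job_name="my-tf-job",
    )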