Commit 0260228: documentation: update PyTorch BYOM topic (#1457)
1 parent 168a478

1 file changed: doc/using_pytorch.rst (+69 -68 lines)

Note that SageMaker doesn't support argparse actions. If you want to use, for example, boolean hyperparameters,
you need to specify ``type`` as ``bool`` in your script and provide an explicit ``True`` or ``False`` value for this hyperparameter
when instantiating the PyTorch Estimator.
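As a sketch of what this looks like on the script side (the ``--use-cuda`` hyperparameter name is illustrative, not from the original text): SageMaker passes each hyperparameter to your script as a command-line string, so declaring it with ``type=bool`` applies Python's ``bool()`` to that string.

```python
import argparse

# Illustrative script-side parsing; SageMaker passes hyperparameters
# to the entry point as command-line strings, e.g. --use-cuda True.
parser = argparse.ArgumentParser()
parser.add_argument('--use-cuda', type=bool, default=False)

# Simulate the command line built from hyperparameters={'use-cuda': True}:
args = parser.parse_args(['--use-cuda', 'True'])
print(args.use_cuda)  # bool('True') -> True
# Caveat: bool() returns True for ANY non-empty string, so the string
# 'False' would also parse as True; pass values deliberately.
```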

For more on training environment variables, see the `SageMaker Training Toolkit <https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md>`_.

Save the Model
--------------

Your training script should save your model
to a certain filesystem path called ``model_dir``. This value is accessible through the environment variable ``SM_MODEL_DIR``. For example:

.. code:: python

    with open(os.path.join(args.model_dir, 'model.pth'), 'wb') as f:
        torch.save(model.state_dict(), f)

After your training job is complete, SageMaker compresses and uploads the serialized model to S3, and your model data
will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator.

If you are using Elastic Inference, you must convert your models to the TorchScript format and use ``torch.jit.save`` to save the model.

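For example, a model can be converted to TorchScript by tracing and then saved with ``torch.jit.save``. The tiny ``nn.Sequential`` model below is a stand-in for your trained network; the file name is illustrative.

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; substitute your trained network.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
model.eval()

# Convert to TorchScript by tracing with a representative input,
# then save the scripted module with torch.jit.save.
traced = torch.jit.trace(model, torch.randn(1, 10))
torch.jit.save(traced, 'model.pth')
```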
The function should return a byte array of data serialized to ``content_type``.

The default implementation expects ``prediction`` to be a ``torch.Tensor`` and can serialize the result to JSON, CSV, or NPY.
It accepts response content types of "application/json", "text/csv", and "application/x-npy".

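A minimal sketch of a custom ``output_fn`` that serializes predictions to JSON; the function name and signature follow the serving contract described above, while the conversion logic is illustrative.

```python
import json
import numpy as np

def output_fn(prediction, content_type):
    # Serving contract: receive the prediction and the requested
    # response content type, return the serialized response body.
    if content_type == 'application/json':
        if hasattr(prediction, 'detach'):  # torch.Tensor -> numpy
            prediction = prediction.detach().cpu().numpy()
        if isinstance(prediction, np.ndarray):
            prediction = prediction.tolist()
        return json.dumps(prediction)
    raise ValueError('Unsupported content type: {}'.format(content_type))
```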
Bring your own model
====================

You can deploy a PyTorch model that you trained outside of SageMaker by using the ``PyTorchModel`` class.
Typically, you save a PyTorch model as a file with extension ``.pt`` or ``.pth``.
To deploy the model on SageMaker, you need to:

* Write an inference script.
* Create the directory structure for your model files.
* Create the ``PyTorchModel`` object.

Write an inference script
-------------------------

You must create an inference script that implements (at least) the ``model_fn`` function, which loads the saved model so that the model server can use it to get predictions.

**Note**: If you use Elastic Inference with PyTorch, you can use the default ``model_fn`` implementation provided in the serving container.

Optionally, you can also implement ``input_fn`` and ``output_fn`` to process input and output,
and ``predict_fn`` to customize how the model server gets predictions from the loaded model.
For information about how to write an inference script, see `Serve a PyTorch Model <#serve-a-pytorch-model>`_.
Save the inference script in the same folder where you saved your PyTorch model.
Pass the filename of the inference script as the ``entry_point`` parameter when you create the ``PyTorchModel`` object.

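A minimal sketch of such an inference script, assuming the model was saved with ``torch.save(model.state_dict(), ...)``; the ``Net`` class here is hypothetical and would match your trained network in a real script.

```python
import os
import torch
import torch.nn as nn

# Hypothetical model class; in a real script this matches the network
# architecture you trained.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

def model_fn(model_dir):
    # Serving contract: receive the directory holding the unpacked
    # model artifacts and return the loaded model.
    model = Net()
    state = torch.load(os.path.join(model_dir, 'model.pth'), map_location='cpu')
    model.load_state_dict(state)
    model.eval()
    return model
```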
Create the directory structure for your model files
---------------------------------------------------

You have to create a directory structure and place your model files in the correct location.
The ``PyTorchModel`` constructor packs the files into a ``tar.gz`` file and uploads it to S3.

The directory structure where you saved your PyTorch model should look something like the following:

**Note:** This directory structure is for PyTorch versions 1.2 and higher.
For the directory structure for versions 1.1 and lower,
see `For versions 1.1 and lower <#for-versions-1.1-and-lower>`_.

::

    | my_model
    | |--model.pth
    |
    | code
    | |--inference.py
    | |--requirements.txt

Where ``requirements.txt`` is an optional file that specifies dependencies on third-party libraries.

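As a sketch, the layout above can be created like this; all file names are the illustrative ones from the tree, and the ``touch`` placeholders stand in for your real model, script, and requirements files.

```shell
# Create the example layout shown above (placeholders for real files).
mkdir -p my_model code
touch my_model/model.pth
touch code/inference.py code/requirements.txt
ls my_model code
```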
Create a ``PyTorchModel`` object
--------------------------------

Now call the :class:`sagemaker.pytorch.model.PyTorchModel` constructor to create a model object, and then call its ``deploy()`` method to deploy your model for inference.

.. code:: python

    from sagemaker import get_execution_role
    from sagemaker.pytorch import PyTorchModel

    role = get_execution_role()

    pytorch_model = PyTorchModel(model_data='s3://my-bucket/my-path/model.tar.gz', role=role,
                                 entry_point='inference.py')

    predictor = pytorch_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1)

Now you can call the ``predict()`` method to get predictions from your deployed model.

***********************************************
Attach an estimator to an existing training job
***********************************************

You can attach a PyTorch Estimator to an existing training job using the
``attach`` method.

The ``attach`` method accepts the following arguments:

- ``sagemaker_session:`` The Session used
  to interact with SageMaker

*************************
PyTorch Training Examples
*************************
