Commit 32aaa73 by Edward J Kim
doc: add xgboost documentation for inference (#1659)
1 parent 3624364

1 file changed: doc/frameworks/xgboost/using_xgboost.rst (+216, -13 lines)
…and a dictionary of the hyperparameters to pass to the training script.

.. code:: python

        role=role,
        train_instance_count=1,
        train_instance_type="ml.m5.2xlarge",
        framework_version="1.0-1",
    )
After you create an estimator, call the ``fit`` method to run the training job.

Deploy Open Source XGBoost Models
=================================

After you fit an XGBoost Estimator, you can host the newly created model in SageMaker.

After you call ``fit``, you can call ``deploy`` on an ``XGBoost`` estimator to create a SageMaker endpoint.
The endpoint runs a SageMaker-provided XGBoost model server and hosts the model produced by your training script,
which was run when you called ``fit``. This was the model you saved to ``model_dir``.

``deploy`` returns a ``Predictor`` object, which you can use to do inference on the endpoint hosting your XGBoost model.
Each ``Predictor`` provides a ``predict`` method, which can do inference with NumPy arrays, Python lists, or strings.
After the arrays or lists are serialized and sent to the XGBoost model server, ``predict`` returns the result of
inference against your model.

.. code:: python

    predictor = estimator.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge"
    )
    predictor.serializer = str
    predictor.content_type = "text/libsvm"

    with open("abalone") as f:
        payload = f.read()

    predictor.predict(payload)

SageMaker XGBoost Model Server
------------------------------

You can configure two components of the SageMaker XGBoost model server: model loading and model serving.
Model loading is the process of deserializing your saved model back into an XGBoost model.
Model serving is the process of translating endpoint requests into inference calls on the loaded model.

You configure the XGBoost model server by defining functions in the Python source file you passed to the XGBoost constructor.

Load a Model
^^^^^^^^^^^^

Before a model can be served, it must be loaded. The SageMaker XGBoost model server loads your model by invoking a
``model_fn`` function that you must provide in your script. The ``model_fn`` should have the following signature:

.. code:: python

    def model_fn(model_dir)

SageMaker injects the directory where your model files and sub-directories, saved by ``save``, have been mounted.
Your model function should return an ``xgboost.Booster`` object that can be used for model serving.

The following code snippet shows an example ``model_fn`` implementation.
It loads and returns a pickled XGBoost model from an ``xgboost-model`` file in the SageMaker model directory ``model_dir``.

.. code:: python

    import os
    import pickle as pkl

    def model_fn(model_dir):
        with open(os.path.join(model_dir, "xgboost-model"), "rb") as f:
            booster = pkl.load(f)
        return booster

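For context, the training-side counterpart persists the booster under the same file name and pickle format that ``model_fn`` expects to reload. The ``save_model`` helper below is a hypothetical sketch for illustration, not part of the SageMaker SDK:

.. code:: python

    import os
    import pickle as pkl

    def save_model(model, model_dir):
        # Hypothetical helper mirroring model_fn above: pickle the trained
        # booster to <model_dir>/xgboost-model so model_fn can reload it.
        with open(os.path.join(model_dir, "xgboost-model"), "wb") as f:
            pkl.dump(model, f)
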
Serve a Model
^^^^^^^^^^^^^

After the SageMaker model server has loaded your model by calling ``model_fn``, SageMaker serves your model.
The SageMaker XGBoost model server breaks request handling into three steps:

- input processing,
- prediction, and
- output processing.

As with model loading, you can customize the inference behavior by defining functions in your inference
script, which can be either in the same file as your training script or in a separate file.

Each step involves invoking a Python function, with information about the request and the return value from the previous
function in the chain.
Inside the SageMaker XGBoost model server, the process looks like this:

.. code:: python

    # Deserialize the Invoke request body into an object we can perform prediction on
    input_object = input_fn(request_body, request_content_type)

    # Perform prediction on the deserialized object, with the loaded model
    prediction = predict_fn(input_object, model)

    # Serialize the prediction result into the desired response content type
    output = output_fn(prediction, response_content_type)

The above code sample shows the three function definitions:

- ``input_fn``: Takes request data and deserializes the data into an object for prediction.
- ``predict_fn``: Takes the deserialized request object and performs inference against the loaded model.
- ``output_fn``: Takes the result of prediction and serializes this according to the response content type.

These functions are optional.
The SageMaker XGBoost model server provides default implementations of these functions, and you can
override any of them with your own implementation in your hosting script.
If you omit any definition, the SageMaker XGBoost model server uses its default implementation for that
function.

The following sections describe the default implementations of ``input_fn``, ``predict_fn``, and ``output_fn``,
along with the input arguments and expected return types of each, so you can define your own implementations.

Process Input
"""""""""""""

When a request is made against an endpoint running a SageMaker XGBoost model server, the model server receives two
pieces of information:

- the request Content-Type, for example "application/x-npy" or "text/libsvm", and
- the request data body, a byte array.

The SageMaker XGBoost model server invokes an ``input_fn`` function in your inference script, passing in this
information. If you define an ``input_fn`` function, it should return an object that can be passed
to ``predict_fn`` and should have the following signature:

.. code:: python

    def input_fn(request_body, request_content_type)

where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string.

The SageMaker XGBoost model server provides a default implementation of ``input_fn``.
This function deserializes CSV, LIBSVM, or protobuf recordIO data into an ``xgboost.DMatrix``.

Default CSV deserialization requires that ``request_body`` contain one or more lines of CSV numerical data.
The data is first loaded into a two-dimensional array, where each line break defines the boundaries of the first
dimension, and is then converted to an ``xgboost.DMatrix``. The default implementation assumes that CSV input
does not have the label column.

Default LIBSVM deserialization requires ``request_body`` to follow the `LIBSVM <https://www.csie.ntu.edu.tw/~cjlin/libsvm/>`_ format.
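
To make the CSV behavior concrete, the snippet below sketches just the parsing step under the assumptions described above (one record per line break, no label column). The function name ``parse_csv_records`` and its details are ours, not the library's actual implementation, which additionally wraps the array in an ``xgboost.DMatrix``:

.. code:: python

    import numpy as np

    def parse_csv_records(request_body):
        # Line breaks delimit records; commas delimit unlabeled feature values.
        if isinstance(request_body, bytes):
            request_body = request_body.decode("utf-8")
        rows = [
            [float(value) for value in line.split(",")]
            for line in request_body.strip().splitlines()
        ]
        # The real default input_fn converts this array to an xgboost.DMatrix.
        return np.array(rows)
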

The example below shows a custom ``input_fn`` for preparing pickled NumPy arrays.

.. code:: python

    from io import BytesIO

    import numpy as np
    import xgboost as xgb

    def input_fn(request_body, request_content_type):
        """An input_fn that loads a numpy array."""
        if request_content_type == "application/npy":
            array = np.load(BytesIO(request_body))
            return xgb.DMatrix(array)
        else:
            # Handle other content types here, or raise an exception
            # if the content type is not supported.
            raise ValueError("Unsupported content type: {}".format(request_content_type))

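On the client side, a payload for this custom ``input_fn`` can be produced by saving the array to an in-memory buffer in NPY format (``serialize_numpy_array`` is a hypothetical helper name, not a SageMaker API):

.. code:: python

    from io import BytesIO

    import numpy as np

    def serialize_numpy_array(array):
        # Write the array to an in-memory buffer in NPY format; the resulting
        # bytes are what the custom input_fn above receives as request_body.
        buffer = BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()

    payload = serialize_numpy_array(np.array([[1.0, 2.0], [3.0, 4.0]]))
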
Get Predictions
"""""""""""""""

After the inference request has been deserialized by ``input_fn``, the SageMaker XGBoost model server invokes
``predict_fn`` on the return value of ``input_fn``.

As with ``input_fn``, you can define your own ``predict_fn`` or use the SageMaker XGBoost model server default.

The ``predict_fn`` function has the following signature:

.. code:: python

    def predict_fn(input_object, model)

where ``input_object`` is the object returned from ``input_fn`` and ``model`` is the model loaded by ``model_fn``.

The default implementation of ``predict_fn`` invokes the loaded model's ``predict`` function on ``input_object``
and returns the resulting value. The return type should be a NumPy array to be compatible with the default
``output_fn``.

The example below shows an overridden ``predict_fn`` that returns a two-dimensional NumPy array in which
the first column contains the predictions and the remaining columns contain the feature contributions
(`SHAP values <https://github.com/slundberg/shap>`_) for each prediction.
When ``pred_contribs`` is ``True`` in ``xgboost.Booster.predict()``, the output is a matrix of size
(nsample, nfeats + 1), with each record indicating the feature contributions for that prediction.
Note that the final column is the bias term.

.. code:: python

    import numpy as np

    def predict_fn(input_data, model):
        prediction = model.predict(input_data)
        feature_contribs = model.predict(input_data, pred_contribs=True)
        output = np.hstack((prediction[:, np.newaxis], feature_contribs))
        return output

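The shapes in the ``np.hstack`` call above can be checked without a trained model by substituting dummy arrays for the two ``model.predict`` results:

.. code:: python

    import numpy as np

    # Dummy stand-ins for the two model.predict() results: 4 samples, 3 features.
    prediction = np.array([0.1, 0.9, 0.4, 0.7])   # shape (4,)
    feature_contribs = np.zeros((4, 3 + 1))       # shape (4, nfeats + 1); last column is the bias term

    # prediction[:, np.newaxis] reshapes the vector to a (4, 1) column so that
    # hstack can append the contribution columns to its right.
    output = np.hstack((prediction[:, np.newaxis], feature_contribs))
    # output.shape is (4, 5): one prediction column plus nfeats + 1 contribution columns
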
If you implement your own prediction function, you should take care to ensure that:

- the first argument is the return value from ``input_fn``,
- the second argument is the loaded model, and
- the return value is of the correct type to be passed as the first argument to ``output_fn``
  (if you use the default ``output_fn``, this should be a NumPy array).

Process Output
""""""""""""""

After invoking ``predict_fn``, the model server invokes ``output_fn``, passing in the return value from
``predict_fn`` and the requested response content type.

The ``output_fn`` has the following signature:

.. code:: python

    def output_fn(prediction, content_type)

``prediction`` is the result of invoking ``predict_fn``, and ``content_type`` is the requested response content type.
The function should return a byte array of data serialized to ``content_type``.

The default implementation expects ``prediction`` to be a NumPy array and can serialize the result to JSON, CSV, or NPY.
It accepts response content types of "application/json", "text/csv", and "application/x-npy".
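
As an illustration of this contract (a sketch, not the server's actual default implementation), a custom ``output_fn`` that serializes a NumPy prediction array to CSV bytes might look like this:

.. code:: python

    import io

    import numpy as np

    def output_fn(prediction, content_type):
        # Hypothetical implementation: only "text/csv" is handled here.
        if content_type == "text/csv":
            buffer = io.StringIO()
            np.savetxt(buffer, np.atleast_2d(prediction), fmt="%g", delimiter=",")
            return buffer.getvalue().encode("utf-8")
        raise ValueError("Unsupported content type: {}".format(content_type))
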

Host Multiple Models with Multi-Model Endpoints
-----------------------------------------------

To create an endpoint that can host multiple models, use multi-model endpoints.
Multi-model endpoints are supported in SageMaker XGBoost versions ``0.90-2``, ``1.0-1``, and later.
For information about using multiple XGBoost models with multi-model endpoints, see
`Host Multiple Models with Multi-Model Endpoints <https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html>`_
in the AWS documentation.
For a sample notebook that uses Amazon SageMaker to deploy multiple XGBoost models to an endpoint, see the
`Multi-Model Endpoint XGBoost Sample Notebook <https://github.com/awslabs/amazon-sagemaker-examples/blob/master/advanced_functionality/multi_model_xgboost_home_value/xgboost_multi_model_endpoint_home_value.ipynb>`_.

*************************
