Skip to content

Commit f53f3c6

Browse files
authored
Merge branch 'master' into master
2 parents ee202a8 + a1b6c64 commit f53f3c6

File tree

20 files changed

+1278
-80
lines changed

20 files changed

+1278
-80
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# Changelog
22

3+
## v2.48.0 (2021-07-07)
4+
5+
### Features
6+
7+
* HuggingFace Inference
8+
9+
### Bug Fixes and Other Changes
10+
11+
* add support for SageMaker workflow tuning step
12+
313
## v2.47.2.post0 (2021-07-01)
414

515
### Documentation Changes

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.47.3.dev0
1+
2.48.1.dev0

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ Steps
115115

116116
.. autoclass:: sagemaker.workflow.steps.ProcessingStep
117117

118+
.. autoclass:: sagemaker.workflow.steps.TuningStep
119+
118120
Utilities
119121
---------
120122

src/sagemaker/huggingface/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
17+
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401

src/sagemaker/huggingface/estimator.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
warn_if_parameter_server_with_multi_gpu,
2424
validate_smdistributed,
2525
)
26+
from sagemaker.huggingface.model import HuggingFaceModel
2627
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2728

2829
logger = logging.getLogger("sagemaker")
@@ -233,8 +234,58 @@ def create_model(
233234
dependencies=None,
234235
**kwargs
235236
):
236-
"""Placeholder docstring"""
237-
raise NotImplementedError("Creating model with HuggingFace training job is not supported.")
237+
"""Create a SageMaker ``HuggingFaceModel`` object that can be deployed to an ``Endpoint``.
238+
239+
Args:
240+
model_server_workers (int): Optional. The number of worker processes
241+
used by the inference server. If None, server will use one
242+
worker per vCPU.
243+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
244+
which is also used during transform jobs. If not specified, the
245+
role from the Estimator will be used.
246+
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
247+
the model. Default: use subnets and security groups from this Estimator.
248+
* 'Subnets' (list[str]): List of subnet ids.
249+
* 'SecurityGroupIds' (list[str]): List of security group ids.
250+
entry_point (str): Path (absolute or relative) to the local Python source file which
251+
should be executed as the entry point to training. If ``source_dir`` is specified,
252+
then ``entry_point`` must point to a file located at the root of ``source_dir``.
253+
Defaults to `None`.
254+
source_dir (str): Path (absolute or relative) to a directory with any other serving
255+
source code dependencies aside from the entry point file.
256+
If not specified, the model source directory from training is used.
257+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
258+
any additional libraries that will be exported to the container.
259+
If not specified, the dependencies from training are used.
260+
This is not supported with "local code" in Local Mode.
261+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.huggingface.model.HuggingFaceModel`
262+
constructor.
263+
Returns:
264+
sagemaker.huggingface.model.HuggingFaceModel: A SageMaker ``HuggingFaceModel``
265+
object. See :func:`~sagemaker.huggingface.model.HuggingFaceModel` for full details.
266+
"""
267+
if "image_uri" not in kwargs:
268+
kwargs["image_uri"] = self.image_uri
269+
270+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
271+
272+
return HuggingFaceModel(
273+
role or self.role,
274+
model_data=self.model_data,
275+
entry_point=entry_point,
276+
transformers_version=self.framework_version,
277+
tensorflow_version=self.tensorflow_version,
278+
pytorch_version=self.pytorch_version,
279+
py_version=self.py_version,
280+
source_dir=(source_dir or self._model_source_dir()),
281+
container_log_level=self.container_log_level,
282+
code_location=self.code_location,
283+
model_server_workers=model_server_workers,
284+
sagemaker_session=self.sagemaker_session,
285+
vpc_config=self.get_vpc_config(vpc_config_override),
286+
dependencies=(dependencies or self.dependencies),
287+
**kwargs
288+
)
238289

239290
@classmethod
240291
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

0 commit comments

Comments
 (0)