|
23 | 23 | warn_if_parameter_server_with_multi_gpu,
|
24 | 24 | validate_smdistributed,
|
25 | 25 | )
|
| 26 | +from sagemaker.huggingface.model import HuggingFaceModel |
26 | 27 | from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
|
27 | 28 |
|
28 | 29 | logger = logging.getLogger("sagemaker")
|
@@ -233,8 +234,58 @@ def create_model(
|
233 | 234 | dependencies=None,
|
234 | 235 | **kwargs
|
235 | 236 | ):
|
236 |
| - """Placeholder docstring""" |
237 |
| - raise NotImplementedError("Creating model with HuggingFace training job is not supported.") |
| 237 | + """Create a SageMaker ``HuggingFaceModel`` object that can be deployed to an ``Endpoint``. |
| 238 | +
|
| 239 | + Args: |
| 240 | + model_server_workers (int): Optional. The number of worker processes |
| 241 | + used by the inference server. If None, server will use one |
| 242 | + worker per vCPU. |
| 243 | + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, |
| 244 | + which is also used during transform jobs. If not specified, the |
| 245 | + role from the Estimator will be used. |
| 246 | + vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on |
| 247 | + the model. Default: use subnets and security groups from this Estimator. |
| 248 | + * 'Subnets' (list[str]): List of subnet ids. |
| 249 | + * 'SecurityGroupIds' (list[str]): List of security group ids. |
| 250 | + entry_point (str): Path (absolute or relative) to the local Python source file which |
| 251 | + should be executed as the entry point to training. If ``source_dir`` is specified, |
| 252 | + then ``entry_point`` must point to a file located at the root of ``source_dir``. |
| 253 | + Defaults to `None`. |
| 254 | + source_dir (str): Path (absolute or relative) to a directory with any other serving |
| 255 | + source code dependencies aside from the entry point file. |
| 256 | + If not specified, the model source directory from training is used. |
| 257 | + dependencies (list[str]): A list of paths to directories (absolute or relative) with |
| 258 | + any additional libraries that will be exported to the container. |
| 259 | + If not specified, the dependencies from training are used. |
| 260 | + This is not supported with "local code" in Local Mode. |
| 261 | + **kwargs: Additional kwargs passed to the :class:`~sagemaker.huggingface.model.HuggingFaceModel` |
| 262 | + constructor. |
| 263 | + Returns: |
| 264 | + sagemaker.huggingface.model.HuggingFaceModel: A SageMaker ``HuggingFaceModel`` |
| 265 | + object. See :class:`~sagemaker.huggingface.model.HuggingFaceModel` for full details. |
| 266 | + """ |
| 267 | + if "image_uri" not in kwargs: |
| 268 | + kwargs["image_uri"] = self.image_uri |
| 269 | + |
| 270 | + kwargs["name"] = self._get_or_create_name(kwargs.get("name")) |
| 271 | + |
| 272 | + return HuggingFaceModel( |
| 273 | + role or self.role, |
| 274 | + model_data=self.model_data, |
| 275 | + entry_point=entry_point, |
| 276 | + transformers_version=self.framework_version, |
| 277 | + tensorflow_version=self.tensorflow_version, |
| 278 | + pytorch_version=self.pytorch_version, |
| 279 | + py_version=self.py_version, |
| 280 | + source_dir=(source_dir or self._model_source_dir()), |
| 281 | + container_log_level=self.container_log_level, |
| 282 | + code_location=self.code_location, |
| 283 | + model_server_workers=model_server_workers, |
| 284 | + sagemaker_session=self.sagemaker_session, |
| 285 | + vpc_config=self.get_vpc_config(vpc_config_override), |
| 286 | + dependencies=(dependencies or self.dependencies), |
| 287 | + **kwargs |
| 288 | + ) |
238 | 289 |
|
239 | 290 | @classmethod
|
240 | 291 | def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
|
|
0 commit comments