14
14
from __future__ import absolute_import
15
15
16
16
import logging
17
+ from typing import Optional , Union , List , Dict
17
18
18
19
import sagemaker
19
- from sagemaker import image_uris
20
+ from sagemaker import image_uris , ModelMetrics
20
21
from sagemaker .deserializers import JSONDeserializer
22
+ from sagemaker .drift_check_baselines import DriftCheckBaselines
21
23
from sagemaker .fw_utils import (
22
24
model_code_key_prefix ,
23
25
validate_version_or_image_args ,
24
26
)
27
+ from sagemaker .metadata_properties import MetadataProperties
25
28
from sagemaker .model import FrameworkModel , MODEL_SERVER_WORKERS_PARAM_NAME
26
29
from sagemaker .predictor import Predictor
27
30
from sagemaker .serializers import JSONSerializer
28
31
from sagemaker .session import Session
32
+ from sagemaker .utils import to_string
33
+ from sagemaker .workflow import is_pipeline_variable
34
+ from sagemaker .workflow .entities import PipelineVariable
29
35
30
36
logger = logging .getLogger ("sagemaker" )
31
37
@@ -100,16 +106,16 @@ class HuggingFaceModel(FrameworkModel):
100
106
101
107
def __init__ (
102
108
self ,
103
- role ,
104
- model_data = None ,
105
- entry_point = None ,
106
- transformers_version = None ,
107
- tensorflow_version = None ,
108
- pytorch_version = None ,
109
- py_version = None ,
110
- image_uri = None ,
111
- predictor_cls = HuggingFacePredictor ,
112
- model_server_workers = None ,
109
+ role : str ,
110
+ model_data : Optional [ Union [ str , PipelineVariable ]] = None ,
111
+ entry_point : Optional [ str ] = None ,
112
+ transformers_version : Optional [ str ] = None ,
113
+ tensorflow_version : Optional [ str ] = None ,
114
+ pytorch_version : Optional [ str ] = None ,
115
+ py_version : Optional [ str ] = None ,
116
+ image_uri : Optional [ Union [ str , PipelineVariable ]] = None ,
117
+ predictor_cls : callable = HuggingFacePredictor ,
118
+ model_server_workers : Optional [ Union [ int , PipelineVariable ]] = None ,
113
119
** kwargs ,
114
120
):
115
121
"""Initialize a HuggingFaceModel.
@@ -299,27 +305,27 @@ def deploy(
299
305
300
306
def register (
301
307
self ,
302
- content_types ,
303
- response_types ,
304
- inference_instances = None ,
305
- transform_instances = None ,
306
- model_package_name = None ,
307
- model_package_group_name = None ,
308
- image_uri = None ,
309
- model_metrics = None ,
310
- metadata_properties = None ,
311
- marketplace_cert = False ,
312
- approval_status = None ,
313
- description = None ,
314
- drift_check_baselines = None ,
315
- customer_metadata_properties = None ,
316
- domain = None ,
317
- sample_payload_url = None ,
318
- task = None ,
319
- framework = None ,
320
- framework_version = None ,
321
- nearest_model_name = None ,
322
- data_input_configuration = None ,
308
+ content_types : List [ Union [ str , PipelineVariable ]] ,
309
+ response_types : List [ Union [ str , PipelineVariable ]] ,
310
+ inference_instances : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
311
+ transform_instances : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
312
+ model_package_name : Optional [ Union [ str , PipelineVariable ]] = None ,
313
+ model_package_group_name : Optional [ Union [ str , PipelineVariable ]] = None ,
314
+ image_uri : Optional [ Union [ str , PipelineVariable ]] = None ,
315
+ model_metrics : Optional [ ModelMetrics ] = None ,
316
+ metadata_properties : Optional [ MetadataProperties ] = None ,
317
+ marketplace_cert : bool = False ,
318
+ approval_status : Optional [ Union [ str , PipelineVariable ]] = None ,
319
+ description : Optional [ str ] = None ,
320
+ drift_check_baselines : Optional [ DriftCheckBaselines ] = None ,
321
+ customer_metadata_properties : Optional [ Dict [ str , Union [ str , PipelineVariable ]]] = None ,
322
+ domain : Optional [ Union [ str , PipelineVariable ]] = None ,
323
+ sample_payload_url : Optional [ Union [ str , PipelineVariable ]] = None ,
324
+ task : Optional [ Union [ str , PipelineVariable ]] = None ,
325
+ framework : Optional [ Union [ str , PipelineVariable ]] = None ,
326
+ framework_version : Optional [ Union [ str , PipelineVariable ]] = None ,
327
+ nearest_model_name : Optional [ Union [ str , PipelineVariable ]] = None ,
328
+ data_input_configuration : Optional [ Union [ str , PipelineVariable ]] = None ,
323
329
):
324
330
"""Creates a model package for creating SageMaker models or listing on Marketplace.
325
331
@@ -377,6 +383,13 @@ def register(
377
383
region_name = self .sagemaker_session .boto_session .region_name ,
378
384
instance_type = instance_type ,
379
385
)
386
+ if not is_pipeline_variable (framework ):
387
+ framework = (
388
+ framework
389
+ or fetch_framework_and_framework_version (
390
+ self .tensorflow_version , self .pytorch_version
391
+ )[0 ]
392
+ ).upper ()
380
393
return super (HuggingFaceModel , self ).register (
381
394
content_types ,
382
395
response_types ,
@@ -395,12 +408,7 @@ def register(
395
408
domain = domain ,
396
409
sample_payload_url = sample_payload_url ,
397
410
task = task ,
398
- framework = (
399
- framework
400
- or fetch_framework_and_framework_version (
401
- self .tensorflow_version , self .pytorch_version
402
- )[0 ]
403
- ).upper (),
411
+ framework = framework ,
404
412
framework_version = framework_version
405
413
or fetch_framework_and_framework_version (self .tensorflow_version , self .pytorch_version )[
406
414
1
@@ -449,7 +457,9 @@ def prepare_container_def(
449
457
deploy_env .update (self ._script_mode_env_vars ())
450
458
451
459
if self .model_server_workers :
452
- deploy_env [MODEL_SERVER_WORKERS_PARAM_NAME .upper ()] = str (self .model_server_workers )
460
+ deploy_env [MODEL_SERVER_WORKERS_PARAM_NAME .upper ()] = to_string (
461
+ self .model_server_workers
462
+ )
453
463
return sagemaker .container_def (
454
464
deploy_image , self .repacked_model_data or self .model_data , deploy_env
455
465
)
0 commit comments