Skip to content

Commit 0983d5f

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/master'
2 parents b78f12d + ff734f8 commit 0983d5f

File tree

7 files changed

+372
-66
lines changed

7 files changed

+372
-66
lines changed

src/sagemaker/jumpstart/artifacts/metric_definitions.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _retrieve_default_training_metric_definitions(
3434
tolerate_vulnerable_model: bool = False,
3535
tolerate_deprecated_model: bool = False,
3636
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
37+
instance_type: Optional[str] = None,
3738
) -> Optional[List[Dict[str, str]]]:
3839
"""Retrieves the default training metric definitions for the model.
3940
@@ -55,6 +56,8 @@ def _retrieve_default_training_metric_definitions(
5556
object, used for SageMaker interactions. If not
5657
specified, one is created using the default AWS configuration
5758
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
59+
instance_type (str): An instance type to optionally supply in order to get
60+
metric definitions specific for the instance type.
5861
Returns:
5962
list: the default training metric definitions to use for the model or None.
6063
"""
@@ -72,4 +75,29 @@ def _retrieve_default_training_metric_definitions(
7275
sagemaker_session=sagemaker_session,
7376
)
7477

75-
return deepcopy(model_specs.metrics) if model_specs.metrics else None
78+
default_metric_definitions = (
79+
deepcopy(model_specs.metrics) if getattr(model_specs, "metrics") else []
80+
)
81+
82+
instance_specific_metric_definitions = (
83+
model_specs.training_instance_type_variants.get_instance_specific_metric_definitions(
84+
instance_type
85+
)
86+
if instance_type
87+
and getattr(model_specs, "training_instance_type_variants", None) is not None
88+
else []
89+
)
90+
91+
instance_specific_metric_name: str
92+
for instance_specific_metric_definition in instance_specific_metric_definitions:
93+
instance_specific_metric_name = instance_specific_metric_definition["Name"]
94+
default_metric_definitions = list(
95+
filter(
96+
lambda metric_definition: metric_definition["Name"]
97+
!= instance_specific_metric_name,
98+
default_metric_definitions,
99+
)
100+
)
101+
default_metric_definitions.append(instance_specific_metric_definition)
102+
103+
return default_metric_definitions

src/sagemaker/jumpstart/factory/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ def _add_metric_definitions_to_kwargs(
633633
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
634634
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
635635
sagemaker_session=kwargs.sagemaker_session,
636+
instance_type=kwargs.instance_type,
636637
)
637638
or []
638639
)

src/sagemaker/jumpstart/types.py

+36
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,42 @@ def to_json(self) -> Dict[str, Any]:
403403
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
404404
return json_obj
405405

406+
def get_instance_specific_metric_definitions(
407+
self, instance_type: str
408+
) -> List[JumpStartHyperparameter]:
409+
"""Returns instance specific metric definitions.
410+
411+
Returns empty list if a model, instance type tuple does not have specific
412+
metric definitions.
413+
"""
414+
415+
if self.variants is None:
416+
return []
417+
418+
instance_specific_metric_definitions: List[Dict[str, Union[str, Any]]] = (
419+
self.variants.get(instance_type, {}).get("properties", {}).get("metrics", [])
420+
)
421+
422+
instance_type_family = get_instance_type_family(instance_type)
423+
424+
instance_family_metric_definitions: List[Dict[str, Union[str, Any]]] = (
425+
self.variants.get(instance_type_family, {}).get("properties", {}).get("metrics", [])
426+
if instance_type_family not in {"", None}
427+
else []
428+
)
429+
430+
instance_specific_metric_names = {
431+
metric_definition["Name"] for metric_definition in instance_specific_metric_definitions
432+
}
433+
434+
metric_definitions_to_return = deepcopy(instance_specific_metric_definitions)
435+
436+
for instance_family_metric_definition in instance_family_metric_definitions:
437+
if instance_family_metric_definition["Name"] not in instance_specific_metric_names:
438+
metric_definitions_to_return.append(instance_family_metric_definition)
439+
440+
return metric_definitions_to_return
441+
406442
def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
407443
"""Returns instance specific model artifact key.
408444

src/sagemaker/metric_definitions.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def retrieve_default(
2929
region: Optional[str] = None,
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
32+
instance_type: Optional[str] = None,
3233
tolerate_vulnerable_model: bool = False,
3334
tolerate_deprecated_model: bool = False,
3435
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -42,6 +43,8 @@ def retrieve_default(
4243
retrieve the default training metric definitions. (Default: None).
4344
model_version (str): The version of the model for which to retrieve the
4445
default training metric definitions. (Default: None).
46+
instance_type (str): An instance type to optionally supply in order to get
47+
metric definitions specific for the instance type.
4548
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4649
specifications should be tolerated (exception not raised). If False, raises an
4750
exception if the script used by this version of the model has dependencies with known
@@ -66,10 +69,11 @@ def retrieve_default(
6669
)
6770

6871
return artifacts._retrieve_default_training_metric_definitions(
69-
model_id,
70-
model_version,
71-
region,
72-
tolerate_vulnerable_model,
73-
tolerate_deprecated_model,
72+
model_id=model_id,
73+
model_version=model_version,
74+
instance_type=instance_type,
75+
region=region,
76+
tolerate_vulnerable_model=tolerate_vulnerable_model,
77+
tolerate_deprecated_model=tolerate_deprecated_model,
7478
sagemaker_session=sagemaker_session,
7579
)

tests/unit/sagemaker/jumpstart/constants.py

+102-55
Original file line numberDiff line numberDiff line change
@@ -190,59 +190,8 @@
190190
"framework_version": "1.5.0",
191191
"py_version": "py3",
192192
},
193-
"hosting_instance_type_variants": {
194-
"regional_aliases": {
195-
"us-west-2": {
196-
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
197-
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
198-
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
199-
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
200-
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
201-
}
202-
},
203-
"variants": {
204-
"p2": {
205-
"regional_properties": {
206-
"image_uri": "$gpu_image_uri",
207-
"model_package_arn": "$gpu_model_package_arn",
208-
}
209-
},
210-
"p3": {
211-
"regional_properties": {
212-
"image_uri": "$gpu_image_uri",
213-
"model_package_arn": "$gpu_model_package_arn",
214-
}
215-
},
216-
"p4": {
217-
"regional_properties": {
218-
"image_uri": "$gpu_image_uri",
219-
"model_package_arn": "$gpu_model_package_arn",
220-
}
221-
},
222-
"g4dn": {
223-
"regional_properties": {
224-
"image_uri": "$gpu_image_uri",
225-
"model_package_arn": "$gpu_model_package_arn",
226-
}
227-
},
228-
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
229-
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
230-
"ml.g5.48xlarge": {
231-
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
232-
},
233-
"ml.g5.12xlarge": {
234-
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
235-
},
236-
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
237-
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
238-
},
239-
},
240-
"training_ecr_specs": {
241-
"framework": "pytorch",
242-
"framework_version": "1.5.0",
243-
"py_version": "py3",
244-
},
245193
"training_instance_type_variants": {
194+
"regional_aliases": {},
246195
"variants": {
247196
"ml.p2.12xlarge": {
248197
"properties": {
@@ -305,9 +254,28 @@
305254
"scope": "algorithm",
306255
},
307256
],
257+
"metrics": [
258+
{
259+
"Name": "huggingface-textgeneration:instance-typemetric-loss",
260+
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)",
261+
},
262+
{
263+
"Name": "huggingface-textgeneration:eval-loss",
264+
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)",
265+
},
266+
{
267+
"Name": "huggingface-textgeneration:train-loss",
268+
"Regex": "'instance type specific': ([0-9]+\\.[0-9]+)",
269+
},
270+
{
271+
"Name": "huggingface-textgeneration:noneyourbusiness-loss",
272+
"Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)",
273+
},
274+
],
308275
}
309276
},
310277
"p2": {
278+
"regional_properties": {"image_uri": "$gpu_ecr_uri_2"},
311279
"properties": {
312280
"hyperparameters": [
313281
{
@@ -372,10 +340,80 @@
372340
"default": "20",
373341
"scope": "container",
374342
},
375-
]
343+
],
344+
"metrics": [
345+
{
346+
"Name": "huggingface-textgeneration:wtafigo",
347+
"Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)",
348+
},
349+
{
350+
"Name": "huggingface-textgeneration:eval-loss",
351+
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)",
352+
},
353+
{
354+
"Name": "huggingface-textgeneration:train-loss",
355+
"Regex": "'instance family specific': ([0-9]+\\.[0-9]+)",
356+
},
357+
{
358+
"Name": "huggingface-textgeneration:noneyourbusiness-loss",
359+
"Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)",
360+
},
361+
],
362+
},
363+
},
364+
},
365+
},
366+
"hosting_instance_type_variants": {
367+
"regional_aliases": {
368+
"us-west-2": {
369+
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
370+
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
371+
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
372+
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
373+
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
374+
}
375+
},
376+
"variants": {
377+
"p2": {
378+
"regional_properties": {
379+
"image_uri": "$gpu_image_uri",
380+
"model_package_arn": "$gpu_model_package_arn",
376381
}
377382
},
378-
}
383+
"p3": {
384+
"regional_properties": {
385+
"image_uri": "$gpu_image_uri",
386+
"model_package_arn": "$gpu_model_package_arn",
387+
}
388+
},
389+
"p4": {
390+
"regional_properties": {
391+
"image_uri": "$gpu_image_uri",
392+
"model_package_arn": "$gpu_model_package_arn",
393+
}
394+
},
395+
"g4dn": {
396+
"regional_properties": {
397+
"image_uri": "$gpu_image_uri",
398+
"model_package_arn": "$gpu_model_package_arn",
399+
}
400+
},
401+
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
402+
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
403+
"ml.g5.48xlarge": {
404+
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
405+
},
406+
"ml.g5.12xlarge": {
407+
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
408+
},
409+
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
410+
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
411+
},
412+
},
413+
"training_ecr_specs": {
414+
"framework": "pytorch",
415+
"framework_version": "1.5.0",
416+
"py_version": "py3",
379417
},
380418
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
381419
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
@@ -515,7 +553,16 @@
515553
"ml.c5.2xlarge",
516554
],
517555
"hosting_use_script_uri": True,
518-
"metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
556+
"metrics": [
557+
{
558+
"Name": "huggingface-textgeneration:train-loss",
559+
"Regex": "'loss default': ([0-9]+\\.[0-9]+)",
560+
},
561+
{
562+
"Name": "huggingface-textgeyyyuyuyuyneration:train-loss",
563+
"Regex": "'loss default': ([0-9]+\\.[0-9]+)",
564+
},
565+
],
519566
"model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"},
520567
"deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"},
521568
"estimator_kwargs": {

0 commit comments

Comments
 (0)