Skip to content

Commit 975def8

Browse files
committed
chore: update integ test
1 parent 1da64e7 commit 975def8

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

src/sagemaker/jumpstart/artifacts.py

+3
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ def _retrieve_default_instance_type(
437437
specified region due to lack of supported computing instances.
438438
"""
439439

440+
if region is None:
441+
region = JUMPSTART_DEFAULT_REGION_NAME
442+
440443
model_specs = verify_model_region_and_return_specs(
441444
model_id=model_id,
442445
version=model_version,

tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414
import os
1515

16-
from sagemaker import image_uris, model_uris, script_uris
16+
from sagemaker import image_uris, instance_types, model_uris, script_uris
1717
from sagemaker.jumpstart.constants import INFERENCE_ENTRY_POINT_SCRIPT_NAME
1818
from sagemaker.model import Model
1919
from tests.integ.sagemaker.jumpstart.constants import (
@@ -31,8 +31,12 @@
3131

3232
def test_jumpstart_inference_model_class(setup):
3333

34-
model_id, model_version = "catboost-classification-model", "1.0.0"
35-
instance_type, instance_count = "ml.m5.xlarge", 1
34+
model_id, model_version = "catboost-classification-model", "1.2.7"
35+
36+
instance_type = instance_types.retrieve_default(
37+
model_id=model_id, model_version=model_version, scope="inference"
38+
)
39+
instance_count = 1
3640

3741
print("Starting inference...")
3842

tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313
from __future__ import absolute_import
1414
import os
1515

16-
from sagemaker import hyperparameters, metric_definitions, image_uris, model_uris, script_uris
16+
from sagemaker import (
17+
hyperparameters,
18+
instance_types,
19+
metric_definitions,
20+
image_uris,
21+
model_uris,
22+
script_uris,
23+
)
1724
from sagemaker.estimator import Estimator
1825
from sagemaker.jumpstart.constants import (
1926
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
@@ -36,8 +43,13 @@
3643
def test_jumpstart_transfer_learning_estimator_class(setup):
3744

3845
model_id, model_version = "huggingface-spc-bert-base-cased", "1.2.3"
39-
training_instance_type = "ml.p3.2xlarge"
40-
inference_instance_type = "ml.p2.xlarge"
46+
47+
inference_instance_type = instance_types.retrieve_default(
48+
model_id=model_id, model_version=model_version, scope="inference"
49+
)
50+
training_instance_type = instance_types.retrieve_default(
51+
model_id=model_id, model_version=model_version, scope="training"
52+
)
4153
instance_count = 1
4254

4355
print("Starting training...")

0 commit comments

Comments
 (0)