Skip to content

Commit c2ac5f1

Browse files
jessicazhu3 authored and navinsoni committed
fix: image_uri does not need to be specified with instance_groups
1 parent 28434fe commit c2ac5f1

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import os
19+
import re
1920
import uuid
2021
from abc import ABCMeta, abstractmethod
2122
from typing import Any, Dict, Union, Optional, List
@@ -1520,6 +1521,39 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
15201521
init_params["max_wait"] = max_wait
15211522
return init_params
15221523

1524+
def _get_instance_type(self):
    """Determine the instance type to be used in the training_image_uri function.

    ``self.instance_type`` is returned as-is when it is set. Otherwise the
    entries of ``self.instance_groups`` are scanned in order and the first
    instance type whose family starts with ``g`` or ``p`` (for example
    ``ml.p3.16xlarge``) is returned. When no group has such a family, the
    first group's instance type is returned.

    Returns:
        str: The instance_type to be used.

    Raises:
        ValueError: If ``instance_type`` is not set and ``instance_groups`` is
            not a non-empty list, or if a scanned group carries an instance
            type that is not a string of the ``ml.<family>[.<size>]`` form.
    """
    if self.instance_type is not None:
        return self.instance_type

    groups = self.instance_groups
    if not isinstance(groups, list) or not groups:
        raise ValueError(
            "instance_groups must be set if instance_type is not set and instance_groups "
            "must be a list."
        )

    # Compiled once here rather than on every loop iteration.
    pattern = re.compile(r"^ml[\._]([a-z\d]+)\.?\w*$")

    for group in groups:
        candidate = group.instance_type
        # A non-string value (e.g. None) cannot be matched; report it through
        # the same ValueError as a malformed name instead of letting the regex
        # engine raise an opaque TypeError.
        match = pattern.match(candidate) if isinstance(candidate, str) else None
        if match is None:
            raise ValueError(
                "Invalid SageMaker instance type for training with heterogeneous clusters: {}. "
                "For options see: https://aws.amazon.com/sagemaker/pricing/instance-types".format(
                    candidate
                )
            )

        # The first character of the family decides whether this group is
        # preferred (g*/p* families); the first such group wins.
        if match.group(1)[0] in ("g", "p"):
            return candidate

    # No g*/p* group found: fall back to the first group's instance type.
    return groups[0].instance_type
1556+
15231557
def transformer(
15241558
self,
15251559
instance_count,
@@ -2903,7 +2937,7 @@ def training_image_uri(self, region=None):
29032937
compiler_config=getattr(self, "compiler_config", None),
29042938
tensorflow_version=getattr(self, "tensorflow_version", None),
29052939
pytorch_version=getattr(self, "pytorch_version", None),
2906-
instance_type=self.instance_type,
2940+
instance_type=self._get_instance_type(),
29072941
)
29082942

29092943
@classmethod

tests/unit/test_estimator.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,68 @@ def test_invalid_custom_code_bucket(sagemaker_session):
13341334
assert "Expecting 's3' scheme" in str(error)
13351335

13361336

1337+
def test_get_instance_type_gpu(sagemaker_session):
    """A p-family group is picked even when a c-family group is listed first."""
    groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.p3.16xlarge", 2),
    ]
    estimator = Estimator(
        image_uri="some-image",
        role="some_image",
        instance_groups=groups,
        sagemaker_session=sagemaker_session,
        base_job_name="base_job_name",
    )

    resolved = estimator._get_instance_type()

    assert "ml.p3.16xlarge" == resolved
1350+
1351+
1352+
def test_get_instance_type_cpu(sagemaker_session):
    """With only c-family groups, the first group's instance type is used."""
    groups = [
        InstanceGroup("group1", "ml.c4.xlarge", 1),
        InstanceGroup("group2", "ml.c5.xlarge", 2),
    ]
    estimator = Estimator(
        image_uri="some-image",
        role="some_image",
        instance_groups=groups,
        sagemaker_session=sagemaker_session,
        base_job_name="base_job_name",
    )

    resolved = estimator._get_instance_type()

    assert "ml.c4.xlarge" == resolved
1365+
1366+
1367+
def test_get_instance_type_no_instance_groups(sagemaker_session):
    """An explicitly set instance_type is returned when no groups are given."""
    estimator_kwargs = dict(
        image_uri="some-image",
        role="some_image",
        instance_type="ml.c4.xlarge",
        instance_count=1,
        sagemaker_session=sagemaker_session,
        base_job_name="base_job_name",
    )
    estimator = Estimator(**estimator_kwargs)

    resolved = estimator._get_instance_type()

    assert "ml.c4.xlarge" == resolved
1378+
1379+
1380+
def test_get_instance_type_no_instance_groups_or_instance_type(sagemaker_session):
    """With neither instance_type nor instance_groups set, a ValueError is raised."""
    estimator = Estimator(
        image_uri="some-image",
        role="some_image",
        instance_type=None,
        instance_count=None,
        instance_groups=None,
        sagemaker_session=sagemaker_session,
        base_job_name="base_job_name",
    )
    with pytest.raises(ValueError) as error:
        estimator._get_instance_type()

    # Inspect the raised exception itself (error.value): str() of the
    # ExceptionInfo wrapper is its repr on pytest >= 5, not the message.
    assert (
        "instance_groups must be set if instance_type is not set and instance_groups must be a list."
        in str(error.value)
    )
1397+
1398+
13371399
def test_augmented_manifest(sagemaker_session):
13381400
fw = DummyFramework(
13391401
entry_point=SCRIPT_PATH,

0 commit comments

Comments
 (0)