Skip to content

Commit 53bf53a

Browse files
ChoiByungWookjesterhazy
authored andcommitted
add explicit pull for local serve (#455)
* add explicit pull for local serve
1 parent beece5a commit 53bf53a

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88
* feature: Estimators: dependencies attribute allows export of additional libraries into the container
99
* feature: Add APIs to export Airflow transform and deploy config
1010
* bug-fix: Allow code_location argument to be S3 URI in training_config API
11+
* enhancement: Local Mode: add explicit pull for serving
1112

1213
1.15.0
1314
======

src/sagemaker/local/image.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
118118
additional_env_vars=training_env_vars)
119119
compose_command = self._compose()
120120

121-
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
121+
if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
122+
_pull_image(self.image)
123+
122124
process = subprocess.Popen(compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
123125

124126
try:
@@ -164,7 +166,8 @@ def serve(self, model_dir, environment):
164166
if parsed_uri.scheme == 'file':
165167
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
166168

167-
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
169+
if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
170+
_pull_image(self.image)
168171

169172
self._generate_compose_file('serve',
170173
additional_env_vars=environment,
@@ -656,11 +659,11 @@ def _write_json_file(filename, content):
656659
def _ecr_login_if_needed(boto_session, image):
657660
# Only ECR images need login
658661
if not ('dkr.ecr' in image and 'amazonaws.com' in image):
659-
return
662+
return False
660663

661664
# do we have the image?
662665
if _check_output('docker images -q %s' % image).strip():
663-
return
666+
return False
664667

665668
if not boto_session:
666669
raise RuntimeError('A boto session is required to login to ECR.'
@@ -676,3 +679,13 @@ def _ecr_login_if_needed(boto_session, image):
676679

677680
cmd = "docker login -u AWS -p %s %s" % (token, ecr_url)
678681
subprocess.check_output(cmd, shell=True)
682+
683+
return True
684+
685+
686+
def _pull_image(image):
687+
pull_image_command = ('docker pull %s' % image).strip()
688+
logger.info('docker command: {}'.format(pull_image_command))
689+
690+
subprocess.check_output(pull_image_command, shell=True)
691+
logger.info('image pulled: {}'.format(image))

tests/unit/test_image.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,20 +537,22 @@ def test_prepare_serving_volumes_with_local_model(get_data_source_instance, sage
537537

538538
def test_ecr_login_non_ecr():
539539
session_mock = Mock()
540-
sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')
540+
result = sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')
541541

542542
session_mock.assert_not_called()
543+
assert result is False
543544

544545

545546
@patch('sagemaker.local.image._check_output', return_value='123451324')
546547
def test_ecr_login_image_exists(_check_output):
547548
session_mock = Mock()
548549

549550
image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0'
550-
sagemaker.local.image._ecr_login_if_needed(session_mock, image)
551+
result = sagemaker.local.image._ecr_login_if_needed(session_mock, image)
551552

552553
session_mock.assert_not_called()
553554
_check_output.assert_called()
555+
assert result is False
554556

555557

556558
@patch('subprocess.check_output', return_value=''.encode('utf-8'))
@@ -577,13 +579,26 @@ def test_ecr_login_needed(check_output):
577579
}
578580
session_mock.client('ecr').get_authorization_token.return_value = response
579581
image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1'
580-
sagemaker.local.image._ecr_login_if_needed(session_mock, image)
582+
result = sagemaker.local.image._ecr_login_if_needed(session_mock, image)
581583

582584
expected_command = 'docker login -u AWS -p %s https://520713654638.dkr.ecr.us-east-1.amazonaws.com' % token
583585

584586
check_output.assert_called_with(expected_command, shell=True)
585587
session_mock.client('ecr').get_authorization_token.assert_called_with(registryIds=['520713654638'])
586588

589+
assert result is True
590+
591+
592+
@patch('subprocess.check_output', return_value=''.encode('utf-8'))
593+
def test_pull_image(check_output):
594+
image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1'
595+
596+
sagemaker.local.image._pull_image(image)
597+
598+
expected_command = 'docker pull %s' % image
599+
600+
check_output.assert_called_once_with(expected_command, shell=True)
601+
587602

588603
def test__aws_credentials_with_long_lived_credentials():
589604
credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=None)

0 commit comments

Comments
 (0)