Skip to content

Commit c77fd6a

Browse files
author
Dan Choi
committed
add explicit pull for local serve
1 parent 507f2cd commit c77fd6a

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/sagemaker/local/image.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,3 +668,13 @@ def _ecr_login_if_needed(boto_session, image):
668668

669669
cmd = "docker login -u AWS -p %s %s" % (token, ecr_url)
670670
subprocess.check_output(cmd, shell=True)
671+
672+
return True
673+
674+
675+
def _pull_image(image):
676+
pull_image_command = ('docker pull %s' % image).strip()
677+
print('docker command: {}'.format(pull_image_command))
678+
679+
subprocess.check_output(pull_image_command, shell=True)
680+
print('image pulled: {}'.format(image))

tests/unit/test_image.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,13 +513,26 @@ def test_ecr_login_needed(check_output):
513513
}
514514
session_mock.client('ecr').get_authorization_token.return_value = response
515515
image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1'
516-
sagemaker.local.image._ecr_login_if_needed(session_mock, image)
516+
result = sagemaker.local.image._ecr_login_if_needed(session_mock, image)
517517

518518
expected_command = 'docker login -u AWS -p %s https://520713654638.dkr.ecr.us-east-1.amazonaws.com' % token
519519

520520
check_output.assert_called_with(expected_command, shell=True)
521521
session_mock.client('ecr').get_authorization_token.assert_called_with(registryIds=['520713654638'])
522522

523+
assert result
524+
525+
526+
@patch('subprocess.check_output', return_value=''.encode('utf-8'))
527+
def test_pull_image(check_output):
528+
image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1'
529+
530+
sagemaker.local.image._pull_image(image)
531+
532+
expected_command = 'docker pull %s' % image
533+
534+
check_output.assert_called_once_with(expected_command, shell=True)
535+
523536

524537
def test__aws_credentials_with_long_lived_credentials():
525538
credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=None)

0 commit comments

Comments
 (0)