diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6afdf8f4c9..3ab41cdb45 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ CHANGELOG * feature: Estimators: dependencies attribute allows export of additional libraries into the container * feature: Add APIs to export Airflow transform and deploy config * bug-fix: Allow code_location argument to be S3 URI in training_config API +* enhancement: Local Mode: add explicit pull for serving 1.15.0 ====== diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 56196639d7..8416e47048 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -118,7 +118,9 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name additional_env_vars=training_env_vars) compose_command = self._compose() - _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image) + if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image): + _pull_image(self.image) + process = subprocess.Popen(compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) try: @@ -164,7 +166,8 @@ def serve(self, model_dir, environment): if parsed_uri.scheme == 'file': volumes.append(_Volume(parsed_uri.path, '/opt/ml/code')) - _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image) + if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image): + _pull_image(self.image) self._generate_compose_file('serve', additional_env_vars=environment, @@ -656,11 +659,11 @@ def _write_json_file(filename, content): def _ecr_login_if_needed(boto_session, image): # Only ECR images need login if not ('dkr.ecr' in image and 'amazonaws.com' in image): - return + return False # do we have the image? if _check_output('docker images -q %s' % image).strip(): - return + return False if not boto_session: raise RuntimeError('A boto session is required to login to ECR.' @@ -676,3 +679,13 @@ def _ecr_login_if_needed(boto_session, image): cmd = "docker login -u AWS -p %s %s" % (token, ecr_url) subprocess.check_output(cmd, shell=True) + + return True + + +def _pull_image(image): + pull_image_command = ('docker pull %s' % image).strip() + logger.info('docker command: {}'.format(pull_image_command)) + + subprocess.check_output(pull_image_command, shell=True) + logger.info('image pulled: {}'.format(image)) diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 1d5f62b8a2..5a2cbb39fe 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -537,9 +537,10 @@ def test_prepare_serving_volumes_with_local_model(get_data_source_instance, sage def test_ecr_login_non_ecr(): session_mock = Mock() - sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu') + result = sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu') session_mock.assert_not_called() + assert result is False @patch('sagemaker.local.image._check_output', return_value='123451324') @@ -547,10 +548,11 @@ def test_ecr_login_image_exists(_check_output): session_mock = Mock() image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0' - sagemaker.local.image._ecr_login_if_needed(session_mock, image) + result = sagemaker.local.image._ecr_login_if_needed(session_mock, image) session_mock.assert_not_called() _check_output.assert_called() + assert result is False @patch('subprocess.check_output', return_value=''.encode('utf-8')) @@ -577,13 +579,26 @@ def test_ecr_login_needed(check_output): } session_mock.client('ecr').get_authorization_token.return_value = response image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1' - sagemaker.local.image._ecr_login_if_needed(session_mock, image) + result = sagemaker.local.image._ecr_login_if_needed(session_mock, image) expected_command = 'docker login -u AWS -p %s https://520713654638.dkr.ecr.us-east-1.amazonaws.com' % token check_output.assert_called_with(expected_command, shell=True) session_mock.client('ecr').get_authorization_token.assert_called_with(registryIds=['520713654638']) + assert result is True + + +@patch('subprocess.check_output', return_value=''.encode('utf-8')) +def test_pull_image(check_output): + image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1' + + sagemaker.local.image._pull_image(image) + + expected_command = 'docker pull %s' % image + + check_output.assert_called_once_with(expected_command, shell=True) + def test__aws_credentials_with_long_lived_credentials(): credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=None)