Skip to content

Commit fb7269d

Browse files
committed
Fixing container path for source files when using local_session
Copying env variable instead Adding tests forl local code
1 parent 6555f09 commit fb7269d

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ CHANGELOG
1414
* feature: Local Mode: Add support for intermediate output to a local directory.
1515
* bug-fix: Update PyYAML version to avoid conflicts with docker-compose
1616
* doc-fix: Correct the numbered list in the table of contents
17+
* bug-fix: Local Mode: No longer requires s3 permissions to run local entry point file
1718
* doc-fix: Add Airflow API documentation
1819

1920
1.16.1.post1

src/sagemaker/local/image.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
106106
data_dir = self._create_tmp_folder()
107107
volumes = self._prepare_training_volumes(data_dir, input_data_config, output_data_config,
108108
hyperparameters)
109-
109+
# If local, source directory needs to be updated to mounted /opt/ml/code path
110+
hyperparameters = self._update_local_src_path(hyperparameters, key=sagemaker.estimator.DIR_PARAM_NAME)
111+
110112
# Create the configuration files for each container that we will create
111113
# Each container will map the additional local volumes (if any).
112114
for host in self.hosts:
@@ -169,6 +171,9 @@ def serve(self, model_dir, environment):
169171
parsed_uri = urlparse(script_dir)
170172
if parsed_uri.scheme == 'file':
171173
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
174+
# Update path to mount location
175+
environment = environment.copy()
176+
environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = '/opt/ml/code'
172177

173178
if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
174179
_pull_image(self.image)
@@ -302,7 +307,7 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
302307
volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
303308

304309
# If there is a training script directory and it is a local directory,
305-
# mount it to the container.
310+
# mount it to the container.
306311
if sagemaker.estimator.DIR_PARAM_NAME in hyperparameters:
307312
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
308313
parsed_uri = urlparse(training_dir)
@@ -321,6 +326,16 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
321326

322327
return volumes
323328

329+
def _update_local_src_path(self, params, key):
330+
if key in params:
331+
src_dir = json.loads(params[key])
332+
parsed_uri = urlparse(src_dir)
333+
if parsed_uri.scheme == 'file':
334+
new_params = params.copy()
335+
new_params[key] = '/opt/ml/code'
336+
return new_params
337+
return params
338+
324339
def _prepare_serving_volumes(self, model_location):
325340
volumes = []
326341
host = self.hosts[0]

tests/unit/test_image.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -388,12 +388,18 @@ def test_train_local_code(tmpdir, sagemaker_session):
388388
with open(docker_compose_file, 'r') as f:
389389
config = yaml.load(f)
390390
assert len(config['services']) == instance_count
391-
for h in sagemaker_container.hosts:
392-
assert config['services'][h]['image'] == image
393-
assert config['services'][h]['command'] == 'train'
394-
volumes = config['services'][h]['volumes']
395-
assert '%s:/opt/ml/code' % '/tmp/code' in volumes
396-
assert '%s:/opt/ml/shared' % shared_folder_path in volumes
391+
392+
for h in sagemaker_container.hosts:
393+
assert config['services'][h]['image'] == image
394+
assert config['services'][h]['command'] == 'train'
395+
volumes = config['services'][h]['volumes']
396+
assert '%s:/opt/ml/code' % '/tmp/code' in volumes
397+
assert '%s:/opt/ml/shared' % shared_folder_path in volumes
398+
399+
config_file_root = os.path.join(sagemaker_container.container_root, h, 'input', 'config')
400+
hyperparameters_file = os.path.join(config_file_root, 'hyperparameters.json')
401+
hyperparameters_data = json.load(open(hyperparameters_file))
402+
assert hyperparameters_data['sagemaker_submit_directory'] == '/opt/ml/code'
397403

398404

399405
@patch('sagemaker.local.local_session.LocalSession', Mock())
@@ -506,6 +512,7 @@ def test_serve_local_code(tmpdir, sagemaker_session):
506512

507513
volumes = config['services'][h]['volumes']
508514
assert '%s:/opt/ml/code' % '/tmp/code' in volumes
515+
assert 'SAGEMAKER_SUBMIT_DIRECTORY=/opt/ml/code' in config['services'][h]['environment']
509516

510517

511518
@patch('sagemaker.local.image._HostingContainer.run', Mock())

0 commit comments

Comments
 (0)