@@ -106,7 +106,9 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
106
106
data_dir = self ._create_tmp_folder ()
107
107
volumes = self ._prepare_training_volumes (data_dir , input_data_config , output_data_config ,
108
108
hyperparameters )
109
-
109
+ # If local, source directory needs to be updated to mounted /opt/ml/code path
110
+ hyperparameters = self ._update_local_src_path (hyperparameters , key = sagemaker .estimator .DIR_PARAM_NAME )
111
+
110
112
# Create the configuration files for each container that we will create
111
113
# Each container will map the additional local volumes (if any).
112
114
for host in self .hosts :
@@ -169,6 +171,9 @@ def serve(self, model_dir, environment):
169
171
parsed_uri = urlparse (script_dir )
170
172
if parsed_uri .scheme == 'file' :
171
173
volumes .append (_Volume (parsed_uri .path , '/opt/ml/code' ))
174
+ # Update path to mount location
175
+ environment = environment .copy ()
176
+ environment [sagemaker .estimator .DIR_PARAM_NAME .upper ()] = '/opt/ml/code'
172
177
173
178
if _ecr_login_if_needed (self .sagemaker_session .boto_session , self .image ):
174
179
_pull_image (self .image )
@@ -302,7 +307,7 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
302
307
volumes .append (_Volume (data_source .get_root_dir (), channel = channel_name ))
303
308
304
309
# If there is a training script directory and it is a local directory,
305
- # mount it to the container.
310
+ # mount it to the container.
306
311
if sagemaker .estimator .DIR_PARAM_NAME in hyperparameters :
307
312
training_dir = json .loads (hyperparameters [sagemaker .estimator .DIR_PARAM_NAME ])
308
313
parsed_uri = urlparse (training_dir )
@@ -321,6 +326,16 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
321
326
322
327
return volumes
323
328
329
+ def _update_local_src_path (self , params , key ):
330
+ if key in params :
331
+ src_dir = json .loads (params [key ])
332
+ parsed_uri = urlparse (src_dir )
333
+ if parsed_uri .scheme == 'file' :
334
+ new_params = params .copy ()
335
+ new_params [key ] = '/opt/ml/code'
336
+ return new_params
337
+ return params
338
+
324
339
def _prepare_serving_volumes (self , model_location ):
325
340
volumes = []
326
341
host = self .hosts [0 ]
0 commit comments