|
14 | 14 |
|
15 | 15 | import json
|
16 | 16 | import logging
|
| 17 | +import os |
17 | 18 | from abc import ABCMeta
|
18 | 19 | from abc import abstractmethod
|
19 | 20 | from six import with_metaclass, string_types
|
20 | 21 |
|
21 |
| -from sagemaker.fw_utils import tar_and_upload_dir |
22 |
| -from sagemaker.fw_utils import parse_s3_url |
23 |
| -from sagemaker.fw_utils import UploadedCode |
24 |
| -from sagemaker.local.local_session import LocalSession, file_input |
| 22 | +from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir |
| 23 | +from sagemaker.local import LocalSession, file_input |
25 | 24 |
|
26 | 25 | from sagemaker.model import Model
|
27 | 26 | from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
|
|
30 | 29 | from sagemaker.predictor import RealTimePredictor
|
31 | 30 | from sagemaker.session import Session
|
32 | 31 | from sagemaker.session import s3_input
|
33 |
| -from sagemaker.utils import base_name_from_image, name_from_base |
| 32 | +from sagemaker.utils import base_name_from_image, name_from_base, get_config_value |
34 | 33 |
|
35 | 34 |
|
36 | 35 | class EstimatorBase(with_metaclass(ABCMeta, object)):
|
@@ -83,13 +82,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
|
83 | 82 | self.input_mode = input_mode
|
84 | 83 |
|
85 | 84 | if self.train_instance_type in ('local', 'local_gpu'):
|
86 |
| - self.local_mode = True |
87 | 85 | if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
|
88 | 86 | raise RuntimeError("Distributed Training in Local GPU is not supported")
|
89 |
| - |
90 | 87 | self.sagemaker_session = sagemaker_session or LocalSession()
|
91 | 88 | else:
|
92 |
| - self.local_mode = False |
93 | 89 | self.sagemaker_session = sagemaker_session or Session()
|
94 | 90 |
|
95 | 91 | self.base_job_name = base_job_name
|
@@ -158,9 +154,14 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
|
158 | 154 | base_name = self.base_job_name or base_name_from_image(self.train_image())
|
159 | 155 | self._current_job_name = name_from_base(base_name)
|
160 | 156 |
|
161 |
| - # if output_path was specified we use it otherwise initialize here |
| 157 | + # If output_path was specified, we use it; otherwise initialize here. |
| 158 | + # For Local Mode with local_code=True we don't need an explicit output_path. |
162 | 159 | if self.output_path is None:
|
163 |
| - self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket()) |
| 160 | + local_code = get_config_value('local.local_code', self.sagemaker_session.config) |
| 161 | + if self.sagemaker_session.local_mode and local_code: |
| 162 | + self.output_path = '' |
| 163 | + else: |
| 164 | + self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket()) |
164 | 165 |
|
165 | 166 | self.latest_training_job = _TrainingJob.start_new(self, inputs)
|
166 | 167 | if wait:
|
@@ -323,7 +324,7 @@ def start_new(cls, estimator, inputs):
|
323 | 324 | sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
|
324 | 325 | """
|
325 | 326 |
|
326 |
| - local_mode = estimator.local_mode |
| 327 | + local_mode = estimator.sagemaker_session.local_mode |
327 | 328 |
|
328 | 329 | # Allow file:// input only in local mode
|
329 | 330 | if isinstance(inputs, str) and inputs.startswith('file://'):
|
@@ -604,27 +605,54 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
|
604 | 605 | base_name = self.base_job_name or base_name_from_image(self.train_image())
|
605 | 606 | self._current_job_name = name_from_base(base_name)
|
606 | 607 |
|
| 608 | + # validate_source_dir will raise a ValueError if there is something wrong with the |
| 609 | + # source directory. We are intentionally not handling it because this is a critical error. |
| 610 | + if self.source_dir and not self.source_dir.lower().startswith('s3://'): |
| 611 | + validate_source_dir(self.entry_point, self.source_dir) |
| 612 | + |
| 613 | + # If we are in local mode with local_code=True, we want the container to just |
| 614 | + # mount the source dir instead of uploading it to S3. |
| 615 | + local_code = get_config_value('local.local_code', self.sagemaker_session.config) |
| 616 | + if self.sagemaker_session.local_mode and local_code: |
| 617 | + # if there is no source dir, use the directory containing the entry point. |
| 618 | + if self.source_dir is None: |
| 619 | + self.source_dir = os.path.dirname(self.entry_point) |
| 620 | + self.entry_point = os.path.basename(self.entry_point) |
| 621 | + |
| 622 | + code_dir = 'file://' + self.source_dir |
| 623 | + script = self.entry_point |
| 624 | + else: |
| 625 | + self.uploaded_code = self._stage_user_code_in_s3() |
| 626 | + code_dir = self.uploaded_code.s3_prefix |
| 627 | + script = self.uploaded_code.script_name |
| 628 | + |
| 629 | + # Modify hyperparameters in-place to point to the right code directory and script URIs |
| 630 | + self._hyperparameters[DIR_PARAM_NAME] = code_dir |
| 631 | + self._hyperparameters[SCRIPT_PARAM_NAME] = script |
| 632 | + self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics |
| 633 | + self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level |
| 634 | + self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name |
| 635 | + self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name |
| 636 | + super(Framework, self).fit(inputs, wait, logs, self._current_job_name) |
| 637 | + |
| 638 | + def _stage_user_code_in_s3(self): |
| 639 | + """ Upload the user training script to S3 and return the location. |
| 640 | +
|
| 641 | + Returns: S3 URI |
| 642 | +
|
| 643 | + """ |
607 | 644 | if self.code_location is None:
|
608 | 645 | code_bucket = self.sagemaker_session.default_bucket()
|
609 | 646 | code_s3_prefix = '{}/source'.format(self._current_job_name)
|
610 | 647 | else:
|
611 | 648 | code_bucket, key_prefix = parse_s3_url(self.code_location)
|
612 | 649 | code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)
|
613 | 650 |
|
614 |
| - self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session, |
615 |
| - bucket=code_bucket, |
616 |
| - s3_key_prefix=code_s3_prefix, |
617 |
| - script=self.entry_point, |
618 |
| - directory=self.source_dir) |
619 |
| - |
620 |
| - # Modify hyperparameters in-place to add the URLs to the uploaded code. |
621 |
| - self._hyperparameters[DIR_PARAM_NAME] = self.uploaded_code.s3_prefix |
622 |
| - self._hyperparameters[SCRIPT_PARAM_NAME] = self.uploaded_code.script_name |
623 |
| - self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics |
624 |
| - self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level |
625 |
| - self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name |
626 |
| - self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_session.region_name |
627 |
| - super(Framework, self).fit(inputs, wait, logs, self._current_job_name) |
| 651 | + return tar_and_upload_dir(session=self.sagemaker_session.boto_session, |
| 652 | + bucket=code_bucket, |
| 653 | + s3_key_prefix=code_s3_prefix, |
| 654 | + script=self.entry_point, |
| 655 | + directory=self.source_dir) |
628 | 656 |
|
629 | 657 | def hyperparameters(self):
|
630 | 658 | """Return the hyperparameters as a dictionary to use for training.
|
|
0 commit comments