Skip to content

Commit 04f8abf

Browse files
move sagemaker_s3_output to model class (#560)
1 parent 2ad6c1d commit 04f8abf

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CHANGELOG
77

88
* bug-fix: Append retry id to default Airflow job name to avoid name collisions in retry
99
* bug-fix: Local Mode: No longer requires s3 permissions to run local entry point file
10+
* bug-fix: Local Mode: Move dependency on sagemaker_s3_output from rl.estimator to model
1011

1112
1.16.2
1213
======

src/sagemaker/local/image.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,7 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
317317
volumes.append(_Volume(shared_dir, '/opt/ml/shared'))
318318

319319
parsed_uri = urlparse(output_data_config['S3OutputPath'])
320-
if parsed_uri.scheme == 'file' \
321-
and sagemaker.rl.estimator.SAGEMAKER_OUTPUT_LOCATION in hyperparameters:
320+
if parsed_uri.scheme == 'file' and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters:
322321
intermediate_dir = os.path.join(parsed_uri.path, 'output', 'intermediate')
323322
if not os.path.exists(intermediate_dir):
324323
os.makedirs(intermediate_dir)

src/sagemaker/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
257257
JOB_NAME_PARAM_NAME = 'sagemaker_job_name'
258258
MODEL_SERVER_WORKERS_PARAM_NAME = 'sagemaker_model_server_workers'
259259
SAGEMAKER_REGION_PARAM_NAME = 'sagemaker_region'
260+
SAGEMAKER_OUTPUT_LOCATION = 'sagemaker_s3_output'
260261

261262

262263
class FrameworkModel(Model):

src/sagemaker/rl/estimator.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sagemaker.estimator import Framework
2020
import sagemaker.fw_utils as fw_utils
21-
from sagemaker.model import FrameworkModel
21+
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2222
from sagemaker.mxnet.model import MXNetModel
2323
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2424

@@ -27,7 +27,6 @@
2727

2828
SAGEMAKER_ESTIMATOR = 'sagemaker_estimator'
2929
SAGEMAKER_ESTIMATOR_VALUE = 'RLEstimator'
30-
SAGEMAKER_OUTPUT_LOCATION = 'sagemaker_s3_output'
3130
PYTHON_VERSION = 'py3'
3231
TOOLKIT_FRAMEWORK_VERSION_MAP = {
3332
'coach': {

0 commit comments

Comments
 (0)