Skip to content

Commit a931ec0

Browse files
authored
feature: add estimator preparation to airflow configuration (#1053)
Also added airflow configuration tests, as those did not exist. This change also contains a number of airflow fixes, as the tests revealed a few bugs.
1 parent 4b0f5af commit a931ec0

File tree

9 files changed

+785
-3
lines changed

9 files changed

+785
-3
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def read_version():
8686
"pandas",
8787
"black==19.3b0 ; python_version >= '3.6'",
8888
"stopit==1.1.2",
89+
"apache-airflow==1.10.5",
8990
]
9091
},
9192
entry_points={"console_scripts": ["sagemaker=sagemaker.cli.main:main"]},

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,21 @@ class constructor
127127
del init_params["image"]
128128
return init_params
129129

130+
def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
131+
"""Calls _prepare_for_training. Used when setting up a workflow.
132+
133+
Args:
134+
records (:class:`~RecordSet`): The records to train this ``Estimator`` on.
135+
mini_batch_size (int or None): The size of each mini-batch to use when
136+
training. If ``None``, a default value will be used.
137+
job_name (str): Name of the training job to be created. If not
138+
specified, one is generated, using the base name given to the
139+
constructor if applicable.
140+
"""
141+
self._prepare_for_training(
142+
records=records, mini_batch_size=mini_batch_size, job_name=job_name
143+
)
144+
130145
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
131146
"""Set hyperparameters needed for training.
132147

src/sagemaker/estimator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,16 @@ def enable_network_isolation(self):
261261
"""
262262
return False
263263

264+
def prepare_workflow_for_training(self, job_name=None):
265+
"""Calls _prepare_for_training. Used when setting up a workflow.
266+
267+
Args:
268+
job_name (str): Name of the training job to be created. If not
269+
specified, one is generated, using the base name given to the
270+
constructor if applicable.
271+
"""
272+
self._prepare_for_training(job_name=job_name)
273+
264274
def _prepare_for_training(self, job_name=None):
265275
"""Set any values in the estimator that need to be set before training.
266276

src/sagemaker/tensorflow/estimator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sagemaker.tensorflow.defaults import TF_VERSION
2828
from sagemaker.tensorflow.model import TensorFlowModel
2929
from sagemaker.tensorflow.serving import Model
30+
from sagemaker.transformer import Transformer
3031
from sagemaker import utils
3132
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3233

@@ -755,8 +756,31 @@ def transformer(
755756
* 'Subnets' (list[str]): List of subnet ids.
756757
* 'SecurityGroupIds' (list[str]): List of security group ids.
757758
"""
758-
759759
role = role or self.role
760+
761+
if self.latest_training_job is None:
762+
logging.warning(
763+
"No finished training job found associated with this estimator. Please make sure "
764+
"this estimator is only used for building workflow config"
765+
)
766+
return Transformer(
767+
self._current_job_name,
768+
instance_count,
769+
instance_type,
770+
strategy=strategy,
771+
assemble_with=assemble_with,
772+
output_path=output_path,
773+
output_kms_key=output_kms_key,
774+
accept=accept,
775+
max_concurrent_transforms=max_concurrent_transforms,
776+
max_payload=max_payload,
777+
env=env or {},
778+
tags=tags,
779+
base_transform_job_name=self.base_job_name,
780+
volume_kms_key=volume_kms_key,
781+
sagemaker_session=self.sagemaker_session,
782+
)
783+
760784
model = self.create_model(
761785
model_server_workers=model_server_workers,
762786
role=role,

src/sagemaker/workflow/airflow.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ def prepare_framework(estimator, s3_operations):
3434
if estimator.code_location is not None:
3535
bucket, key = fw_utils.parse_s3_url(estimator.code_location)
3636
key = os.path.join(key, estimator._current_job_name, "source", "sourcedir.tar.gz")
37+
elif estimator.uploaded_code is not None:
38+
bucket, key = fw_utils.parse_s3_url(estimator.uploaded_code.s3_prefix)
3739
else:
3840
bucket = estimator.sagemaker_session._default_bucket
3941
key = os.path.join(estimator._current_job_name, "source", "sourcedir.tar.gz")
42+
4043
script = os.path.basename(estimator.entry_point)
44+
4145
if estimator.source_dir and estimator.source_dir.lower().startswith("s3://"):
4246
code_dir = estimator.source_dir
4347
estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
@@ -96,7 +100,7 @@ def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
96100
estimator.mini_batch_size = mini_batch_size
97101

98102

99-
def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
103+
def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None): # noqa: C901
100104
"""Export Airflow base training config from an estimator
101105
102106
Args:
@@ -134,6 +138,13 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
134138
dict: Training config that can be directly used by
135139
SageMakerTrainingOperator in Airflow.
136140
"""
141+
if isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
142+
estimator.prepare_workflow_for_training(
143+
records=inputs, mini_batch_size=mini_batch_size, job_name=job_name
144+
)
145+
else:
146+
estimator.prepare_workflow_for_training(job_name=job_name)
147+
137148
default_bucket = estimator.sagemaker_session.default_bucket()
138149
s3_operations = {}
139150

@@ -528,6 +539,7 @@ def model_config_from_estimator(
528539
model_server_workers=model_server_workers,
529540
role=role,
530541
vpc_config_override=vpc_config_override,
542+
entry_point=estimator.entry_point,
531543
)
532544
else:
533545
raise TypeError(

src/sagemaker/xgboost/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ def create_model(
145145
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
146146
"""
147147
role = role or self.role
148+
149+
# Remove unwanted entry_point kwarg
150+
if "entry_point" in kwargs:
151+
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
152+
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}
153+
148154
return XGBoostModel(
149155
self.model_data,
150156
role,

0 commit comments

Comments
 (0)