Skip to content

Commit 7511dee

Browse files
authored
Merge branch 'reinvent' into latest
2 parents 036def9 + 3988a09 commit 7511dee

File tree

99 files changed

+49870
-576
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

99 files changed

+49870
-576
lines changed

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
[flake8]
22
application_import_names = sagemaker, tests
33
import-order-style = google
4+
per-file-ignores =
5+
tests/unit/test_tuner.py: F405

buildspec.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ version: 0.2
33
phases:
44
build:
55
commands:
6+
# TODO-reinvent-2019 [akarpur]: Remove this (adding internal boto models)
7+
- aws configure add-model --service-model file://./tests/data/boto_models/sagemaker/2017-07-24/normal.json --service-name sagemaker
8+
69
- IGNORE_COVERAGE=-
710

811
# run integration tests
@@ -16,7 +19,7 @@ phases:
1619
- start_time=`date +%s`
1720
- |
1821
if has-matching-changes "tests/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"; then
19-
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-east-1"}'
22+
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-east-2"}'
2023
fi
2124
- ./ci-scripts/displaytime.sh 'py36 tests/integ' $start_time
2225

doc/analytics.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,8 @@ Analytics
1515
:members:
1616
:undoc-members:
1717
:show-inheritance:
18+
19+
.. autoclass:: sagemaker.analytics.ExperimentAnalytics
20+
:members:
21+
:undoc-members:
22+
:show-inheritance:

doc/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __getattr__(cls, name):
4141
"tensorflow.python.framework",
4242
"tensorflow_serving",
4343
"tensorflow_serving.apis",
44-
"numpy",
4544
"scipy",
4645
"scipy.sparse",
4746
]

doc/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,5 @@ SageMaker APIs to export configurations for creating and managing Airflow workfl
164164
:maxdepth: 2
165165

166166
sagemaker.workflow.airflow
167+
168+

src/sagemaker/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from sagemaker.model import Model, ModelPackage # noqa: F401
5151
from sagemaker.pipeline import PipelineModel # noqa: F401
5252
from sagemaker.predictor import RealTimePredictor # noqa: F401
53+
from sagemaker.processing import Processor, ScriptProcessor # noqa: F401
5354
from sagemaker.session import Session # noqa: F401
5455
from sagemaker.session import container_def, pipeline_container_def # noqa: F401
5556
from sagemaker.session import production_variant # noqa: F401

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from sagemaker.model import NEO_IMAGE_ACCOUNT
2828
from sagemaker.session import s3_input
2929
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
30+
from sagemaker.xgboost.defaults import XGBOOST_VERSION_1, XGBOOST_SUPPORTED_VERSIONS
3031
from sagemaker.xgboost.estimator import get_xgboost_image_uri
31-
from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION
3232

3333
logger = logging.getLogger(__name__)
3434

@@ -178,7 +178,15 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
178178
self.feature_dim = feature_dim
179179
self.mini_batch_size = mini_batch_size
180180

181-
def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None):
181+
def fit(
182+
self,
183+
records,
184+
mini_batch_size=None,
185+
wait=True,
186+
logs=True,
187+
job_name=None,
188+
experiment_config=None,
189+
):
182190
"""Fit this Estimator on serialized Record objects, stored in S3.
183191
184192
``records`` should be an instance of :class:`~RecordSet`. This
@@ -206,10 +214,16 @@ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None
206214
job_name (str): Training job name. If not specified, the estimator
207215
generates a default job name, based on the training image name
208216
and current timestamp.
217+
experiment_config (dict[str, str]): Experiment management configuration.
218+
Dictionary contains three optional keys, 'ExperimentName',
219+
'TrialName', and 'TrialComponentName'
220+
(default: ``None``).
209221
"""
210222
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
211223

212-
self.latest_training_job = _TrainingJob.start_new(self, records)
224+
self.latest_training_job = _TrainingJob.start_new(
225+
self, records, experiment_config=experiment_config
226+
)
213227
if wait:
214228
self.latest_training_job.wait(logs=logs)
215229

@@ -559,13 +573,23 @@ def get_image_uri(region_name, repo_name, repo_version=1):
559573
"""
560574
if repo_name == "xgboost":
561575
if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]:
562-
return get_xgboost_image_uri(region_name, XGBOOST_LATEST_VERSION)
576+
return get_xgboost_image_uri(region_name, XGBOOST_VERSION_1)
577+
578+
supported_version = [
579+
version
580+
for version in XGBOOST_SUPPORTED_VERSIONS
581+
if repo_version in (version, version + "-cpu-py3")
582+
]
583+
if supported_version:
584+
return get_xgboost_image_uri(region_name, supported_version[0])
585+
563586
logging.warning(
564-
"There is a more up to date SageMaker XGBoost image."
587+
"There is a more up to date SageMaker XGBoost image. "
565588
"To use the newer image, please set 'repo_version'="
566-
"'0.90-1. For example:\n"
589+
"'%s'. For example:\n"
567590
"\tget_image_uri(region, 'xgboost', '%s').",
568-
XGBOOST_LATEST_VERSION,
591+
XGBOOST_VERSION_1,
592+
XGBOOST_VERSION_1,
569593
)
570594
repo = "{}:{}".format(repo_name, repo_version)
571595
return "{}/{}".format(registry(region_name, repo_name), repo)

src/sagemaker/amazon/linear_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def __init__(
366366
self.balance_multiclass_weights = balance_multiclass_weights
367367

368368
if self.predictor_type == "multiclass_classifier" and (
369-
num_classes is None or num_classes < 3
369+
num_classes is None or int(num_classes) < 3
370370
):
371371
raise ValueError(
372372
"For predictor_type 'multiclass_classifier', 'num_classes' should be set to a "

0 commit comments

Comments
 (0)