Skip to content

change: removing unnecessary test cases #951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions tests/integ/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,12 @@
import sagemaker.utils
import tests.integ as integ
from sagemaker.tensorflow import TensorFlow
from tests.integ import test_region, timeout, HOSTING_NO_P3_REGIONS
from tests.integ import timeout

horovod_dir = os.path.join(os.path.dirname(__file__), "..", "data", "horovod")


# NOTE: the ml.p3 variant (and its region-based skip) was dropped; the old
# decorator referenced test_region / HOSTING_NO_P3_REGIONS, which are no
# longer imported, so only the single-parameter fixture is kept here.
@pytest.fixture(scope="session", params=["ml.c4.xlarge"])
def instance_type(request):
    """Session-scoped fixture yielding the instance type to run tests against."""
    return request.param

Expand Down
38 changes: 0 additions & 38 deletions tests/integ/test_pytorch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
from __future__ import absolute_import

import os
import time

import numpy
import pytest
import tests.integ
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name

Expand Down Expand Up @@ -80,42 +78,6 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
assert output.shape == (batch_size, 10)


@pytest.mark.skipif(
    tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS
    or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS,
    reason="no ml.p2 instances in these regions",
)
def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
    """Start a fit without waiting, re-attach to the job, then deploy and predict."""
    # TODO: add tests against local mode when it's ready to be used
    instance = "ml.p2.xlarge"
    job_name = ""

    with timeout(minutes=10):
        estimator = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance)
        # Kick off training asynchronously and remember the job so we can re-attach.
        estimator.fit({"training": _upload_training_data(estimator)}, wait=False)
        job_name = estimator.latest_training_job.name

        print("Waiting to re-attach to the training job: %s" % job_name)
        time.sleep(20)

    if not _is_local_mode(instance):
        endpoint_name = "test-pytorch-async-fit-attach-deploy-{}".format(sagemaker_timestamp())

        with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
            print("Re-attaching now to: %s" % job_name)
            attached = PyTorch.attach(
                training_job_name=job_name, sagemaker_session=sagemaker_session
            )
            predictor = attached.deploy(1, instance, endpoint_name=endpoint_name)

            # One MNIST-shaped batch; the model emits 10 class scores per sample.
            n_samples = 100
            sample = numpy.random.rand(n_samples, 1, 28, 28).astype(numpy.float32)
            prediction = predictor.predict(sample)

            assert prediction.shape == (n_samples, 10)


def _upload_training_data(pytorch):
return pytorch.sagemaker_session.upload_data(
path=os.path.join(MNIST_DIR, "training"),
Expand Down
83 changes: 0 additions & 83 deletions tests/integ/test_tf_cifar.py

This file was deleted.

6 changes: 1 addition & 5 deletions tests/integ/test_tf_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
@pytest.mark.skipif(
tests.integ.PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2."
)
@pytest.mark.skipif(
tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS,
reason="no ml.p2 instances in these regions",
)
def test_keras(sagemaker_session):
script_path = os.path.join(tests.integ.DATA_DIR, "cifar_10", "source")
dataset_path = os.path.join(tests.integ.DATA_DIR, "cifar_10", "data")
Expand Down Expand Up @@ -60,7 +56,7 @@ def test_keras(sagemaker_session):

endpoint_name = estimator.latest_training_job.name
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.p2.xlarge")
predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.c4.xlarge")

data = np.random.randn(32, 32, 3)
predict_response = predictor.predict(data)
Expand Down
29 changes: 12 additions & 17 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,7 @@
TAGS = [{"Key": "some-key", "Value": "some-value"}]


# NOTE: the ml.p2 variant (and its P2-region skip) was dropped; the old
# decorator referenced tests.integ.HOSTING_NO_P2_REGIONS / TRAINING_NO_P2_REGIONS,
# so only the single-parameter fixture is kept here.
@pytest.fixture(scope="session", params=["ml.c4.xlarge"])
def instance_type(request):
    """Session-scoped fixture yielding the instance type to run tests against."""
    return request.param

Expand Down Expand Up @@ -228,8 +215,15 @@ def _assert_s3_files_exist(s3_url, files):
raise ValueError("File {} is not found under {}".format(f, s3_url))


def _assert_tags_match(sagemaker_client, resource_arn, tags):
actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"]
def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=1):
actual_tags = None
for _ in range(retries):
actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"]
if actual_tags:
break
else:
# endpoint tags might take minutes to propagate. Sleeping.
time.sleep(30)
assert actual_tags == tags


Expand All @@ -240,7 +234,8 @@ def _assert_model_tags_match(sagemaker_client, model_name, tags):

def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
    """Assert the endpoint's tags match ``tags``, allowing time to propagate.

    Fix: the span contained both the pre- and post-change call lines (diff
    residue); keeping both would assert once WITHOUT retries first and fail
    spuriously. Only the retried call is kept.
    """
    endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)

    # endpoint tags may take minutes to appear, hence the generous retry count
    _assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags, retries=10)


def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
Expand Down