Skip to content

Commit 09d7a76

Browse files
authored
infra: use pytest fixtures in batch transform integ tests to train and upload to S3 only once (#1410)
1 parent bba4b7a commit 09d7a76

File tree

1 file changed

+40
-140
lines changed

1 file changed

+40
-140
lines changed

tests/integ/test_transformer.py

Lines changed: 40 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,13 @@
3636
from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
3737
from tests.integ.vpc_test_utils import get_or_create_vpc_resources
3838

39+
MXNET_MNIST_PATH = os.path.join(DATA_DIR, "mxnet_mnist")
3940

40-
@pytest.mark.canary_quick
41-
def test_transform_mxnet(sagemaker_session, mxnet_full_version, cpu_instance_type):
42-
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
43-
script_path = os.path.join(data_path, "mnist.py")
4441

42+
@pytest.fixture(scope="module")
43+
def mxnet_estimator(sagemaker_session, mxnet_full_version, cpu_instance_type):
4544
mx = MXNet(
46-
entry_point=script_path,
45+
entry_point=os.path.join(MXNET_MNIST_PATH, "mnist.py"),
4746
role="SageMakerRole",
4847
train_instance_count=1,
4948
train_instance_type=cpu_instance_type,
@@ -52,29 +51,39 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version, cpu_instance_typ
5251
)
5352

5453
train_input = mx.sagemaker_session.upload_data(
55-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
54+
path=os.path.join(MXNET_MNIST_PATH, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
5655
)
5756
test_input = mx.sagemaker_session.upload_data(
58-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
57+
path=os.path.join(MXNET_MNIST_PATH, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
5958
)
60-
job_name = unique_name_from_base("test-mxnet-transform")
6159

60+
job_name = unique_name_from_base("test-mxnet-transform")
6261
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
6362
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
6463

65-
transform_input_path = os.path.join(data_path, "transform", "data.csv")
64+
return mx
65+
66+
67+
@pytest.fixture(scope="module")
68+
def mxnet_transform_input(sagemaker_session):
69+
transform_input_path = os.path.join(MXNET_MNIST_PATH, "transform", "data.csv")
6670
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
67-
transform_input = mx.sagemaker_session.upload_data(
71+
return sagemaker_session.upload_data(
6872
path=transform_input_path, key_prefix=transform_input_key_prefix
6973
)
7074

75+
76+
@pytest.mark.canary_quick
77+
def test_transform_mxnet(
78+
mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
79+
):
7180
kms_key_arn = get_or_create_kms_key(sagemaker_session)
7281
output_filter = "$"
7382
input_filter = "$"
7483

7584
transformer = _create_transformer_and_transform_job(
76-
mx,
77-
transform_input,
85+
mxnet_estimator,
86+
mxnet_transform_input,
7887
cpu_instance_type,
7988
kms_key_arn,
8089
input_filter=input_filter,
@@ -197,39 +206,13 @@ def test_transform_pytorch_vpc_custom_model_bucket(
197206
assert custom_bucket_name == model_bucket
198207

199208

200-
def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version, cpu_instance_type):
201-
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
202-
script_path = os.path.join(data_path, "mnist.py")
209+
def test_transform_mxnet_tags(
210+
mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
211+
):
203212
tags = [{"Key": "some-tag", "Value": "value-for-tag"}]
204213

205-
mx = MXNet(
206-
entry_point=script_path,
207-
role="SageMakerRole",
208-
train_instance_count=1,
209-
train_instance_type=cpu_instance_type,
210-
sagemaker_session=sagemaker_session,
211-
framework_version=mxnet_full_version,
212-
)
213-
214-
train_input = mx.sagemaker_session.upload_data(
215-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
216-
)
217-
test_input = mx.sagemaker_session.upload_data(
218-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
219-
)
220-
job_name = unique_name_from_base("test-mxnet-transform")
221-
222-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
223-
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
224-
225-
transform_input_path = os.path.join(data_path, "transform", "data.csv")
226-
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
227-
transform_input = mx.sagemaker_session.upload_data(
228-
path=transform_input_path, key_prefix=transform_input_key_prefix
229-
)
230-
231-
transformer = mx.transformer(1, cpu_instance_type, tags=tags)
232-
transformer.transform(transform_input, content_type="text/csv")
214+
transformer = mxnet_estimator.transformer(1, cpu_instance_type, tags=tags)
215+
transformer.transform(mxnet_transform_input, content_type="text/csv")
233216

234217
with timeout_and_delete_model_with_transformer(
235218
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
@@ -306,86 +289,29 @@ def test_transform_byo_estimator(sagemaker_session, cpu_instance_type):
306289
assert tags == model_tags
307290

308291

309-
def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version, cpu_instance_type):
310-
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
311-
script_path = os.path.join(data_path, "mnist.py")
312-
313-
mx = MXNet(
314-
entry_point=script_path,
315-
role="SageMakerRole",
316-
train_instance_count=1,
317-
train_instance_type=cpu_instance_type,
318-
sagemaker_session=sagemaker_session,
319-
framework_version=mxnet_full_version,
320-
)
321-
322-
train_input = mx.sagemaker_session.upload_data(
323-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
324-
)
325-
test_input = mx.sagemaker_session.upload_data(
326-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
327-
)
328-
job_name = unique_name_from_base("test-mxnet-transform")
329-
330-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
331-
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
332-
333-
transform_input_path = os.path.join(data_path, "transform", "data.csv")
334-
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
335-
transform_input = mx.sagemaker_session.upload_data(
336-
path=transform_input_path, key_prefix=transform_input_key_prefix
337-
)
338-
339-
transformer = mx.transformer(1, cpu_instance_type)
292+
def test_single_transformer_multiple_jobs(
293+
mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
294+
):
295+
transformer = mxnet_estimator.transformer(1, cpu_instance_type)
340296

341297
job_name = unique_name_from_base("test-mxnet-transform")
342-
transformer.transform(transform_input, content_type="text/csv", job_name=job_name)
298+
transformer.transform(mxnet_transform_input, content_type="text/csv", job_name=job_name)
343299
with timeout_and_delete_model_with_transformer(
344300
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
345301
):
346302
assert transformer.output_path == "s3://{}/{}".format(
347303
sagemaker_session.default_bucket(), job_name
348304
)
349305
job_name = unique_name_from_base("test-mxnet-transform")
350-
transformer.transform(transform_input, content_type="text/csv", job_name=job_name)
306+
transformer.transform(mxnet_transform_input, content_type="text/csv", job_name=job_name)
351307
assert transformer.output_path == "s3://{}/{}".format(
352308
sagemaker_session.default_bucket(), job_name
353309
)
354310

355311

356-
def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_type):
357-
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
358-
script_path = os.path.join(data_path, "mnist.py")
359-
tags = [{"Key": "some-tag", "Value": "value-for-tag"}]
360-
361-
mx = MXNet(
362-
entry_point=script_path,
363-
role="SageMakerRole",
364-
train_instance_count=1,
365-
train_instance_type=cpu_instance_type,
366-
sagemaker_session=sagemaker_session,
367-
framework_version=mxnet_full_version,
368-
)
369-
370-
train_input = mx.sagemaker_session.upload_data(
371-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
372-
)
373-
test_input = mx.sagemaker_session.upload_data(
374-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
375-
)
376-
job_name = unique_name_from_base("test-mxnet-transform")
377-
378-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
379-
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
380-
381-
transform_input_path = os.path.join(data_path, "transform", "data.csv")
382-
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
383-
transform_input = mx.sagemaker_session.upload_data(
384-
path=transform_input_path, key_prefix=transform_input_key_prefix
385-
)
386-
387-
transformer = mx.transformer(1, cpu_instance_type, tags=tags)
388-
transformer.transform(transform_input, content_type="text/csv")
312+
def test_stop_transform_job(mxnet_estimator, mxnet_transform_input, cpu_instance_type):
313+
transformer = mxnet_estimator.transformer(1, cpu_instance_type)
314+
transformer.transform(mxnet_transform_input, content_type="text/csv")
389315

390316
time.sleep(15)
391317

@@ -401,39 +327,12 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_
401327
assert desc["TransformJobStatus"] == "Stopped"
402328

403329

404-
def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version, cpu_instance_type):
405-
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
406-
script_path = os.path.join(data_path, "mnist.py")
407-
408-
mx = MXNet(
409-
entry_point=script_path,
410-
role="SageMakerRole",
411-
train_instance_count=1,
412-
train_instance_type=cpu_instance_type,
413-
sagemaker_session=sagemaker_session,
414-
framework_version=mxnet_full_version,
415-
)
416-
417-
train_input = mx.sagemaker_session.upload_data(
418-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
419-
)
420-
test_input = mx.sagemaker_session.upload_data(
421-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
422-
)
423-
job_name = unique_name_from_base("test-mxnet-transform")
424-
425-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
426-
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
427-
428-
transform_input_path = os.path.join(data_path, "transform", "data.csv")
429-
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
430-
transform_input = mx.sagemaker_session.upload_data(
431-
path=transform_input_path, key_prefix=transform_input_key_prefix
432-
)
433-
330+
def test_transform_mxnet_logs(
331+
mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type
332+
):
434333
with timeout(minutes=45):
435334
transformer = _create_transformer_and_transform_job(
436-
mx, transform_input, cpu_instance_type, wait=True, logs=True
335+
mxnet_estimator, mxnet_transform_input, cpu_instance_type, wait=True, logs=True
437336
)
438337

439338
with timeout_and_delete_model_with_transformer(
@@ -462,5 +361,6 @@ def _create_transformer_and_transform_job(
462361
join_source=join_source,
463362
wait=wait,
464363
logs=logs,
364+
job_name=unique_name_from_base("test-transform"),
465365
)
466366
return transformer

0 commit comments

Comments
 (0)