Skip to content

Commit 1f71808

Browse files
author
Chuyang Deng
committed
creating a unique object for checkpointing
1 parent 8ebcf7b commit 1f71808

File tree

1 file changed

+10
-26
lines changed

1 file changed

+10
-26
lines changed

tests/integ/test_tf_script_mode.py

+10-26
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020

2121
from sagemaker.tensorflow import TensorFlow
22-
from sagemaker.utils import unique_name_from_base
22+
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2323

2424
import tests.integ
2525
from tests.integ import timeout
@@ -39,7 +39,11 @@
3939
TAGS = [{"Key": "some-key", "Value": "some-value"}]
4040

4141

42-
def test_mnist(sagemaker_session, instance_type):
42+
def test_mnist_with_checkpoint_config(sagemaker_session, instance_type):
43+
checkpoint_s3_uri = "s3://{}/tf-{}".format(
44+
sagemaker_session.default_bucket(), sagemaker_timestamp()
45+
)
46+
checkpoint_local_path = "/test/checkpoint/path"
4347
estimator = TensorFlow(
4448
entry_point=SCRIPT,
4549
role="SageMakerRole",
@@ -50,13 +54,16 @@ def test_mnist(sagemaker_session, instance_type):
5054
framework_version=TensorFlow.LATEST_VERSION,
5155
py_version=tests.integ.PYTHON_VERSION,
5256
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
57+
checkpoint_s3_uri=checkpoint_s3_uri,
58+
checkpoint_local_path=checkpoint_local_path
5359
)
5460
inputs = estimator.sagemaker_session.upload_data(
5561
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
5662
)
5763

64+
training_job_name = unique_name_from_base("test-tf-sm-mnist")
5865
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
59-
estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-mnist"))
66+
estimator.fit(inputs=inputs, job_name=training_job_name)
6067
assert_s3_files_exist(
6168
sagemaker_session,
6269
estimator.model_dir,
@@ -65,29 +72,6 @@ def test_mnist(sagemaker_session, instance_type):
6572
df = estimator.training_job_analytics.dataframe()
6673
assert df.size > 0
6774

68-
69-
def test_checkpoint_config(sagemaker_session, instance_type):
70-
checkpoint_s3_uri = "s3://{}".format(sagemaker_session.default_bucket())
71-
checkpoint_local_path = "/test/checkpoint/path"
72-
estimator = TensorFlow(
73-
entry_point=SCRIPT,
74-
role="SageMakerRole",
75-
train_instance_count=1,
76-
train_instance_type=instance_type,
77-
sagemaker_session=sagemaker_session,
78-
script_mode=True,
79-
framework_version=TensorFlow.LATEST_VERSION,
80-
py_version=tests.integ.PYTHON_VERSION,
81-
checkpoint_s3_uri=checkpoint_s3_uri,
82-
checkpoint_local_path=checkpoint_local_path,
83-
)
84-
inputs = estimator.sagemaker_session.upload_data(
85-
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="script/mnist"
86-
)
87-
training_job_name = unique_name_from_base("test-tf-sm-checkpoint")
88-
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
89-
estimator.fit(inputs=inputs, job_name=training_job_name)
90-
9175
expected_training_checkpoint_config = {
9276
"S3Uri": checkpoint_s3_uri,
9377
"LocalPath": checkpoint_local_path,

0 commit comments

Comments
 (0)