19
19
import pytest
20
20
21
21
from sagemaker .tensorflow import TensorFlow
22
- from sagemaker .utils import unique_name_from_base
22
+ from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
23
23
24
24
import tests .integ
25
25
from tests .integ import timeout
39
39
TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
40
40
41
41
42
- def test_mnist (sagemaker_session , instance_type ):
42
+ def test_mnist_with_checkpoint_config (sagemaker_session , instance_type ):
43
+ checkpoint_s3_uri = "s3://{}/tf-{}" .format (
44
+ sagemaker_session .default_bucket (), sagemaker_timestamp ()
45
+ )
46
+ checkpoint_local_path = "/test/checkpoint/path"
43
47
estimator = TensorFlow (
44
48
entry_point = SCRIPT ,
45
49
role = "SageMakerRole" ,
@@ -50,13 +54,16 @@ def test_mnist(sagemaker_session, instance_type):
50
54
framework_version = TensorFlow .LATEST_VERSION ,
51
55
py_version = tests .integ .PYTHON_VERSION ,
52
56
metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
57
+ checkpoint_s3_uri = checkpoint_s3_uri ,
58
+ checkpoint_local_path = checkpoint_local_path
53
59
)
54
60
inputs = estimator .sagemaker_session .upload_data (
55
61
path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "scriptmode/mnist"
56
62
)
57
63
64
+ training_job_name = unique_name_from_base ("test-tf-sm-mnist" )
58
65
with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
59
- estimator .fit (inputs = inputs , job_name = unique_name_from_base ( "test-tf-sm-mnist" ) )
66
+ estimator .fit (inputs = inputs , job_name = training_job_name )
60
67
assert_s3_files_exist (
61
68
sagemaker_session ,
62
69
estimator .model_dir ,
@@ -65,29 +72,6 @@ def test_mnist(sagemaker_session, instance_type):
65
72
df = estimator .training_job_analytics .dataframe ()
66
73
assert df .size > 0
67
74
68
-
69
- def test_checkpoint_config (sagemaker_session , instance_type ):
70
- checkpoint_s3_uri = "s3://{}" .format (sagemaker_session .default_bucket ())
71
- checkpoint_local_path = "/test/checkpoint/path"
72
- estimator = TensorFlow (
73
- entry_point = SCRIPT ,
74
- role = "SageMakerRole" ,
75
- train_instance_count = 1 ,
76
- train_instance_type = instance_type ,
77
- sagemaker_session = sagemaker_session ,
78
- script_mode = True ,
79
- framework_version = TensorFlow .LATEST_VERSION ,
80
- py_version = tests .integ .PYTHON_VERSION ,
81
- checkpoint_s3_uri = checkpoint_s3_uri ,
82
- checkpoint_local_path = checkpoint_local_path ,
83
- )
84
- inputs = estimator .sagemaker_session .upload_data (
85
- path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "script/mnist"
86
- )
87
- training_job_name = unique_name_from_base ("test-tf-sm-checkpoint" )
88
- with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
89
- estimator .fit (inputs = inputs , job_name = training_job_name )
90
-
91
75
expected_training_checkpoint_config = {
92
76
"S3Uri" : checkpoint_s3_uri ,
93
77
"LocalPath" : checkpoint_local_path ,
0 commit comments