Skip to content

Commit 05d0895

Browse files
committed
change: Tests targeting Training Compiler in TensorFlow estimator
1 parent 6f727f7 commit 05d0895

File tree

2 files changed

+517
-19
lines changed

2 files changed

+517
-19
lines changed

tests/integ/test_training_compiler.py

+73-19
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
import pytest
1818

19-
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig
19+
from sagemaker.huggingface import HuggingFace
20+
from sagemaker.huggingface import TrainingCompilerConfig as HFTrainingCompilerConfig
21+
from sagemaker.tensorflow import TensorFlow
22+
from sagemaker.tensorflow import TrainingCompilerConfig as TFTrainingCompilerConfig
23+
2024
from tests import integ
2125
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2226
from tests.integ.timeout import timeout
@@ -27,15 +31,15 @@ def gpu_instance_type(request):
2731
return "ml.p3.2xlarge"
2832

2933

34+
@pytest.fixture(scope="module", autouse=True)
35+
def skip_if_incompatible(request):
36+
if integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS:
37+
pytest.skip("SageMaker Training Compiler is not supported in this region")
38+
if integ.test_region() in integ.TRAINING_NO_P3_REGIONS:
39+
pytest.skip("no ml.p3 instances in this region")
40+
41+
3042
@pytest.mark.release
31-
@pytest.mark.skipif(
32-
integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS,
33-
reason="SageMaker Training Compiler is not supported in this region",
34-
)
35-
@pytest.mark.skipif(
36-
integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
37-
reason="no ml.p3 instances in this region",
38-
)
3943
def test_huggingface_pytorch(
4044
sagemaker_session,
4145
gpu_instance_type,
@@ -66,7 +70,7 @@ def test_huggingface_pytorch(
6670
environment={"GPU_NUM_DEVICES": "1"},
6771
sagemaker_session=sagemaker_session,
6872
disable_profiler=True,
69-
compiler_config=TrainingCompilerConfig(),
73+
compiler_config=HFTrainingCompilerConfig(),
7074
)
7175

7276
train_input = hf.sagemaker_session.upload_data(
@@ -78,14 +82,6 @@ def test_huggingface_pytorch(
7882

7983

8084
@pytest.mark.release
81-
@pytest.mark.skipif(
82-
integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS,
83-
reason="SageMaker Training Compiler is not supported in this region",
84-
)
85-
@pytest.mark.skipif(
86-
integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
87-
reason="no ml.p3 instances in this region",
88-
)
8985
def test_huggingface_tensorflow(
9086
sagemaker_session,
9187
gpu_instance_type,
@@ -113,11 +109,69 @@ def test_huggingface_tensorflow(
113109
},
114110
sagemaker_session=sagemaker_session,
115111
disable_profiler=True,
116-
compiler_config=TrainingCompilerConfig(),
112+
compiler_config=HFTrainingCompilerConfig(),
117113
)
118114

119115
train_input = hf.sagemaker_session.upload_data(
120116
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/huggingface/train"
121117
)
122118

123119
hf.fit(train_input)
120+
121+
122+
@pytest.mark.release
123+
def test_tensorflow(
124+
sagemaker_session,
125+
gpu_instance_type,
126+
tensorflow_training_latest_version,
127+
):
128+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
129+
epochs = 10
130+
batch = 256
131+
train_steps = int(10240 * epochs / batch)
132+
steps_per_loop = train_steps // 10
133+
overrides = (
134+
f"runtime.enable_xla=True,"
135+
f"runtime.num_gpus=1,"
136+
f"runtime.distribution_strategy=one_device,"
137+
f"runtime.mixed_precision_dtype=float16,"
138+
f"task.train_data.global_batch_size={batch},"
139+
f"task.train_data.input_path=/opt/ml/input/data/training/validation*,"
140+
f"task.train_data.cache=False,"
141+
f"trainer.train_steps={train_steps},"
142+
f"trainer.steps_per_loop={steps_per_loop},"
143+
f"trainer.summary_interval={steps_per_loop},"
144+
f"trainer.checkpoint_interval={train_steps},"
145+
f"task.model.backbone.type=resnet,"
146+
f"task.model.backbone.resnet.model_id=50"
147+
)
148+
tf = TensorFlow(
149+
py_version="py39",
150+
git_config={
151+
"repo": "https://github.com/tensorflow/models.git",
152+
"branch": "v2.9.2",
153+
},
154+
source_dir=".",
155+
entry_point="official/vision/train.py",
156+
model_dir=False,
157+
role="SageMakerRole",
158+
framework_version=tensorflow_training_latest_version,
159+
instance_count=1,
160+
instance_type=gpu_instance_type,
161+
hyperparameters={
162+
"experiment": "resnet_imagenet",
163+
"config_file": "official/vision/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml",
164+
"mode": "train",
165+
"model_dir": "/opt/ml/model",
166+
"params_override": overrides,
167+
},
168+
sagemaker_session=sagemaker_session,
169+
disable_profiler=True,
170+
compiler_config=TFTrainingCompilerConfig(),
171+
)
172+
173+
tf.fit(
174+
inputs="s3://collection-of-ml-datasets/Imagenet/TFRecords/validation",
175+
logs=True,
176+
wait=True,
177+
)

0 commit comments

Comments
 (0)