Skip to content

Commit faa5c67

Browse files
committed
2 parents cb2d374 + 8be2914 commit faa5c67

File tree

2 files changed

+5
-15
lines changed

2 files changed

+5
-15
lines changed

tests/data/huggingface/run_tf.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,8 @@
4444

4545
# Load dataset
4646
train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])
47-
train_dataset = train_dataset.shuffle().select(
48-
range(5000)
49-
) # smaller the size for train dataset to 5k
50-
test_dataset = test_dataset.shuffle().select(
51-
range(500)
52-
) # smaller the size for test dataset to 500
47+
train_dataset = train_dataset.shuffle().select(range(5000)) # smaller the size for train dataset to 5k
48+
test_dataset = test_dataset.shuffle().select(range(500)) # smaller the size for test dataset to 500
5349

5450
# Preprocess train dataset
5551
train_dataset = train_dataset.map(
@@ -82,9 +78,7 @@
8278
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
8379

8480
start_train_time = time.time()
85-
train_results = model.fit(
86-
tf_train_dataset, epochs=args.epochs, batch_size=args.per_device_train_batch_size
87-
)
81+
train_results = model.fit(tf_train_dataset, epochs=args.epochs, batch_size=args.per_device_train_batch_size)
8882
end_train_time = time.time() - start_train_time
8983

9084
logger.info("*** Train ***")

tests/integ/test_huggingface.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ def test_huggingface_training(
107107

108108

109109
@pytest.mark.release
110-
@pytest.mark.skipif(
111-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS, reason="no ml.p2 instances in this region"
112-
)
110+
@pytest.mark.skipif(integ.test_region() in integ.TRAINING_NO_P2_REGIONS, reason="no ml.p2 instances in this region")
113111
def test_huggingface_training_tf(
114112
sagemaker_session,
115113
gpu_instance_type,
@@ -172,9 +170,7 @@ def test_huggingface_inference(
172170
pytorch_version=huggingface_inference_pytorch_latest_version,
173171
)
174172
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
175-
model.deploy(
176-
instance_type=gpu_instance_type, initial_instance_count=1, endpoint_name=endpoint_name
177-
)
173+
model.deploy(instance_type=gpu_instance_type, initial_instance_count=1, endpoint_name=endpoint_name)
178174

179175
predictor = HuggingFacePredictor(endpoint_name=endpoint_name)
180176
data = {

0 commit comments

Comments
 (0)