Skip to content

Commit 856a949

Browse files
committed
added condition for test to also work with lower transformers version
1 parent cb2d374 commit 856a949

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

tests/data/huggingface/run_tf.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datasets import load_dataset
88

99
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
10+
import transformers
1011

1112

1213
if __name__ == "__main__":
@@ -57,8 +58,15 @@
5758
)
5859
train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
5960

60-
train_features = {x: train_dataset[x] for x in ["input_ids", "attention_mask"]}
61-
61+
if transformers.__version__ > "4.12.0":
62+
train_features = {x: train_dataset[x] for x in ["input_ids", "attention_mask"]}
63+
else:
64+
train_features = {
65+
x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
66+
for x in ["input_ids", "attention_mask"]
67+
}
68+
69+
6270
tf_train_dataset = tf.data.Dataset.from_tensor_slices(
6371
(train_features, train_dataset["label"])
6472
).batch(args.per_device_train_batch_size)
@@ -69,7 +77,14 @@
6977
)
7078
test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
7179

72-
test_features = {x: test_dataset[x] for x in ["input_ids", "attention_mask"]}
80+
if transformers.__version__ > "4.12.0":
81+
test_features = {x: test_dataset[x] for x in ["input_ids", "attention_mask"]}
82+
else:
83+
test_features = {
84+
x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
85+
for x in ["input_ids", "attention_mask"]
86+
}
87+
7388

7489
tf_test_dataset = tf.data.Dataset.from_tensor_slices(
7590
(test_features, test_dataset["label"])

0 commit comments

Comments
 (0)