Skip to content

Commit 1007498

Browse files
committed
refactor feature generation
1 parent a5c6012 commit 1007498

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

tests/data/huggingface/run_tf.py

+19-16
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,21 @@
44
import time
55

66
import tensorflow as tf
7+
import transformers
78
from datasets import load_dataset
8-
99
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
10-
import transformers
10+
11+
12+
def _get_dataset_features(dataset, tokenizer, columns=()):
    """Build the feature dict fed to ``tf.data.Dataset.from_tensor_slices``.

    Parameters
    ----------
    dataset:
        A ``datasets`` dataset, presumably already formatted with
        ``set_format(type="tensorflow", ...)`` — confirm against callers.
    tokenizer:
        Tokenizer whose ``model_max_length`` bounds the padded sequence length.
    columns:
        Iterable of column names to extract, e.g. ``["input_ids", "attention_mask"]``.
        (Default changed from a mutable ``[]`` to an immutable ``()`` — the
        function only iterates it, so callers are unaffected.)

    Returns
    -------
    dict
        Mapping of each requested column name to its tensor representation.
    """
    # NOTE(review): the original compared version STRINGS ("4.9.0" > "4.12.0"
    # is True lexicographically but 4.9 predates 4.12). Compare numeric
    # release tuples instead so the branch is taken for the right versions.
    if _release_tuple(transformers.__version__) > (4, 12, 0):
        # Newer transformers: columns come back as dense tensors already.
        features = {name: dataset[name] for name in columns}
    else:
        # Older transformers: columns are ragged; densify to a fixed
        # [batch, model_max_length] shape, padding with 0.
        features = {
            name: dataset[name].to_tensor(
                default_value=0, shape=[None, tokenizer.model_max_length]
            )
            for name in columns
        }
    return features


def _release_tuple(version):
    """Parse the leading numeric components of a version string into a tuple.

    Stops at the first non-numeric component, so ``"4.12.0.dev0"`` -> ``(4, 12, 0)``.
    Good enough for the single release-ordering check above; avoids adding a
    third-party dependency such as ``packaging``.
    """
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)
1122

1223

1324
if __name__ == "__main__":
@@ -58,13 +69,9 @@
5869
)
5970
train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
6071

61-
if transformers.__version__ > "4.12.0":
62-
train_features = {x: train_dataset[x] for x in ["input_ids", "attention_mask"]}
63-
else:
64-
train_features = {
65-
x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
66-
for x in ["input_ids", "attention_mask"]
67-
}
72+
train_features = _get_dataset_features(
73+
train_dataset, tokenizer, columns=["input_ids", "attention_mask"]
74+
)
6875

6976
tf_train_dataset = tf.data.Dataset.from_tensor_slices(
7077
(train_features, train_dataset["label"])
@@ -76,13 +83,9 @@
7683
)
7784
test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
7885

79-
if transformers.__version__ > "4.12.0":
80-
test_features = {x: test_dataset[x] for x in ["input_ids", "attention_mask"]}
81-
else:
82-
test_features = {
83-
x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
84-
for x in ["input_ids", "attention_mask"]
85-
}
86+
test_features = _get_dataset_features(
87+
test_dataset, tokenizer, columns=["input_ids", "attention_mask"]
88+
)
8689

8790
tf_test_dataset = tf.data.Dataset.from_tensor_slices(
8891
(test_features, test_dataset["label"])

0 commit comments

Comments
 (0)