@@ -4,10 +4,21 @@
 import time
 
 import tensorflow as tf
+import transformers
 from datasets import load_dataset
-
 from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
-import transformers
+
+
+def _get_dataset_features(dataset, tokenizer, columns=[]):
+    if transformers.__version__ > "4.12.0":
+        features = {x: dataset[x] for x in columns}
+    else:
+        features = {
+            x: dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
+            for x in columns
+        }
+
+    return features
 
 
 if __name__ == "__main__":
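
Note: the version gate in the new helper relies on plain string comparison, so transformers.__version__ > "4.12.0" misorders releases like "4.9.0", which sorts above "4.12.0" lexicographically. A minimal sketch of a numeric check using the packaging library; the helper name _version_gt is illustrative and not part of this commit:

# Sketch only (not in the commit): compare release segments numerically
# rather than lexicographically.
from packaging import version
import transformers

def _version_gt(threshold="4.12.0"):
    # version.parse orders "4.9.0" < "4.12.0" correctly, unlike str comparison.
    return version.parse(transformers.__version__) > version.parse(threshold)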
@@ -58,13 +69,9 @@
     )
     train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
 
-    if transformers.__version__ > "4.12.0":
-        train_features = {x: train_dataset[x] for x in ["input_ids", "attention_mask"]}
-    else:
-        train_features = {
-            x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
-            for x in ["input_ids", "attention_mask"]
-        }
+    train_features = _get_dataset_features(
+        train_dataset, tokenizer, columns=["input_ids", "attention_mask"]
+    )
 
     tf_train_dataset = tf.data.Dataset.from_tensor_slices(
         (train_features, train_dataset["label"])
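
Note: the else branch of the helper covers setups where the formatted columns come back as tf.RaggedTensor values; to_tensor is the RaggedTensor method that pads variable-length rows out to a dense tensor. A toy illustration of that padding, with made-up token ids and width (nothing here is from the commit):

# Illustration only: RaggedTensor.to_tensor pads each row with default_value
# up to the requested width.
import tensorflow as tf

ragged = tf.ragged.constant([[101, 2054, 102], [101, 102]])
dense = ragged.to_tensor(default_value=0, shape=[None, 4])
# dense == [[101, 2054, 102, 0],
#           [101,  102,   0, 0]]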
@@ -76,13 +83,9 @@
     )
     test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
 
-    if transformers.__version__ > "4.12.0":
-        test_features = {x: test_dataset[x] for x in ["input_ids", "attention_mask"]}
-    else:
-        test_features = {
-            x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
-            for x in ["input_ids", "attention_mask"]
-        }
+    test_features = _get_dataset_features(
+        test_dataset, tokenizer, columns=["input_ids", "attention_mask"]
+    )
 
     tf_test_dataset = tf.data.Dataset.from_tensor_slices(
         (test_features, test_dataset["label"])
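
Note: the visible context ends while the tensor-slice datasets are being built. A typical next step (assumed, not shown in this commit) is to shuffle and batch them before handing them to Keras; the buffer and batch sizes below are arbitrary placeholders:

# Assumed downstream usage, not part of the diff.
tf_train_dataset = tf_train_dataset.shuffle(buffer_size=1000).batch(16)
tf_test_dataset = tf_test_dataset.batch(16)
# model.fit(tf_train_dataset, validation_data=tf_test_dataset, epochs=3)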