13 | 13 |
14 | 14 | import numpy as np
15 | 15 |
   | 16 | +
   | 17 | +from functools import partial
16 | 18 | import h5py
17 | 19 | from keras.engine import saving
18 | 20 | from keras.callbacks import EarlyStopping
@@ -1019,35 +1021,33 @@ def predict(self, X, load_weights=False):
1019 | 1021 |
1020 | 1022 |         _, X_combined = self.prepare_X_data(X)
1021 | 1023 |
1022 |      | -        pred = self.model.predict(X_combined)
1023 |      | -
1024 |      | -        pred = np.asarray(pred)
1025 |      | -
1026 | 1024 |         # Compute validation score
1027 | 1025 |
     | 1026 | +        pred = np.asarray(self.model.predict(X_combined))
     | 1027 | +        pred = np.asarray(pred)
1028 | 1028 |         pred_index = np.argmax(pred, axis=-1)
1029 | 1029 |
1030 |      | -        # NOTE: indexing ind2label[0] will only work in the case of making
1031 |      | -        # predictions with a single task model.
1032 | 1030 |
1033 |      | -        ind2labelNew = self.ind2label[0].copy()
     | 1031 | +        # Add 0 to labels to account for padding
1034 | 1032 |
1035 |      | -        # Index 0 in the predictions refers to padding
     | 1033 | +        ind2labelNew = self.ind2label.copy()
     | 1034 | +        [labels.update({0: "null"}) for labels in ind2labelNew]
1036 | 1035 |
1037 |      | -        ind2labelNew.update({0: "null"})
     | 1036 | +        # Compute the labels for each prediction for each task
1038 | 1037 |
1039 |      | -        # Compute the labels for each prediction
1040 |      | -        pred_label = [[ind2labelNew[x] for x in a] for a in pred_index]
     | 1038 | +        pred_label = []
     | 1039 | +        for i in range(len(ind2labelNew)):
     | 1040 | +            out = [[ind2labelNew[i][x] for x in a] for a in pred_index[i]]
     | 1041 | +            pred_label.append(out)
1041 | 1042 |
1042 | 1043 |         # Flatten data
1043 | 1044 |
1044 | 1045 |         # Remove the padded tokens. This is done by counting the number of
1045 | 1046 |         # tokens in the input example, and then removing the additional padded
1046 |      | -        # tokens that are added before this. It has to be done this way because
1047 |      | -        # the model can predict padding tokens, and sometimes it gets it wrong
1048 |      | -        # so if we remove all padding tokens, then we end up with mismatches in
1049 |      | -        # the length of input tokens and the length of predictions.
     | 1047 | +        # tokens that are added before this.
     | 1048 | +
     | 1049 | +        # This is performed on each set of predictions relating to each task
1050 | 1050 |
1051 |      | -        out = remove_padding_from_predictions(X, pred_label, self.padding_style)
     | 1051 | +        out = list(map(lambda x: remove_padding_from_predictions(X, x, self.padding_style), pred_label))
1052 | 1052 |
1053 | 1053 |         return out
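The per-task flow introduced by this hunk can be sketched outside the class as follows. This is a minimal, self-contained sketch under stated assumptions: `ind2label` is a list of `{index: label}` dicts (one per task) with index 0 reserved for the padding class, the model's multi-task output stacks into an array of shape `(num_tasks, num_examples, max_len, num_classes)`, and sequences are post-padded. The `trim_padding` helper is a hypothetical stand-in for the repository's `remove_padding_from_predictions`, whose implementation is not shown in this diff.

```python
import numpy as np

# Hypothetical stand-ins for the pieces the diff assumes: X is a list of
# token lists, ind2label is one {index: label} dict per task, and index 0
# is reserved for the padding class.
X = [["The", "cat", "sat"], ["Hello", "world"]]
ind2label = [
    {1: "O", 2: "B-ENT", 3: "I-ENT"},   # task 0
    {1: "O", 2: "B-ACT", 3: "I-ACT"},   # task 1
]

MAXLEN = 5          # padded sequence length
NUM_CLASSES = 4     # 3 real labels + the padding class at index 0

# Mock softmax output with shape (num_tasks, num_examples, MAXLEN, NUM_CLASSES),
# standing in for self.model.predict(X_combined) on a multi-task model.
rng = np.random.default_rng(0)
pred = np.asarray(rng.random((len(ind2label), len(X), MAXLEN, NUM_CLASSES)))

pred_index = np.argmax(pred, axis=-1)   # (num_tasks, num_examples, MAXLEN)

# Map indices to label strings per task, treating index 0 as "null" padding.
ind2label_new = [{**labels, 0: "null"} for labels in ind2label]
pred_label = [
    [[ind2label_new[i][x] for x in seq] for seq in pred_index[i]]
    for i in range(len(ind2label_new))
]

# Trim each predicted sequence back to the length of its input sentence,
# i.e. drop the trailing padded positions regardless of what was predicted
# there (assumes post-padding).
def trim_padding(X_tokens, labels_per_sentence):
    return [labels[: len(tokens)] for tokens, labels in zip(X_tokens, labels_per_sentence)]

out = [trim_padding(X, task_labels) for task_labels in pred_label]

for task_id, task_out in enumerate(out):
    print(f"task {task_id}: {task_out}")
```

Trimming by input length rather than by dropping every predicted padding label keeps each output sequence aligned with its input tokens even when the model wrongly predicts the padding class for a real token, which is the rationale the removed comment described.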