Skip to content

Commit 20afa75

Browse files
chg: Update predict function for multitask scenario
1 parent 3e1b20b commit 20afa75

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

deep_reference_parser/deep_reference_parser.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
import numpy as np
1515

16+
17+
from functools import partial
1618
import h5py
1719
from keras.engine import saving
1820
from keras.callbacks import EarlyStopping
@@ -1019,35 +1021,33 @@ def predict(self, X, load_weights=False):
10191021

10201022
_, X_combined = self.prepare_X_data(X)
10211023

1022-
pred = self.model.predict(X_combined)
1023-
1024-
pred = np.asarray(pred)
1025-
10261024
# Compute validation score
10271025

1026+
pred = np.asarray(self.model.predict(X_combined))
1027+
pred = np.asarray(pred)
10281028
pred_index = np.argmax(pred, axis=-1)
10291029

1030-
# NOTE: indexing ind2label[0] will only work in the case of making
1031-
# predictions with a single task model.
10321030

1033-
ind2labelNew = self.ind2label[0].copy()
1031+
# Add 0 to labels to account for padding
10341032

1035-
# Index 0 in the predictions refers to padding
1033+
ind2labelNew = self.ind2label.copy()
1034+
[labels.update({0: "null"}) for labels in ind2labelNew]
10361035

1037-
ind2labelNew.update({0: "null"})
1036+
# Compute the labels for each prediction for each task
10381037

1039-
# Compute the labels for each prediction
1040-
pred_label = [[ind2labelNew[x] for x in a] for a in pred_index]
1038+
pred_label = []
1039+
for i in range(len(ind2labelNew)):
1040+
out = [[ind2labelNew[i][x] for x in a] for a in pred_index[i]]
1041+
pred_label.append(out)
10411042

10421043
# Flatten data
10431044

10441045
# Remove the padded tokens. This is done by counting the number of
10451046
# tokens in the input example, and then removing the additional padded
1046-
# tokens that are added before this. It has to be done this way because
1047-
# the model can predict padding tokens, and sometimes it gets it wrong
1048-
# so if we remove all padding tokens, then we end up with mismatches in
1049-
# the length of input tokens and the length of predictions.
1047+
# tokens that are added before this.
1048+
1049+
# This is performed on each set of predictions relating to each task
10501050

1051-
out = remove_padding_from_predictions(X, pred_label, self.padding_style)
1051+
out = list(map(lambda x: remove_padding_from_predictions(X, x, self.padding_style), pred_label))
10521052

10531053
return out

0 commit comments

Comments (0)