Skip to content

Commit 5b17587

Browse files
new: Add logic to handle single task case
1 parent b33c2b2 commit 5b17587

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

deep_reference_parser/deep_reference_parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,11 +1035,16 @@ def predict(self, X, load_weights=False):
10351035

10361036
# Compute the labels for each prediction for each task
10371037

1038+
# If running a single task model, wrap pred_index in a list so that it
1039+
# can use the same logic as multitask models.
1040+
1041+
if len(pred_index) == 1 :
1042+
pred_index = [pred_index]
1043+
10381044
pred_label = []
10391045
for i in range(len(ind2labelNew)):
10401046
out = [[ind2labelNew[i][x] for x in a] for a in pred_index[i]]
10411047
pred_label.append(out)
1042-
10431048
# Flatten data
10441049

10451050
# Remove the padded tokens. This is done by counting the number of

deep_reference_parser/parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def parse(self, text, verbose=False):
116116

117117
preds = self.drp.predict(tokens, load_weights=True)
118118

119-
flat_predictions = list(itertools.chain.from_iterable(preds))
119+
flat_predictions = list(itertools.chain.from_iterable(preds))[0]
120120
flat_X = list(itertools.chain.from_iterable(tokens))
121121
rows = [i for i in zip(flat_X, flat_predictions)]
122122

deep_reference_parser/split.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def split(self, text, return_tokens=False, verbose=False):
124124

125125
if return_tokens:
126126

127-
flat_predictions = list(itertools.chain.from_iterable(preds))
127+
flat_predictions = list(itertools.chain.from_iterable(preds))[0]
128128
flat_X = list(itertools.chain.from_iterable(tokens))
129129
rows = [i for i in zip(flat_X, flat_predictions)]
130130

@@ -145,7 +145,7 @@ def split(self, text, return_tokens=False, verbose=False):
145145

146146
# Otherwise convert the tokens into references and return
147147

148-
refs = tokens_to_references(tokens, preds)
148+
refs = tokens_to_references(tokens, preds[0])
149149

150150
if verbose:
151151

0 commit comments

Comments
 (0)