Skip to content

Commit d899121

Browse files
Merge pull request #22 from wellcometrust/feature/ivyleavedtoadflax/train_multitask
Fix token/label offset issue
2 parents bb49ede + 34fe5a7 commit d899121

File tree

7 files changed

+1943
-495
lines changed

7 files changed

+1943
-495
lines changed

deep_reference_parser/io/io.py

Lines changed: 44 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -13,18 +13,49 @@
1313

1414
from ..logger import logger
1515

16-
def _split_list_by_linebreaks(tokens):
16+
def _unpack(tuples):
17+
"""Convert list of tuples into the correct format:
18+
19+
From:
20+
21+
[
22+
(
23+
(token0, token1, token2, token3),
24+
(label0, label1, label2, label3),
25+
),
26+
(
27+
(token0, token1, token2),
28+
(label0, label1, label2),
29+
),
30+
)
31+
32+
to:
33+
]
34+
(
35+
(token0, token1, token2, token3),
36+
(token0, token1, token2),
37+
),
38+
(
39+
(label0, label1, label2, label3),
40+
(label0, label1, label2),
41+
),
42+
]
43+
"""
44+
return list(zip(*list(tuples)))
45+
46+
def _split_list_by_linebreaks(rows):
1747
"""Cycle through a list of tokens (or labels) and split them into lists
1848
based on the presence of Nones or more likely math.nan caused by converting
1949
pd.DataFrame columns to lists.
2050
"""
2151
out = []
22-
tokens_gen = iter(tokens)
52+
rows_gen = iter(rows)
2353
while True:
2454
try:
25-
token = next(tokens_gen)
55+
row = next(rows_gen)
56+
token = row[0]
2657
if isinstance(token, str) and token:
27-
out.append(token)
58+
out.append(row)
2859
else:
2960
yield out
3061
out = []
@@ -40,10 +71,8 @@ def load_tsv(filepath, split_char="\t"):
4071
Expects data in the following format (tab separations).
4172
4273
References o o
43-
o o
4474
1 o o
4575
. o o
46-
o o
4776
WHO title b-r
4877
treatment title i-r
4978
guidelines title i-r
@@ -55,8 +84,6 @@ def load_tsv(filepath, split_char="\t"):
5584
, title i-r
5685
2016 title i-r
5786
58-
59-
6087
Args:
6188
filepath (str): Path to the data.
6289
split_char (str): Character to be used to split each line of the
@@ -67,9 +94,16 @@ def load_tsv(filepath, split_char="\t"):
6794
filepath.
6895
6996
"""
70-
7197
df = pd.read_csv(filepath, delimiter=split_char, header=None, skip_blank_lines=False)
72-
out = [list(_split_list_by_linebreaks(column)) for _, column in df.iteritems()]
98+
tuples = _split_list_by_linebreaks(df.to_records(index=False))
99+
100+
# Remove leading empty lists if found
101+
102+
tuples = list(filter(None, tuples))
103+
104+
unpacked_tuples = list(map(_unpack, tuples))
105+
106+
out = _unpack(unpacked_tuples)
73107

74108
logger.info("Loaded %s training examples", len(out[0]))
75109

0 commit comments

Comments
 (0)