Skip to content

Commit cea55a7

Browse files
Merge pull request #23 from wellcometrust/revert-22-feature/ivyleavedtoadflax/train_multitask
Revert "Fix token/label offset issue"
2 parents d899121 + 95509c3 commit cea55a7

File tree

7 files changed

+495
-1943
lines changed

7 files changed

+495
-1943
lines changed

deep_reference_parser/io/io.py

Lines changed: 10 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,18 @@
1313

1414
from ..logger import logger
1515

16-
def _unpack(tuples):
17-
"""Convert list of tuples into the correct format:
18-
19-
From:
20-
21-
[
22-
(
23-
(token0, token1, token2, token3),
24-
(label0, label1, label2, label3),
25-
),
26-
(
27-
(token0, token1, token2),
28-
(label0, label1, label2),
29-
),
30-
]
31-
32-
to:
33-
[
34-
(
35-
(token0, token1, token2, token3),
36-
(token0, token1, token2),
37-
),
38-
(
39-
(label0, label1, label2, label3),
40-
(label0, label1, label2),
41-
),
42-
]
43-
"""
44-
return list(zip(*list(tuples)))
45-
46-
def _split_list_by_linebreaks(rows):
16+
def _split_list_by_linebreaks(tokens):
4717
"""Cycle through a list of tokens (or labels) and split them into lists
4818
based on the presence of Nones or more likely math.nan caused by converting
4919
pd.DataFrame columns to lists.
5020
"""
5121
out = []
52-
rows_gen = iter(rows)
22+
tokens_gen = iter(tokens)
5323
while True:
5424
try:
55-
row = next(rows_gen)
56-
token = row[0]
25+
token = next(tokens_gen)
5726
if isinstance(token, str) and token:
58-
out.append(row)
27+
out.append(token)
5928
else:
6029
yield out
6130
out = []
@@ -71,8 +40,10 @@ def load_tsv(filepath, split_char="\t"):
7140
Expects data in the following format (tab separations).
7241
7342
References o o
43+
o o
7444
1 o o
7545
. o o
46+
o o
7647
WHO title b-r
7748
treatment title i-r
7849
guidelines title i-r
@@ -84,6 +55,8 @@ def load_tsv(filepath, split_char="\t"):
8455
, title i-r
8556
2016 title i-r
8657
58+
59+
8760
Args:
8861
filepath (str): Path to the data.
8962
split_char (str): Character to be used to split each line of the
@@ -94,16 +67,9 @@ def load_tsv(filepath, split_char="\t"):
9467
filepath.
9568
9669
"""
97-
df = pd.read_csv(filepath, delimiter=split_char, header=None, skip_blank_lines=False)
98-
tuples = _split_list_by_linebreaks(df.to_records(index=False))
9970

100-
# Remove leading empty lists if found
101-
102-
tuples = list(filter(None, tuples))
103-
104-
unpacked_tuples = list(map(_unpack, tuples))
105-
106-
out = _unpack(unpacked_tuples)
71+
df = pd.read_csv(filepath, delimiter=split_char, header=None, skip_blank_lines=False)
72+
out = [list(_split_list_by_linebreaks(column)) for _, column in df.iteritems()]
10773

10874
logger.info("Loaded %s training examples", len(out[0]))
10975

0 commit comments

Comments
 (0)