
Commit f392f9f

new: Add multitask split_parse command
1 parent: 3e48684

File tree

4 files changed: +203 −1 lines


deep_reference_parser/__main__.py

Lines changed: 2 additions & 0 deletions
@@ -12,11 +12,13 @@
 from .train import train
 from .split import split
 from .parse import parse
+from .split_parse import split_parse
 
 commands = {
     "split": split,
     "parse": parse,
     "train": train,
+    "split_parse": split_parse,
 }
 
 if len(sys.argv) == 1:
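
With the new entry in the commands table, split_parse is dispatched like the existing split and parse commands. A minimal invocation might look like the following; the quoted text is a hypothetical placeholder, not part of the commit:

    python -m deep_reference_parser split_parse "Some plaintext containing references."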

deep_reference_parser/__version__.py

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@
 __license__ = "MIT"
 __splitter_model_version__ = "2020.3.6_splitting"
 __parser_model_version__ = "2020.3.8_parsing"
+__multitask_model_version__ = "2020.3.18_multitask"

deep_reference_parser/common.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@
 from urllib import parse, request
 
 from .logger import logger
-from .__version__ import __splitter_model_version__, __parser_model_version__
+from .__version__ import __splitter_model_version__, __parser_model_version__, __multitask_model_version__
 
 
 def get_path(path):
@@ -15,6 +15,7 @@ def get_path(path):
 
 SPLITTER_CFG = get_path(f"configs/{__splitter_model_version__}.ini")
 PARSER_CFG = get_path(f"configs/{__parser_model_version__}.ini")
+MULTITASK_CFG = get_path(f"configs/{__multitask_model_version__}.ini")
 
 
 def download_model_artefact(artefact, s3_slug):
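
Like SPLITTER_CFG and PARSER_CFG above it, MULTITASK_CFG resolves to a versioned .ini config shipped inside the package. A quick sketch of inspecting it; the printed path is illustrative and depends on where the package is installed:

    from deep_reference_parser.common import MULTITASK_CFG

    # Something like .../deep_reference_parser/configs/2020.3.18_multitask.ini
    print(MULTITASK_CFG)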

deep_reference_parser/split_parse.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+# coding: utf-8
+"""
+Run predictions from a pre-trained model
+"""
+
+import itertools
+import json
+import os
+
+import en_core_web_sm
+import plac
+import spacy
+import wasabi
+
+import warnings
+
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+    from deep_reference_parser import __file__
+    from deep_reference_parser.__version__ import __splitter_model_version__
+    from deep_reference_parser.common import MULTITASK_CFG, download_model_artefact
+    from deep_reference_parser.deep_reference_parser import DeepReferenceParser
+    from deep_reference_parser.logger import logger
+    from deep_reference_parser.model_utils import get_config
+    from deep_reference_parser.reference_utils import break_into_chunks
+    from deep_reference_parser.tokens_to_references import tokens_to_references
+
+msg = wasabi.Printer(icons={"check": "\u2023"})
+
+
+class SplitParser:
+    def __init__(self, config_file):
+
+        msg.info(f"Using config file: {config_file}")
+
+        cfg = get_config(config_file)
+
+        msg.info(
+            f"Attempting to download model artefacts if they are not found locally in {cfg['build']['output_path']}. This may take some time..."
+        )
+
+        # Build config
+
+        OUTPUT_PATH = cfg["build"]["output_path"]
+        S3_SLUG = cfg["data"]["s3_slug"]
+
+        # Check whether the necessary artefacts exist and download them if
+        # not.
+
+        artefacts = [
+            "indices.pickle",
+            "weights.h5",
+        ]
+
+        for artefact in artefacts:
+            with msg.loading(f"Could not find {artefact} locally, downloading..."):
+                try:
+                    artefact = os.path.join(OUTPUT_PATH, artefact)
+                    download_model_artefact(artefact, S3_SLUG)
+                    msg.good(f"Found {artefact}")
+                except Exception:
+                    msg.fail(f"Could not download {S3_SLUG}{artefact}")
+                    logger.exception("Could not download %s%s", S3_SLUG, artefact)
+
+        # Check for the word embeddings and download them if they do not exist
+
+        WORD_EMBEDDINGS = cfg["build"]["word_embeddings"]
+
+        with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."):
+            try:
+                download_model_artefact(WORD_EMBEDDINGS, S3_SLUG)
+                msg.good(f"Found {WORD_EMBEDDINGS}")
+            except Exception:
+                msg.fail(f"Could not download {S3_SLUG}{WORD_EMBEDDINGS}")
+                logger.exception("Could not download %s", WORD_EMBEDDINGS)
+
+        OUTPUT = cfg["build"]["output"]
+        PRETRAINED_EMBEDDING = cfg["build"]["pretrained_embedding"]
+        DROPOUT = float(cfg["build"]["dropout"])
+        LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"])
+        WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"])
+        CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"])
+
+        self.MAX_WORDS = int(cfg["data"]["line_limit"])
+
+        # Evaluate config
+
+        self.drp = DeepReferenceParser(output_path=OUTPUT_PATH)
+
+        # Encode data and load required mapping dicts. Note that the max word
+        # and max char lengths will be loaded in this step.
+
+        self.drp.load_data(OUTPUT_PATH)
+
+        # Build the model architecture
+
+        self.drp.build_model(
+            output=OUTPUT,
+            word_embeddings=WORD_EMBEDDINGS,
+            pretrained_embedding=PRETRAINED_EMBEDDING,
+            dropout=DROPOUT,
+            lstm_hidden=LSTM_HIDDEN,
+            word_embedding_size=WORD_EMBEDDING_SIZE,
+            char_embedding_size=CHAR_EMBEDDING_SIZE,
+        )
+
+    def split_parse(self, text, return_tokens=False, verbose=False):
+
+        nlp = en_core_web_sm.load()
+        doc = nlp(text)
+        chunks = break_into_chunks(doc, max_words=self.MAX_WORDS)
+        tokens = [[token.text for token in chunk] for chunk in chunks]
+
+        preds = self.drp.predict(tokens, load_weights=True)
+
+        return preds
+
+        # If the tokens argument is passed, return the labelled tokens
+
+        # if return_tokens:
+
+        #     flat_predictions = list(itertools.chain.from_iterable(preds))
+        #     flat_X = list(itertools.chain.from_iterable(tokens))
+        #     rows = [i for i in zip(flat_X, flat_predictions)]
+
+        #     if verbose:
+
+        #         msg.divider("Token Results")
+
+        #         header = ("token", "label")
+        #         aligns = ("r", "l")
+        #         formatted = wasabi.table(
+        #             rows, header=header, divider=True, aligns=aligns
+        #         )
+        #         print(formatted)
+
+        #     out = rows
+
+        # else:
+
+        #     # Otherwise convert the tokens into references and return
+
+        #     refs = tokens_to_references(tokens, preds)
+
+        #     if verbose:
+
+        #         msg.divider("Results")
+
+        #         if refs:
+
+        #             msg.good(f"Found {len(refs)} references.")
+        #             msg.info("Printing found references:")
+
+        #             for ref in refs:
+        #                 msg.text(ref, icon="check", spaced=True)
+
+        #         else:
+
+        #             msg.fail("Failed to find any references.")
+
+        #     out = refs
+
+        # return out
+
+
+@plac.annotations(
+    text=("Plaintext from which to extract references", "positional", None, str),
+    config_file=("Path to config file", "option", "c", str),
+    tokens=("Output tokens instead of complete references", "flag", "t", str),
+    outfile=("Path to json file to which results will be written", "option", "o", str),
+)
+def split_parse(text, config_file=MULTITASK_CFG, tokens=False, outfile=None):
+    """
+    Runs the default multitask model and pretty prints results to the console
+    unless --outfile is passed with a path. Files output to the path specified
+    in --outfile will be valid json. Can output either tokens (with -t|--tokens)
+    or split naively into references based on the b-r tag (default).
+
+    NOTE: this function is provided as an example only and should not be used
+    in production, as the model is instantiated each time the command is run.
+    To use in a production setting, a more sensible approach would be to
+    replicate the split or parse functions within your own logic.
+    """
+    mt = SplitParser(config_file)
+    if outfile:
+        out = mt.split_parse(text, return_tokens=tokens, verbose=False)
+
+        try:
+            with open(outfile, "w") as fb:
+                json.dump(out, fb)
+            msg.good(f"Wrote model output to {outfile}")
+        except Exception:
+            msg.fail(f"Failed to write output to {outfile}")
+
+    else:
+        out = mt.split_parse(text, return_tokens=tokens, verbose=True)
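
As the docstring's NOTE warns, the command-line entry point rebuilds the model on every run. A minimal sketch of the production-style reuse it suggests, instantiating SplitParser once and calling its split_parse method repeatedly; the example texts are placeholders and MULTITASK_CFG is simply the packaged default config:

    from deep_reference_parser.common import MULTITASK_CFG
    from deep_reference_parser.split_parse import SplitParser

    # Download artefacts and build the model once.
    mt = SplitParser(MULTITASK_CFG)

    documents = [
        "First plaintext containing references.",   # placeholder input
        "Second plaintext containing references.",  # placeholder input
    ]

    # Reuse the same instance; each call tokenises the text and runs prediction.
    predictions = [mt.split_parse(doc) for doc in documents]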
