Skip to content

Commit 57e7588

Browse files
chg: Use the max_len passed at init
* Remove the confusing max_words parameter
* Set max_len in __init__ with a default of 250
1 parent 711354f commit 57e7588

File tree

4 files changed

+30
-20
lines changed

4 files changed

+30
-20
lines changed

deep_reference_parser/deep_reference_parser.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
y_train=None,
7373
y_test=None,
7474
y_valid=None,
75+
max_len=250,
7576
digits_word="$NUM$",
7677
ukn_words="out-of-vocabulary",
7778
padding_style="pre",
@@ -126,9 +127,8 @@ def __init__(
126127
self.X_validation = list()
127128
self.X_testing = list()
128129

129-
self.max_len = int()
130+
self.max_len = max_len
130131
self.max_char = int()
131-
self.max_words = int()
132132

133133
# Defined in prepare_data
134134

@@ -156,7 +156,7 @@ def prepare_data(self, save=False):
156156
Save(bool): If True, then data objects will be saved to
157157
`self.output_path`.
158158
"""
159-
self.max_len = max([len(xx) for xx in self.X_train])
159+
#self.max_len = max([len(xx) for xx in self.X_train])
160160

161161
self.X_train_merged, self.X_test_merged, self.X_valid_merged = merge_digits(
162162
[self.X_train, self.X_test, self.X_valid], self.digits_word
@@ -246,14 +246,14 @@ def prepare_data(self, save=False):
246246
# Create character level data
247247

248248
# Create the character level data
249-
self.char2ind, self.max_words, self.max_char = character_index(
249+
self.char2ind, self.max_char = character_index(
250250
self.X_train, self.digits_word
251251
)
252252

253253
self.X_train_char = character_data(
254254
self.X_train,
255255
self.char2ind,
256-
self.max_words,
256+
self.max_len,
257257
self.max_char,
258258
self.digits_word,
259259
self.padding_style,
@@ -262,7 +262,7 @@ def prepare_data(self, save=False):
262262
self.X_test_char = character_data(
263263
self.X_test,
264264
self.char2ind,
265-
self.max_words,
265+
self.max_len,
266266
self.max_char,
267267
self.digits_word,
268268
self.padding_style,
@@ -271,7 +271,7 @@ def prepare_data(self, save=False):
271271
self.X_valid_char = character_data(
272272
self.X_valid,
273273
self.char2ind,
274-
self.max_words,
274+
self.max_len,
275275
self.max_char,
276276
self.digits_word,
277277
self.padding_style,
@@ -292,7 +292,6 @@ def prepare_data(self, save=False):
292292
write_pickle(self.char2ind, "char2ind.pickle", path=self.output_path)
293293

294294
maxes = {
295-
"max_words": self.max_words,
296295
"max_char": self.max_char,
297296
"max_len": self.max_len,
298297
}
@@ -317,11 +316,9 @@ def load_data(self, out_path):
317316

318317
self.max_len = maxes["max_len"]
319318
self.max_char = maxes["max_char"]
320-
self.max_words = maxes["max_words"]
321319

322320
logger.debug("Setting max_len to %s", self.max_len)
323321
logger.debug("Setting max_char to %s", self.max_char)
324-
logger.debug("Setting max_words to %s", self.max_words)
325322

326323
def build_model(
327324
self,
@@ -370,7 +367,7 @@ def build_model(
370367

371368
if word_embeddings:
372369

373-
word_input = Input((self.max_words,))
370+
word_input = Input((self.max_len,))
374371
inputs.append(word_input)
375372

376373
# TODO: More sensible handling of options for pretrained embedding.
@@ -406,7 +403,7 @@ def build_model(
406403

407404
if self.max_char != 0:
408405

409-
character_input = Input((self.max_words, self.max_char,))
406+
character_input = Input((self.max_len, self.max_char,))
410407

411408
char_embedding = self.character_embedding_layer(
412409
char_embedding_type=char_embedding_type,
@@ -474,7 +471,7 @@ def build_model(
474471

475472
self.model = model
476473

477-
logger.debug(self.model.summary(line_length=150))
474+
#logger.debug(self.model.summary(line_length=150))
478475

479476
def train_model(
480477
self, epochs=25, batch_size=100, early_stopping_patience=5, metric="val_f1"
@@ -613,10 +610,6 @@ def evaluate(
613610

614611
# Compute classification report
615612

616-
# Initialise list for storing predictions which will be written
617-
# to tsv file.
618-
619-
620613
for i, y_target in enumerate(self.y_valid_encoded):
621614

622615
# Compute predictions, flatten
@@ -970,7 +963,7 @@ def prepare_X_data(self, X):
970963
X_char = character_data(
971964
X,
972965
self.char2ind,
973-
self.max_words,
966+
self.max_len,
974967
self.max_char,
975968
self.digits_word,
976969
self.padding_style,

deep_reference_parser/model_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,9 @@ def character_index(X, digits_word):
205205

206206
# For padding
207207

208-
max_words = max([len(s) for s in X])
209208
max_char = max([len(w) for s in X for w in s])
210209

211-
return char2ind, max_words, max_char
210+
return char2ind, max_char
212211

213212

214213
def character_data(X, char2ind, max_words, max_char, digits_word, padding_style):

deep_reference_parser/train.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def train(config_file):
5656
LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"])
5757
WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"])
5858
CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"])
59+
MAX_LEN = int(cfg["data"]["line_limit"])
5960

6061
# Train config
6162

@@ -74,6 +75,20 @@ def train(config_file):
7475
X_test, y_test = test_data[0], test_data[1:]
7576
X_valid, y_valid = valid_data[0], valid_data[1:]
7677

78+
import statistics
79+
80+
logger.info("Max token length %s", max([len(i) for i in X_train]))
81+
logger.info("Min token length %s", min([len(i) for i in X_train]))
82+
logger.info("Mean token length %s", statistics.median([len(i) for i in X_train]))
83+
84+
logger.info("Max token length %s", max([len(i) for i in X_test]))
85+
logger.info("Min token length %s", min([len(i) for i in X_test]))
86+
logger.info("Mean token length %s", statistics.median([len(i) for i in X_test]))
87+
88+
logger.info("Max token length %s", max([len(i) for i in X_valid]))
89+
logger.info("Min token length %s", min([len(i) for i in X_valid]))
90+
logger.info("Mean token length %s", statistics.median([len(i) for i in X_valid]))
91+
7792
logger.info("X_train, y_train examples: %s, %s", len(X_train), list(map(len, y_train)))
7893
logger.info("X_test, y_test examples: %s, %s", len(X_test), list(map(len, y_test)))
7994
logger.info("X_valid, y_valid examples: %s, %s", len(X_valid), list(map(len, y_valid)))
@@ -85,6 +100,7 @@ def train(config_file):
85100
y_train=y_train,
86101
y_test=y_test,
87102
y_valid=y_valid,
103+
max_len=MAX_LEN,
88104
output_path=OUTPUT_PATH,
89105
)
90106

tests/test_deep_reference_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def test_DeepReferenceParser_train(tmpdir, cfg):
7777
y_train=y_test,
7878
y_test=y_test,
7979
y_valid=y_test,
80+
max_len=250,
8081
output_path=tmpdir,
82+
8183
)
8284

8385
# Prepare the data

0 commit comments

Comments (0)