Commit 6a244cf (parent: 711354f)

chg: Use the max_len sent at init

Don't set it based on maximum sequence length

File tree: 2 files changed (+26, −9 lines)
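In short, the commit stops deriving the padded sequence length from the longest training sequence and instead uses the max_len value handed to __init__, which train.py now reads from the config's data.line_limit. A minimal sketch of the behavioural difference (the sample data and the config stand-in below are made up for illustration):

# Sample tokenised sequences; values are illustrative only.
X_train = [["a", "b", "c"], ["d", "e"]]

# Before this commit: prepare_data() computed max_len from the data,
# so the model's input width changed with whatever data was loaded.
max_len_before = max([len(xx) for xx in X_train])  # -> 3

# After this commit: max_len comes from the config's data.line_limit
# and is passed in at init, fixing the input width up front.
cfg = {"data": {"line_limit": "250"}}  # stand-in for the parsed config
max_len_after = int(cfg["data"]["line_limit"])  # -> 250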

deep_reference_parser/deep_reference_parser.py

Lines changed: 10 additions & 9 deletions
@@ -72,6 +72,7 @@ def __init__(
         y_train=None,
         y_test=None,
         y_valid=None,
+        max_len=None,
         digits_word="$NUM$",
         ukn_words="out-of-vocabulary",
         padding_style="pre",
@@ -126,7 +127,7 @@ def __init__(
         self.X_validation = list()
         self.X_testing = list()

-        self.max_len = int()
+        self.max_len = max_len
         self.max_char = int()
         self.max_words = int()

@@ -156,7 +157,7 @@ def prepare_data(self, save=False):
         Save(bool): If True, then data objects will be saved to
             `self.output_path`.
         """
-        self.max_len = max([len(xx) for xx in self.X_train])
+        #self.max_len = max([len(xx) for xx in self.X_train])

         self.X_train_merged, self.X_test_merged, self.X_valid_merged = merge_digits(
             [self.X_train, self.X_test, self.X_valid], self.digits_word
@@ -253,7 +254,7 @@ def prepare_data(self, save=False):
         self.X_train_char = character_data(
             self.X_train,
             self.char2ind,
-            self.max_words,
+            self.max_len,
             self.max_char,
             self.digits_word,
             self.padding_style,
@@ -262,7 +263,7 @@ def prepare_data(self, save=False):
         self.X_test_char = character_data(
             self.X_test,
             self.char2ind,
-            self.max_words,
+            self.max_len,
             self.max_char,
             self.digits_word,
             self.padding_style,
@@ -271,7 +272,7 @@ def prepare_data(self, save=False):
         self.X_valid_char = character_data(
             self.X_valid,
             self.char2ind,
-            self.max_words,
+            self.max_len,
             self.max_char,
             self.digits_word,
             self.padding_style,
@@ -370,7 +371,7 @@ def build_model(

         if word_embeddings:

-            word_input = Input((self.max_words,))
+            word_input = Input((self.max_len,))
             inputs.append(word_input)

             # TODO: More sensible handling of options for pretrained embedding.
@@ -406,7 +407,7 @@ def build_model(

         if self.max_char != 0:

-            character_input = Input((self.max_words, self.max_char,))
+            character_input = Input((self.max_len, self.max_char,))

             char_embedding = self.character_embedding_layer(
                 char_embedding_type=char_embedding_type,
@@ -474,7 +475,7 @@ def build_model(

         self.model = model

-        logger.debug(self.model.summary(line_length=150))
+        #logger.debug(self.model.summary(line_length=150))

     def train_model(
         self, epochs=25, batch_size=100, early_stopping_patience=5, metric="val_f1"
@@ -970,7 +971,7 @@ def prepare_X_data(self, X):
         X_char = character_data(
             X,
             self.char2ind,
-            self.max_words,
+            self.max_len,
             self.max_char,
             self.digits_word,
             self.padding_style,
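The practical effect in build_model() is that both input layers are now sized by max_len rather than max_words. A small sketch of the resulting tensor shapes, assuming the Keras functional API implied by the Input(...) calls in the diff (the import path and the concrete values are assumptions):

from tensorflow.keras.layers import Input  # assumed import path

max_len, max_char = 250, 30  # illustrative values

word_input = Input((max_len,))                 # one word id per token position
character_input = Input((max_len, max_char,))  # char ids for each token position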

deep_reference_parser/train.py

Lines changed: 16 additions & 0 deletions
@@ -56,6 +56,7 @@ def train(config_file):
     LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"])
     WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"])
     CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"])
+    MAX_LEN = int(cfg["data"]["line_limit"])

     # Train config

@@ -74,6 +75,20 @@ def train(config_file):
     X_test, y_test = test_data[0], test_data[1:]
     X_valid, y_valid = valid_data[0], valid_data[1:]

+    import statistics
+
+    logger.info("Max token length %s", max([len(i) for i in X_train]))
+    logger.info("Min token length %s", min([len(i) for i in X_train]))
+    logger.info("Mean token length %s", statistics.median([len(i) for i in X_train]))
+
+    logger.info("Max token length %s", max([len(i) for i in X_test]))
+    logger.info("Min token length %s", min([len(i) for i in X_test]))
+    logger.info("Mean token length %s", statistics.median([len(i) for i in X_test]))
+
+    logger.info("Max token length %s", max([len(i) for i in X_valid]))
+    logger.info("Min token length %s", min([len(i) for i in X_valid]))
+    logger.info("Mean token length %s", statistics.median([len(i) for i in X_valid]))
+
     logger.info("X_train, y_train examples: %s, %s", len(X_train), list(map(len, y_train)))
     logger.info("X_test, y_test examples: %s, %s", len(X_test), list(map(len, y_test)))
     logger.info("X_valid, y_valid examples: %s, %s", len(X_valid), list(map(len, y_valid)))
@@ -85,6 +100,7 @@ def train(config_file):
         y_train=y_train,
         y_test=y_test,
         y_valid=y_valid,
+        max_len=MAX_LEN,
         output_path=OUTPUT_PATH,
     )
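For the new MAX_LEN line to work, the training config needs a line_limit key in its data section. A hypothetical snippet, assuming the config is an INI file read with configparser (which matches the cfg["..."]["..."] access pattern in the diff); all values are made up:

import configparser

# Hypothetical config carrying the keys that train() reads.
cfg = configparser.ConfigParser()
cfg.read_string("""
[build]
lstm_hidden = 400
word_embedding_size = 300
char_embedding_size = 100

[data]
line_limit = 250
""")

MAX_LEN = int(cfg["data"]["line_limit"])  # as in the new line 59 of train.py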