Commit a3323e5

chg: Use lists for y case to allow multiple labels
1 parent fbc37d1 commit a3323e5

3 files changed: +54 -34 lines

deep_reference_parser/deep_reference_parser.py

Lines changed: 42 additions & 26 deletions
@@ -72,8 +72,8 @@ def __init__(
         y_train=None,
         y_test=None,
         y_valid=None,
-        digits_word="$NUM$",
-        ukn_words="out-of-vocabulary",
+        digits_word="<NUM>",
+        ukn_words="<OOV>",
         padding_style="pre",
         output_path="data/model_output",
     ):
@@ -165,13 +165,11 @@ def prepare_data(self, save=False):
         # Compute indexes for words+labels in the training data

         self.word2ind, self.ind2word = index_x(self.X_train_merged, self.ukn_words)
-        self.label2ind, ind2label = index_y(self.y_train)

-        # NOTE: The original code expected self.ind2label to be a list,
-        # in case you are training a multi-task model. For this reason,
-        # self.index2label is wrapped in a list.
+        y_labels = list(map(index_y, self.y_train))

-        self.ind2label.append(ind2label)
+        self.ind2label = [ind2label for _, ind2label in y_labels]
+        self.label2ind = [label2ind for label2ind, _ in y_labels]

         # Convert data into indexes data

@@ -209,21 +207,41 @@ def prepare_data(self, save=False):

         # Encode y variables

-        self.y_train_encoded = encode_y(
-            self.y_train, self.label2ind, self.max_len, self.padding_style
-        )
+        for i, labels in enumerate(self.y_train):
+            self.y_train_encoded.append(
+                encode_y(
+                    labels,
+                    self.label2ind[i],
+                    self.max_len,
+                    self.padding_style
+                )
+            )

-        self.y_test_encoded = encode_y(
-            self.y_test, self.label2ind, self.max_len, self.padding_style
-        )
+        for i, labels in enumerate(self.y_test):
+            self.y_test_encoded.append(
+                encode_y(
+                    labels,
+                    self.label2ind[i],
+                    self.max_len,
+                    self.padding_style
+                )
+            )

-        self.y_valid_encoded = encode_y(
-            self.y_valid, self.label2ind, self.max_len, self.padding_style
-        )
+        for i, labels in enumerate(self.y_valid):
+            self.y_valid_encoded.append(
+                encode_y(
+                    labels,
+                    self.label2ind[i],
+                    self.max_len,
+                    self.padding_style
+                )
+            )
+
+
+        logger.debug("Training target dimensions: %s", self.y_train_encoded[0].shape)
+        logger.debug("Test target dimensions: %s", self.y_test_encoded[0].shape)
+        logger.debug("Validation target dimensions: %s", self.y_valid_encoded[0].shape)

-        logger.debug("Training target dimensions: %s", self.y_train_encoded.shape)
-        logger.debug("Test target dimensions: %s", self.y_test_encoded.shape)
-        logger.debug("Validation target dimensions: %s", self.y_valid_encoded.shape)

         # Create character level data

@@ -456,7 +474,7 @@ def build_model(

         self.model = model

-        # logger.debug(self.model.summary(line_length=150))
+        logger.debug(self.model.summary(line_length=150))

     def train_model(
         self, epochs=25, batch_size=100, early_stopping_patience=5, metric="val_f1"
@@ -481,10 +499,8 @@ def train_model(

         # Use custom classification scores callback

-        # NOTE: X lists are important for input here
-
         classification_scores = Classification_Scores(
-            [self.X_training, [self.y_train_encoded]], self.ind2label, self.weights_path
+            [self.X_training, self.y_train_encoded], self.ind2label, self.weights_path
         )

         callbacks.append(classification_scores)
@@ -503,12 +519,12 @@ def train_model(

         hist = self.model.fit(
             x=self.X_training,
-            y=[self.y_train_encoded],
-            validation_data=[self.X_testing, [self.y_test_encoded]],
+            y=self.y_train_encoded,
+            validation_data=[self.X_testing, self.y_test_encoded],
             epochs=epochs,
             batch_size=batch_size,
             callbacks=callbacks,
-            verbose=2,
+            verbose=1,
         )

         logger.info(
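
The gist of this file's change: y_train, y_test and y_valid are now lists of label columns, each column getting its own label2ind/ind2label mapping and its own encoded target array, so a multi-task model can be trained against several tag sets at once. A minimal standalone sketch of the pattern (the toy_index_y helper and the inline padding below are simplified stand-ins, not the package's actual index_y/encode_y):

# Illustrative sketch only: simplified stand-ins to show how per-column
# label mappings and encoded targets are built from lists of y.

def toy_index_y(labels):
    # Map each distinct label to an integer index, reserving 0 for padding.
    uniq = sorted({lab for seq in labels for lab in seq})
    label2ind = {lab: i + 1 for i, lab in enumerate(uniq)}
    ind2label = {i: lab for lab, i in label2ind.items()}
    return label2ind, ind2label

# Two label columns ("tasks") over the same two token sequences.
y_train = [
    [["b-r", "i-r"], ["o", "o"]],        # column 1: reference spans
    [["author", "title"], ["o", "o"]],   # column 2: component types
]

y_labels = list(map(toy_index_y, y_train))
ind2label = [i2l for _, i2l in y_labels]
label2ind = [l2i for l2i, _ in y_labels]

# One encoded target per column, "pre"-padded to max_len with 0.
max_len = 3
y_train_encoded = []
for i, labels in enumerate(y_train):
    encoded = [[0] * (max_len - len(seq)) + [label2ind[i][lab] for lab in seq] for seq in labels]
    y_train_encoded.append(encoded)

print(label2ind[0])        # {'b-r': 1, 'i-r': 2, 'o': 3}
print(y_train_encoded[0])  # [[0, 1, 2], [0, 3, 3]]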

deep_reference_parser/model_utils.py

Lines changed: 2 additions & 2 deletions
@@ -154,9 +154,9 @@ def encode_y(y, label2ind, max_len, padding_style):

    # Encode y (with pad)

-    # Transform each label into its index in the data
+    # Transform each label into its index and adding "pre" padding

-    y_pad = [[0] * (max_len - len(ey)) + [label2ind[c] for c in ey] for ey in y]
+    y_pad = [[0] * (max_len - len(yi)) + [label2ind[label] for label in yi] for yi in y]

    # One-hot-encode label
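
The rename from ey/c to yi/label does not change behaviour: the comprehension still indexes each label and left-pads with 0 up to max_len ("pre" padding). A self-contained sketch of what that step produces (the label2ind values and the plain-numpy one-hot step below are chosen for illustration; they are not the package's own helper):

import numpy as np

# Toy inputs chosen for illustration.
label2ind = {"o": 1, "b-r": 2, "i-r": 3}
max_len = 4
y = [["b-r", "i-r"], ["o", "o", "o"]]

# "Pre" padding: index each label, then left-pad with 0 up to max_len.
y_pad = [[0] * (max_len - len(yi)) + [label2ind[label] for label in yi] for yi in y]
# -> [[0, 0, 2, 3], [0, 1, 1, 1]]

# One-hot-encode each padded index (0 is the padding class).
n_classes = max(label2ind.values()) + 1
y_onehot = np.eye(n_classes)[np.array(y_pad)]
print(y_onehot.shape)  # (2, 4, 4)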

deep_reference_parser/train.py

Lines changed: 10 additions & 6 deletions
@@ -70,13 +70,17 @@ def train(config_file):

    # Load policy data

-    X_train, y_train = load_tsv(POLICY_TRAIN)
-    X_test, y_test = load_tsv(POLICY_TEST)
-    X_valid, y_valid = load_tsv(POLICY_VALID)
+    train_data = load_tsv(POLICY_TRAIN)
+    test_data = load_tsv(POLICY_TEST)
+    valid_data = load_tsv(POLICY_VALID)

-    logger.info("X_train, y_train examples: %s, %s", len(X_train), len(y_train))
-    logger.info("X_test, y_test examples: %s, %s", len(X_test), len(y_test))
-    logger.info("X_valid, y_valid examples: %s, %s", len(X_valid), len(y_valid))
+    X_train, y_train = train_data[0], train_data[1:]
+    X_test, y_test = test_data[0], test_data[1:]
+    X_valid, y_valid = valid_data[0], valid_data[1:]
+
+    logger.info("X_train, y_train examples: %s, %s", len(X_train), list(map(len, y_train)))
+    logger.info("X_test, y_test examples: %s, %s", len(X_test), list(map(len, y_test)))
+    logger.info("X_valid, y_valid examples: %s, %s", len(X_valid), list(map(len, y_valid)))

    drp = DeepReferenceParser(
        X_train=X_train,
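
With load_tsv now returning the token column plus one or more label columns, the slice train_data[0], train_data[1:] splits the result into X_train and a list of label columns for y_train, and the log line reports one count per column. A hypothetical example of that unpacking (the three-element return value below is assumed for illustration, not the actual output of load_tsv):

# Hypothetical load_tsv output: token column plus two label columns.
train_data = (
    [["References", "1."], ["Smith", "J", "."]],   # X: token sequences
    [["o", "o"], ["b-r", "i-r", "i-r"]],           # y column 1
    [["title", "o"], ["author", "author", "o"]],   # y column 2
)

X_train, y_train = train_data[0], train_data[1:]

print(len(X_train))             # 2 examples
print(list(map(len, y_train)))  # [2, 2] -> one count per label column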

0 commit comments
