@@ -72,8 +72,8 @@ def __init__(
         y_train=None,
         y_test=None,
         y_valid=None,
-        digits_word="$NUM$",
-        ukn_words="out-of-vocabulary",
+        digits_word="<NUM>",
+        ukn_words="<OOV>",
         padding_style="pre",
         output_path="data/model_output",
     ):
@@ -165,13 +165,11 @@ def prepare_data(self, save=False):
         # Compute indexes for words+labels in the training data

         self.word2ind, self.ind2word = index_x(self.X_train_merged, self.ukn_words)
-        self.label2ind, ind2label = index_y(self.y_train)

-        # NOTE: The original code expected self.ind2label to be a list,
-        # in case you are training a multi-task model. For this reason,
-        # self.index2label is wrapped in a list.
+        y_labels = list(map(index_y, self.y_train))

-        self.ind2label.append(ind2label)
+        self.ind2label = [ind2label for _, ind2label in y_labels]
+        self.label2ind = [label2ind for label2ind, _ in y_labels]

         # Convert data into indexes data

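Note on the hunk above: the call sites suggest that index_y returns a (label2ind, ind2label) pair for one task's label sequences, so mapping it over self.y_train yields one pair per task and the two comprehensions split them into parallel lists. A minimal, self-contained sketch of that pattern (the toy index_y below is an assumption, not the repository's implementation):

    # Toy data: one list of label sequences per task.
    y_train = [
        [["B-PER", "O"], ["O", "B-LOC"]],        # task 0
        [["NOUN", "VERB"], ["VERB", "NOUN"]],    # task 1
    ]

    def index_y(label_sequences):
        # Hypothetical stand-in: build label<->index mappings, keeping 0 free for padding.
        labels = sorted({label for seq in label_sequences for label in seq})
        label2ind = {label: i for i, label in enumerate(labels, start=1)}
        ind2label = {i: label for label, i in label2ind.items()}
        return label2ind, ind2label

    y_labels = list(map(index_y, y_train))
    ind2label = [ind2label for _, ind2label in y_labels]    # one mapping per task
    label2ind = [label2ind for label2ind, _ in y_labels]
    print(label2ind[0])  # {'B-LOC': 1, 'B-PER': 2, 'O': 3}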
@@ -209,21 +207,41 @@ def prepare_data(self, save=False):

         # Encode y variables

-        self.y_train_encoded = encode_y(
-            self.y_train, self.label2ind, self.max_len, self.padding_style
-        )
+        for i, labels in enumerate(self.y_train):
+            self.y_train_encoded.append(
+                encode_y(
+                    labels,
+                    self.label2ind[i],
+                    self.max_len,
+                    self.padding_style
+                )
+            )

-        self.y_test_encoded = encode_y(
-            self.y_test, self.label2ind, self.max_len, self.padding_style
-        )
+        for i, labels in enumerate(self.y_test):
+            self.y_test_encoded.append(
+                encode_y(
+                    labels,
+                    self.label2ind[i],
+                    self.max_len,
+                    self.padding_style
+                )
+            )

-        self.y_valid_encoded = encode_y(
-            self.y_valid, self.label2ind, self.max_len, self.padding_style
-        )
+        for i, labels in enumerate(self.y_valid):
+            self.y_valid_encoded.append(
+                encode_y(
+                    labels,
+                    self.label2ind[i],
+                    self.max_len,
+                    self.padding_style
+                )
+            )
+
+
+        logger.debug("Training target dimensions: %s", self.y_train_encoded[0].shape)
+        logger.debug("Test target dimensions: %s", self.y_test_encoded[0].shape)
+        logger.debug("Validation target dimensions: %s", self.y_valid_encoded[0].shape)

-        logger.debug("Training target dimensions: %s", self.y_train_encoded.shape)
-        logger.debug("Test target dimensions: %s", self.y_test_encoded.shape)
-        logger.debug("Validation target dimensions: %s", self.y_valid_encoded.shape)

         # Create character level data

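The [0] in the new debug lines works because each entry of y_*_encoded is now the encoded target array for one task. Judging from its arguments, encode_y presumably maps labels to indexes, pads each sentence to max_len, and one-hot encodes the result; a rough sketch of that behaviour (an illustration only, not the project's helper):

    import numpy as np

    def encode_y_sketch(label_sequences, label2ind, max_len, padding_style="pre"):
        # Hypothetical stand-in: index, pad, then one-hot encode one task's labels.
        num_classes = len(label2ind) + 1          # index 0 acts as the padding class
        out = np.zeros((len(label_sequences), max_len, num_classes))
        for row, seq in enumerate(label_sequences):
            indexes = [label2ind[label] for label in seq][:max_len]
            offset = max_len - len(indexes) if padding_style == "pre" else 0
            for col, ind in enumerate(indexes):
                out[row, offset + col, ind] = 1.0
        return out

    encoded = encode_y_sketch([["O", "B-LOC"]], {"B-LOC": 1, "O": 2}, max_len=5)
    print(encoded.shape)  # (1, 5, 3) -> (sentences, max_len, labels + padding)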
@@ -456,7 +474,7 @@ def build_model(

         self.model = model

-        # logger.debug(self.model.summary(line_length=150))
+        logger.debug(self.model.summary(line_length=150))

     def train_model(
         self, epochs=25, batch_size=100, early_stopping_patience=5, metric="val_f1"
@@ -481,10 +499,8 @@ def train_model(

         # Use custom classification scores callback

-        # NOTE: X lists are important for input here
-
         classification_scores = Classification_Scores(
-            [self.X_training, [self.y_train_encoded]], self.ind2label, self.weights_path
+            [self.X_training, self.y_train_encoded], self.ind2label, self.weights_path
         )
         callbacks.append(classification_scores)
@@ -503,12 +519,12 @@ def train_model(

         hist = self.model.fit(
             x=self.X_training,
-            y=[self.y_train_encoded],
-            validation_data=[self.X_testing, [self.y_test_encoded]],
+            y=self.y_train_encoded,
+            validation_data=[self.X_testing, self.y_test_encoded],
             epochs=epochs,
             batch_size=batch_size,
             callbacks=callbacks,
-            verbose=2,
+            verbose=1,
         )

         logger.info(
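Dropping the extra list wrapping around y_train_encoded and y_test_encoded matches how Keras handles multi-output models: fit expects one target array per output layer, so the per-task list can be passed straight through. A minimal, generic sketch of that convention (a toy two-headed model, unrelated to the repository's build_model):

    import numpy as np
    from tensorflow import keras

    # Two-headed toy model: fit() takes a list of targets, one per output.
    inputs = keras.Input(shape=(10,))
    hidden = keras.layers.Dense(16, activation="relu")(inputs)
    out_a = keras.layers.Dense(3, activation="softmax", name="task_a")(hidden)
    out_b = keras.layers.Dense(5, activation="softmax", name="task_b")(hidden)
    model = keras.Model(inputs, [out_a, out_b])
    model.compile(optimizer="adam", loss="categorical_crossentropy")

    x = np.random.rand(8, 10)
    y = [
        np.eye(3)[np.random.randint(3, size=8)],  # targets for task_a
        np.eye(5)[np.random.randint(5, size=8)],  # targets for task_b
    ]
    model.fit(x, y, epochs=1, batch_size=4, verbose=1)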