Skip to content

Commit 3e1b20b

Browse files
chg: Use lower level weight loading
keras_contrib.utils.load_save_utils is a soft wrapper around keras.saving which is not required in this relatively simple case.
1 parent f392f9f commit 3e1b20b

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

deep_reference_parser/deep_reference_parser.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
import numpy as np
1515

16+
import h5py
17+
from keras.engine import saving
1618
from keras.callbacks import EarlyStopping
1719
from keras.layers import (
1820
LSTM,
@@ -30,7 +32,6 @@
3032
from keras.models import Model
3133
from keras.optimizers import Adam, RMSprop
3234
from keras_contrib.layers import CRF
33-
from keras_contrib.utils import save_load_utils
3435
from sklearn_crfsuite import metrics
3536

3637
from deep_reference_parser.logger import logger
@@ -976,26 +977,17 @@ def load_weights(self):
976977

977978
if not self.model:
978979

979-
# Assumes that model has been buit with build_model!
980-
981980
logger.exception(
982981
"No model. you must build the model first with build_model"
983982
)
984983

985-
# NOTE: This is not required if incldue_optimizer is set to false in
986-
# load_all_weights.
987-
988-
# Run the model for one epoch to initialise network weights. Then load
989-
# full trained weights
990-
991-
# self.model.fit(x=self.X_testing, y=self.y_test_encoded,
992-
# batch_size=2500, epochs=1)
993-
994984
logger.debug("Loading weights from %s", self.weights_path)
995985

996-
save_load_utils.load_all_weights(
997-
self.model, self.weights_path, include_optimizer=False
998-
)
986+
with h5py.File(self.weights_path, mode='r') as f:
987+
saving.load_weights_from_hdf5_group(
988+
f['model_weights'], self.model.layers
989+
)
990+
999991

1000992
def predict(self, X, load_weights=False):
1001993
"""

0 commit comments

Comments
 (0)