Skip to content

Commit cd88696

Browse files
Merge pull request #29 from 0saurabh0/compatibility_k3
resolved the compatibility issue with keras 3
2 parents 5ebd0c6 + 49fe8be commit cd88696

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

lstm_word_segmentation/word_segmenter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from keras.models import Sequential
66
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Embedding, Dropout
77
from tensorflow import keras
8+
import tensorflow as tf
89

910
from . import constants
1011
from .helpers import sigmoid
@@ -600,7 +601,9 @@ def save_model(self):
600601
This function saves the current trained model of this word_segmenter instance.
601602
"""
602603
# Save the model using Keras
603-
self.model.save(Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name))
604+
model_path = (Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name))
605+
tf.saved_model.save(self.model, model_path)
606+
604607
# Save one np array that holds all weights
605608
file = Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name + "/weights")
606609
np.save(str(file), self.model.weights)
@@ -652,7 +655,7 @@ def pick_lstm_model(model_name, embedding, train_data, eval_data):
652655
eval_data: the data set to test the model. Often, it should have the same structure as training data set.
653656
"""
654657
file = Path.joinpath(Path(__file__).parent.parent.absolute(), 'Models/' + model_name)
655-
model = keras.models.load_model(file)
658+
model = keras.layers.TFSMLayer(file, call_endpoint='serving_default')
656659

657660
# Figuring out name of the model
658661
language = None

0 commit comments

Comments
 (0)