Skip to content

Commit 12ea6d3

Browse files
Deleted outdated files and regenerated weights_tf_free files
1 parent 1b8c94b commit 12ea6d3

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ Models/Burmese_model4_version2/
1515
Models/Other/
1616
*~
1717
venv/
18+
convert_weights.py

lstm_word_segmentation/word_segmenter.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from pathlib import Path
22
import numpy as np
3-
import h5py
43
import json
54
from icu import Char
65
from keras.models import Sequential
@@ -605,13 +604,10 @@ def save_model(self):
605604
model_path = (Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name))
606605
tf.saved_model.save(self.model, model_path)
607606

608-
# Inlining weight saving directly into HDF5 format
609-
weights_file = Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name + "/weights.h5")
610-
with h5py.File(str(weights_file), 'w') as hdf5_file:
611-
# Iterate over the model weights and save each one as a dataset in the HDF5 file
612-
for i, weight in enumerate(self.model.weights):
613-
weight_name = f"weight_{i+1}"
614-
hdf5_file.create_dataset(weight_name, data=weight.numpy()) # Save weight tensor directly
607+
# Save one np array that holds all weights
608+
file = Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name + "/weights")
609+
np.save(str(file), self.model.weights)
610+
615611
# Save the model in json format, that has both weights and grapheme clusters dictionary
616612
json_file = Path.joinpath(Path(__file__).parent.parent.absolute(), "Models/" + self.name + "/weights.json")
617613
with open(str(json_file), 'w') as wfile:
@@ -640,7 +636,6 @@ def save_model(self):
640636
dic_model["data"] = serial_mat
641637
output["mat{}".format(i+1)] = dic_model
642638
json.dump(output, wfile)
643-
print(f"Model, weights in .h5, and weights metadata in .json saved successfully!")
644639

645640
def set_model(self, input_model):
646641
"""

0 commit comments

Comments
 (0)