Skip to content

Commit 1894226

Browse files
committed
Update Files
1 parent 1ddeac1 commit 1894226

File tree

7 files changed

+642
-0
lines changed

7 files changed

+642
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
FROM python:3.6-slim-buster
2+
3+
EXPOSE 8000
4+
5+
WORKDIR /var/www/NERmodel/
6+
7+
COPY ./NERmodel.py /var/www/NERmodel/NERmodel.py
8+
COPY ./models /var/www/NERmodel/models/
9+
COPY ./model.py /var/www/NERmodel/model.py
10+
COPY ./utils.py /var/www/NERmodel/utils.py
11+
COPY ./config.yml /var/www/NERmodel/config.yml
12+
13+
# Add Python Packages
14+
RUN pip install --upgrade --no-cache-dir pip
15+
RUN pip install --upgrade --ignore-installed --no-cache-dir PyYAML
16+
RUN pip install --upgrade --no-cache-dir flask
17+
RUN pip install --upgrade --no-cache-dir flasgger==0.8.1
18+
RUN pip install --upgrade --no-cache-dir torch==1.2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
19+
20+
CMD python /var/www/NERmodel/NERmodel.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
# import necessary Python Packages
5+
import re
6+
import pickle
7+
import torch
8+
import yaml
9+
from model import BiLSTMCRF
10+
from utils import *
11+
import warnings
12+
import numpy as np
13+
from flask import Flask, request
14+
from flasgger import Swagger
15+
16+
17+
warnings.filterwarnings("ignore")
18+
device = torch.device("cpu")
19+
20+
21+
app = Flask(__name__)
22+
swagger = Swagger(app)
23+
24+
25+
def load_params(path: str):
26+
"""
27+
Load the parameters (data)
28+
"""
29+
with open(path + "data.pkl", "rb") as fopen:
30+
data_map = pickle.load(fopen)
31+
return data_map
32+
33+
34+
def strQ2B(ustring):
35+
rstring = ""
36+
for uchar in ustring:
37+
inside_code=ord(uchar)
38+
if inside_code == 12288:
39+
inside_code = 32
40+
elif inside_code >= 65281 and inside_code <= 65374:
41+
inside_code -= 65248
42+
rstring += chr(inside_code)
43+
return rstring
44+
45+
46+
def cut_text(text, length):
47+
textArr = re.findall('.{' + str(length) + '}', text)
48+
textArr.append(text[(len(textArr) * length):])
49+
return textArr
50+
51+
52+
def load_config():
53+
"""
54+
Load hyper-parameters from the YAML file
55+
"""
56+
fopen = open("config.yml")
57+
config = yaml.load(fopen, Loader=yaml.FullLoader)
58+
fopen.close()
59+
return config
60+
61+
62+
class ChineseNER:
63+
def __init__(self, entry="train"):
64+
# Load some Hyper-parameters
65+
config = load_config()
66+
self.embedding_size = config.get("embedding_size")
67+
self.hidden_size = config.get("hidden_size")
68+
self.batch_size = config.get("batch_size")
69+
self.model_path = config.get("model_path")
70+
self.dropout = config.get("dropout")
71+
self.tags = config.get("tags")
72+
self.learning_rate = config.get("learning_rate")
73+
self.epochs = config.get("epochs")
74+
self.weight_decay = config.get("weight_decay")
75+
self.transfer_learning = config.get("transfer_learning")
76+
self.lr_decay_step = config.get("lr_decay_step")
77+
self.lr_decay_rate = config.get("lr_decay_rate")
78+
self.max_length = config.get("max_length")
79+
80+
# Model Initialization
81+
self.main_model(entry)
82+
83+
def main_model(self, entry):
84+
"""
85+
Model Initialization
86+
"""
87+
# The Testing & Inference Process
88+
if entry == "predict":
89+
data_map = load_params(path=self.model_path)
90+
input_size = data_map.get("input_size")
91+
self.tag_map = data_map.get("tag_map")
92+
self.vocab = data_map.get("vocab")
93+
self.model = BiLSTMCRF(
94+
tag_map=self.tag_map,
95+
vocab_size=input_size,
96+
dropout=0.0,
97+
embedding_dim=self.embedding_size,
98+
hidden_dim=self.hidden_size,
99+
max_length=self.max_length
100+
)
101+
self.restore_model()
102+
103+
def restore_model(self):
104+
"""
105+
Restore the model if there is one
106+
"""
107+
try:
108+
self.model.load_state_dict(torch.load(self.model_path + "params.pkl"))
109+
print("Model Successfully Restored!")
110+
except Exception as error:
111+
print("Model Failed to restore! {}".format(error))
112+
113+
def predict(self, input_str):
114+
"""
115+
Prediction & Inference Stage
116+
:param input_str: Input Chinese sentence
117+
:return entities: Predicted entities
118+
"""
119+
if len(input_str) != 0:
120+
# Full-width to half-width
121+
input_str = strQ2B(input_str)
122+
input_str = re.sub(pattern='。', repl='.', string=input_str)
123+
text = cut_text(text=input_str, length=self.max_length)
124+
125+
cut_out = []
126+
for cuttext in text:
127+
# Get the embedding vector (Input Vector) from vocab
128+
input_vec = [self.vocab.get(i, 0) for i in cuttext]
129+
130+
# convert it to tensor and run the model
131+
sentences = torch.tensor(input_vec).view(1, -1)
132+
133+
length = np.expand_dims(np.shape(sentences)[1], axis=0)
134+
length = torch.tensor(length, dtype=torch.int64, device=device)
135+
136+
_, paths = self.model(sentences=sentences, real_length=length, lengths=None)
137+
138+
# Get the entities from the model
139+
entities = []
140+
for tag in self.tags:
141+
tags = get_tags(paths[0], tag, self.tag_map)
142+
entities += format_result(tags, cuttext, tag)
143+
144+
# Get all the entities
145+
all_start = []
146+
for entity in entities:
147+
start = entity.get('start')
148+
all_start.append([start, entity])
149+
150+
# Sort the results by the "start" index
151+
sort_d = [value for index, value in sorted(enumerate(all_start), key=lambda all_start: all_start[1])]
152+
153+
if len(sort_d) == 0:
154+
return print("There was no entity in this sentence!!")
155+
else:
156+
sort_d = np.reshape(np.array(sort_d)[:, 1], [np.shape(sort_d)[0], 1])
157+
cut_out.append(sort_d)
158+
return cut_out
159+
else:
160+
return print('Invalid input! Please re-input!!\n')
161+
162+
163+
@app.route('/predict', methods=["GET"])
164+
def predict_iris_file():
165+
"""Named Entity Recognition (NER) Prediction for Medical Services
166+
---
167+
parameters:
168+
- name: input_str
169+
in: query
170+
type: string
171+
required: true
172+
"""
173+
input_str = request.args.get("input_str")
174+
cn = ChineseNER("predict")
175+
prediction = cn.predict(input_str)
176+
return str(prediction)
177+
178+
179+
# main function
180+
if __name__ == "__main__":
181+
app.run(host='0.0.0.0', port=8000)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
embedding_size: 50 # 30 ~ 50 dimensionality for ~ M corpus
2+
hidden_size: 256
3+
model_path: models/
4+
dataset_path: data/
5+
batch_size: 16
6+
dropout: 0.50
7+
learning_rate: 0.001
8+
lr_decay_step: 5
9+
lr_decay_rate: 0.90
10+
epochs: 1000
11+
weight_decay: 0.0005
12+
max_length: 120
13+
transfer_learning: False
14+
tags:
15+
- E95f2a617
16+
- E320ca3f6
17+
- E340ca71c
18+
- E1ceb2bd7
19+
- E1deb2d6a
20+
- E370cabd5
21+
- E360caa42
22+
- E310ca263
23+
- E300ca0d0
24+
- E18eb258b
25+
- E3c0cb3b4
26+
- E1beb2a44
27+
- E3d0cb547
28+
- E8ff29ca5
29+
- E330ca589
30+
- E1eeb2efd
31+
- E17eb23f8
32+
- E94f2a484

0 commit comments

Comments
 (0)