
Commit 995496d

Update Files
1 parent 10c00ff commit 995496d

18 files changed, +669 -0 lines changed
9 files renamed without changes.
Dockerfile (+30 lines)
@@ -0,0 +1,30 @@
# Base image with Python 3; the service listens on port 8000
FROM python:3
EXPOSE 8000

# Install Apache plus the headers needed to build mod_wsgi
RUN apt-get update && apt-get install -y apache2 \
    apache2-dev \
    vim \
    && apt-get clean \
    && apt-get autoremove \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /var/www/NERmodel/

# Copy the application code, model artefacts and configuration
COPY ./NERmodel.py /var/www/NERmodel/NERmodel.py
COPY ./NERmodel.wsgi /var/www/NERmodel/NERmodel.wsgi
COPY ./models /var/www/NERmodel/models/
COPY ./model.py /var/www/NERmodel/model.py
COPY ./utils.py /var/www/NERmodel/utils.py
COPY ./requirements.txt /var/www/NERmodel/requirements.txt
COPY ./config.yml /var/www/NERmodel/config.yml

# Install Python dependencies (CPU-only PyTorch build)
RUN pip install --upgrade pip
RUN pip install --upgrade --ignore-installed PyYAML
RUN pip install torch==1.5.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install -r requirements.txt

# Build the mod_wsgi Apache module and generate a server configuration
RUN mod_wsgi-express install-module
RUN mod_wsgi-express setup-server NERmodel.wsgi --port=8000 \
    --user www-data --group www-data \
    --server-root=/etc/mod_wsgi-express-80

# Run Apache/mod_wsgi in the foreground
CMD /etc/mod_wsgi-express-80/apachectl start -D FOREGROUND
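
The image serves the Flask app through mod_wsgi-express on port 8000. Below is a minimal client sketch, assuming the container is running with that port published on localhost (e.g. via `docker run -p 8000:8000 <image>`); the /predict route and the input_str parameter come from NERmodel.py, and the example sentence is purely illustrative.

# Minimal client sketch: query the containerized NER service over HTTP.
# Assumes port 8000 of the container is published on localhost.
import urllib.parse
import urllib.request

query = urllib.parse.urlencode({"input_str": "患者出现头痛和发热症状。"})
url = "http://localhost:8000/predict?" + query

with urllib.request.urlopen(url) as response:
    # The Flask view returns str(prediction), so the body is plain text
    print(response.read().decode("utf-8"))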
NERmodel.py (+181 lines)
@@ -0,0 +1,181 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# import necessary Python Packages
import re
import pickle
import torch
import yaml
from model import BiLSTMCRF
from utils import *
import warnings
import numpy as np
from flask import Flask, request
from flasgger import Swagger


warnings.filterwarnings("ignore")
device = torch.device("cpu")


app = Flask(__name__)
swagger = Swagger(app)


def load_params(path: str):
    """
    Load the parameters (data)
    """
    with open(path + "data.pkl", "rb") as fopen:
        data_map = pickle.load(fopen)
    return data_map


def strQ2B(ustring):
    """
    Convert full-width characters in the string to half-width
    """
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:  # full-width space
            inside_code = 32
        elif 65281 <= inside_code <= 65374:  # full-width ASCII range
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring


def cut_text(text, length):
    """
    Split the text into chunks of at most `length` characters
    """
    textArr = re.findall('.{' + str(length) + '}', text)
    textArr.append(text[(len(textArr) * length):])
    return textArr


def load_config():
    """
    Load hyper-parameters from the YAML file
    """
    with open("config.yml") as fopen:
        config = yaml.load(fopen, Loader=yaml.FullLoader)
    return config


class ChineseNER:
    def __init__(self, entry="train"):
        # Load some Hyper-parameters
        config = load_config()
        self.embedding_size = config.get("embedding_size")
        self.hidden_size = config.get("hidden_size")
        self.batch_size = config.get("batch_size")
        self.model_path = config.get("model_path")
        self.dropout = config.get("dropout")
        self.tags = config.get("tags")
        self.learning_rate = config.get("learning_rate")
        self.epochs = config.get("epochs")
        self.weight_decay = config.get("weight_decay")
        self.transfer_learning = config.get("transfer_learning")
        self.lr_decay_step = config.get("lr_decay_step")
        self.lr_decay_rate = config.get("lr_decay_rate")
        self.max_length = config.get("max_length")

        # Model Initialization
        self.main_model(entry)

    def main_model(self, entry):
        """
        Model Initialization
        """
        # The Testing & Inference Process
        if entry == "predict":
            data_map = load_params(path=self.model_path)
            input_size = data_map.get("input_size")
            self.tag_map = data_map.get("tag_map")
            self.vocab = data_map.get("vocab")
            self.model = BiLSTMCRF(
                tag_map=self.tag_map,
                vocab_size=input_size,
                dropout=0.0,
                embedding_dim=self.embedding_size,
                hidden_dim=self.hidden_size,
                max_length=self.max_length
            )
            self.restore_model()

    def restore_model(self):
        """
        Restore the model if there is one
        """
        try:
            self.model.load_state_dict(torch.load(self.model_path + "params.pkl"))
            print("Model Successfully Restored!")
        except Exception as error:
            print("Model Failed to restore! {}".format(error))

    def predict(self, input_str):
        """
        Prediction & Inference Stage
        :param input_str: Input Chinese sentence
        :return entities: Predicted entities
        """
        if len(input_str) != 0:
            # Full-width to half-width
            input_str = strQ2B(input_str)
            input_str = re.sub(pattern='。', repl='.', string=input_str)
            text = cut_text(text=input_str, length=self.max_length)

            cut_out = []
            for cuttext in text:
                # Get the embedding vector (Input Vector) from vocab
                input_vec = [self.vocab.get(i, 0) for i in cuttext]

                # convert it to tensor and run the model
                sentences = torch.tensor(input_vec).view(1, -1)

                length = np.expand_dims(np.shape(sentences)[1], axis=0)
                length = torch.tensor(length, dtype=torch.int64, device=device)

                _, paths = self.model(sentences=sentences, real_length=length, lengths=None)

                # Get the entities from the model
                entities = []
                for tag in self.tags:
                    tags = get_tags(paths[0], tag, self.tag_map)
                    entities += format_result(tags, cuttext, tag)

                # Get all the entities
                all_start = []
                for entity in entities:
                    start = entity.get('start')
                    all_start.append([start, entity])

                # Sort the results by the "start" index
                sort_d = [value for index, value in sorted(enumerate(all_start), key=lambda all_start: all_start[1])]

                if len(sort_d) == 0:
                    return print("There was no entity in this sentence!!")
                else:
                    sort_d = np.reshape(np.array(sort_d)[:, 1], [np.shape(sort_d)[0], 1])
                    cut_out.append(sort_d)
            return cut_out
        else:
            return print('Invalid input! Please re-input!!\n')


@app.route('/predict', methods=["GET"])
def predict_iris_file():
    """Named Entity Recognition (NER) Prediction for Medical Services
    ---
    parameters:
      - name: input_str
        in: query
        type: string
        required: true
    """
    input_str = request.args.get("input_str")
    cn = ChineseNER("predict")
    prediction = cn.predict(input_str)
    return str(prediction)


# main function
if __name__ == "__main__":
    app.run(host='0.0.0.0')
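
load_params() expects a pickled dictionary at models/data.pkl containing the keys read in main_model(). The sketch below shows how such a file might be produced; the key names (input_size, tag_map, vocab) come from the code above, while the values and the BIO-style tag ids are purely illustrative assumptions (in practice they are generated during training).

# Illustrative sketch only: values are made up for demonstration.
import pickle

data_map = {
    "input_size": 3000,                                        # vocabulary size passed to BiLSTMCRF
    "tag_map": {"O": 0, "B-E95f2a617": 1, "I-E95f2a617": 2},   # assumed BIO-style tag ids
    "vocab": {"头": 1, "痛": 2, "发": 3, "热": 4},               # character -> index mapping
}

with open("models/data.pkl", "wb") as fopen:
    pickle.dump(data_map, fopen)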
NERmodel.wsgi (+12 lines)
@@ -0,0 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os

# mod_wsgi serves the module-level callable named "application"
from NERmodel import app as application

sys.path.insert(0, "/var/www/NERmodel")
sys.path.insert(0, '/usr/local/lib/python3.8/site-packages')
sys.path.insert(0, "/usr/local/lib/python3.8/bin/")

os.environ['PYTHONPATH'] = '/usr/local/lib/python3.8/bin/python'
config.yml (+32 lines)
@@ -0,0 +1,32 @@
embedding_size: 50       # 30 ~ 50 dimensionality for ~ M corpus
hidden_size: 256
model_path: models/
dataset_path: data/
batch_size: 16
dropout: 0.50
learning_rate: 0.001
lr_decay_step: 5
lr_decay_rate: 0.90
epochs: 1000
weight_decay: 0.0005
max_length: 120
transfer_learning: False
tags:
  - E95f2a617
  - E320ca3f6
  - E340ca71c
  - E1ceb2bd7
  - E1deb2d6a
  - E370cabd5
  - E360caa42
  - E310ca263
  - E300ca0d0
  - E18eb258b
  - E3c0cb3b4
  - E1beb2a44
  - E3d0cb547
  - E8ff29ca5
  - E330ca589
  - E1eeb2efd
  - E17eb23f8
  - E94f2a484
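
These hyper-parameters are consumed by load_config() in NERmodel.py. A minimal standalone sketch of reading the file is shown below; yaml.safe_load is used here only for the example, whereas the app itself uses yaml.load with FullLoader.

# Minimal sketch: read config.yml and inspect a few values used by ChineseNER.
import yaml

with open("config.yml") as fopen:
    config = yaml.safe_load(fopen)

print(config["embedding_size"], config["hidden_size"], config["max_length"])
print(len(config["tags"]), "entity tags configured")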
