forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path export_model.py
More file actions
99 lines (85 loc) · 4.31 KB
/
export_model.py
File metadata and controls
99 lines (85 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pprint import pprint
import paddle
from paddlenlp.ops import FasterUNIMOText
from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer
from paddlenlp.utils.log import logger
# yapf: disable
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default="checkpoint", type=str, help="The model name to specify the UNIMOText to use. ")
parser.add_argument("--inference_model_dir", default="./export_checkpoint", type=str, help="Path to save inference model of UNIMOText. ")
parser.add_argument("--topk", default=4, type=int, help="The number of candidate to procedure top_k sampling. ")
parser.add_argument("--topp", default=1.0, type=float, help="The probability threshold to procedure top_p sampling. ")
parser.add_argument("--max_dec_len", default=20, type=int, help="Maximum output length. ")
parser.add_argument("--min_dec_len", default=3, type=int, help="Minimum output length. ")
parser.add_argument("--temperature", default=1.0, type=float, help="The temperature to set. ")
parser.add_argument("--num_return_sequences", default=1, type=int, help="The number of returned sequences. ")
parser.add_argument("--use_fp16_decoding", action="store_true", help="Whether to use fp16 decoding to predict. ")
parser.add_argument("--decoding_strategy", default="beam_search", choices=["sampling", "beam_search"], type=str, help="The main strategy to decode. ")
parser.add_argument("--num_beams", default=6, type=int, help="The number of candidate to procedure beam search. ")
parser.add_argument("--diversity_rate", default=0.0, type=float, help="The diversity rate to procedure beam search. ")
parser.add_argument("--length_penalty", default=1.2, type=float, help="The diversity rate to procedure beam search. ")
args = parser.parse_args()
return args
def do_predict(args):
    """Export a UNIMOText generator as a static-graph inference model.

    Loads the UNIMO LM-head model and tokenizer from
    ``args.model_name_or_path`` and wraps the model in ``FasterUNIMOText``.
    It then converts the wrapper with ``paddle.jit.to_static`` and saves it
    under ``args.inference_model_dir`` using the path prefix ``unimo_text``.

    Args:
        args: Parsed options from ``parse_args``.
    """
    # The export device is fixed to GPU; it is not configurable from the CLI.
    paddle.set_device("gpu")

    checkpoint = args.model_name_or_path
    lm_model = UNIMOLMHeadModel.from_pretrained(checkpoint)
    tokenizer = UNIMOTokenizer.from_pretrained(checkpoint)

    generator = FasterUNIMOText(
        model=lm_model,
        use_fp16_decoding=args.use_fp16_decoding,
        trans_out=True,
    )
    # Put the wrapper in evaluation mode before conversion.
    generator.eval()

    # Tensor inputs, declared with dynamic (None) dimensions.
    tensor_specs = [
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # token_type_ids
        paddle.static.InputSpec(shape=[None, 1, None, None],
                                dtype="float32"),  # attention_mask
        paddle.static.InputSpec(shape=[None], dtype="int64"),  # seq_len
    ]
    # Plain Python values read from args/tokenizer. These are presumably
    # fixed into the converted program -- confirm against to_static docs.
    decode_settings = [
        args.max_dec_len,
        args.min_dec_len,
        args.topk,
        args.topp,
        args.num_beams,  # used for beam_search
        args.decoding_strategy,
        tokenizer.cls_token_id,  # cls/bos
        tokenizer.mask_token_id,  # mask/eos
        tokenizer.pad_token_id,  # pad
        args.diversity_rate,  # used for beam search
        args.temperature,
        args.num_return_sequences,
        args.length_penalty,
    ]
    static_generator = paddle.jit.to_static(
        generator, input_spec=tensor_specs + decode_settings)

    save_prefix = os.path.join(args.inference_model_dir, "unimo_text")
    paddle.jit.save(static_generator, save_prefix)
    logger.info(f"UNIMOText has been saved to {args.inference_model_dir}.")
if __name__ == "__main__":
    # Parse the command line, print the resolved options, then run the export.
    cli_args = parse_args()
    pprint(cli_args)
    do_predict(cli_args)