forked from openvinotoolkit/openvino_notebooks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathov_llm_model.py
369 lines (326 loc) · 13.3 KB
/
ov_llm_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from transformers import PretrainedConfig, AutoTokenizer
from optimum.utils import NormalizedTextConfig, NormalizedConfigManager
from optimum.intel.openvino import OVModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from optimum.intel.openvino.utils import OV_XML_FILE_NAME
from pathlib import Path
from typing import Optional, Union, Dict, Tuple, Any, List
from pathlib import Path
import openvino as ov
import torch
import numpy as np
class OVMPTModel(OVModelForCausalLM):
    """
    Optimum Intel compatible model wrapper for MPT.

    Registers an ``"mpt"`` entry in optimum's ``NormalizedConfigManager`` (MPT's
    config names its layer/head counts ``n_layers``/``n_heads``), relaxes the
    sequence dimension of the model inputs, and implements ``forward`` for both
    non-stateful IRs (explicit past key/value inputs) and stateful IRs (cache kept
    inside the infer request, selected via ``beam_idx``).
    """

    def __init__(
        self,
        model: "Model",
        config: "PretrainedConfig" = None,
        device: str = "CPU",
        dynamic_shapes: bool = True,
        ov_config: Optional[Dict[str, str]] = None,
        model_save_dir: Optional[Union[str, Path]] = None,
        **kwargs,
    ):
        # Map MPT's attribute names onto the normalized names optimum reads.
        NormalizedConfigManager._conf["mpt"] = NormalizedTextConfig.with_args(
            num_layers="n_layers", num_attention_heads="n_heads"
        )
        super().__init__(
            model, config, device, dynamic_shapes, ov_config, model_save_dir, **kwargs
        )

    def _reshape(self, model: "Model", *args, **kwargs):
        """
        Make the sequence-length dimension of the model inputs dynamic.

        Rank-2/3 inputs get dimension 1 relaxed, rank-1 inputs are skipped, and all
        other inputs get dimension 3 relaxed when their name contains ``".key"``,
        dimension 2 otherwise. Extra arguments are accepted for interface
        compatibility with the base class and ignored.

        Returns:
            The same ``model`` object, reshaped in place.
        """
        shapes = {}
        for inputs in model.inputs:
            shapes[inputs] = inputs.get_partial_shape()
            if shapes[inputs].rank.get_length() in [2, 3]:
                shapes[inputs][1] = -1
            elif shapes[inputs].rank.get_length() == 1:
                continue
            else:
                # NOTE(review): presumably MPT stores the key cache with the
                # sequence axis last and the value cache with it at axis 2 — confirm.
                if ".key" in inputs.get_any_name():
                    shapes[inputs][3] = -1
                else:
                    shapes[inputs][2] = -1
        model.reshape(shapes)
        return model

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Run one inference step on the compiled OpenVINO model.

        Args:
            input_ids: Token ids. Once a cache is present, only the last position is fed.
            attention_mask: Optional mask. When the model needs it and none is given,
                an all-ones mask of length ``past_len + current_len`` is built.
            past_key_values: For non-stateful models, per-layer tuples of cache
                arrays. For stateful models, the marker returned by the previous call,
                or ``None`` on the first step (which resets the request state).
            position_ids: Optional. When the model needs it and none is given, it is
                derived from the cumulative sum of the attention mask.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` moved to ``self.device`` and the
            cache object to pass back on the next call.
        """
        self.compile()
        if self.use_cache and past_key_values is not None:
            input_ids = input_ids[:, -1:]
        batch_size = input_ids.shape[0]
        inputs = {}
        past_len = 0
        if not self.stateful:
            if past_key_values is not None:
                # Assumes the value cache (index 1 of each layer) has the sequence
                # axis at -2 — TODO confirm against the exported IR.
                past_len = past_key_values[0][1].shape[-2]
                if self._pkv_precision == ov.Type.bf16:
                    # numpy does not support bf16: the cache buffers are re-wrapped
                    # as OpenVINO tensors typed bf16.
                    past_key_values = tuple(
                        ov.Tensor(past_key_value, past_key_value.shape, ov.Type.bf16)
                        for pkv_per_layer in past_key_values
                        for past_key_value in pkv_per_layer
                    )
                else:
                    # Flatten the per-layer (key, value) tuples into one sequence.
                    past_key_values = tuple(
                        past_key_value
                        for pkv_per_layer in past_key_values
                        for past_key_value in pkv_per_layer
                    )
                # Feed the flattened cache to the matching model inputs.
                inputs = dict(zip(self.key_value_input_names, past_key_values))
            elif self.use_cache:
                # First generation step: create empty (zero-length) cache inputs.
                for input_name in self.key_value_input_names:
                    model_inputs = self.model.input(input_name)
                    shape = model_inputs.get_partial_shape()
                    if self.config.model_type == 'chatglm':
                        # NOTE(review): presumably carried over from a shared
                        # implementation; an "mpt" config does not take this branch.
                        shape[0] = 0
                        shape[1] = batch_size
                    else:
                        shape[0] = batch_size
                        if shape[2].is_dynamic:
                            shape[2] = 0
                        elif shape.rank.get_length() == 4 and shape[3].is_dynamic:
                            shape[3] = 0
                        else:
                            shape[1] = 0
                    inputs[input_name] = ov.Tensor(model_inputs.get_element_type(), shape.get_shape())
        else:
            # The cache lives inside the infer request (stateful model).
            if past_key_values is None:
                # Marker that is not None and is truthy, so the next call takes the
                # "cache present" path at the top of this method.
                past_key_values = ((),)
                # First iteration of a sequence: reset every request state.
                for state in self.request.query_state():
                    state.reset()
                # Initial beam_idx for this step. Presumably updated by
                # _reorder_cache on later steps when beam search is used.
                self.next_beam_idx = np.array(range(batch_size), dtype=int)
        inputs["input_ids"] = np.array(input_ids)
        if "attention_mask" in self.input_names or "position_ids" in self.input_names:
            if attention_mask is not None:
                attention_mask = np.array(attention_mask)
            else:
                attention_mask = np.ones(
                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                )
        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask
        if "position_ids" in self.input_names:
            if position_ids is not None:
                position_ids = np.array(position_ids)
            else:
                position_ids = np.cumsum(attention_mask, axis=1) - 1
                position_ids[attention_mask == 0] = 1
                if past_key_values:
                    # With a cache, only the position of the newest token is needed.
                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
            inputs["position_ids"] = position_ids
        if hasattr(self, 'next_beam_idx'):
            inputs['beam_idx'] = self.next_beam_idx
        # Run inference.
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
        if not self.stateful:
            if self.use_cache:
                # Flat tuple: number of layers * number of cache tensors per layer.
                past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
                # Regroup into per-layer tuples of length self.num_pkv.
                past_key_values = tuple(
                    past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                )
            else:
                past_key_values = None
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: PretrainedConfig,
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        subfolder: str = "",
        from_onnx: bool = False,
        local_files_only: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        """
        Resolve the OpenVINO IR file (``OV_XML_FILE_NAME`` unless ``file_name`` is
        given), load it, and wrap it in an instance of ``cls``.

        ``cls`` is used instead of a hard-coded class name so that subclasses get
        instances of themselves. ``from_onnx`` is accepted for interface
        compatibility and is not used here.
        """
        model_path = Path(model_id)
        file_name = file_name or OV_XML_FILE_NAME
        model_cache_path = cls._cached_file(
            model_path=model_path,
            use_auth_token=use_auth_token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            file_name=file_name,
            subfolder=subfolder,
            local_files_only=local_files_only,
        )
        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
        return cls(
            model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs
        )
class OVQWENModel(OVModelForCausalLM):
    """
    Optimum Intel compatible model wrapper for QWEN.

    Registers a ``"qwen"`` entry in optimum's ``NormalizedConfigManager`` and
    relaxes dimension 1 of every model input except ``beam_idx``. Inference is
    inherited from ``OVModelForCausalLM``.
    """

    def __init__(
        self,
        model: "Model",
        config: "PretrainedConfig" = None,
        device: str = "CPU",
        dynamic_shapes: bool = True,
        ov_config: Optional[Dict[str, str]] = None,
        model_save_dir: Optional[Union[str, Path]] = None,
        **kwargs,
    ):
        # Tell optimum which config attributes hold the layer/head/hidden sizes.
        NormalizedConfigManager._conf["qwen"] = NormalizedTextConfig.with_args(
            num_layers="num_hidden_layers",
            num_attention_heads="num_attention_heads",
            hidden_size="hidden_size",
        )
        super().__init__(
            model, config, device, dynamic_shapes, ov_config, model_save_dir, **kwargs
        )

    def _reshape(self, model: "Model", *args, **kwargs):
        """
        Make dimension 1 of every input dynamic, except inputs whose name starts
        with ``beam_idx``. Extra arguments are accepted for interface compatibility
        and ignored.

        Returns:
            The same ``model`` object, reshaped in place.
        """
        shapes = {}
        for inputs in model.inputs:
            shapes[inputs] = inputs.get_partial_shape()
            if inputs.get_any_name().startswith('beam_idx'):
                # beam_idx is skipped (its recorded shape is passed through unchanged).
                continue
            shapes[inputs][1] = -1
        model.reshape(shapes)
        return model

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: PretrainedConfig,
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        subfolder: str = "",
        from_onnx: bool = False,
        local_files_only: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        """
        Resolve the OpenVINO IR file (``OV_XML_FILE_NAME`` unless ``file_name`` is
        given), load it, and wrap it in an instance of ``cls``.

        ``cls`` is used instead of a hard-coded class name so that subclasses get
        instances of themselves. ``from_onnx`` is accepted for interface
        compatibility and is not used here.
        """
        model_path = Path(model_id)
        file_name = file_name or OV_XML_FILE_NAME
        model_cache_path = cls._cached_file(
            model_path=model_path,
            use_auth_token=use_auth_token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            file_name=file_name,
            subfolder=subfolder,
            local_files_only=local_files_only,
        )
        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
        return cls(
            model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs
        )
class OVCHATGLMModel(OVModelForCausalLM):
    """
    Optimum Intel compatible model wrapper for CHATGLM2 (and, per the
    ``model_classes`` table, ChatGLM3).

    Registers a ``"chatglm"`` entry in optimum's ``NormalizedConfigManager`` and
    reshapes the model inputs so that the batch and sequence dimensions are
    dynamic. Inference is inherited from ``OVModelForCausalLM``.
    """

    def __init__(
        self,
        model: "Model",
        config: "PretrainedConfig" = None,
        device: str = "CPU",
        dynamic_shapes: bool = True,
        ov_config: Optional[Dict[str, str]] = None,
        model_save_dir: Optional[Union[str, Path]] = None,
        **kwargs,
    ):
        # Tell optimum which config attributes hold the layer/head/hidden sizes.
        NormalizedConfigManager._conf["chatglm"] = NormalizedTextConfig.with_args(
            num_layers="num_hidden_layers",
            num_attention_heads="num_attention_heads",
            hidden_size="hidden_size",
        )
        super().__init__(
            model, config, device, dynamic_shapes, ov_config, model_save_dir, **kwargs
        )

    def _reshape(
        self,
        model: "Model",
        batch_size: int,
        sequence_length: int,
        height: int = None,
        width: int = None,
    ):
        """
        Relax the input shapes of ``model`` in place and return it.

        Dimension 0 of every input is made dynamic. Then:

        * inputs named ``beam_idx*`` get no further change;
        * ``past_key_values*`` inputs get dimension 1 dynamic and dimension 2 fixed to 2;
        * other inputs of rank > 1 get dimension 1 dynamic.

        The ``batch_size``/``sequence_length``/``height``/``width`` arguments are
        accepted for interface compatibility and are not used.
        """
        shapes = {}
        for inputs in model.inputs:
            shapes[inputs] = inputs.get_partial_shape()
            shapes[inputs][0] = -1
            input_name = inputs.get_any_name()
            if input_name.startswith('beam_idx'):
                continue
            if input_name.startswith('past_key_values'):
                shapes[inputs][1] = -1
                # NOTE(review): 2 is presumably ChatGLM2/3's multi-query group count
                # (config.multi_query_group_num) — confirm before using other checkpoints.
                shapes[inputs][2] = 2
            elif shapes[inputs].rank.get_length() > 1:
                shapes[inputs][1] = -1
        model.reshape(shapes)
        return model

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: PretrainedConfig,
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        subfolder: str = "",
        from_onnx: bool = False,
        local_files_only: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        """
        Resolve the OpenVINO IR file (``OV_XML_FILE_NAME`` unless ``file_name`` is
        given), load it, and wrap it in an instance of ``cls``.

        ``cls`` is used instead of a hard-coded class name so that subclasses get
        instances of themselves. ``from_onnx`` is accepted for interface
        compatibility and is not used here.
        """
        model_path = Path(model_id)
        file_name = file_name or OV_XML_FILE_NAME
        model_cache_path = cls._cached_file(
            model_path=model_path,
            use_auth_token=use_auth_token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            file_name=file_name,
            subfolder=subfolder,
            local_files_only=local_files_only,
        )
        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
        return cls(
            model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs
        )
# Lookup table from a model-family key to the OpenVINO wrapper class that loads it.
model_classes = dict(
    mpt=OVMPTModel,
    qwen=OVQWENModel,
    chatglm3=OVCHATGLMModel,
)