from typing import Optional, Dict, Any, Union, Tuple
import torch
-from transformers import PreTrainedModel, AutoTokenizer, AutoConfig
+from transformers import PreTrainedModel, AutoTokenizer, AutoConfig, BitsAndBytesConfig
from ..quant.gguf import GGUFQuantizer, SUPPORTED_GGUF_BITS, SUPPORTED_GGUF_TYPES
from ..utils.logger import logger
+import psutil
+import math
+
+def get_gpu_memory():
+    """Get available GPU memory in GB."""
+    if torch.cuda.is_available():
+        gpu_mem = []
+        for i in range(torch.cuda.device_count()):
+            total = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3)  # Convert to GB
+            allocated = torch.cuda.memory_allocated(i) / (1024 ** 3)
+            gpu_mem.append(total - allocated)
+        return gpu_mem
+    return []
+
+def get_system_memory():
+    """Get available system memory in GB."""
+    return psutil.virtual_memory().available / (1024 ** 3)
+
+def estimate_model_size(model_name: Union[str, PreTrainedModel]) -> float:
+    """Estimate model size in GB."""
+    try:
+        if isinstance(model_name, PreTrainedModel):
+            params = sum(p.numel() for p in model_name.parameters())
+            return (params * 2) / (1024 ** 3)  # Assuming FP16
+        else:
+            config = AutoConfig.from_pretrained(model_name)
+            if hasattr(config, 'num_parameters'):
+                return (config.num_parameters * 2) / (1024 ** 3)  # Assuming FP16
+            elif hasattr(config, 'n_params'):
+                return (config.n_params * 2) / (1024 ** 3)  # Assuming FP16
+            # Estimate based on common architectures
+            elif hasattr(config, 'hidden_size') and hasattr(config, 'num_hidden_layers'):
+                # More accurate estimation for transformer models
+                hidden_size = config.hidden_size
+                num_layers = config.num_hidden_layers
+                vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else 32000
+
+                # Calculate main components
+                attention_params = 4 * num_layers * hidden_size * hidden_size  # Q,K,V,O matrices
+                ffn_params = 8 * num_layers * hidden_size * hidden_size  # FFN layers
+                embedding_params = vocab_size * hidden_size  # Input embeddings
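+                # Rough worked example (illustrative): a LLaMA-7B-style config with
+                # hidden_size=4096, num_hidden_layers=32, vocab_size=32000 gives
+                # ~2.1B attention + ~4.3B FFN + ~0.13B embedding parameters, i.e.
+                # roughly 6.6B params and about 12 GB at 2 bytes per parameter (FP16).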
+
+                total_params = attention_params + ffn_params + embedding_params
+                return (total_params * 2) / (1024 ** 3)  # Assuming FP16
+
+            # If no size info available, estimate based on model name
+            if "llama" in model_name.lower():
+                if "7b" in model_name.lower():
+                    return 13.0
+                elif "13b" in model_name.lower():
+                    return 24.0
+                elif "70b" in model_name.lower():
+                    return 130.0
+                elif "3b" in model_name.lower():
+                    return 6.0
+            return 7.0  # Default assumption
+    except Exception as e:
+        logger.log_warning(f"Error estimating model size: {e}. Using default size.")
+        return 7.0  # Default assumption

class QuantLLM:
    """High-level API for GGUF model quantization."""
@@ -80,11 +139,11 @@ def get_recommended_quant_type(
        if model_size_gb <= 2:
            bits, qtype = (5, "Q5_1") if priority == "quality" else (4, "Q4_K_M")
        elif model_size_gb <= 7:
-            bits, qtype = (4, "Q4_K_M") if priority != "speed" else (4, "Q4_K_S")
+            bits, qtype = (5, "Q5_1") if priority == "quality" else (4, "Q4_K_M")
        elif model_size_gb <= 13:
-            bits, qtype = (3, "Q3_K_M") if priority != "speed" else (3, "Q3_K_S")
+            bits, qtype = (4, "Q4_K_M") if priority != "speed" else (4, "Q4_K_S")
        else:
-            bits, qtype = (2, "Q2_K")
+            bits, qtype = (3, "Q3_K_M")

        return bits, qtype

@@ -108,10 +167,11 @@ def quantize_from_pretrained(
        offload_state_dict: bool = False,
        torch_dtype: Optional[torch.dtype] = torch.float16,
        auto_device: bool = True,
-        optimize_for: str = "balanced"
+        optimize_for: str = "balanced",
+        cpu_offload: bool = False
    ) -> PreTrainedModel:
        """
-        Quantize a model using GGUF format with BitsAndBytes and Accelerate for efficient loading.
+        Quantize a model using GGUF format with optimized resource handling.

        Args:
            model_name: Model identifier or instance
@@ -133,6 +193,7 @@ def quantize_from_pretrained(
            torch_dtype: Default torch dtype
            auto_device: Automatically determine optimal device
            optimize_for: Optimization priority ("speed", "quality", or "balanced")
+            cpu_offload: Whether to use CPU offloading

        Returns:
            Quantized model
@@ -145,42 +206,56 @@ def quantize_from_pretrained(
        if quant_type and quant_type not in SUPPORTED_GGUF_TYPES.get(bits, {}):
            raise ValueError(f"Unsupported quant_type: {quant_type} for {bits} bits")

-        # Auto-determine device if requested
-        if auto_device and device is None:
-            if torch.cuda.is_available():
-                # Check available GPU memory
-                gpu_mem = torch.cuda.get_device_properties(0).total_memory
-                model_size = 0
-                if isinstance(model_name, PreTrainedModel):
-                    model_size = sum(p.numel() * p.element_size() for p in model_name.parameters())
-
-                # If model is too large for GPU, use CPU offloading
-                if model_size > gpu_mem * 0.7:  # Leave 30% margin
-                    logger.log_info("Model too large for GPU memory. Using CPU offloading.")
+        # Estimate model size and available resources
+        model_size_gb = estimate_model_size(model_name)
+        gpu_mem = get_gpu_memory()
+        system_mem = get_system_memory()
+
+        logger.log_info(f"Estimated model size: {model_size_gb:.2f} GB")
+        logger.log_info(f"Available GPU memory: {gpu_mem}")
+        logger.log_info(f"Available system memory: {system_mem:.2f} GB")
+
+        # Auto-configure resources
+        if auto_device:
+            if torch.cuda.is_available() and gpu_mem:
+                max_gpu_mem = max(gpu_mem)
+                if model_size_gb * 1.5 > max_gpu_mem:  # Need 1.5x for safe loading
+                    logger.log_info("Insufficient GPU memory. Using CPU offloading.")
                    device = "cpu"
+                    cpu_offload = True
                    device_map = "cpu"
                    max_memory = None
                else:
                    device = "cuda"
+                    # Calculate memory distribution
+                    if device_map == "auto":
+                        max_memory = {
+                            i: f"{int(mem * 0.8)}GB"  # Use 80% of available memory
+                            for i, mem in enumerate(gpu_mem)
+                        }
+                        max_memory["cpu"] = f"{int(system_mem * 0.5)}GB"  # Use 50% of system RAM
            else:
                device = "cpu"
+                cpu_offload = True
                device_map = "cpu"
                max_memory = None
            logger.log_info(f"Auto-selected device: {device}")
+
+        # Configure BitsAndBytes for 4-bit quantization
+        if load_in_4bit:
+            compute_dtype = bnb_4bit_compute_dtype or torch.float16
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type=bnb_4bit_quant_type,
+                bnb_4bit_compute_dtype=compute_dtype,
+                bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
+                llm_int8_enable_fp32_cpu_offload=cpu_offload
+            )
+        else:
+            bnb_config = None

-        # If no quant_type specified, use recommended type based on optimization priority
+        # If no quant_type specified, use recommended type
        if not quant_type:
-            if isinstance(model_name, PreTrainedModel):
-                model_size_gb = sum(p.numel() * p.element_size() for p in model_name.parameters()) / (1024 ** 3)
-            else:
-                # Estimate model size based on common architectures
-                config = AutoConfig.from_pretrained(model_name)
-                params = config.n_params if hasattr(config, 'n_params') else None
-                if params:
-                    model_size_gb = (params * 2) / (1024 ** 3)  # Assuming FP16
-                else:
-                    model_size_gb = 7  # Default assumption
-
            bits, quant_type = QuantLLM.get_recommended_quant_type(
                model_size_gb=model_size_gb,
                priority=optimize_for
@@ -194,17 +269,14 @@ def quantize_from_pretrained(
                quant_type=quant_type,
                use_packed=use_packed,
                device=device,
-                load_in_8bit=load_in_8bit,
-                load_in_4bit=load_in_4bit,
-                bnb_4bit_quant_type=bnb_4bit_quant_type,
-                bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
-                bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
+                quantization_config=bnb_config,
                use_gradient_checkpointing=use_gradient_checkpointing,
                device_map=device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,
-                torch_dtype=torch_dtype
+                torch_dtype=torch_dtype,
+                cpu_offload=cpu_offload
            )

        return quantizer.model
@@ -223,32 +295,61 @@ def save_quantized_model(
        save_tokenizer: bool = True,
        quant_config: Optional[Dict[str, Any]] = None
    ):
-        """
-        Save a quantized model in GGUF format.
-
-        Args:
-            model: Quantized model to save
-            output_path: Path to save the model
-            save_tokenizer: Whether to save the tokenizer
-            quant_config: Optional quantization configuration
-        """
+        """Save a quantized model in GGUF format."""
        try:
-            logger.log_info(f"Converting model to GGUF format: {output_path}")
+            logger.log_info("\n" + "=" * 60)
+            logger.log_info("Starting GGUF Export Process")
+            logger.log_info("=" * 60)
+
+            # Log model details
+            total_params = sum(p.numel() for p in model.parameters())
+            model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 3)
+
+            logger.log_info(f"\nModel Information:")
+            logger.log_info(f"Architecture: {model.config.model_type}")
+            logger.log_info(f"Total Parameters: {total_params:,}")
+            logger.log_info(f"Model Size: {model_size_gb:.2f} GB")
+
+            # Get quantization info
+            if hasattr(model.config, 'quantization_config'):
+                config_dict = model.config.quantization_config
+                if isinstance(config_dict, BitsAndBytesConfig):
+                    # Handle BitsAndBytesConfig
+                    bits = 4 if config_dict.load_in_4bit else (8 if config_dict.load_in_8bit else 16)
+                    quant_config = {
+                        'bits': bits,
+                        'group_size': 128,  # Default group size
+                        'quant_type': f"Q{bits}_K_M" if bits <= 8 else "F16"
+                    }
+                    logger.log_info(f"\nQuantization Configuration:")
+                    logger.log_info(f"Bits: {bits}")
+                    logger.log_info(f"Quantization Type: {quant_config['quant_type']}")
+                    if config_dict.load_in_4bit:
+                        logger.log_info(f"4-bit Type: {config_dict.bnb_4bit_quant_type}")
+                        logger.log_info(f"Compute dtype: {config_dict.bnb_4bit_compute_dtype}")
+                else:
+                    quant_config = config_dict

-            # Get quantization config from model if not provided
-            if not quant_config and hasattr(model.config, 'quantization_config'):
-                quant_config = model.config.quantization_config
+            if not quant_config:
+                logger.log_info("\nUsing default 4-bit quantization settings")
+                quant_config = {
+                    'bits': 4,
+                    'group_size': 128,
+                    'quant_type': "Q4_K_M"
+                }

-            # Create quantizer with existing or default config
+            # Create quantizer with config
+            logger.log_info("\nInitializing GGUF quantizer...")
            quantizer = GGUFQuantizer(
                model_name=model,
-                bits=quant_config.get('bits', 4) if quant_config else 4,
-                group_size=quant_config.get('group_size', 128) if quant_config else 128,
-                quant_type=quant_config.get('quant_type', None) if quant_config else None,
-                use_packed=quant_config.get('use_packed', True) if quant_config else True
+                bits=quant_config['bits'],
+                group_size=quant_config.get('group_size', 128),
+                quant_type=quant_config.get('quant_type'),
+                use_packed=quant_config.get('use_packed', True)
            )

            # Convert to GGUF
+            logger.log_info("\nConverting model to GGUF format...")
            quantizer.convert_to_gguf(output_path)
            logger.log_info("GGUF conversion completed successfully")

@@ -260,12 +361,14 @@ def save_quantized_model(
                        model.config._name_or_path,
                        trust_remote_code=True
                    )
-                    tokenizer.save_pretrained(output_path)
-                    logger.log_info("Tokenizer saved successfully")
+                    tokenizer_path = output_path.rsplit('.', 1)[0] + "_tokenizer"
+                    tokenizer.save_pretrained(tokenizer_path)
+                    logger.log_info(f"Tokenizer saved to: {tokenizer_path}")
                except Exception as e:
                    logger.log_warning(f"Failed to save tokenizer: {e}")

-            logger.log_info("Model saved successfully")
+            logger.log_info("\nModel export completed successfully!")
+            logger.log_info("=" * 60)

        except Exception as e:
            logger.log_error(f"Failed to save model: {str(e)}")
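
A minimal usage sketch of the high-level API changed above, assuming the package exposes QuantLLM at the top level; the import path and model id are placeholders, and only keyword arguments that appear in this diff are used:

    from quantllm import QuantLLM  # import path is an assumption

    # Quantize with automatic device/memory configuration; when quant_type is omitted,
    # get_recommended_quant_type picks one from the estimated model size.
    model = QuantLLM.quantize_from_pretrained(
        model_name="facebook/opt-125m",  # placeholder model id
        bits=4,
        load_in_4bit=True,
        auto_device=True,
        optimize_for="balanced",
    )

    # Export to GGUF; the tokenizer is written to "<output stem>_tokenizer"
    QuantLLM.save_quantized_model(model=model, output_path="opt-125m-q4_k_m.gguf")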