Commit b849bad

Update quantization and GGUF support for models.
1 parent 4a3e499 commit b849bad

File tree: 4 files changed, +307 −378 lines

quantllm/api/high_level.py

Lines changed: 160 additions & 57 deletions
@@ -1,8 +1,67 @@
 from typing import Optional, Dict, Any, Union, Tuple
 import torch
-from transformers import PreTrainedModel, AutoTokenizer, AutoConfig
+from transformers import PreTrainedModel, AutoTokenizer, AutoConfig, BitsAndBytesConfig
 from ..quant.gguf import GGUFQuantizer, SUPPORTED_GGUF_BITS, SUPPORTED_GGUF_TYPES
 from ..utils.logger import logger
+import psutil
+import math
+
+def get_gpu_memory():
+    """Get available GPU memory in GB."""
+    if torch.cuda.is_available():
+        gpu_mem = []
+        for i in range(torch.cuda.device_count()):
+            total = torch.cuda.get_device_properties(i).total_memory / (1024**3)  # Convert to GB
+            allocated = torch.cuda.memory_allocated(i) / (1024**3)
+            gpu_mem.append(total - allocated)
+        return gpu_mem
+    return []
+
+def get_system_memory():
+    """Get available system memory in GB."""
+    return psutil.virtual_memory().available / (1024**3)
+
+def estimate_model_size(model_name: Union[str, PreTrainedModel]) -> float:
+    """Estimate model size in GB."""
+    try:
+        if isinstance(model_name, PreTrainedModel):
+            params = sum(p.numel() for p in model_name.parameters())
+            return (params * 2) / (1024**3)  # Assuming FP16
+        else:
+            config = AutoConfig.from_pretrained(model_name)
+            if hasattr(config, 'num_parameters'):
+                return (config.num_parameters * 2) / (1024**3)  # Assuming FP16
+            elif hasattr(config, 'n_params'):
+                return (config.n_params * 2) / (1024**3)  # Assuming FP16
+            # Estimate based on common architectures
+            elif hasattr(config, 'hidden_size') and hasattr(config, 'num_hidden_layers'):
+                # More accurate estimation for transformer models
+                hidden_size = config.hidden_size
+                num_layers = config.num_hidden_layers
+                vocab_size = config.vocab_size if hasattr(config, 'vocab_size') else 32000
+
+                # Calculate main components
+                attention_params = 4 * num_layers * hidden_size * hidden_size  # Q,K,V,O matrices
+                ffn_params = 8 * num_layers * hidden_size * hidden_size  # FFN layers
+                embedding_params = vocab_size * hidden_size  # Input embeddings
+
+                total_params = attention_params + ffn_params + embedding_params
+                return (total_params * 2) / (1024**3)  # Assuming FP16
+
+            # If no size info available, estimate based on model name
+            if "llama" in model_name.lower():
+                if "7b" in model_name.lower():
+                    return 13.0
+                elif "13b" in model_name.lower():
+                    return 24.0
+                elif "70b" in model_name.lower():
+                    return 130.0
+                elif "3b" in model_name.lower():
+                    return 6.0
+            return 7.0  # Default assumption
+    except Exception as e:
+        logger.log_warning(f"Error estimating model size: {e}. Using default size.")
+        return 7.0  # Default assumption
 
 class QuantLLM:
     """High-level API for GGUF model quantization."""
@@ -80,11 +139,11 @@ def get_recommended_quant_type(
         if model_size_gb <= 2:
             bits, qtype = (5, "Q5_1") if priority == "quality" else (4, "Q4_K_M")
         elif model_size_gb <= 7:
-            bits, qtype = (4, "Q4_K_M") if priority != "speed" else (4, "Q4_K_S")
+            bits, qtype = (5, "Q5_1") if priority == "quality" else (4, "Q4_K_M")
         elif model_size_gb <= 13:
-            bits, qtype = (3, "Q3_K_M") if priority != "speed" else (3, "Q3_K_S")
+            bits, qtype = (4, "Q4_K_M") if priority != "speed" else (4, "Q4_K_S")
         else:
-            bits, qtype = (2, "Q2_K")
+            bits, qtype = (3, "Q3_K_M")
 
         return bits, qtype
 
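With the retuned thresholds above, models up to roughly 7 GB now default to 4-bit Q4_K_M (or 5-bit Q5_1 when quality is prioritised), and nothing drops to Q2_K any more. A quick sketch of the expected picks, assuming the method is used as the class-level helper it is called as elsewhere in this file:

```python
from quantllm.api.high_level import QuantLLM

# Expected results under the retuned table (derived from the branch logic above).
print(QuantLLM.get_recommended_quant_type(model_size_gb=7, priority="balanced"))   # (4, "Q4_K_M")
print(QuantLLM.get_recommended_quant_type(model_size_gb=7, priority="quality"))    # (5, "Q5_1")
print(QuantLLM.get_recommended_quant_type(model_size_gb=30, priority="balanced"))  # (3, "Q3_K_M")
```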
@@ -108,10 +167,11 @@ def quantize_from_pretrained(
         offload_state_dict: bool = False,
         torch_dtype: Optional[torch.dtype] = torch.float16,
         auto_device: bool = True,
-        optimize_for: str = "balanced"
+        optimize_for: str = "balanced",
+        cpu_offload: bool = False
     ) -> PreTrainedModel:
         """
-        Quantize a model using GGUF format with BitsAndBytes and Accelerate for efficient loading.
+        Quantize a model using GGUF format with optimized resource handling.
 
         Args:
             model_name: Model identifier or instance
@@ -133,6 +193,7 @@ def quantize_from_pretrained(
             torch_dtype: Default torch dtype
             auto_device: Automatically determine optimal device
             optimize_for: Optimization priority ("speed", "quality", or "balanced")
+            cpu_offload: Whether to use CPU offloading
 
         Returns:
             Quantized model
@@ -145,42 +206,56 @@ def quantize_from_pretrained(
         if quant_type and quant_type not in SUPPORTED_GGUF_TYPES.get(bits, {}):
             raise ValueError(f"Unsupported quant_type: {quant_type} for {bits} bits")
 
-        # Auto-determine device if requested
-        if auto_device and device is None:
-            if torch.cuda.is_available():
-                # Check available GPU memory
-                gpu_mem = torch.cuda.get_device_properties(0).total_memory
-                model_size = 0
-                if isinstance(model_name, PreTrainedModel):
-                    model_size = sum(p.numel() * p.element_size() for p in model_name.parameters())
-
-                # If model is too large for GPU, use CPU offloading
-                if model_size > gpu_mem * 0.7:  # Leave 30% margin
-                    logger.log_info("Model too large for GPU memory. Using CPU offloading.")
+        # Estimate model size and available resources
+        model_size_gb = estimate_model_size(model_name)
+        gpu_mem = get_gpu_memory()
+        system_mem = get_system_memory()
+
+        logger.log_info(f"Estimated model size: {model_size_gb:.2f} GB")
+        logger.log_info(f"Available GPU memory: {gpu_mem}")
+        logger.log_info(f"Available system memory: {system_mem:.2f} GB")
+
+        # Auto-configure resources
+        if auto_device:
+            if torch.cuda.is_available() and gpu_mem:
+                max_gpu_mem = max(gpu_mem)
+                if model_size_gb * 1.5 > max_gpu_mem:  # Need 1.5x for safe loading
+                    logger.log_info("Insufficient GPU memory. Using CPU offloading.")
                     device = "cpu"
+                    cpu_offload = True
                     device_map = "cpu"
                     max_memory = None
                 else:
                     device = "cuda"
+                    # Calculate memory distribution
+                    if device_map == "auto":
+                        max_memory = {
+                            i: f"{int(mem * 0.8)}GB"  # Use 80% of available memory
+                            for i, mem in enumerate(gpu_mem)
+                        }
+                        max_memory["cpu"] = f"{int(system_mem * 0.5)}GB"  # Use 50% of system RAM
             else:
                 device = "cpu"
+                cpu_offload = True
                 device_map = "cpu"
                 max_memory = None
             logger.log_info(f"Auto-selected device: {device}")
+
+        # Configure BitsAndBytes for 4-bit quantization
+        if load_in_4bit:
+            compute_dtype = bnb_4bit_compute_dtype or torch.float16
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type=bnb_4bit_quant_type,
+                bnb_4bit_compute_dtype=compute_dtype,
+                bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
+                llm_int8_enable_fp32_cpu_offload=cpu_offload
+            )
+        else:
+            bnb_config = None
 
-        # If no quant_type specified, use recommended type based on optimization priority
+        # If no quant_type specified, use recommended type
         if not quant_type:
-            if isinstance(model_name, PreTrainedModel):
-                model_size_gb = sum(p.numel() * p.element_size() for p in model_name.parameters()) / (1024**3)
-            else:
-                # Estimate model size based on common architectures
-                config = AutoConfig.from_pretrained(model_name)
-                params = config.n_params if hasattr(config, 'n_params') else None
-                if params:
-                    model_size_gb = (params * 2) / (1024**3)  # Assuming FP16
-                else:
-                    model_size_gb = 7  # Default assumption
-
             bits, quant_type = QuantLLM.get_recommended_quant_type(
                 model_size_gb=model_size_gb,
                 priority=optimize_for
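
The new branch above assembles a plain transformers BitsAndBytesConfig plus an Accelerate-style max_memory map. Outside QuantLLM, an equivalent setup can be written directly against transformers; a sketch with a placeholder model id and memory budgets:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Roughly what the auto-configuration above builds: 4-bit NF4 weights,
# FP16 compute, double quantization, and FP32 CPU offload when VRAM is tight.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# 80% of each GPU and half the system RAM, mirroring the max_memory map above.
max_memory = {0: "19GB", "cpu": "32GB"}

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",             # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
    max_memory=max_memory,
)
```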
@@ -194,17 +269,14 @@ def quantize_from_pretrained(
             quant_type=quant_type,
             use_packed=use_packed,
             device=device,
-            load_in_8bit=load_in_8bit,
-            load_in_4bit=load_in_4bit,
-            bnb_4bit_quant_type=bnb_4bit_quant_type,
-            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
-            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
+            quantization_config=bnb_config,
             use_gradient_checkpointing=use_gradient_checkpointing,
             device_map=device_map,
             max_memory=max_memory,
             offload_folder=offload_folder,
             offload_state_dict=offload_state_dict,
-            torch_dtype=torch_dtype
+            torch_dtype=torch_dtype,
+            cpu_offload=cpu_offload
         )
 
         return quantizer.model
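
From the caller's side, the reworked signature might be exercised like this; a sketch only, with an illustrative model id, since the full argument list is not shown in these hunks:

```python
from quantllm.api.high_level import QuantLLM

model = QuantLLM.quantize_from_pretrained(
    model_name="facebook/opt-1.3b",  # placeholder model id
    load_in_4bit=True,
    auto_device=True,                # picks GPU or CPU from the measured memory
    optimize_for="balanced",         # feeds get_recommended_quant_type when quant_type is unset
    cpu_offload=False,               # flipped to True automatically if VRAM is insufficient
)
```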
@@ -223,32 +295,61 @@ def save_quantized_model(
         save_tokenizer: bool = True,
         quant_config: Optional[Dict[str, Any]] = None
     ):
-        """
-        Save a quantized model in GGUF format.
-
-        Args:
-            model: Quantized model to save
-            output_path: Path to save the model
-            save_tokenizer: Whether to save the tokenizer
-            quant_config: Optional quantization configuration
-        """
+        """Save a quantized model in GGUF format."""
         try:
-            logger.log_info(f"Converting model to GGUF format: {output_path}")
+            logger.log_info("\n" + "="*60)
+            logger.log_info("Starting GGUF Export Process")
+            logger.log_info("="*60)
+
+            # Log model details
+            total_params = sum(p.numel() for p in model.parameters())
+            model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
+
+            logger.log_info(f"\nModel Information:")
+            logger.log_info(f"Architecture: {model.config.model_type}")
+            logger.log_info(f"Total Parameters: {total_params:,}")
+            logger.log_info(f"Model Size: {model_size_gb:.2f} GB")
+
+            # Get quantization info
+            if hasattr(model.config, 'quantization_config'):
+                config_dict = model.config.quantization_config
+                if isinstance(config_dict, BitsAndBytesConfig):
+                    # Handle BitsAndBytesConfig
+                    bits = 4 if config_dict.load_in_4bit else (8 if config_dict.load_in_8bit else 16)
+                    quant_config = {
+                        'bits': bits,
+                        'group_size': 128,  # Default group size
+                        'quant_type': f"Q{bits}_K_M" if bits <= 8 else "F16"
+                    }
+                    logger.log_info(f"\nQuantization Configuration:")
+                    logger.log_info(f"Bits: {bits}")
+                    logger.log_info(f"Quantization Type: {quant_config['quant_type']}")
+                    if config_dict.load_in_4bit:
+                        logger.log_info(f"4-bit Type: {config_dict.bnb_4bit_quant_type}")
+                        logger.log_info(f"Compute dtype: {config_dict.bnb_4bit_compute_dtype}")
+                else:
+                    quant_config = config_dict
 
-            # Get quantization config from model if not provided
-            if not quant_config and hasattr(model.config, 'quantization_config'):
-                quant_config = model.config.quantization_config
+            if not quant_config:
+                logger.log_info("\nUsing default 4-bit quantization settings")
+                quant_config = {
+                    'bits': 4,
+                    'group_size': 128,
+                    'quant_type': "Q4_K_M"
+                }
 
-            # Create quantizer with existing or default config
+            # Create quantizer with config
+            logger.log_info("\nInitializing GGUF quantizer...")
             quantizer = GGUFQuantizer(
                 model_name=model,
-                bits=quant_config.get('bits', 4) if quant_config else 4,
-                group_size=quant_config.get('group_size', 128) if quant_config else 128,
-                quant_type=quant_config.get('quant_type', None) if quant_config else None,
-                use_packed=quant_config.get('use_packed', True) if quant_config else True
+                bits=quant_config['bits'],
+                group_size=quant_config.get('group_size', 128),
+                quant_type=quant_config.get('quant_type'),
+                use_packed=quant_config.get('use_packed', True)
             )
 
             # Convert to GGUF
+            logger.log_info("\nConverting model to GGUF format...")
             quantizer.convert_to_gguf(output_path)
             logger.log_info("GGUF conversion completed successfully")
 
@@ -260,12 +361,14 @@ def save_quantized_model(
                     model.config._name_or_path,
                     trust_remote_code=True
                 )
-                tokenizer.save_pretrained(output_path)
-                logger.log_info("Tokenizer saved successfully")
+                tokenizer_path = output_path.rsplit('.', 1)[0] + "_tokenizer"
+                tokenizer.save_pretrained(tokenizer_path)
+                logger.log_info(f"Tokenizer saved to: {tokenizer_path}")
             except Exception as e:
                 logger.log_warning(f"Failed to save tokenizer: {e}")
 
-            logger.log_info("Model saved successfully")
+            logger.log_info("\nModel export completed successfully!")
+            logger.log_info("="*60)
 
         except Exception as e:
             logger.log_error(f"Failed to save model: {str(e)}")
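
Putting both halves together, the export path with the new tokenizer handling might look like the following sketch (model id and output path are placeholders); note the tokenizer now lands in a sibling directory derived from the output name rather than inside the GGUF path itself:

```python
from quantllm.api.high_level import QuantLLM

# Hypothetical end-to-end flow: quantize, then export to GGUF.
model = QuantLLM.quantize_from_pretrained(
    model_name="facebook/opt-1.3b",  # placeholder model id
    load_in_4bit=True,
)

# With the change above, the tokenizer is written to "model-q4_k_m_tokenizer"
# alongside the GGUF file instead of into output_path.
QuantLLM.save_quantized_model(
    model=model,
    output_path="model-q4_k_m.gguf",
    save_tokenizer=True,
)
```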

0 commit comments