
Commit db80817

Update the quantization flow for llama.cpp.
1 parent b849bad commit db80817

File tree

6 files changed: +388 -481 lines changed

quantllm/api/high_level.py

Lines changed: 53 additions & 58 deletions
@@ -5,6 +5,7 @@
 from ..utils.logger import logger
 import psutil
 import math
+import os
 
 def get_gpu_memory():
     """Get available GPU memory in GB."""
@@ -219,7 +220,7 @@ def quantize_from_pretrained(
         if auto_device:
             if torch.cuda.is_available() and gpu_mem:
                 max_gpu_mem = max(gpu_mem)
-                if model_size_gb * 1.5 > max_gpu_mem:  # Need 1.5x for safe loading
+                if model_size_gb > max_gpu_mem:
                     logger.log_info("Insufficient GPU memory. Using CPU offloading.")
                     device = "cpu"
                     cpu_offload = True
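
This hunk drops the 1.5x loading headroom and falls back to CPU offloading only when the raw model size exceeds the largest visible GPU. A minimal sketch of that device-selection check, assuming a standalone helper (the function name and return shape are illustrative, not the package's actual API):

    import torch

    def pick_device(model_size_gb: float, gpu_mem: list) -> tuple:
        """Mirror the check in the hunk above: offload to CPU when the model
        no longer fits in the largest visible GPU (previously 1.5x headroom)."""
        if torch.cuda.is_available() and gpu_mem:
            max_gpu_mem = max(gpu_mem)
            if model_size_gb > max_gpu_mem:
                return "cpu", True   # insufficient GPU memory -> CPU offload
            return "cuda", False
        return "cpu", True           # no usable GPU at all

    # A 13.5 GB model against a 24 GB card passes the new check when CUDA is present.
    print(pick_device(13.5, [24.0]))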
@@ -262,6 +263,7 @@ def quantize_from_pretrained(
         )
         logger.log_info(f"Selected quantization type: {quant_type} ({bits}-bit)")
 
+        # Create and store quantizer
         quantizer = GGUFQuantizer(
             model_name=model_name,
             bits=bits,
@@ -279,6 +281,9 @@ def quantize_from_pretrained(
             cpu_offload=cpu_offload
         )
 
+        # Store quantizer instance in model for later use
+        quantizer.model._quantizer = quantizer
+
         return quantizer.model
 
     except Exception as e:
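
Net effect of the quantize_from_pretrained hunks: the GGUFQuantizer instance now rides along on the returned model as `_quantizer`, so later export code can reuse it. A hedged usage sketch; the model name and the `bits` keyword are assumptions about the public signature rather than confirmed parameters:

    from quantllm.api.high_level import quantize_from_pretrained

    # Hypothetical call; only the model_name/bits pairing is suggested by the diff context.
    model = quantize_from_pretrained(model_name="facebook/opt-125m", bits=4)

    # Per the added lines above, downstream code can recover the quantizer from the model.
    quantizer = getattr(model, "_quantizer", None)
    print(type(quantizer).__name__ if quantizer is not None else "no quantizer attached")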
@@ -297,78 +302,68 @@ def save_quantized_model(
 ):
     """Save a quantized model in GGUF format."""
     try:
-        logger.log_info("\n" + "="*60)
-        logger.log_info("Starting GGUF Export Process")
-        logger.log_info("="*60)
+        logger.log_info("\n" + "="*80)
+        logger.log_info("Starting GGUF Export Process".center(80))
+        logger.log_info("="*80 + "\n")
 
         # Log model details
         total_params = sum(p.numel() for p in model.parameters())
         model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
 
-        logger.log_info(f"\nModel Information:")
-        logger.log_info(f"Architecture: {model.config.model_type}")
-        logger.log_info(f"Total Parameters: {total_params:,}")
-        logger.log_info(f"Model Size: {model_size_gb:.2f} GB")
+        logger.log_info("📊 Model Information:")
+        logger.log_info("-"*40)
+        logger.log_info(f"• Architecture: {model.config.model_type}")
+        logger.log_info(f"• Total Parameters: {total_params:,}")
+        logger.log_info(f"• Model Size: {model_size_gb:.2f} GB")
+        logger.log_info("")
 
         # Get quantization info
-        if hasattr(model.config, 'quantization_config'):
-            config_dict = model.config.quantization_config
-            if isinstance(config_dict, BitsAndBytesConfig):
-                # Handle BitsAndBytesConfig
-                bits = 4 if config_dict.load_in_4bit else (8 if config_dict.load_in_8bit else 16)
+        if not quant_config:
+            if hasattr(model.config, 'quantization_config'):
+                config_dict = model.config.quantization_config
+                if isinstance(config_dict, BitsAndBytesConfig):
+                    # Handle BitsAndBytesConfig
+                    bits = 4 if config_dict.load_in_4bit else (8 if config_dict.load_in_8bit else 16)
+                    quant_config = {
+                        'bits': bits,
+                        'group_size': 128,  # Default group size
+                        'quant_type': f"Q{bits}_K_M" if bits <= 8 else "F16"
+                    }
+                    logger.log_info("📊 Quantization Configuration:")
+                    logger.log_info("-"*40)
+                    logger.log_info(f"• Bits: {bits}")
+                    logger.log_info(f"• Quantization Type: {quant_config['quant_type']}")
+                    if config_dict.load_in_4bit:
+                        logger.log_info(f"• 4-bit Type: {config_dict.bnb_4bit_quant_type}")
+                        logger.log_info(f"• Compute dtype: {config_dict.bnb_4bit_compute_dtype}")
+                else:
+                    quant_config = config_dict
+            else:
+                logger.log_info("\nUsing default 4-bit quantization settings")
                 quant_config = {
-                    'bits': bits,
-                    'group_size': 128,  # Default group size
-                    'quant_type': f"Q{bits}_K_M" if bits <= 8 else "F16"
+                    'bits': 4,
+                    'group_size': 128,
+                    'quant_type': "Q4_K_M"
                 }
-                logger.log_info(f"\nQuantization Configuration:")
-                logger.log_info(f"Bits: {bits}")
-                logger.log_info(f"Quantization Type: {quant_config['quant_type']}")
-                if config_dict.load_in_4bit:
-                    logger.log_info(f"4-bit Type: {config_dict.bnb_4bit_quant_type}")
-                    logger.log_info(f"Compute dtype: {config_dict.bnb_4bit_compute_dtype}")
-            else:
-                quant_config = config_dict
 
-        if not quant_config:
-            logger.log_info("\nUsing default 4-bit quantization settings")
-            quant_config = {
-                'bits': 4,
-                'group_size': 128,
-                'quant_type': "Q4_K_M"
-            }
+        # Create output directory
+        output_dir = os.path.dirname(output_path) or "."
+        os.makedirs(output_dir, exist_ok=True)
 
-        # Create quantizer with config
-        logger.log_info("\nInitializing GGUF quantizer...")
-        quantizer = GGUFQuantizer(
-            model_name=model,
+        # Convert to GGUF using the new converter
+        from ..quant.llama_cpp_utils import LlamaCppConverter
+
+        converter = LlamaCppConverter()
+        gguf_path = converter.convert_to_gguf(
+            model=model,
+            output_dir=output_dir,
             bits=quant_config['bits'],
             group_size=quant_config.get('group_size', 128),
-            quant_type=quant_config.get('quant_type'),
-            use_packed=quant_config.get('use_packed', True)
+            save_tokenizer=save_tokenizer
        )
 
-        # Convert to GGUF
-        logger.log_info("\nConverting model to GGUF format...")
-        quantizer.convert_to_gguf(output_path)
-        logger.log_info("GGUF conversion completed successfully")
-
-        # Save tokenizer if requested
-        if save_tokenizer and hasattr(model, 'config'):
-            if hasattr(model.config, '_name_or_path'):
-                try:
-                    tokenizer = AutoTokenizer.from_pretrained(
-                        model.config._name_or_path,
-                        trust_remote_code=True
-                    )
-                    tokenizer_path = output_path.rsplit('.', 1)[0] + "_tokenizer"
-                    tokenizer.save_pretrained(tokenizer_path)
-                    logger.log_info(f"Tokenizer saved to: {tokenizer_path}")
-                except Exception as e:
-                    logger.log_warning(f"Failed to save tokenizer: {e}")
-
-        logger.log_info("\nModel export completed successfully!")
-        logger.log_info("="*60)
+        logger.log_info("\n✨ Model export completed successfully!")
+        logger.log_info("="*80)
 
     except Exception as e:
         logger.log_error(f"Failed to save model: {str(e)}")
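
With this rewrite, save_quantized_model no longer builds a second GGUFQuantizer: it resolves a quantization config, ensures the output directory exists, and hands the model to LlamaCppConverter.convert_to_gguf. A condensed sketch of that flow, assuming the converter accepts exactly the keyword arguments visible in the diff (logging and BitsAndBytesConfig detection omitted):

    import os
    from quantllm.quant.llama_cpp_utils import LlamaCppConverter  # module path taken from the diff

    def export_gguf(model, output_path: str, save_tokenizer: bool = True) -> str:
        """Condensed sketch of the new save path shown above."""
        # Default config from the diff; real code derives this from model.config when possible.
        quant_config = {"bits": 4, "group_size": 128, "quant_type": "Q4_K_M"}

        # Make sure the target directory exists, as the new code does.
        output_dir = os.path.dirname(output_path) or "."
        os.makedirs(output_dir, exist_ok=True)

        # Delegate the actual GGUF conversion (and tokenizer export) to the converter.
        converter = LlamaCppConverter()
        return converter.convert_to_gguf(
            model=model,
            output_dir=output_dir,
            bits=quant_config["bits"],
            group_size=quant_config.get("group_size", 128),
            save_tokenizer=save_tokenizer,
        )

Handing tokenizer export to the converter also removes the separate AutoTokenizer.from_pretrained branch that the old code carried.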

quantllm/quant/formats.py

Lines changed: 0 additions & 143 deletions
This file was deleted.
