diff --git a/dspy/clients/lm_local.py b/dspy/clients/lm_local.py index 8329db9786..9e93642f49 100644 --- a/dspy/clients/lm_local.py +++ b/dspy/clients/lm_local.py @@ -211,8 +211,20 @@ def train_sft_locally(model_name, train_data, train_kwargs): ) logger.info(f"Using device: {device}") + USE_QUANTIZATION = train_kwargs.get("use_quantization", False) + quantization_config = None + + if USE_QUANTIZATION: + from transformers import BitsAndBytesConfig + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=model_name + pretrained_model_name_or_path=model_name, quantization_config=quantization_config ).to(device) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)