
Commit 81b4f98

SunMarc and MekkCyber authored
transformers serve quantization docs + some api fixes for bitsandbytes (#41253)
* doc
* fix api
* fix
* fix
* fix
* fix args
* minor doc fix
* fix
* style
* rm check for now
* fix
* style
* Update docs/source/en/serving.md
  Co-authored-by: Mohamed Mekkouri <[email protected]>
* add log and update value

---------

Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 2a3f66d · commit 81b4f98

File tree: 2 files changed, +39 −31 lines

- docs/source/en/serving.md
- src/transformers/cli/serve.py


docs/source/en/serving.md

Lines changed: 24 additions & 2 deletions
@@ -383,6 +383,30 @@ transformers serve \
   --attn_implementation "sdpa"
 ```
 
+### Quantization
+
+transformers serve is compatible with all [quantization methods](https://huggingface.co/docs/transformers/main/quantization/overview) supported in transformers. Quantization can significantly reduce memory usage and improve inference speed, with two main workflows: pre-quantized models and on-the-fly quantization.
+
+#### Pre-quantized models
+
+For models that are already quantized (e.g., GPTQ, AWQ, bitsandbytes), simply choose a quantized model name for serving.
+Make sure to install the required libraries listed in the quantization documentation.
+
+> [!TIP]
+> Pre-quantized models generally provide the best balance of performance and accuracy.
+
+#### On-the-fly quantization
+
+To quantize a model at runtime, pass the `--quantization` flag on the CLI. Note that not all quantization methods support on-the-fly conversion; the full list of supported methods is available in the quantization [overview](https://huggingface.co/docs/transformers/main/quantization/overview).
+
+Currently, transformers serve only supports the following methods: `bnb-4bit` and `bnb-8bit`.
+
+For example, to enable 4-bit quantization with bitsandbytes, pass `--quantization bnb-4bit`:
+
+```sh
+transformers serve --quantization bnb-4bit
+```
+
 ### Performance tips
 
 - Use an efficient attention backend when available:
@@ -397,6 +421,4 @@ transformers serve \
 
 - `--dtype {bfloat16|float16}` typically improve throughput and memory use vs. `float32`
 
-- `--load_in_4bit`/`--load_in_8bit` can reduce memory footprint for LoRA setups
-
 - `--force-model <repo_id>` avoids per-request model hints and helps produce stable, repeatable runs
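To try the new flag end to end, the sketch below sends a request to a locally running `transformers serve --quantization bnb-4bit` instance. It assumes the server's OpenAI-compatible `/v1` routes on the default port 8000; the model id and prompt are placeholders, not values from this commit.

```py
# Minimal client sketch: assumes a local `transformers serve --quantization bnb-4bit`
# instance on the default port exposing the OpenAI-compatible /v1 routes.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```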

src/transformers/cli/serve.py

Lines changed: 15 additions & 29 deletions
@@ -377,14 +377,10 @@ def __init__(
                 help="Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
             ),
         ] = None,
-        load_in_8bit: Annotated[
-            bool, typer.Option(help="Whether to use 8 bit precision for the base model - works only with LoRA.")
-        ] = False,
-        load_in_4bit: Annotated[
-            bool, typer.Option(help="Whether to use 4 bit precision for the base model - works only with LoRA.")
-        ] = False,
-        bnb_4bit_quant_type: Annotated[str, typer.Option(help="Quantization type.")] = "nf4",
-        use_bnb_nested_quant: Annotated[bool, typer.Option(help="Whether to use nested quantization.")] = False,
+        quantization: Annotated[
+            Optional[str],
+            typer.Option(help="Which quantization method to use. choices: 'bnb-4bit', 'bnb-8bit'"),
+        ] = None,
         host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost",
         port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000,
         model_timeout: Annotated[
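For readers unfamiliar with the `Annotated`/`typer.Option` pattern used by the new parameter above, here is a self-contained toy sketch of how such a parameter surfaces as a `--quantization` CLI flag; the `serve` command below is illustrative only, not the real ServeCommand wiring.

```py
# Toy sketch of the Annotated/typer.Option pattern; not the real ServeCommand.
from typing import Annotated, Optional

import typer

app = typer.Typer()


@app.command()
def serve(
    quantization: Annotated[
        Optional[str],
        typer.Option(help="Which quantization method to use. choices: 'bnb-4bit', 'bnb-8bit'"),
    ] = None,
):
    # Typer derives the --quantization flag name from the parameter name.
    typer.echo(f"quantization={quantization}")


if __name__ == "__main__":
    app()
```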
@@ -424,10 +420,7 @@ def __init__(
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
         self.attn_implementation = attn_implementation
-        self.load_in_8bit = load_in_8bit
-        self.load_in_4bit = load_in_4bit
-        self.bnb_4bit_quant_type = bnb_4bit_quant_type
-        self.use_bnb_nested_quant = use_bnb_nested_quant
+        self.quantization = quantization
         self.host = host
         self.port = port
         self.model_timeout = model_timeout
@@ -1688,22 +1681,20 @@ def get_quantization_config(self) -> Optional["BitsAndBytesConfig"]:
         Returns:
             `Optional[BitsAndBytesConfig]`: The quantization config.
         """
-        if self.load_in_4bit:
+        if self.quantization == "bnb-4bit":
             quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
-                # For consistency with model weights, we use the same value as `dtype`
-                bnb_4bit_compute_dtype=self.dtype,
-                bnb_4bit_quant_type=self.bnb_4bit_quant_type,
-                bnb_4bit_use_double_quant=self.use_bnb_nested_quant,
-                bnb_4bit_quant_storage=self.dtype,
-            )
-        elif self.load_in_8bit:
-            quantization_config = BitsAndBytesConfig(
-                load_in_8bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
             )
+        elif self.quantization == "bnb-8bit":
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
         else:
             quantization_config = None
 
+        if quantization_config is not None:
+            logger.info(f"Quantization applied with the following config: {quantization_config}")
+
         return quantization_config
 
     def process_model_name(self, model_id: str) -> str:
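To read the new mapping outside the diff, here is a standalone sketch of the logic `get_quantization_config` implements after this change, written as a module-level function with its own logger for illustration; the real code is a method on the serve command.

```py
# Standalone sketch of the --quantization -> BitsAndBytesConfig mapping above.
import logging
from typing import Optional

from transformers import BitsAndBytesConfig

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: Optional[str]) -> Optional[BitsAndBytesConfig]:
    if quantization == "bnb-4bit":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    elif quantization == "bnb-8bit":
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        quantization_config = None

    if quantization_config is not None:
        logger.info(f"Quantization applied with the following config: {quantization_config}")

    return quantization_config
```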
@@ -1750,27 +1741,22 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
             revision=revision,
             trust_remote_code=self.trust_remote_code,
         )
-
         dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
         quantization_config = self.get_quantization_config()
 
         model_kwargs = {
             "revision": revision,
             "attn_implementation": self.attn_implementation,
             "dtype": dtype,
-            "device_map": "auto",
+            "device_map": self.device,
             "trust_remote_code": self.trust_remote_code,
+            "quantization_config": quantization_config,
         }
-        if quantization_config is not None:
-            model_kwargs["quantization_config"] = quantization_config
 
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         architecture = getattr(transformers, config.architectures[0])
         model = architecture.from_pretrained(model_id, **model_kwargs)
 
-        if getattr(model, "hf_device_map", None) is None:
-            model = model.to(self.device)
-
         has_default_max_length = (
             model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
         )
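As a companion to the loading change above, the sketch below loads a model the way the patched code now does, always passing `quantization_config` (possibly `None`) to `from_pretrained` instead of adding it conditionally. `AutoModelForCausalLM`, the placeholder repo id, and the hard-coded device stand in for the dynamic `getattr(transformers, config.architectures[0])` lookup and `self.device`; the `dtype` keyword follows the diff's model_kwargs.

```py
# Sketch of the patched loading path: quantization_config is always passed through,
# and from_pretrained accepts quantization_config=None when no --quantization is set.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)  # or None when --quantization is not given

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder repo id
    dtype=torch.bfloat16,
    device_map="cuda",  # the serve command passes its own self.device here
    quantization_config=quantization_config,
)
```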
