
Commit e643212

Authored by n1ck-guo, with wenhuach21 and pre-commit-ci[bot]

[Experimental Feature] support for common hf multimodal models (#276)

Signed-off-by: n1ck-guo <[email protected]>
Co-authored-by: wenhuach21 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 4f22871 commit e643212

File tree

25 files changed: +1326 -127 lines

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@ AutoRound
 ===========================
 <h3> Advanced Quantization Algorithm for LLMs</h3>

-[![python](https://img.shields.io/badge/python-3.8%2B-blue)](https://github.com/intel/auto-round)
+[![python](https://img.shields.io/badge/python-3.9%2B-blue)](https://github.com/intel/auto-round)
 [![version](https://img.shields.io/badge/release-0.3.1-green)](https://github.com/intel/auto-round)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/intel/auto-round/blob/main/LICENSE)
 ---
```

auto_round/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .autoround import AutoRound, AutoAdamRound, AutoOPTRound
+from .autoround import AutoRound, AutoRoundAdam, AutoRoundOPT
+from .mllm import AutoRoundMLLM
 from .auto_quantizer import AutoHfQuantizer,AutoRoundConfig
 from .version import __version__
```
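This hunk renames the optimizer-based classes to a consistent `AutoRound*` prefix and adds the new `AutoRoundMLLM` export, so downstream import paths change. A quick before/after, grounded directly in the diff above:

```python
# Before this commit:
# from auto_round import AutoRound, AutoAdamRound, AutoOPTRound

# After this commit (note the new AutoRoundMLLM export as well):
from auto_round import AutoRound, AutoRoundAdam, AutoRoundOPT, AutoRoundMLLM
```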

auto_round/__main__.py

Lines changed: 279 additions & 92 deletions
Large diffs are not rendered by default.

auto_round/autoround.py

Lines changed: 23 additions & 21 deletions

```diff
@@ -22,6 +22,7 @@
 from transformers import set_seed
 from torch import autocast
 from tqdm import tqdm
+import accelerate

 from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer, \
     WrapperTransformerConv1d
@@ -48,10 +49,9 @@
     get_layer_names_in_block,
     mv_module_from_gpu,
     unsupport_meta_device, detect_device_count, clear_memory,
+    get_multimodal_block_names,
 )
-
 from .low_cpu_mem.utils import get_layers_before_block
-import accelerate


 class AutoRound(object):
@@ -529,11 +529,10 @@ def calib(self, nsamples, bs):
             for key in data.keys():
                 data_new[key] = to_device(data[key], self.model.device)
                 if key == 'images':
-                    data_new[key] = to_dtype(data[key], self.model.dtype)
+                    data_new[key] = to_dtype(data_new[key], self.model.dtype)
             input_ids = data_new["input_ids"]
             if input_ids.shape[-1] < self.seqlen:
                 continue
-
             try:
                 if isinstance(data_new, torch.Tensor):
                     self.model(data_new)
@@ -544,7 +543,7 @@ def calib(self, nsamples, bs):
             except NotImplementedError:
                 pass
             except Exception as error:
-                logger.error(error)
+                raise error
             total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
             if total_cnt >= nsamples:
                 break
@@ -595,18 +594,21 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None):
             )
             self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
             clear_memory()
-        except:
-            logger.info("switch to cpu to cache inputs")
-            if "lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 8:
-                logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. "
-                               f"'CUDA_VISIBLE_DEVICES=0,1 python xxx',"
-                               f" for optimal performance during calibration when enabling lm-head quantization. "
-                               f"Otherwise, the process may be significantly slower.")
-            self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
-            clear_memory()
-            all_inputs = self.cache_inter_data(
-                block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name
-            )
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logger.info("switch to cpu to cache inputs")
+                if "lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 8:
+                    logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. "
+                                   f"'CUDA_VISIBLE_DEVICES=0,1 python xxx',"
+                                   f" for optimal performance during calibration when enabling lm-head quantization. "
+                                   f"Otherwise, the process may be significantly slower.")
+                self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
+                clear_memory()
+                all_inputs = self.cache_inter_data(
+                    block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name
+                )
+            else:
+                raise
         return all_inputs

     @torch.no_grad()
@@ -1330,7 +1332,7 @@ def step(self, scaler, optimizer, lr_schedule):
         lr_schedule.step()


-class AutoOPTRound(AutoRound):
+class AutoRoundOPT(AutoRound):
     """Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model.

     Args:
@@ -1413,7 +1415,7 @@ def __init__(
             optimizer="AdamW",
             **kwargs,
     ):
-        super(AutoOPTRound, self).__init__(
+        super(AutoRoundOPT, self).__init__(
             model=model,
             tokenizer=tokenizer,
             bits=bits,
@@ -1493,7 +1495,7 @@ def step(self, scaler, optimizer, lr_schedule):
             htcore.mark_step()


-class AutoAdamRound(AutoOPTRound):
+class AutoRoundAdam(AutoRoundOPT):
     """Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model.
     The default lr has been changed.

@@ -1577,7 +1579,7 @@ def __init__(
             optimizer="AdamW",
             **kwargs,
     ):
-        super(AutoAdamRound, self).__init__(
+        super(AutoRoundAdam, self).__init__(
            model=model,
            tokenizer=tokenizer,
            bits=bits,
```
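The key behavioral change in `try_cache_inter_data_gpucpu` is the narrowed exception handler: the old bare `except:` silently fell back to CPU caching on *any* failure, while the new code falls back only on CUDA out-of-memory and re-raises everything else. A minimal sketch of that pattern, with hypothetical `run_on_gpu`/`run_on_cpu` stand-ins for the caching calls:

```python
import torch

def cache_with_fallback(run_on_gpu, run_on_cpu):
    """Try the GPU path; fall back to CPU only on CUDA OOM."""
    try:
        return run_on_gpu()
    except RuntimeError as e:
        # Only an out-of-memory error justifies the (much slower) CPU path;
        # any other RuntimeError is a real bug and should surface.
        if "CUDA out of memory" in str(e):
            torch.cuda.empty_cache()  # release cached blocks before retrying
            return run_on_cpu()
        raise
```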

auto_round/mllm/README.md

Lines changed: 49 additions & 0 deletions

```diff
@@ -0,0 +1,49 @@
+# AutoRound for MLLMs
+## API Usage (Gaudi2/CPU/GPU)
+```python
+from auto_round import AutoRoundMLLM
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer
+
+model_name = "Qwen/Qwen2-VL-2B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.processor = processor
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    model_name, trust_remote_code=True)
+dataset = "/path/to/llava.json"
+extra_data_dir = "/path/to/images/dir"
+
+bits, group_size = 4, 128
+autoround = AutoRoundMLLM(model, tokenizer, bits=bits, group_size=group_size, dataset=dataset, extra_data_dir=extra_data_dir)
+
+autoround.quantize()
+output_dir = "./tmp_autoround"
+autoround.save_quantized(output_dir, format='auto_round', inplace=True)
+```
+
+## Template
+For AutoRound MLLMs, a Template customizes operations for different models. Users can add a custom chat template through a JSON file as below.
+```json
+{
+    "model_type": "qwen2_vl",
+    "format_user": "<|im_start|>user\n{{content}}<|im_end|>\n",
+    "format_assistant": "<|im_start|>assistant\n{{content}}<|im_end|>\n",
+    "format_system": "<|im_start|>system\n{{content}}<|im_end|>\n",
+    "format_observation": "<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n",
+    "format_separator": "\n",
+    "default_system": "You are a helpful assistant.",
+    "replace_tokens": ["<image>", "<|vision_start|><|image_pad|><|vision_end|>"],
+    "processor": "qwen2_vl" }
+```
+The special token `{{content}}` is a placeholder telling the preprocessor where to fill in the corresponding dialogue content.
+
+`format_*`: adds role-specific tokens around the chat content, depending on the role name.
+
+For example, the input conversation:<br>
+`[{'role': 'user', 'value': '<image>\nWhat are the colors of the bus in the image?'}, {'role': 'assistant', 'value': 'The bus in the image is white and red.'}]`
+
+Using the above template, the input is converted to the format required by Qwen2-VL as below:<br>
+`'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>\nWhat are the colors of the bus in the image?<|im_end|>\n<|im_start|>assistant\nThe bus in the image is white and red.<|im_end|>\n<|im_start|>user\nWhat feature can be seen on the back of the bus?<|im_end|>\n<|im_start|>assistant\nThe back of the bus features an advertisement.<|im_end|>\n<|im_start|>user\nIs the bus driving down the street or pulled off to the side?<|im_end|>\n<|im_start|>assistant\nThe bus is driving down the street, which is crowded with people and other vehicles.<|im_end|>\n'`
+
+## Processor
+The processor is a callback interface for invoking different processors, such as text or image processors, for MLLMs. Users can define their own processor and declare it with the registration function. For more information, please refer to the relevant code in `auto_round/mllm/processor.py`.
```
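The new README points to `auto_round/mllm/processor.py` for the registration mechanism but does not show it. A hypothetical sketch of how such a decorator-based registry typically looks; the names `register_processor`, `PROCESSORS`, and the `get_input` method here are assumptions for illustration, not confirmed by this diff:

```python
# Hypothetical registry sketch; see auto_round/mllm/processor.py for the real API.
PROCESSORS = {}

def register_processor(name):
    """Map a key (e.g. the "processor" field in the template JSON) to a class."""
    def wrapper(cls):
        PROCESSORS[name] = cls
        return cls
    return wrapper

@register_processor("qwen2_vl")
class Qwen2VLProcessor:
    def get_input(self, text, images, **kwargs):
        # Convert raw text/images into model-ready inputs here.
        ...
```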

auto_round/mllm/__init__.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mllm_dataset import get_mllm_dataloader
+from .template import Template, get_template, TEMPLATES
+from .autoround_mllm import AutoRoundMLLM
```
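Given these exports, the MLLM utilities can also be pulled from the subpackage directly. A minimal sketch; the exported names are taken from the diff above, but the `get_template` argument and the shape of `TEMPLATES` are assumptions:

```python
from auto_round.mllm import AutoRoundMLLM, get_template, TEMPLATES

# Assumptions: TEMPLATES is a name-keyed mapping and get_template looks up
# a template by the model_type key used in the template JSON.
template = get_template("qwen2_vl")
print(sorted(TEMPLATES))  # inspect which model types ship with a template
```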
