
Commit 28bbf40

[VLM]fix bs and grad reset (#344)

Authored by n1ck-guo
Signed-off-by: n1ck-guo <[email protected]>
1 parent: 61e04a8

File tree: 8 files changed, +39 -22 lines

auto_round/autoround.py  (+8 -2)

@@ -294,6 +294,7 @@ def quantize(self):
         accelerate.hooks.remove_hook_from_submodules(self.model)  ## self.model.hf_device_map has not been changed
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         logger.info("caching done")
+        pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks))
         for block_names in all_blocks:
             inputs = all_inputs[block_names[0]]
             all_inputs.pop(block_names[0])
@@ -318,6 +319,7 @@ def quantize(self):
                 block_names,
                 nblocks=self.nblocks,
                 device=self.device,
+                pbar=pbar
             )

         self.quant_layers(layer_names, all_inputs)
@@ -1124,6 +1126,7 @@ def quant_blocks(
             block_names,
             nblocks=1,
             device=torch.device("cpu"),
+            pbar=None
     ):
         """Quantize and dequantize the weights of the specified blocks in the model.

@@ -1162,8 +1165,10 @@ def quant_blocks(
                     to_dtype(input_others[key][i], tmp_dtype)
         quant_block = compile_func(self.quant_block, device, self.enable_torch_compile)

-        pbar = tqdm(range(0, len(block_names), nblocks))
-        for i in pbar:
+        if pbar is None:
+            pbar = tqdm(range(0, len(block_names), nblocks))
+        # for i in pbar:
+        for i in range(len(block_names)):
             if nblocks == 1:
                 n = block_names[i]
                 pbar.set_description(f"Quantizing {n}")
@@ -1184,6 +1189,7 @@ def quant_blocks(
                 q_input=q_input,
                 device=device,
             )
+            pbar.update(1)

         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
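The autoround.py hunks above replace the per-call progress bar in quant_blocks with a single bar created once in quantize and threaded through every call, so progress is reported over all blocks of the run instead of resetting for each block group. A minimal sketch of that pattern, assuming simplified stand-in functions (the real per-block quantization work is elided):

from tqdm import tqdm

def quant_blocks(block_names, nblocks=1, pbar=None):
    # Fall back to a local bar so the function still works when called on its own.
    if pbar is None:
        pbar = tqdm(range(0, len(block_names), nblocks))
    for i in range(len(block_names)):
        pbar.set_description(f"Quantizing {block_names[i]}")
        # ... per-block quantize/dequantize would happen here ...
        pbar.update(1)

def quantize(all_blocks, nblocks=1):
    # One bar sized to the total number of blocks, shared by every quant_blocks call.
    pbar = tqdm(range(0, sum(len(b) for b in all_blocks), nblocks))
    for block_names in all_blocks:
        quant_blocks(block_names, nblocks=nblocks, pbar=pbar)

# Hypothetical block names, purely for illustration.
quantize([["model.layers.0", "model.layers.1"], ["visual.blocks.0"]])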

auto_round/mllm/autoround_mllm.py  (+21 -11)

@@ -120,7 +120,7 @@ def __init__(
         low_gpu_mem_usage: bool = False,
         low_cpu_mem_usage: bool = False,
         iters: int = 200,
-        seqlen: int = 2048,
+        seqlen: int = None,
         nsamples: int = 128,
         sampler: str = "rand",
         seed: int = 42,
@@ -136,7 +136,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        truncation: bool = False,
+        truncation: bool = None,
         enable_torch_compile: bool = None,
         **kwargs,
     ):
@@ -152,10 +152,6 @@ def __init__(

         dataset = self.template.default_dataset if dataset is None else dataset

-        if nsamples % batch_size != 0:
-            nsamples = (nsamples // batch_size + 1) * batch_size
-            logger.warning(f"'nsamples' is not divisible by 'batch_size', will adjusted to {nsamples}")
-
         from ..calib_dataset import CALIB_DATASETS
         from .mllm_dataset import MLLM_DATASET
         if isinstance(dataset, str):
@@ -170,17 +166,31 @@ def __init__(

         if dataset in MLLM_DATASET.keys():
             truncation = False
-            batch_size = 1
             seqlen = 512 if seqlen is None else seqlen
+            if batch_size != 1:
+                logger.warning(
+                    f"rest batch_size({batch_size}) to 1 and "
+                    f"gradient_accumulate_steps({gradient_accumulate_steps}) "
+                    f"to {batch_size * gradient_accumulate_steps}, "
+                    f"cause batch_size={batch_size} cannot be used for {dataset}")
+                gradient_accumulate_steps = batch_size * gradient_accumulate_steps
+                batch_size = 1
         if quant_nontext_module and batch_size != 1:
-            logger.warning(f"batch_size({batch_size}) cannot be used for calibrating non-text modules,"
-                           "reset to 1")
+            logger.warning(
+                f"rest batch_size({batch_size}) to 1 and "
+                f"gradient_accumulate_steps({gradient_accumulate_steps}) "
+                f"to {batch_size * gradient_accumulate_steps}, "
+                f"cause batch_size={batch_size} cannot be used for calibrating non-text modules.")
             gradient_accumulate_steps = batch_size * gradient_accumulate_steps
             batch_size = 1
         seqlen = 2048 if seqlen is None else seqlen
         truncation = True if truncation is None else truncation
         self.truncation = truncation

+        if nsamples % batch_size != 0:
+            nsamples = (nsamples // batch_size + 1) * batch_size
+            logger.warning(f"'nsamples' is not divisible by 'batch_size', will adjusted to {nsamples}")
+
         super(AutoRoundMLLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -259,7 +269,7 @@ def calib(self, nsamples, bs):
                 m = m.to(self.device)

         total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader))
-        with tqdm(range(1, total + 1), desc="calib") as pbar:
+        with tqdm(range(1, total + 1), desc="cache block inputs") as pbar:
             for data in self.dataloader:
                 if data is None:
                     pbar.update(1)
@@ -337,7 +347,7 @@ def calib(self, nsamples, bs):
                     exit(-1)
                 elif total_cnt < nsamples:
                     logger.warning(
-                        f"Insufficient number of samples collected may affect the quantification. "
+                        f"Insufficient number of samples collected may affect the quantization. "
                         f"target samples count is {nsamples}, while valid samples count is {total_cnt}"
                     )
                 if total_cnt < self.batch_size:
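The autoround_mllm.py hunks do two related things: when the calibration dataset (or calibration of non-text modules) only supports batch_size=1, the supplied batch size is now folded into gradient_accumulate_steps instead of being silently dropped, and the nsamples padding is moved below that reset so it uses the final batch size. A minimal sketch of the adjustment logic, with hypothetical names (force_bs_one stands in for the two conditions checked in the diff):

import math

def resolve_calib_batching(batch_size, gradient_accumulate_steps, nsamples, force_bs_one):
    # Fold an unsupported batch size into gradient accumulation so the
    # effective batch (batch_size * gradient_accumulate_steps) is preserved.
    if force_bs_one and batch_size != 1:
        gradient_accumulate_steps = batch_size * gradient_accumulate_steps
        batch_size = 1
    # Pad nsamples up to a multiple of the *final* batch size, which is why
    # the commit moves this check after the dataset-specific handling.
    if nsamples % batch_size != 0:
        nsamples = math.ceil(nsamples / batch_size) * batch_size
    return batch_size, gradient_accumulate_steps, nsamples

print(resolve_calib_batching(8, 1, 128, force_bs_one=True))   # (1, 8, 128)
print(resolve_calib_batching(3, 2, 128, force_bs_one=False))  # (3, 2, 129)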

auto_round/script/llm.py  (+3 -3)

@@ -156,7 +156,7 @@ def setup_parser():
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", "--train_bs", default=8, type=int,
+    parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int,
                         help="train batch size")

     parser.add_argument("--iters", "--iter", default=200, type=int,
@@ -178,7 +178,7 @@ def setup_best_parser():
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", "--train_bs", default=8, type=int,
+    parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int,
                         help="train batch size")

     parser.add_argument("--iters", "--iter", default=1000, type=int,
@@ -202,7 +202,7 @@ def setup_fast_parser():
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", "--train_bs", default=4, type=int,
+    parser.add_argument("--batch_size", "--train_bs", "--bs", default=4, type=int,
                         help="train batch size")

     parser.add_argument("--iters", default=200, type=int,

auto_round/script/mllm.py  (+2 -1)

@@ -170,7 +170,7 @@ def setup_parser():
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", "--train_bs", default=8, type=int,
+    parser.add_argument("--batch_size", "--train_bs", "--bs", default=8, type=int,
                         help="train batch size")

     parser.add_argument("--iters", "--iter", default=200, type=int,
@@ -450,6 +450,7 @@ def setup_lmms_parser():
     )
     parser.add_argument(
         "--batch_size",
+        "--bs",
         "-b",
         type=str,
         default=1,

examples/language-modeling/main.py  (+2 -2)

@@ -20,7 +20,7 @@
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", "--train_bs", default=8, type=int,
+    parser.add_argument("--batch_size", "--bs", "--train_bs", default=8, type=int,
                         help="train batch size")

     parser.add_argument("--eval_bs", default=None, type=int,
@@ -136,7 +136,7 @@
     args = parser.parse_args()

     print(
-        "Warning, examples/language-modeling/main.py is deprecated, please use auto-round cmd line instead. The file will be deleted in the V0.4.1 release ")
+        "Warning, examples/language-modeling/main.py is deprecated, please use auto-round cmd line instead. The file will be deleted in the V0.4.2 release ")

     if args.enable_minmax_tuning:
         print(

examples/multimodal-modeling/Common_model/main.py  (+1 -1)

@@ -277,7 +277,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", default=1, type=int,
+    parser.add_argument("--batch_size", "--bs", default=1, type=int,
                         help="train batch size")

     parser.add_argument("--eval_bs", default=4, type=int,

examples/multimodal-modeling/Llava/main.py  (+1 -1)

@@ -113,7 +113,7 @@ def save_tower(model, save_path, quant_nontext_module: bool = False, max_shard_s
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", default=1, type=int,
+    parser.add_argument("--batch_size", "--bs", default=1, type=int,
                         help="train batch size")

     parser.add_argument("--eval_bs", default=4, type=int,

examples/multimodal-modeling/Phi-3-vision/main.py  (+1 -1)

@@ -167,7 +167,7 @@ def create_data_loader(dataset, batch_size=1, data_collator=None):
     parser.add_argument("--group_size", default=128, type=int,
                         help="group size")

-    parser.add_argument("--batch_size", default=1, type=int,
+    parser.add_argument("--batch_size", "--bs", default=1, type=int,
                         help="train batch size")

     parser.add_argument("--eval_bs", default=4, type=int,
