Skip to content

Commit 4b0871e

Browse files
committed
Release the code.
1 parent e1d0092 commit 4b0871e

File tree

155 files changed

+20110
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

155 files changed

+20110
-1
lines changed

README.md

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,58 @@ This is the official PyTorch implementation of [QLLM: Accurate and Efficient Low
77

88
By [Jing Liu](https://jing-liu.com/), [Ruihao Gong](https://xhplus.github.io/), [Xiuying Wei](https://wimh966.github.io/), [Zhiwei Dong](https://zwdong.com.cn/), [Jianfei Cai](https://jianfei-cai.github.io/), and [Bohan Zhuang](https://bohanzhuang.github.io/).
99

10-
Code will be released soon!
10+
![qllm](imgs/qllm.png)
11+
12+
We propose QLLM, an accurate and efficient low-bitwidth post-training quantization method designed for LLMs.
13+
14+
## 📰 News
15+
- [10-03-2024] Release the code!🌟
16+
- [17-01-2024] QLLM is accepted by ICLR 2024! 👏
17+
18+
## 📖 Contents
19+
- [Install](#🛠-install)
20+
- [Usage](#⚙️-usage)
21+
- [Results](#📋-results)
22+
- [Citation](#📝-citation)
23+
- [License](#🧾-license)
24+
- [Acknowledgement](#🙏-acknowledgement)
25+
26+
## 🛠 Install
27+
```
28+
conda create -n qllm python=3.10 -y
29+
conda activate qllm
30+
git clone https://github.com/ModelTC/QLLM
31+
cd QLLM
32+
pip install --upgrade pip
33+
pip install -e .
34+
```
35+
36+
## ⚙️ Usage
37+
We provide the training scripts in the `scripts` folder. For example, to perform W4A4 quantization for LLaMA-7B, run
38+
```
39+
sh scripts/llama-7b/w4a4.sh
40+
```
41+
Remember to set the model path `model` and the output path `output_dir` in the script before running it.
42+
43+
## 📋 Results
44+
* QLLM achieves SoTA performance in weight-activation quantization
45+
46+
![weight_activation_llama_1](imgs/llama_1_results.png)
47+
![weight_activation_llama_2](imgs/llama_2_results.png)
48+
49+
## 📝 Citation
50+
If you find our `QLLM` useful in your research, please consider citing the following paper:
51+
```
52+
@inproceedings{liu2024qllm,
53+
title = {{QLLM}: Accurate and Efficient Low-Bitwidth Quantization for Large Language Models},
54+
author = {Liu, Jing and Gong, Ruihao and Wei, Xiuying and Dong, Zhiwei and Cai, Jianfei and Zhuang, Bohan},
55+
booktitle = {International Conference on Learning Representations (ICLR)},
56+
year = {2024},
57+
}
58+
```
59+
60+
## 🧾 License
61+
This repository is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
62+
63+
## 🙏 Acknowledgement
64+
This repository is built upon [OmniQuant](https://github.com/OpenGVLab/OmniQuant). We thank the authors for their open-sourced code.

assembly/ca_module.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
def do_nothing(x):
    """Identity function: hand the argument back untouched."""
    return x
7+
8+
9+
def bipartite_soft_matching_x_w(x, w, r, scaling_factors):
10+
# We can only reduce by a maximum of 50% channels
11+
# metric shape: [B * N, C]
12+
b = x.shape[0]
13+
t = x.shape[1]
14+
r = min(r, t // 2)
15+
16+
if r <= 0:
17+
return do_nothing, do_nothing
18+
19+
with torch.no_grad():
20+
# xa, xb shape: [B * N, C/2]
21+
xa, xb = x[..., ::2], x[..., 1::2]
22+
# wa, wb shape: [cout, C/2]
23+
wa, wb = w[..., ::2], w[..., 1::2]
24+
xa_c, xb_c = xa.shape[1], xb.shape[1]
25+
26+
# shape: [C/2, C/2]
27+
# fast version
28+
# xdist = (xa.t().reshape(xa_c, b, 1) - xb.reshape(1, b, xb_c)).sum(1)
29+
# score_ij = wi (yi - yj) / 2 + wj (yj - yi) / 2
30+
# score_a[i, :] = w:i xdist:
31+
32+
xdist = torch.cdist(xa.t(), xb.t(), p=2.0)
33+
scores_a = torch.zeros(xa_c, xb_c, device=x.device)
34+
scores_b = torch.zeros(xa_c, xb_c, device=x.device)
35+
scores_fast = torch.zeros(xa_c, xb_c, device=x.device)
36+
for i in range(xb_c):
37+
scores_a[i, :] = (wa[:, i].unsqueeze(1) * xdist[i]).sum(0)
38+
for j in range(xb_c):
39+
scores_b[:, j] = (wb[:, j].unsqueeze(1) * (xdist[:, j])).sum(0)
40+
scores_fast = (scores_a + scores_b).pow(2)
41+
scores = scores_fast
42+
43+
if scaling_factors is not None:
44+
split_mask = scaling_factors != 1.0
45+
split_mask_a = split_mask[::2]
46+
split_index_a = split_mask_a.nonzero().squeeze()
47+
48+
split_mask_b = split_mask[1::2]
49+
split_index_b = split_mask_b.nonzero().squeeze()
50+
scores.index_fill_(
51+
dim=0, index=split_index_a, value=torch.finfo(scores.dtype).max
52+
)
53+
scores.index_fill_(
54+
dim=1, index=split_index_b, value=torch.finfo(scores.dtype).max
55+
)
56+
57+
# scores_a = torch.zeros(xa_c, xb_c, device=x.device)
58+
# for i in range(xb_c):
59+
# scores_a[i, :] = (wa[:, i].unsqueeze(1) * xa[:, i]).mean(0)
60+
61+
# slow version
62+
# scores = torch.zeros(t // 2, t // 2, device=x.device)
63+
# for i in range(t // 2):
64+
# for j in range(t // 2):
65+
# scores[i, j] = (
66+
# (
67+
# (wa[..., i] * (xa[..., i] - xb[..., j]).sum())
68+
# + (wb[..., j] * (xb[..., j] - xa[..., i]).sum())
69+
# )
70+
# .mean(0)
71+
# .pow(2)
72+
# )
73+
74+
# node max, node_idx shape: [C/2], index of b
75+
# Draw one edge from each token in A to its most similar token in B.
76+
node_min, node_idx = scores.min(dim=-1)
77+
# edge_idx shape: [C/2]
78+
# Keep the r most similar edges. index of a
79+
edge_idx = node_min.argsort(dim=-1, descending=False)
80+
81+
# unm_idx shape: [C/2 -r]
82+
# unm_idx = edge_idx[r:] # Unassembled Channels
83+
# src_idx shape: [r]
84+
src_idx = edge_idx[:r] # Assembled Channels
85+
dst_idx = node_idx[src_idx]
86+
return src_idx, dst_idx, scores[src_idx, dst_idx]
87+
88+
89+
def assembly(x, src_idx, dst_idx, r, mode="mean") -> torch.Tensor:
    """Merge selected even channels of ``x`` into odd channels and drop them.

    The last dimension of ``x`` is viewed as two interleaved sets: the
    even-indexed channels (sources) and the odd-indexed channels
    (destinations).  Every source channel listed in ``src_idx`` is reduced
    into the destination channel listed at the same position of ``dst_idx``
    and then removed from the output.  All other channels keep their value
    and relative order.

    Args:
        x: tensor of shape ``[B, N, C]``.
        src_idx: 1-D long tensor of indices into the even-channel set.
        dst_idx: 1-D long tensor of indices into the odd-channel set, same
            length as ``src_idx``.
        r: kept for backward compatibility.  The number of pairs is taken
            from ``src_idx`` itself, because the matching step may return
            fewer pairs than were requested.
        mode: reduction passed to ``Tensor.scatter_reduce``.  The
            destination's own value takes part in the reduction.

    Returns:
        Tensor of shape ``[B, N, C - len(src_idx)]``.
    """
    B, N, C = x.shape
    src_idx = src_idx.reshape(-1)
    dst_idx = dst_idx.reshape(-1)
    num_pairs = src_idx.numel()
    if num_pairs == 0:
        # Nothing to merge.
        return x

    even_channels = torch.arange(0, C, 2, device=x.device)
    odd_channels = torch.arange(1, C, 2, device=x.device)
    src, dst = x[..., even_channels], x[..., odd_channels]

    # A channel is dropped from the output once it has been merged away.
    keep_mask = torch.ones(C, dtype=torch.bool, device=x.device)
    keep_mask[even_channels[src_idx]] = False

    merged_src = src.gather(dim=-1, index=src_idx.expand(B, N, num_pairs))
    dst = dst.scatter_reduce(
        -1, dst_idx.expand(B, N, num_pairs), merged_src, reduce=mode
    )

    # Write the updated destinations back to their odd positions; the even
    # positions keep the original values.  Works for even and odd C alike.
    assembled_x = x.clone()
    assembled_x[..., odd_channels] = dst

    # squeeze(-1) keeps the index 1-D even when a single channel survives.
    return assembled_x.index_select(-1, keep_mask.nonzero().squeeze(-1))
120+
121+
122+
class CAModule(nn.Module):
    """Channel-assembly layer: merges pairs of channels of its input.

    The module is a pass-through until ``find_similar_channels`` has been
    called.  After that, ``forward`` merges the stored channel pairs with
    ``assembly`` (mean reduction), which shrinks the last dimension by the
    number of stored pairs.
    """

    def __init__(self, num_assembled_channels):
        super().__init__()
        # Requested number of channel pairs to merge.
        self.num_assembled_channels = num_assembled_channels
        # Flipped to True once the pair indices have been computed.
        self.have_assembled = False
        # Placeholders; replaced by registered buffers in
        # find_similar_channels.
        self.src_idx = None
        self.dst_idx = None
        # Not used inside this class; set and consumed by external code.
        self.num_disassembly = None
        self.scaling_factors = None

    def find_similar_channels(self, x, fcs):
        """Compute and store the channel pairs to merge.

        Args:
            x: activations of shape ``[B, N, C]``.
            fcs: a layer, or a list/tuple of layers, exposing a ``weight``
                of shape ``[cout, C]``.  The weights are concatenated along
                the output dimension before matching.
        """
        B, N, C = x.shape
        # reshape (not view) so non-contiguous activations are accepted;
        # the matching is done in float32.
        x = x.reshape(B * N, C).float()

        if not isinstance(fcs, (list, tuple)):
            fcs = [fcs]
        fc_weight = torch.cat([fc.weight for fc in fcs], dim=0)

        src_idx, dst_idx, scores = bipartite_soft_matching_x_w(
            x, fc_weight, self.num_assembled_channels, self.scaling_factors
        )

        # register_buffer refuses a name that already exists as a plain
        # attribute, so the placeholder (or a previous buffer) goes first.
        del self.src_idx
        del self.dst_idx
        print("Score: {}".format(scores))
        self.register_buffer("src_idx", src_idx)
        self.register_buffer("dst_idx", dst_idx)
        self.have_assembled = True

    def forward(self, x):
        """Return ``x`` with the stored channel pairs merged, if any."""
        # Only perform assembly after find_similar_channels.
        if not self.have_assembled:
            return x
        # Pass the number of pairs actually stored: the matcher may have
        # clamped the requested num_assembled_channels.
        return assembly(
            x,
            self.src_idx,
            self.dst_idx,
            self.src_idx.numel(),
            mode="mean",
        )

categories.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Maps each task name to a one-element list holding its subject tag.
# The tags used here are the ones grouped by the `categories` dict.
# NOTE(review): the task names look like the MMLU benchmark subjects --
# confirm against the evaluation code that consumes this table.
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}
60+
61+
# Groups subject tags into four top-level categories.  Every tag listed
# here occurs as a value in the `subcategories` table.
categories = {
    "STEM": [
        "physics",
        "chemistry",
        "biology",
        "computer science",
        "math",
        "engineering",
    ],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

0 commit comments

Comments
 (0)