|
| 1 | +# Topic model 实验 |
| 2 | + |
| 3 | +<center>PB18071477 敖旭扬</center> |
| 4 | + |
| 5 | +## 原理 |
| 6 | + |
| 7 | +`LDA`的概率图模型如下: |
| 8 | + |
| 9 | + |
| 10 | + |
| 11 | +假定数据集中共含 $K$ 个话题和 $D$ 篇文档,词来自含 $V$ 个词的字典,观测数据为 $D$ 篇文档 |
| 12 | +$$ |
| 13 | +\mathcal{D}=\{\boldsymbol{d}_1,\boldsymbol{d}_2,\dots,\boldsymbol{d}_D\} |
| 14 | +$$ |
| 15 | + |
| 16 | +其中每一篇文档 |
| 17 | +$$ |
| 18 | +\boldsymbol{d}_m=(w_{m,1},w_{m,2},\dots,w_{m,N_m}) |
| 19 | +$$ |
| 20 | +是一个单词序列, $w_{m,n}$ 和 $N_m$是第 $m$ 篇文档的第 $n$ 个单词和长度(单词数) |
| 21 | + |
| 22 | +单词集合为 |
| 23 | +$$ |
| 24 | +\mathcal{W}=\{w_1,w_2,\dots,w_V\} |
| 25 | +$$ |
| 26 | + |
| 27 | +话题数量是提前给定的,话题集合为 |
| 28 | +$$ |
| 29 | +\mathcal{Z}=\{z_1,z_2,\dots,z_K\} |
| 30 | +$$ |
| 31 | + |
| 32 | +`LDA`模型假设 |
| 33 | +$$ |
| 34 | +p(z|\boldsymbol{d}_m)\sim Mult(\boldsymbol{\theta}_m) \tag{1} |
| 35 | +$$ |
| 36 | +$$ |
| 37 | +p(w|z_j)\sim Mult(\boldsymbol{\varphi}_j) \tag{2} |
| 38 | +$$ |
| 39 | +其中 $\boldsymbol{\theta}_m$ 是所有 $K$ 个主题在文档 $\boldsymbol{d}_m$ 中出现的概率, $\boldsymbol{\varphi}_j$ 是所有 $V$ 个单词在主题 $z_j$ 中出现的概率 |
| 40 | + |
| 41 | +同时假设 |
| 42 | +$$ |
| 43 | +\boldsymbol{\theta}_m \sim Dir(\boldsymbol{\alpha}) \tag{3} |
| 44 | +$$ |
| 45 | +$$ |
| 46 | +\boldsymbol{\varphi}_j \sim Dir(\boldsymbol{\beta}) \tag{4} |
| 47 | +$$ |
| 48 | +其中 $\boldsymbol{\alpha} \in [0,1]^K$ 和 $\boldsymbol{\beta} \in [0,1]^V$ 是超参数,提前给定 |
| 49 | + |
| 50 | +生成文档 $\boldsymbol{d}_m$ 的过程: |
| 51 | + |
| 52 | +* 从以 $\boldsymbol{\alpha}$ 为参数的狄利克雷分布中随机采样一个话题分布 $\boldsymbol{\theta}_m$ |
| 53 | +* 根据 $\boldsymbol{\theta}_m$ 进行话题指派,得到文档 $\boldsymbol{d}_m$ 中第 $n$ 词的话题 $z_{m,n}$ |
| 54 | +* 根据指派的话题 $z_{m,n}$ 所对应的的词分布随 $\boldsymbol{\varphi}_j$ 机采样生成词 $w_{m,n}$ |
| 55 | + |
| 56 | +求解模型时有 |
| 57 | +$$ |
| 58 | +\sum_{k=1}^{K}\theta_{mk}=1,\quad \sum_{v=1}^{V}\varphi_{jv}=1 |
| 59 | +$$ |
| 60 | +最后可得 |
| 61 | + |
| 62 | +$$ |
| 63 | +\theta_{mk}=\dfrac{\sigma_{jk}+\alpha_k}{\sum_{i}^{K}\sigma_{ji}+\alpha_i} \tag{5} |
| 64 | +$$ |
| 65 | +$$ |
| 66 | +\varphi_{kv}=\dfrac{\delta_{kv}+\beta_v}{\sum_{r}^{V}\delta_{kr}+\beta_r} \tag{6} |
| 67 | +$$ |
| 68 | +其中 $\sigma_{jk}$ 是第 $m$ 个文档第 $k$ 个主题的单词个数, $\delta_{kv}$ 是第 $k$ 个主题的第 $v$ 个单词个数 |
| 69 | + |
| 70 | +训练过程: |
| 71 | + |
| 72 | +* 选择 $\boldsymbol{\alpha}$ 和 $\boldsymbol{\beta}$ 为全 $1$ 向量, $z$ 随机初始化 |
| 73 | +* 对于每篇文档 $d$ 的每个单词 $w$ ,重复吉布斯采样更新主题编号 $z_{d,w}$ |
| 74 | + |
| 75 | +## 编程实现 |
| 76 | + |
| 77 | +矩阵运算使用`python`的`numpy`库实现,部分文本预处理使用了`nltk`库的功能。 |
| 78 | + |
| 79 | +训练主要算法如下: |
| 80 | + |
| 81 | +```python |
| 82 | +class LDA: |
| 83 | + def gibbs_sampling(self, epoch=100): |
| 84 | + print("吉布斯采样......") |
| 85 | + for _ in tqdm.tqdm(range(epoch)): |
| 86 | + for m in range(self.D): |
| 87 | + for v in range(len(self.Dset[m])): |
| 88 | + self.z[m][v] = self.topic_updated(m, v) |
| 89 | + |
| 90 | + def topic_updated(self, m, v): |
| 91 | + topic_old = int(self.z[m][v]) |
| 92 | + self.delta[topic_old][self.Vset_to_index[self.Dset[m][v]]] -= 1 |
| 93 | + self.n_k[topic_old] -= 1 |
| 94 | + self.sigma[m][topic_old] -= 1 |
| 95 | + self.n_m[m] -= 1 |
| 96 | + p = np.zeros(self.K) |
| 97 | + for k in range(self.K): |
| 98 | + p[k] = (self.sigma[m][k] + self.alpha[k]) / \ |
| 99 | + (self.n_m[m] + np.sum(self.alpha)) * \ |
| 100 | + (self.delta[k][self.Vset_to_index[self.Dset[m][v]]] + |
| 101 | + self.beta[self.Vset_to_index[self.Dset[m][v]]]) / \ |
| 102 | + (self.n_k[k] + np.sum(self.beta)) |
| 103 | + p = p / np.sum(p) |
| 104 | + topic_new = np.argmax(np.random.multinomial(1, p)) |
| 105 | + self.delta[topic_new][self.Vset_to_index[self.Dset[m][v]]] += 1 |
| 106 | + self.n_k[topic_new] += 1 |
| 107 | + self.sigma[m][topic_new] += 1 |
| 108 | + self.n_m[m] += 1 |
| 109 | + return topic_new |
| 110 | + |
| 111 | + def cal_theta_varphi(self): |
| 112 | + for j in range(self.D): |
| 113 | + for k in range(self.K): |
| 114 | + self.theta[j][k] = ( |
| 115 | + self.sigma[j][k] + self.alpha[k]) / \ |
| 116 | + (self.n_m[j] + np.sum(self.alpha)) |
| 117 | + for k in range(self.K): |
| 118 | + for v in range(self.V): |
| 119 | + self.varphi[k][v] = ( |
| 120 | + self.delta[k][v] + self.beta[v]) / \ |
| 121 | + (self.n_k[k] + np.sum(self.beta)) |
| 122 | + |
| 123 | + def train(self, epoch): |
| 124 | + for m in range(self.D): |
| 125 | + self.n_m[m] = len(self.Dset[m]) |
| 126 | + for v in range(len(self.Dset[m])): |
| 127 | + topic = int(np.random.randint(0, self.K)) |
| 128 | + self.z[m][v] = topic |
| 129 | + self.delta[topic][self.Vset_to_index[self.Dset[m][v]]] += 1 |
| 130 | + self.n_k[topic] += 1 |
| 131 | + self.sigma[m][topic] += 1 |
| 132 | + self.gibbs_sampling(epoch) |
| 133 | + self.cal_theta_varphi() |
| 134 | +``` |
| 135 | + |
| 136 | +完整实验源码见压缩包中的[LDA.py](LDA.py)。 |
| 137 | + |
| 138 | +## 运算结果 |
| 139 | + |
| 140 | +### 实例 |
| 141 | + |
| 142 | + |
| 143 | +在主函数中调用下面的实例,其中 $\boldsymbol{\alpha}$ 和 $\boldsymbol{\beta}$ 为全 $1$ 向量 |
| 144 | + |
| 145 | +```python |
| 146 | +texts = np.load("./data/text.npy") |
| 147 | +lda = LDA(texts, K=20, cold_count=8) |
| 148 | + |
| 149 | +# 训练模型 |
| 150 | +lda.train(60) |
| 151 | + |
| 152 | +# 输出20个主题的top10的词 |
| 153 | +top10words = lda.top_words(10) |
| 154 | +for i in range(lda.K): |
| 155 | + print(top10words[i]) |
| 156 | + np.savetxt("top10words.txt", top10words, '%s', delimiter=',') |
| 157 | +``` |
| 158 | + |
| 159 | +某一次运行后文件`top10words.txt`中为 |
| 160 | + |
| 161 | +```text |
| 162 | +use,window,help,one,seem,two,problem,way,order,would |
| 163 | +use,weapon,section,firearm,military,person,mean,shall,carry,license |
| 164 | +would,make,think,say,good,people,god,like,know,one |
| 165 | +copy,magi,new,issue,vote,would,cover,old,must,term |
| 166 | +year,get,bike,one,food,billion,see,would,insurance,much |
| 167 | +hiv,aid,disease,health,care,say,year,child,find,new |
| 168 | +people,say,one,come,get,well,take,could,like,time |
| 169 | +file,send,use,support,include,system,image,also,mail,graphic |
| 170 | +please,point,anyone,well,post,really,every,know,email,time |
| 171 | +israel,would,state,israeli,attack,lebanese,arab,time,true,jew |
| 172 | +earth,space,launch,probe,program,would,mission,orbit,titan,year |
| 173 | +find,point,sphere,level,new,plane,think,normal,sure,define |
| 174 | +key,use,government,chip,one,system,encryption,need,get,make |
| 175 | +drive,problem,hard,make,work,build,year,even,disk,space |
| 176 | +game,play,good,goal,team,win,fan,get,last,blue |
| 177 | +one,church,jesus,god,people,also,believe,christ,say,fact |
| 178 | +greek,turkish,armenian,population,turk,show,jew,cyprus,jewish,child |
| 179 | +card,monitor,thanks,use,need,cache,switch,color,port,anyone |
| 180 | +get,like,car,know,would,thing,one,good,think,problem |
| 181 | +space,power,use,data,april,option,science,flight,test,system |
| 182 | +
|
| 183 | +``` |
| 184 | + |
| 185 | +### 训练结果 |
| 186 | + |
| 187 | +<center> |
| 188 | + <img style="border-radius: 0.3125em; |
| 189 | + box-shadow: 0 2px 4px 0 rgba(34,36,38,.12),0 2px 10px 0 rgba(34,36,38,.08);" |
| 190 | + src="img\result.png"> |
| 191 | + <br> |
| 192 | + <div style="color:orange; border-bottom: 1px solid #d9d9d9; |
| 193 | + display: inline-block; |
| 194 | + color: #999; |
| 195 | + padding: 2px;">主题top10关键词</div> |
| 196 | +</center> |
| 197 | + |
| 198 | + |
0 commit comments