1
1
# coding: UTF-8
2
2
import torch
3
3
import torch .nn as nn
4
- import torch .nn .functional as F
5
4
# from pytorch_pretrained_bert import BertModel, BertTokenizer
6
- from transformers import BertModel , BertTokenizer
7
-
5
+ from transformers import BertModel , BertTokenizer , BertConfig
6
+ import os
8
7
9
8
class Config(object):
    """Configuration parameters for the BERT text classifier.

    ``dataset`` is the root directory of one dataset; it is expected to
    contain ``data/train.txt``, ``data/dev.txt``, ``data/test.txt`` and
    ``data/class.txt`` (one class name per line).
    """

    def __init__(self, dataset):
        self.model_name = 'bert'
        self.train_path = dataset + '/data/train.txt'   # training set
        self.dev_path = dataset + '/data/dev.txt'       # validation set
        self.test_path = dataset + '/data/test.txt'     # test set
        # Class names, one per line.  Use a context manager so the file
        # handle is closed, and an explicit encoding: class names may be
        # non-ASCII (e.g. Chinese), so relying on the platform default
        # encoding of a bare open() is a latent decoding bug.
        with open(dataset + '/data/class.txt', encoding='utf-8') as f:
            self.class_list = [x.strip() for x in f.readlines()]
        # Where the trained model checkpoint is saved.
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Stop training early if no improvement for this many batches.
        self.require_improvement = 1000
        self.num_classes = len(self.class_list)  # number of target classes
        self.num_epochs = 3                      # number of training epochs
        self.batch_size = 128                    # mini-batch size
        self.pad_size = 32                       # pad/truncate every sentence to this token length
        self.learning_rate = 5e-5
        self.bert_path = './bert'                # directory with pretrained BERT weights/vocab
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768                   # must match the pretrained BERT hidden size
@@ -34,20 +33,16 @@ class Model(nn.Module):
34
33
35
34
def __init__ (self , config ):
36
35
super (Model , self ).__init__ ()
37
- self .bert = BertModel .from_pretrained (config .bert_path )
36
+ bert_config_file = os .path .join (config .bert_path , f'bert_config.json' )
37
+ bert_config = BertConfig .from_json_file (bert_config_file )
38
+ self .bert = BertModel .from_pretrained (config .bert_path ,config = bert_config )
38
39
for param in self .bert .parameters ():
39
40
param .requires_grad = True
40
41
self .fc = nn .Linear (config .hidden_size , config .num_classes )
41
42
42
-
43
- def forward (self , input_ids ,# 输入的句子
44
- input_mask ,# 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0]
45
- segments_ids
46
- ):
47
- _ , pooled = self .bert (input_ids , attention_mask = input_mask ,token_type_ids = segments_ids )#pooled [batch_size, hidden_size]
43
+ def forward (self , x ):
44
+ context = x [0 ] # 输入的句子
45
+ mask = x [2 ] # 对padding部分进行mask,和句子一个size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0]
46
+ _ , pooled = self .bert (context , attention_mask = mask )
48
47
out = self .fc (pooled )
49
48
return out
50
- def loss (self ,outputs ,labels ):
51
- criterion = F .cross_entropy
52
- loss = criterion (outputs , labels )
53
- return loss
0 commit comments