Commit b5064d6

executable version, not verified
1 parent 1991274 commit b5064d6

File tree: 7 files changed, +372 -55 lines


agents/bert.py

Lines changed: 29 additions & 7 deletions

@@ -7,32 +7,54 @@
 from torch import nn
 from torch.backends import cudnn
 from torch.autograd import Variable
+from torch.utils.data import DataLoader
 from tensorboardX import SummaryWriter

 from agents.base import BaseAgent
-from datasets.bert import BERTDataLoader
+from datasets.bert import SentencePairDataset
 from graphs.models.bert import BERTModel4Pretrain
 from utils.optim import optim4GPU
-
+from utils.tokenization import FullTokenizer
+from utils.misc import set_seeds
+''
 cudnn.benchmark = True

+
 class BERTAgent(BaseAgent):
     def __init__(self, config):
         super().__init__(config)
         self.config = config
-
+        set_seeds(self.config.seed)
         self.current_epoch = 0
         self.global_step = 0
         self.best_valid_mean_iou = 0

-        self.dataloader = BERTDataLoader(self.config)
         self.model = BERTModel4Pretrain(self.config)
         self.criterion1 = nn.CrossEntropyLoss(reduction='none')
         self.criterion2 = nn.CrossEntropyLoss()

         self.optimizer = optim4GPU(self.config, self.model)
         self.writer = SummaryWriter(log_dir=self.config.log_dir)

+        tokenizer = FullTokenizer(self.config, do_lower_case=True)
+        tokenizer.vocab
+        train_dataset = SentencePairDataset(self.config, tokenizer, 'train')
+        test_dataset = SentencePairDataset(self.config, tokenizer, 'validate')
+
+        a = train_dataset.__getitem__(0)
+
+        self.train_dataloader = DataLoader(train_dataset,
+                                           batch_size=self.config.batch_size,
+                                           num_workers=self.config.data_loader_workers,
+                                           pin_memory=self.config.pin_memory)
+
+        self.test_dataloader = DataLoader(test_dataset,
+                                          batch_size=self.config.batch_size,
+                                          num_workers=self.config.data_loader_workers,
+                                          pin_memory=self.config.pin_memory)
+
     def load_checkpoint(self, file_name):
         """
         Latest checkpoint loader
@@ -114,12 +136,12 @@ def train_one_epoch(self):
         One epoch of training
         :return:
         """
-        iter_bar = tqdm(self.dataloader.train_dataloader,
-                        total=self.dataloader.train_dataset_len,
+        iter_bar = tqdm(enumerate(self.train_dataloader),
                         desc="Iter (loss=X.XXX)")

         loss_sum = 0.  # the sum of iteration losses to get average loss in every epoch
-        for i, batch in enumerate(iter_bar):
+        for i, batch in iter_bar:
             if self.config.gpu_cpu == 'gpu':
                 batch = [t.to(self.config.gpu_device) for t in batch]
             elif self.config.gpu_cpu == 'cpu':
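Note on the train_one_epoch change: the old code passed total= to tqdm, while
tqdm(enumerate(self.train_dataloader), ...) cannot infer a length from the
enumerate object, so the progress bar loses its total. A minimal standalone
sketch (a toy dataset in place of SentencePairDataset) showing that passing
total=len(loader), the number of batches, restores it:

    # Toy sketch, not repo code: restore tqdm's bar length when wrapping
    # the loader in enumerate(); len(loader) counts batches, not samples.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from tqdm import tqdm

    toy_dataset = TensorDataset(torch.arange(100))       # 100 dummy samples
    toy_loader = DataLoader(toy_dataset, batch_size=10)  # -> 10 batches

    iter_bar = tqdm(enumerate(toy_loader),
                    total=len(toy_loader),               # explicit bar length
                    desc="Iter (loss=X.XXX)")
    for i, batch in iter_bar:
        pass  # training step would go here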
configs/bert_exp_0.json

Lines changed: 2 additions & 5 deletions

@@ -27,11 +27,8 @@
     "validate_every": 2,

     "train_data_ratio": 0.7,
-    "test_data_ratio": 0.2,
-    "validate_data_ratio": 0.1,
-
     "data_loader": "BertDataLoader",
-    "data_loader_workers": 4,
+    "data_loader_workers": 0,
     "tokenizer": "bpe",

     "dim": 768,
@@ -53,6 +50,6 @@
     "data_mode": "corpus",
     "checkpoint_to_load": "",
     "log_dir": "experiments/bert_exp_0/logs",
-    "data_dir": "./data/sejong_cleaned_nsp.txt",
+    "data_dir": "./data/sejong_cleaned_nsp_test.txt",
     "checkpoint_dir": "./checkpoints/bert/"
 }

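Aside: data_loader_workers = 0 keeps batch preparation in the main process
(slower, but simpler to debug than worker subprocesses), and data_dir now
points at a smaller test corpus. A hedged sketch of how a JSON config like
this is typically consumed in agent-style templates; SimpleNamespace here
stands in for whatever attribute-access wrapper the repo actually uses,
which this diff does not show:

    # Illustrative only: load the experiment config into dot-accessible form.
    import json
    from types import SimpleNamespace

    with open("configs/bert_exp_0.json") as f:
        config = SimpleNamespace(**json.load(f))

    print(config.data_loader_workers)  # 0 -> batches load in the main process
    print(config.data_dir)             # ./data/sejong_cleaned_nsp_test.txt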
datasets/bert.py

Lines changed: 6 additions & 38 deletions

@@ -59,13 +59,13 @@ def __getitem__(self, idx):
         # candidate positions of masked tokens
         cand_pos = [i for i, token in enumerate(tokens)
                     if token != '[CLS]' and token != '[SEP]']
-        shuffle(cand_pos)
+        random.shuffle(cand_pos)
         for pos in cand_pos[:n_pred]:
             masked_tokens.append(tokens[pos])
             masked_pos.append(pos)
-            if rand() < 0.8:  # 80%
+            if random.random() < 0.8:  # 80%
                 tokens[pos] = '[MASK]'
-            elif rand() < 0.5:  # 10%
+            elif random.random() < 0.5:  # 10%
                 tokens[pos] = get_random_word(self.vocab)
         # when n_pred < max_pred, we only calculate loss within n_pred
         masked_weights = [1]*len(masked_tokens)
@@ -88,13 +88,9 @@ def __getitem__(self, idx):
         masked_weights.extend([0]*n_pad)

         batch = (input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next)
-        batch_tensors = [torch.tesnor(x, dtype=torch.long) for x in zip(*batch)]
+        batch_tensors = [torch.tensor(x, dtype=torch.long) for x in batch]
         return batch_tensors

-
-
-
     def random_sent(self, idx):
         t1, t2 = self.get_corpus_line(idx)

@@ -104,38 +100,10 @@ def random_sent(self, idx):
         else:
             return t1, self.get_random_line(), 0

-    def get_corpus_line(self):
+    def get_corpus_line(self, idx):
         return self.lines[idx][0], self.lines[idx][1]

-    def get_random_line(self, ):
+    def get_random_line(self):
         return random.choice(self.lines)[1]
-
-
-class BERTDataLoader:
-    def __init__(self, config):
-        self.config = config
-        tokenizer = FullTokenizer(self.config, do_lower_case=True)
-
-        if self.config.mode == "pretrain":
-            train_dataset = SentencePairDataset(self.config, tokenizer, 'train')
-            validate_dataset = SentencePairDataset(self.config, tokenizer, 'validate')
-
-            self.train_dataset_len = len(train_dataset)
-            self.validate_dataset_len = len(validate_dataset)
-
-            self.train_dataloader = DataLoader(train_dataset,
-                                               batch_size=self.config.batch_size,
-                                               num_workers=self.config.data_loader_workers,
-                                               pin_memory=self.config.pin_memory)
-
-            self.validate_dataloader = DataLoader(validate_dataset,
-                                                  batch_size=self.config.batch_size,
-                                                  num_workers=self.config.data_loader_workers,
-                                                  pin_memory=self.config.pin_memory)

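Two notes on this file. The torch.tesnor fix also drops zip(*batch), since
__getitem__ now returns a single example's tensors and leaves batching to the
DataLoader's default collate function. And the masking branch implements
BERT's 80/10/10 rule; it works because the elif draws a fresh random number,
so of the 20% of positions that escape the [MASK] branch, half get a random
vocabulary word and half keep the original token. A standalone sketch
verifying the proportions empirically:

    # Standalone check of the 80/10/10 masking rule used in __getitem__.
    import random
    from collections import Counter

    random.seed(0)
    outcomes = Counter()
    for _ in range(100_000):
        if random.random() < 0.8:
            outcomes['[MASK]'] += 1        # ~80%: replace with [MASK]
        elif random.random() < 0.5:
            outcomes['random word'] += 1   # ~10%: replace with a random word
        else:
            outcomes['unchanged'] += 1     # ~10%: keep the original token
    print(outcomes)  # roughly 80000 / 10000 / 10000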
graphs/models/bert.py

Lines changed: 3 additions & 3 deletions

@@ -37,7 +37,7 @@ class Embeddings(nn.Module):
     "The embedding module from word, position and token_type embeddings."
     def __init__(self, config):
         super().__init__()
-        self.tok_embed = nn.Embedding(config.vocab_size, config.dim)
+        self.tok_embed = nn.Embedding(config.vocab_size+3, config.dim)
         self.pos_embed = nn.Embedding(config.max_len, config.dim)
         self.seg_embed = nn.Embedding(config.n_segments + 1, config.dim)

@@ -76,11 +76,11 @@ def forward(self, x, mask):
         # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
         scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
         if mask is not None:
-            mask = mask[:, None, None, :].flaot()
+            mask = mask[:, None, None, :].float()
             scores -= 10000.0 * (1.0 - mask)
         scores = self.drop(F.softmax(scores, dim=-1))
         # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
-        h = (scores @ V).transpose(1, 2).contiguous()
+        h = (scores @ v).transpose(1, 2).contiguous()
         # -merge-> (B, S, D)
         h = merge_last(h, 2)
         self.scores = scores

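The .flaot() and V typos fixed here sit inside the additive attention mask:
subtracting 10000 from the scores of padded positions drives their
post-softmax weight to effectively zero. (The vocab_size+3 change presumably
reserves embedding rows for special tokens; the diff does not say which.)
A self-contained toy check of the masking trick, reduced to one query over
four keys:

    # Toy demonstration: additive masking before softmax zeroes out padding.
    import torch
    import torch.nn.functional as F

    scores = torch.tensor([[2.0, 1.0, 0.5, 0.3]])  # raw scores, 1 query x 4 keys
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])    # 1 = real token, 0 = padding
    masked = scores - 10000.0 * (1.0 - mask)
    print(F.softmax(masked, dim=-1))
    # -> tensor([[0.7311, 0.2689, 0.0000, 0.0000]]); padding gets no attention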
run.sh

Lines changed: 1 addition & 1 deletion

@@ -6,4 +6,4 @@
 #export CUDA_DEVICE_ORDER=PCI_BUS_ID
 #export CUDA_VISIBLE_DEVICES=1

-python main.py configs/bert_exp_0.json
+python main.py configs/bert_exp_0.json

(The two lines render identically, so this is presumably a whitespace-only
change, e.g. a trailing newline.)

utils/misc.py

Lines changed: 11 additions & 1 deletion

@@ -1,5 +1,8 @@
 import time
 import logging
+import random
+import numpy as np
+import torch


 def timeit(f):
@@ -47,4 +50,11 @@ def truncate_tokens_pair(tokens_a, tokens_b, max_len):

 def get_random_word(vocab_words):
     i = random.randint(0, len(vocab_words)-1)
-    return vocab_words[i]
+    return list(vocab_words.keys())[i]
+
+def set_seeds(seed):
+    "set random seeds"
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)

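The new set_seeds helper covers Python's random module, NumPy, and PyTorch on
CPU and all CUDA devices. One caveat worth knowing: agents/bert.py sets
cudnn.benchmark = True, which lets cuDNN auto-tune algorithms at runtime, so
seeding alone does not guarantee bit-identical GPU runs. A usage sketch with
the two flags PyTorch documents for stricter reproducibility (optional, and a
speed trade-off):

    # Usage sketch; set_seeds is copied from utils/misc.py above.
    import random
    import numpy as np
    import torch

    def set_seeds(seed):
        "set random seeds"
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    set_seeds(42)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print(torch.rand(2), np.random.rand(2))  # identical on every run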