DeBERTa v2

1. Add DeBERTa v2 xlarge, xxlarge and MNLI xlarge-v2, xxlarge-v2 models
2. Fix GLUE data downloading issue.
3. Support plugin tasks
4. Update experiments
BigBird01 committed Feb 9, 2021
1 parent d9e01c6 commit 839e3b4
Showing 72 changed files with 2,077 additions and 2,095 deletions.
4 changes: 3 additions & 1 deletion DeBERTa/apps/__init__.py
@@ -1 +1,3 @@
-from .task_registry import tasks
+import os
+# This statement must be executed at the very beginning, i.e. before import torch
+os.environ["OMP_NUM_THREADS"] = "1"
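
The added lines pin the OpenMP thread count before torch is first imported; OpenMP sizes its thread pool when it initializes, so setting the variable after `import torch` is typically ignored. A minimal standalone sketch of the same pattern (the script body and print are illustrative, not part of the repo):

import os

# Must run before the first `import torch`; once OpenMP has initialized its
# thread pool, changing OMP_NUM_THREADS generally has no effect.
os.environ["OMP_NUM_THREADS"] = "1"

import torch  # imported after the environment variable on purpose

if __name__ == "__main__":
  print("intra-op threads:", torch.get_num_threads())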
3 changes: 3 additions & 0 deletions DeBERTa/apps/models/__init__.py
@@ -0,0 +1,3 @@
+from .ner import *
+from .multi_choice import *
+from .sequence_classification import *
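
This new `models/__init__.py` turns the moved model files into a subpackage and re-exports the classes listed in each module's `__all__`, so callers can import the task heads from one place. A usage sketch (assuming the repository's `DeBERTa` package is importable):

# The star-imports above re-export NERModel, MultiChoiceModel and SequenceClassificationModel.
from DeBERTa.apps.models import NERModel, MultiChoiceModel, SequenceClassificationModel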
DeBERTa/apps/multi_choice.py → DeBERTa/apps/models/multi_choice.py
@@ -15,17 +15,17 @@
 from torch.nn import CrossEntropyLoss
 import math
 
-from ..deberta import *
-from ..utils import *
+from ...deberta import *
+from ...utils import *
 import pdb
 
 __all__ = ['MultiChoiceModel']
 class MultiChoiceModel(NNModule):
   def __init__(self, config, num_labels = 2, drop_out=None, **kwargs):
     super().__init__(config)
-    self.bert = DeBERTa(config)
+    self.deberta = DeBERTa(config)
     self.num_labels = num_labels
-    self.classifier = nn.Linear(config.hidden_size, 1)
+    self.classifier = torch.nn.Linear(config.hidden_size, 1)
     drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
     self.dropout = StableDropout(drop_out)
     self.apply(self.init_weights)
@@ -39,7 +39,7 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
       position_ids = position_ids.view([-1, position_ids.size(-1)])
     if input_mask is not None:
       input_mask = input_mask.view([-1, input_mask.size(-1)])
-    encoder_layers = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
+    encoder_layers = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
         position_ids=position_ids, output_all_encoded_layers=True)
     seqout = encoder_layers[-1]
     cls = seqout[:,:1,:]
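
For context, `MultiChoiceModel.forward` flattens each input from `[batch, num_choices, seq_len]` to `[batch * num_choices, seq_len]` (the `view` calls above), encodes all choices with the shared encoder, scores each choice's `[CLS]` state with the single-output classifier, and regroups the scores per example. A shape-only sketch of that idea with made-up dimensions (not the repo's code):

import torch

batch, num_choices, seq_len, hidden = 2, 4, 16, 8

# [batch, num_choices, seq_len] -> [batch * num_choices, seq_len]
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat_ids = input_ids.view(-1, input_ids.size(-1))        # (8, 16)

# Stand-in for the encoder's [CLS] representation of each flattened sequence.
cls_states = torch.randn(flat_ids.size(0), hidden)       # (8, 8)

# One score per choice, then regroup so each example has num_choices logits.
classifier = torch.nn.Linear(hidden, 1)
logits = classifier(cls_states).view(batch, num_choices) # (2, 4)
print(logits.shape)                                      # torch.Size([2, 4])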
2 changes: 1 addition & 1 deletion DeBERTa/apps/ner.py → DeBERTa/apps/models/ner.py
@@ -15,7 +15,7 @@
 import math
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from ..deberta import DeBERTa,NNModule,ACT2FN,StableDropout
+from ...deberta import DeBERTa,NNModule,ACT2FN,StableDropout
 
 __all__ = ['NERModel']
 
DeBERTa/apps/sequence_classification.py → DeBERTa/apps/models/sequence_classification.py
@@ -14,33 +14,35 @@
 import torch
 from torch.nn import CrossEntropyLoss
 import math
+import pdb
 
-from ..deberta import *
-from ..utils import *
+from ...deberta import *
+from ...utils import *
 
 __all__= ['SequenceClassificationModel']
 class SequenceClassificationModel(NNModule):
   def __init__(self, config, num_labels=2, drop_out=None, pre_trained=None):
     super().__init__(config)
     self.num_labels = num_labels
-    self.bert = DeBERTa(config, pre_trained=pre_trained)
+    self._register_load_state_dict_pre_hook(self._pre_load_hook)
+    self.deberta = DeBERTa(config, pre_trained=pre_trained)
     if pre_trained is not None:
-      self.config = self.bert.config
+      self.config = self.deberta.config
     else:
       self.config = config
     pool_config = PoolConfig(self.config)
-    output_dim = self.bert.config.hidden_size
+    output_dim = self.deberta.config.hidden_size
     self.pooler = ContextPooler(pool_config)
     output_dim = self.pooler.output_dim()
 
     self.classifier = torch.nn.Linear(output_dim, num_labels)
     drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
     self.dropout = StableDropout(drop_out)
     self.apply(self.init_weights)
-    self.bert.apply_state()
+    self.deberta.apply_state()
 
   def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
-    encoder_layers = self.bert(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
+    encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
         position_ids=position_ids, output_all_encoded_layers=True)
     pooled_output = self.pooler(encoder_layers[-1])
     pooled_output = self.dropout(pooled_output)
@@ -69,3 +71,15 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
         loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()
 
     return (logits,loss)
+
+  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
+      missing_keys, unexpected_keys, error_msgs):
+    new_state = dict()
+    bert_prefix = prefix + 'bert.'
+    deberta_prefix = prefix + 'deberta.'
+    for k in list(state_dict.keys()):
+      if k.startswith(bert_prefix):
+        nk = deberta_prefix + k[len(bert_prefix):]
+        value = state_dict[k]
+        del state_dict[k]
+        state_dict[nk] = value
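
The new `_pre_load_hook`, registered through `_register_load_state_dict_pre_hook` in `__init__`, rewrites legacy checkpoint keys in place: parameters saved under the old `bert.*` submodule name are renamed to the new `deberta.*` prefix before `load_state_dict` checks them, so checkpoints from earlier releases keep loading after the rename. A self-contained sketch of the same mechanism with a toy module (`TinyModel` and the dummy checkpoint are hypothetical, not from the repo):

import torch

class TinyModel(torch.nn.Module):
  """Toy stand-in: the submodule was renamed from `bert` to `deberta`."""
  def __init__(self):
    super().__init__()
    self._register_load_state_dict_pre_hook(self._pre_load_hook)
    self.deberta = torch.nn.Linear(4, 4)

  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
      missing_keys, unexpected_keys, error_msgs):
    # Rewrite legacy `bert.*` keys in place so old checkpoints still load.
    bert_prefix = prefix + 'bert.'
    deberta_prefix = prefix + 'deberta.'
    for k in list(state_dict.keys()):
      if k.startswith(bert_prefix):
        state_dict[deberta_prefix + k[len(bert_prefix):]] = state_dict.pop(k)

# A checkpoint written before the rename still uses the old prefix.
old_ckpt = {'bert.weight': torch.zeros(4, 4), 'bert.bias': torch.zeros(4)}
model = TinyModel()
model.load_state_dict(old_ckpt)  # keys are remapped by the hook; loads cleanly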
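The unchanged loss at the top of this hunk is cross-entropy against soft (distribution-valued) labels; when `labels` is one-hot and `label_confidence` is 1 it matches the ordinary cross-entropy loss. A small numeric check (values are arbitrary):

import torch

#   loss = -((log_softmax(logits) * labels).sum(-1) * label_confidence).mean()
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
labels = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # one-hot soft labels
label_confidence = 1

log_softmax = torch.nn.LogSoftmax(-1)
soft_loss = -((log_softmax(logits) * labels).sum(-1) * label_confidence).mean()

hard_loss = torch.nn.functional.cross_entropy(logits, torch.tensor([0, 1]))
print(torch.allclose(soft_loss, hard_loss))  # True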
