Skip to content

Commit 618cdca

Browse files
author
chongjiu.jin
committed
update to transformers 2.3.0
1 parent 0f83d3f commit 618cdca

File tree

3 files changed

+69
-4
lines changed

3 files changed

+69
-4
lines changed

pytorch-bert-code/bert/README.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11

22
update to transformer 2.3.0
3+
### 如何将bert model 的Tensorflow模型 转换为pytorch模型
4+
5+
convert_bert_original_tf_checkpoint_to_pytorch.py
36

4-
转换工具已经失效
7+
运行脚本run.sh
8+
9+
后生成对应pytorch_model.bin
510

11+
---
612
chinese bert
713

814
https://github.com/ymcui/Chinese-BERT-wwm/blob/master/README_EN.md
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# coding=utf-8
2+
# Copyright 2018 The HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Convert BERT checkpoint."""
16+
17+
18+
import argparse
19+
import logging
20+
21+
import torch
22+
23+
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
24+
25+
26+
logging.basicConfig(level=logging.INFO)
27+
28+
29+
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
30+
# Initialise PyTorch model
31+
config = BertConfig.from_json_file(bert_config_file)
32+
print("Building PyTorch model from configuration: {}".format(str(config)))
33+
model = BertForPreTraining(config)
34+
35+
# Load weights from tf checkpoint
36+
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
37+
38+
# Save pytorch-model
39+
print("Save PyTorch model to {}".format(pytorch_dump_path))
40+
torch.save(model.state_dict(), pytorch_dump_path)
41+
42+
43+
if __name__ == "__main__":
44+
parser = argparse.ArgumentParser()
45+
# Required parameters
46+
parser.add_argument(
47+
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
48+
)
49+
parser.add_argument(
50+
"--bert_config_file",
51+
default=None,
52+
type=str,
53+
required=True,
54+
help="The config json file corresponding to the pre-trained BERT model. \n"
55+
"This specifies the model architecture.",
56+
)
57+
parser.add_argument(
58+
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
59+
)
60+
args = parser.parse_args()
61+
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

pytorch-bert-code/bert/run.sh

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
export BERT_BASE_DIR=./
2-
3-
transformers bert $BERT_BASE_DIR/bert_model.ckpt $BERT_BASE_DIR/bert_config.json $BERT_BASE_DIR/pytorch_model.bin
1+
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt --bert_config_file bert_config.json --pytorch_dump_path bert_model.bin

0 commit comments

Comments
 (0)