This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories.

base repository: dmlc/gluon-nlp
base: 37955e355f2798d928f62929a4761cec4b9ad743
..
head repository: dmlc/gluon-nlp
compare: 65c3047e479ca3927452b26536ca77943afe578d
Showing with 5,817 additions and 219 deletions.
  1. +14 −5 .github/workflows/unittests.yml
  2. +16 −0 docs/tutorials/deep_learning_compiler/index.rst
  3. +153 −0 docs/tutorials/deep_learning_compiler/tvm_basic.md
  4. +13 −0 docs/tutorials/index.rst
  5. +2 −0 scripts/classification/train_classification.py
  6. +3 −1 scripts/conversion_toolkits/convert_albert.sh
  7. +7 −5 scripts/conversion_toolkits/convert_bert.sh
  8. +24 −0 scripts/conversion_toolkits/convert_bert_torch.sh
  9. +213 −130 scripts/conversion_toolkits/convert_tf_hub_model.py
  10. +49 −0 scripts/pretraining/convert_electra_pretrain_backbone.py
  11. +52 −10 scripts/pretraining/pretraining_utils.py
  12. +9 −44 scripts/pretraining/run_electra.py
  13. +65 −0 scripts/pretraining/torch/bert/README.md
  14. +266 −0 scripts/pretraining/torch/bert/prepare_quickthought.py
  15. +445 −0 scripts/pretraining/torch/bert/run_pretraining.py
  16. +48 −13 setup.py
  17. +1 −0 src/gluonnlp/__init__.py
  18. +14 −0 src/gluonnlp/base.py
  19. +10 −1 src/gluonnlp/data/loading.py
  20. +19 −4 src/gluonnlp/layers.py
  21. +2 −1 src/gluonnlp/models/bert.py
  22. +6 −0 src/gluonnlp/torch/__init__.py
  23. +543 −0 src/gluonnlp/torch/attention_cell.py
  24. +22 −0 src/gluonnlp/torch/clib/amp_C_frontend.cpp
  25. +5 −0 src/gluonnlp/torch/clib/compat.h
  26. +133 −0 src/gluonnlp/torch/clib/multi_tensor_apply.cuh
  27. +443 −0 src/gluonnlp/torch/clib/multi_tensor_l2norm_kernel.cu
  28. +320 −0 src/gluonnlp/torch/clib/multi_tensor_lans.cu
  29. +207 −0 src/gluonnlp/torch/clib/type_shim.h
  30. +1 −0 src/gluonnlp/torch/data/__init__.py
  31. +528 −0 src/gluonnlp/torch/data/batchify.py
  32. +348 −0 src/gluonnlp/torch/layers.py
  33. +2 −0 src/gluonnlp/torch/models/__init__.py
  34. +573 −0 src/gluonnlp/torch/models/bert.py
  35. +622 −0 src/gluonnlp/torch/models/transformer.py
  36. +4 −0 src/gluonnlp/torch/optimizers/__init__.py
  37. +160 −0 src/gluonnlp/torch/optimizers/fused_lans.py
  38. +54 −0 src/gluonnlp/torch/optimizers/schedules.py
  39. +147 −0 src/gluonnlp/torch/utils.py
  40. +1 −0 src/gluonnlp/utils/__init__.py
  41. +48 −3 src/gluonnlp/utils/misc.py
  42. +54 −0 src/gluonnlp/utils/shm.py
  43. +73 −0 tests/torch/test_attention_cell_torch.py
  44. +81 −0 tests/torch/test_bert_torch.py
  45. +15 −0 tests/torch/test_layers_torch.py
  46. +1 −1 tools/docker/ubuntu18.04-cpu.Dockerfile
  47. +1 −1 tools/docker/ubuntu18.04-gpu.Dockerfile
19 changes: 14 additions & 5 deletions .github/workflows/unittests.yml
@@ -38,18 +38,27 @@ jobs:
          restore-keys: |
            ${{ runner.os }}-ccache
      - name: Setup python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64

      # Install Linux specific dependencies
      - name: Install Linux dependencies
        if: matrix.os == 'ubuntu-latest'
        # TODO https://github.com/apache/incubator-mxnet/issues/18293
        run: |
          sudo apt-get install -y libopenblas-dev ninja-build libedit-dev libxml2-dev
          python -m pip install "torch==1.7.1+cpu" -f https://download.pytorch.org/whl/torch_stable.html
      # Install Mac specific dependencies
      - name: Install Mac dependencies
        if: matrix.os == 'macos-latest'
        run: |
          python -m pip install torch==1.7.1
      - name: Setup python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64
      - name: Install Other Dependencies
        run: |
          python -m pip install --upgrade pip
16 changes: 16 additions & 0 deletions docs/tutorials/deep_learning_compiler/index.rst
@@ -0,0 +1,16 @@
Compile NLP Models
==================

.. container:: cards

   .. card::
      :title: Compile and accelerate NLP models via TVM
      :link: tvm_basic.html

      Basic example of how to use TVM to compile backbone models in GluonNLP.

.. toctree::
   :hidden:
   :maxdepth: 2

   tvm_basic.ipynb
153 changes: 153 additions & 0 deletions docs/tutorials/deep_learning_compiler/tvm_basic.md
@@ -0,0 +1,153 @@
# Convert GluonNLP Models to TVM

This tutorial compiles a GluonNLP backbone with the TVM deep learning compiler and runs inference through the TVM graph runtime. A GPU context is used throughout, so it assumes an MXNet GPU build and a TVM build with CUDA enabled.

```{.python .input}
import mxnet as mx
import numpy as np
from gluonnlp.models import get_backbone
from gluonnlp.utils.lazy_imports import try_import_tvm
from gluonnlp.data.batchify import Pad, Stack
mx.npx.set_np()
ctx = mx.gpu()
```

## Load the ELECTRA-base


```{.python .input}
import os
model_name = 'google_electra_base'
model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
model = model_cls.from_cfg(cfg)
model.hybridize()
model.load_parameters(backbone_param_path, ctx=ctx)
```


```{.python .input}
sentences = ['hello world', 'orbit workbench demo via gluon toolkits']
tokens = tokenizer.encode(sentences, int)
tokens = [[tokenizer.vocab.cls_id] + tokens[0] + [tokenizer.vocab.sep_id],
          [tokenizer.vocab.cls_id] + tokens[1] + [tokenizer.vocab.sep_id]]
print(tokens)
```
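
Each encoded sentence is wrapped with the vocabulary's `cls_id` and `sep_id` so the inputs match the `[CLS] ... [SEP]` format the ELECTRA backbone was pretrained with.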


```{.python .input}
token_ids = Pad()(tokens)
valid_length = Stack()(list(map(len, tokens)))
segment_ids = np.zeros_like(token_ids)
print(token_ids)
print(valid_length)
```
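
`Pad()` batches the variable-length token lists into a single right-padded matrix, `Stack()` collects the true lengths into `valid_length`, and the all-zero `segment_ids` mark every token as belonging to the first (and only) segment.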


```{.python .input}
contextual_embeddings, cls_embedding = model(mx.np.array(token_ids, ctx=ctx),
                                             mx.np.array(segment_ids, ctx=ctx),
                                             mx.np.array(valid_length, ctx=ctx))
```
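
Besides producing reference outputs to compare against later, this forward pass populates the hybridized model's cached symbolic graph (`model._cached_graph`), which the TVM conversion below reads.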


```{.python .input}
contextual_embeddings
```


```{.python .input}
cls_embedding
```

## Use TVM for Inference


```{.python .input}
_TVM_RT_CACHE = dict()


def compile_tvm_graph_runtime(model, model_name, cfg,
                              batch_size, seq_length, dtype, instance_type):
    """Compile the hybridized model into a TVM graph runtime, with caching."""
    layout = cfg.MODEL.layout
    compute_layout = cfg.MODEL.compute_layout
    key = (model_name, layout, compute_layout, batch_size, seq_length, dtype, instance_type)
    if key in _TVM_RT_CACHE:
        return _TVM_RT_CACHE[key]
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    from gluonnlp.utils.tvm_utils import get_ec2_tvm_flags, update_tvm_convert_map
    flags = get_ec2_tvm_flags()[instance_type]
    update_tvm_convert_map()
    token_ids_shape = (batch_size, seq_length) if layout == 'NT' else (seq_length, batch_size)
    valid_length_shape = (batch_size,)
    # The input signature depends on the backbone family: BART takes encoder
    # and decoder inputs, RoBERTa/XLM-R take no segment ids, and BERT-like
    # models take token ids, segment ids and valid lengths.
    if 'bart' in model_name:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': valid_length_shape,
            'data2': token_ids_shape,
            'data3': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
            'data2': 'int32',
            'data3': 'int32',
        }
    elif 'roberta' in model_name or 'xlmr' in model_name:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
        }
    else:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': token_ids_shape,
            'data2': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
            'data2': 'int32'
        }
    # Convert the cached MXNet symbol and its parameters to a Relay module.
    sym = model._cached_graph[1]
    params = {}
    for k, v in model.collect_params().items():
        params[v._var_name] = tvm.nd.array(v.data().asnumpy())
    mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict,
                                            arg_params=params)
    target = flags['target']
    use_gpu = flags['use_gpu']
    opt_level = flags['opt_level']
    required_pass = flags['required_pass']
    with tvm.transform.PassContext(opt_level=opt_level, required_pass=required_pass):
        lib = relay.build(mod, target, params=params)
    if use_gpu:
        ctx = tvm.gpu()
    else:
        ctx = tvm.cpu()
    rt = graph_runtime.GraphModule(lib["default"](ctx))
    _TVM_RT_CACHE[key] = rt
    return rt
```


```{.python .input}
rt = compile_tvm_graph_runtime(model, model_name, cfg, token_ids.shape[0],
                               token_ids.shape[1], 'float32', 'g4')
```
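
Compilation can take a while the first time. The resulting runtime is memoized in `_TVM_RT_CACHE`, so later calls with the same model, shapes, dtype and instance type return the cached `GraphModule` immediately.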


```{.python .input}
rt.set_input(data0=token_ids.asnumpy(), data1=segment_ids.asnumpy(), data2=valid_length.asnumpy())
rt.run()
tvm_contextual_embedding = rt.get_output(0)
tvm_cls_embedding = rt.get_output(1)
```


```{.python .input}
tvm_cls_embedding
```
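
As a quick sanity check, the TVM outputs can be compared against the MXNet results computed earlier. This is a minimal sketch under the assumption that modest tolerances are acceptable, since the two stacks may reorder floating-point operations:

```{.python .input}
import numpy.testing as npt
# The compiled graph should reproduce the MXNet outputs up to small
# numerical differences from reordered float arithmetic.
npt.assert_allclose(tvm_cls_embedding.asnumpy(), cls_embedding.asnumpy(),
                    rtol=1e-3, atol=1e-3)
npt.assert_allclose(tvm_contextual_embedding.asnumpy(),
                    contextual_embeddings.asnumpy(), rtol=1e-3, atol=1e-3)
```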
13 changes: 13 additions & 0 deletions docs/tutorials/index.rst
@@ -77,6 +77,18 @@ Using Pretrained Models
      An example of using pretrained models in GluonNLP.


Compiling NLP Models
--------------------

.. container:: cards

   .. card::
      :title: Compile and accelerate NLP models via TVM
      :link: deep_learning_compiler/tvm_basic.html

      Basic example of how to use TVM to compile backbone models in GluonNLP.


.. toctree::
   :hidden:
   :maxdepth: 2

@@ -86,3 +98,4 @@ Using Pretrained Models
   question_answering/index
   tokenization/index
   pretrained_models/index
   deep_learning_compiler/index
2 changes: 2 additions & 0 deletions scripts/classification/train_classification.py
@@ -238,6 +238,8 @@ def train(args):
    #random seed
    set_seed(args.seed)
    level = logging.INFO
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    detail_dir = os.path.join(args.output_dir, args.task_name)
    if not os.path.exists(detail_dir):
        os.mkdir(detail_dir)
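
The added guard creates the output directory tree only when it is missing. A minimal alternative sketch (hypothetical helper, not part of the diff) is `os.makedirs` with `exist_ok=True`, which creates intermediate directories in one call and avoids the separate existence checks:

```python
import os

def ensure_output_dirs(output_dir, task_name):
    # Hypothetical helper: os.makedirs with exist_ok=True creates
    # output_dir and output_dir/task_name in one call, replacing the
    # exists()/mkdir() pairs added in the diff above.
    detail_dir = os.path.join(output_dir, task_name)
    os.makedirs(detail_dir, exist_ok=True)
    return detail_dir
```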
4 changes: 3 additions & 1 deletion scripts/conversion_toolkits/convert_albert.sh
@@ -1,10 +1,12 @@
set -ex

python3 -m pip install tensorflow==1.15 --upgrade --user
python3 -m pip install tensorflow_hub --upgrade --user
export TF_FORCE_GPU_ALLOW_GROWTH="true"
for model in base large xlarge xxlarge
do
hub_directory="google_albert_${model}_v2"
mkdir ${hub_directory}
mkdir -p ${hub_directory}
wget "https://tfhub.dev/google/albert_${model}/3?tf-hub-format=compressed" -O "${hub_directory}.tar.gz"
tar -xvf ${hub_directory}.tar.gz --directory ${hub_directory}
python3 convert_tf_hub_model.py --tf_hub_model_path ${hub_directory} --model_type albert --test
12 changes: 7 additions & 5 deletions scripts/conversion_toolkits/convert_bert.sh
@@ -1,4 +1,6 @@
python3 -m pip install tensorflow==2.3.0 --upgrade --user
set -ex

python3 -m pip install 'tensorflow<3' --upgrade --user
python3 -m pip install tensorflow_hub --upgrade --user
export TF_FORCE_GPU_ALLOW_GROWTH="true"

@@ -8,7 +10,7 @@ do
for case in cased uncased
do
hub_directory="google_en_${case}_bert_${model}"
mkdir ${hub_directory}
mkdir -p ${hub_directory}
if [ ${model} == base ];then
url="https://tfhub.dev/google/bert_${case}_L-12_H-768_A-12/1?tf-hub-format=compressed"
else
@@ -24,7 +26,7 @@ done
# Conversion for Chinese Models
url="https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/2?tf-hub-format=compressed"
hub_directory="google_zh_bert_base"
mkdir ${hub_directory}
mkdir -p ${hub_directory}
wget ${url} -O "${hub_directory}.tar.gz"
tar -xvf ${hub_directory}.tar.gz --directory ${hub_directory}
cp bert_base_config.json ${hub_directory}/assets/
@@ -33,7 +35,7 @@ python3 convert_tf_hub_model.py --tf_hub_model_path ${hub_directory} --model_typ
# Conversion for Multi-lingual Models
url="https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/2?tf-hub-format=compressed"
hub_directory="google_multi_cased_bert_base"
mkdir ${hub_directory}
mkdir -p ${hub_directory}
wget ${url} -O "${hub_directory}.tar.gz"
tar -xvf ${hub_directory}.tar.gz --directory ${hub_directory}
cp bert_base_config.json ${hub_directory}/assets/
@@ -43,7 +45,7 @@ python3 convert_tf_hub_model.py --tf_hub_model_path ${hub_directory} --model_typ
for case in cased uncased
do
hub_directory="google_en_${case}_bert_wwm_large"
mkdir ${hub_directory}
mkdir -p ${hub_directory}
url="https://tfhub.dev/tensorflow/bert_en_wwm_${case}_L-24_H-1024_A-16/2?tf-hub-format=compressed"
wget ${url} -O "${hub_directory}.tar.gz"
tar -xvf ${hub_directory}.tar.gz --directory ${hub_directory}
24 changes: 24 additions & 0 deletions scripts/conversion_toolkits/convert_bert_torch.sh
@@ -0,0 +1,24 @@
set -ex

python3 -m pip install 'tensorflow<3' --upgrade --user
python3 -m pip install tensorflow_hub --upgrade --user
export TF_FORCE_GPU_ALLOW_GROWTH="true"

# Conversion for English Models
for model in base large
do
for case in cased uncased
do
hub_directory="google_en_${case}_bert_${model}"
mkdir -p ${hub_directory}
if [ ${model} == base ];then
url="https://tfhub.dev/google/bert_${case}_L-12_H-768_A-12/1?tf-hub-format=compressed"
else
url="https://tfhub.dev/google/bert_${case}_L-24_H-1024_A-16/1?tf-hub-format=compressed"
fi
wget ${url} -O "${hub_directory}.tar.gz"
tar -xvf ${hub_directory}.tar.gz --directory ${hub_directory}
cp bert_${model}_config.json ${hub_directory}/assets/
python3 convert_tf_hub_model.py --tf_hub_model_path ${hub_directory} --model_type bert --test --torch
done
done
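
Judging by the `--torch` flag passed to `convert_tf_hub_model.py`, this new script mirrors `convert_bert.sh` but targets the PyTorch backbone added under `src/gluonnlp/torch/`.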