
[Bug]: Cannot export Flair model to ONNX #3517

Closed
malteos opened this issue Aug 1, 2024 · 3 comments · Fixed by #3530
Labels
bug Something isn't working

Comments


malteos commented Aug 1, 2024

Describe the bug

Hi all,

thanks for your work on Flair!

I want to export a Flair model to ONNX, following the corresponding tutorial (apart from using a different dataset). However, the export fails with a "Rank must be 0 or 1, not None" error. I have already tried different flair/torch/transformers/onnx version combinations, but none of them fixed the error.

Best,
Malte

To Reproduce

from flair.datasets import NER_GERMAN_LEGAL
from flair.embeddings import TransformerDocumentEmbeddings, TransformerWordEmbeddings
from flair.models import SequenceTagger

# Load the large NER model and verify it uses transformer embeddings.
model = SequenceTagger.load("ner-large")
assert isinstance(model.embeddings, (TransformerWordEmbeddings, TransformerDocumentEmbeddings))

# A handful of example sentences is needed to trace the model for export.
sentences = list(NER_GERMAN_LEGAL().test)[:5]

# Export the embeddings to ONNX and swap them into the tagger.
model.embeddings = model.embeddings.export_onnx(
    "flert-embeddings.onnx",
    sentences,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    session_options={},
)

model.predict(sentences)

Expected behavior

ONNX model is exported without error.

Logs and Stack traces

python convert_flair_to_onnx.py 
2024-08-01 12:07:47,787 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>
2024-08-01 12:07:47,860 Reading data from /Users/my-user/.flair/datasets/ner_german_legal
2024-08-01 12:07:47,860 Train: /Users/my-user/.flair/datasets/ner_german_legal/ler_train.conll
2024-08-01 12:07:47,860 Dev: /Users/my-user/.flair/datasets/ner_german_legal/ler_dev.conll
2024-08-01 12:07:47,860 Test: /Users/my-user/.flair/datasets/ner_german_legal/ler_test.conll
/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/jit/annotations.py:389: UserWarning: TorchScript will treat type annotations of Tensor dtype-specific subtypes as if they are normal Tensors. dtype constraints are not enforced in compilation either.
  warnings.warn(
Traceback (most recent call last):
  File "/Users/my-user/repos/tools-team/ner/performance-testing/convert_flair_to_onnx.py", line 77, in <module>
    model.embeddings = model.embeddings.export_onnx("flert-embeddings.onnx", sentences, providers=["CUDAExecutionProvider", "CPUExecutionProvider"], session_options={})
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/flair/embeddings/transformer.py", line 1497, in export_onnx
    return self.onnx_cls.export_from_embedding(path, self, example_sentences, **kwargs)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/flair/embeddings/transformer.py", line 868, in export_from_embedding
    torch.onnx.export(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 1174, in _model_to_graph
    graph = _optimize_graph(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 714, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 1997, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 6812, in prim_loop
    torch._C._jit_pass_onnx_block(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 1997, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 6904, in prim_if
    torch._C._jit_pass_onnx_block(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 1997, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 6982, in onnx_placeholder
    return torch._C._jit_onnx_convert_pattern_from_subblock(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/utils.py", line 1997, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py", line 571, in slice
    return symbolic_helper._slice_helper(
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 749, in _slice_helper
    return _slice10(g, input, axes, starts, ends, steps)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py", line 551, in _slice
    ends = to_slice_input(ends, default_value=_constants.INT64_MAX)
  File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py", line 530, in to_slice_input
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Rank must be 0 or 1, not None  [Caused by the value '3159 defined in (%3159 : LongTensor(device=cpu)[] = onnx::Gather[axis=0](%3156, %3158) # /Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/flair/embeddings/transformer.py:101:50
)' (type 'List[Tensor]') in the TorchScript graph. The containing node has kind 'onnx::Gather'.] 
    (node defined in   File "/Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/flair/embeddings/transformer.py", line 101
            end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :]
            sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0)
            sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat(
                                                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                (start_part, mid_part, end_part), dim=0
            )
)

    Inputs:
        #0: 3156 defined in (%3156 : int[] = onnx::Shape(%sentence_hidden_state) # <string>:13:9
    )  (type 'List[int]')
        #1: 3158 defined in (%3158 : Long(device=cpu) = onnx::Constant[value={0}]() # /Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/flair/embeddings/transformer.py:101:50
    )  (type 'Tensor')
    Outputs:
        #0: 3159 defined in (%3159 : LongTensor(device=cpu)[] = onnx::Gather[axis=0](%3156, %3158) # /Users/my-user/miniconda3/envs/triton-server-experimentation/lib/python3.10/site-packages/flair/embeddings/transformer.py:101:50
    )  (type 'List[Tensor]')
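For context on the failing line above (flair/embeddings/transformer.py:101-115): the traced code writes a variable-length concatenation of hidden-state chunks into a fixed-size, zero-padded buffer. The dynamic slice on the left-hand side of that assignment appears to be what the ONNX exporter stumbles over. A simplified, self-contained sketch of the pattern (illustrative only, not Flair's actual code; all names and shapes here are made up):

```python
import torch

# Padded buffer for 2 "sentences", each up to 10 tokens of hidden size 4.
hidden_size = 4
max_len = 10
sentence_hidden_states = torch.zeros(2, max_len, hidden_size)

# Variable-length chunks (e.g. start/mid/end parts of a long sentence)
# are concatenated and written into the padded buffer. The slice bound
# `: sentence_hidden_state.shape[0]` is only known at runtime, which is
# the kind of dynamic shape the exporter's slice lowering rejects.
parts = [torch.ones(3, hidden_size), torch.ones(2, hidden_size)]
sentence_hidden_state = torch.cat(parts, dim=0)  # shape (5, hidden_size)
sentence_hidden_states[0, : sentence_hidden_state.shape[0]] = sentence_hidden_state

print(sentence_hidden_states[0].sum().item())  # 5 rows of ones * 4 cols = 20.0
```

In eager PyTorch this runs fine; the problem only surfaces when torch.onnx.export tries to lower the dynamically-bounded slice assignment inside a scripted loop, as the trace above shows.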

Screenshots

No response

Additional Context

No response

Environment

Versions:

- Flair: 0.14.0
- PyTorch: 2.4.0
- Transformers: 4.43.3
- GPU: False

pip freeze | grep onnx:

onnx==1.16.1
onnxconverter-common==1.13.0
onnxruntime==1.18.1
onnxruntime-tools==1.7.0
@malteos added the bug label Aug 1, 2024
@stefan-it (Member) commented

Hi @helpmefindaname, could you please have a look at this issue? Any help is highly appreciated 😇

@helpmefindaname (Collaborator) commented

Hi @malteos @stefan-it,

can you check if #3530 fixes the issue for you?

@malteos (Author) commented Aug 21, 2024

Thanks, this fixes the issue!

@malteos closed this as completed Aug 21, 2024