Support LSTMCell layer from PyTorch #1344

@dzhelonkin

🌱 Describe your Feature Request

Add PyTorch conversion support for LSTMCell, not only for LSTM. At a minimum, this requires support for the unsafe_chunk op.
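To illustrate why a chunk-style op shows up here, the following is a minimal sketch that recomputes one LSTMCell step by hand. It assumes the unsafe_chunk in the traced graph corresponds to splitting the fused gate tensor into four equal parts, which is what torch.chunk does below. The helper name manual_lstm_cell is hypothetical and not part of PyTorch or coremltools.

```python
import torch
from torch import nn


def manual_lstm_cell(cell, x, h, c):
    """Recompute one nn.LSTMCell step with explicit ops (hypothetical helper)."""
    # One fused affine map yields all four gates: shape (batch, 4 * hidden_size).
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    # Split into four equal chunks; PyTorch stores them as input, forget, cell, output.
    i, f, g, o = gates.chunk(4, dim=1)
    c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_next = torch.sigmoid(o) * torch.tanh(c_next)
    return h_next, c_next


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=3).eval()
x, h, c = torch.randn(1, 3), torch.randn(1, 3), torch.randn(1, 3)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h, c))                      # built-in layer
    h_man, c_man = manual_lstm_cell(cell, x, h, c)      # explicit chunk-based version
```

The two results should agree to within floating-point tolerance. Whether a traced module written with torch.chunk converts cleanly with coremltools is not verified here.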

Use cases

LSTMCell is needed for streaming (online) processing on device. For example, it is useful for speech processing or any other time-series task where inputs arrive one step at a time.
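A minimal sketch of the streaming pattern, in plain PyTorch: an LSTMCell is stepped one frame at a time while the caller carries the hidden and cell state between steps. The sizes and the 12-frame sequence are illustrative assumptions. The weights are copied from an nn.LSTM only to show that the streamed result matches offline processing of the whole sequence.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(3, 3).eval()        # offline: consumes the whole sequence at once
cell = nn.LSTMCell(3, 3).eval()    # streaming: consumes one frame per call

# Share parameters so both modules compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

sequence = torch.randn(12, 1, 3)   # (time, batch, features)

with torch.no_grad():
    offline_out, _ = lstm(sequence)

    # Streaming loop: the state (h, c) is kept by the caller between frames.
    h = torch.zeros(1, 3)
    c = torch.zeros(1, 3)
    steps = []
    for frame in sequence:         # each frame has shape (batch, features)
        h, c = cell(frame, (h, c))
        steps.append(h)
    streamed_out = torch.stack(steps)
```

On device, each loop iteration would correspond to one prediction call on the converted model, with the returned states fed back in as inputs for the next frame.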

Describe alternatives you've considered

Converting via PyTorch -> ONNX -> CoreML works, but the ONNX -> CoreML path will be deprecated.

Additional context

Working LSTM conversion:

import torch
from torch import nn
import coremltools as ct

lstm = nn.LSTM(3, 3)
lstm.eval()
inputs = torch.randn(12, 1, 3)
traced_model = torch.jit.trace(lstm, inputs)

ct.convert(
    model=traced_model,
    inputs=[
        ct.TensorType(name="sequence", shape=(ct.RangeDim(1, 50), 1, 3))
    ]
)

Converting LSTMCell fails with the error "PyTorch convert function for op 'unsafe_chunk' not implemented.":

import torch
from torch import nn
import coremltools as ct

class OneStep(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstmcell = nn.LSTMCell(input_size=3, hidden_size=3)
        
    def forward(self, sequence, hidden_state, cell_state):
        return self.lstmcell(sequence, (hidden_state, cell_state))
    
lstm_cell = OneStep()
lstm_cell.eval()
dummy_input = [torch.zeros((1, 3)), torch.zeros((1, 3)), torch.zeros((1, 3))]
traced_model = torch.jit.trace(lstm_cell, dummy_input)

ct.convert(
    model=traced_model,
    inputs=[
        ct.TensorType(name="sequence", shape=(1, 3)), 
        ct.TensorType(name="hidden_state", shape=(1, 3)), 
        ct.TensorType(name="cell_state", shape=(1, 3)), 
    ]
)

Metadata

Assignees

No one assigned

    Labels

    LSTM/RNN
    PyTorch (traced)
    missing layer type: Unable to convert a layer type from the relevant framework
    triaged: Reviewed and examined, release has been assigned if applicable (status)
