Status: Open
Labels: LSTM/RNN, PyTorch (traced), missing layer type (unable to convert a layer type from the relevant framework), triaged (reviewed and examined; release has been assigned if applicable)
Description
🌱 Describe your Feature Request
Add PyTorch conversion support for LSTMCell, not only for LSTM. At minimum, this requires a converter for the unsafe_chunk op.
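For context, here is a minimal NumPy sketch of one LSTMCell step, following the gate equations in the PyTorch nn.LSTMCell documentation. The pre-activation gates are split into four equal parts (input, forget, cell, output). Our understanding, not verified against every PyTorch version, is that PyTorch's lstm_cell performs this split with unsafe_chunk(4, 1), which is why the op appears in the traced graph. Here np.split stands in for it, and lstm_cell_step is a hypothetical helper name.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTMCell step (hypothetical helper mirroring PyTorch's equations).

    x: (batch, input_size); h, c: (batch, hidden_size)
    w_ih: (4*hidden, input_size); w_hh: (4*hidden, hidden); biases: (4*hidden,)
    """
    gates = x @ w_ih.T + b_ih + h @ w_hh.T + b_hh  # (batch, 4*hidden)
    # The 4-way split along dim 1: the step that unsafe_chunk(4, 1) is assumed
    # to perform inside PyTorch's lstm_cell. np.split is used as a stand-in.
    i, f, g, o = np.split(gates, 4, axis=1)
    i, f, g, o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)
    c_next = f * c + i * g           # new cell state
    h_next = o * np.tanh(c_next)     # new hidden state
    return h_next, c_next


rng = np.random.default_rng(0)
hidden, inp = 3, 3
w_ih = rng.standard_normal((4 * hidden, inp))
w_hh = rng.standard_normal((4 * hidden, hidden))
b_ih = rng.standard_normal(4 * hidden)
b_hh = rng.standard_normal(4 * hidden)

x = rng.standard_normal((1, inp))
h = np.zeros((1, hidden))
c = np.zeros((1, hidden))
h, c = lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh)
print(h.shape, c.shape)
```

Per the PyTorch docs, torch.unsafe_chunk works like torch.chunk but without enforcing autograd restrictions on in-place modification of the outputs, so numerically it is the same split. A converter could plausibly map it to an existing split operation, but that is an assumption, not documented coremltools behavior.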
Use cases
LSTMCell is needed for streaming (online) on-device processing, where the model consumes one time step per call and the hidden and cell states are carried between calls. This is useful for speech processing and other time-series tasks.
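To illustrate the streaming pattern, here is a minimal sketch. It assumes an nn.LSTMCell whose weights are copied explicitly from an nn.LSTM with the same sizes as the example below. Stepping the cell one frame at a time and carrying (h, c) between calls should reproduce the full-sequence LSTM output. The variable names are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=3)
cell = nn.LSTMCell(input_size=3, hidden_size=3)

# Share weights so both modules compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

sequence = torch.randn(12, 1, 3)  # (time, batch, features)

with torch.no_grad():
    # Offline: the whole sequence at once.
    full_output, (h_full, c_full) = lstm(sequence)

    # Streaming: one frame per call, state carried between calls.
    h = torch.zeros(1, 3)
    c = torch.zeros(1, 3)
    streamed = []
    for frame in sequence:  # frame has shape (1, 3)
        h, c = cell(frame, (h, c))
        streamed.append(h)
    streamed_output = torch.stack(streamed)  # (12, 1, 3)

print(torch.allclose(full_output, streamed_output, atol=1e-5))
```

The streaming loop needs a model that exposes a single step with explicit state inputs and outputs, which is exactly what a converted LSTMCell would provide on device.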
Describe alternatives you've considered
Converting PyTorch -> ONNX -> Core ML works, but the ONNX -> Core ML conversion path will be deprecated.
Additional context
Working LSTM conversion
import torch
from torch import nn
import coremltools as ct
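
lstm = nn.LSTM(3, 3)  # input_size=3, hidden_size=3
lstm.eval()
inputs = torch.randn(12, 1, 3)  # (sequence length, batch, features)
traced_model = torch.jit.trace(lstm, inputs)
ct.convert(
    model=traced_model,
    inputs=[
        # Flexible sequence length from 1 to 50, batch size 1, 3 features
        ct.TensorType(name="sequence", shape=(ct.RangeDim(1, 50), 1, 3))
    ],
)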
Converting LSTMCell fails with the error: PyTorch convert function for op 'unsafe_chunk' not implemented.
import torch
from torch import nn
import coremltools as ct

class OneStep(nn.Module):
    """Wraps nn.LSTMCell so the input and both states are separate tensors."""

    def __init__(self):
        super().__init__()
        self.lstmcell = nn.LSTMCell(input_size=3, hidden_size=3)

    def forward(self, sequence, hidden_state, cell_state):
        # Returns the next (hidden_state, cell_state) for a single time step
        return self.lstmcell(sequence, (hidden_state, cell_state))

lstm_cell = OneStep()
lstm_cell.eval()
dummy_input = (torch.zeros(1, 3), torch.zeros(1, 3), torch.zeros(1, 3))
traced_model = torch.jit.trace(lstm_cell, dummy_input)
# Fails with: PyTorch convert function for op 'unsafe_chunk' not implemented
ct.convert(
    model=traced_model,
    inputs=[
        ct.TensorType(name="sequence", shape=(1, 3)),
        ct.TensorType(name="hidden_state", shape=(1, 3)),
        ct.TensorType(name="cell_state", shape=(1, 3)),
    ],
)
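A possible workaround until unsafe_chunk is supported is to compute the cell by hand, splitting the gates with plain tensor slicing so the traced graph contains slice ops instead of aten::unsafe_chunk. This is a minimal sketch, assuming the gate order (input, forget, cell, output) documented for nn.LSTMCell. SlicedLSTMCell is a hypothetical name, and we have not verified that coremltools converts the resulting graph. The sketch only checks that the outputs match nn.LSTMCell and that unsafe_chunk no longer appears in the trace.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SlicedLSTMCell(nn.Module):
    """Hypothetical stand-in for nn.LSTMCell that splits gates by slicing.

    Reuses the parameters of an existing nn.LSTMCell (bias=True assumed).
    """

    def __init__(self, cell: nn.LSTMCell):
        super().__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size

    def forward(self, x, hidden_state, cell_state):
        gates = F.linear(x, self.cell.weight_ih, self.cell.bias_ih) + F.linear(
            hidden_state, self.cell.weight_hh, self.cell.bias_hh
        )
        hs = self.hidden_size
        # Slice the four gates instead of chunking: input, forget, cell, output.
        i = torch.sigmoid(gates[:, 0:hs])
        f = torch.sigmoid(gates[:, hs:2 * hs])
        g = torch.tanh(gates[:, 2 * hs:3 * hs])
        o = torch.sigmoid(gates[:, 3 * hs:4 * hs])
        c_next = f * cell_state + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next


torch.manual_seed(0)
reference = nn.LSTMCell(input_size=3, hidden_size=3).eval()
sliced = SlicedLSTMCell(reference).eval()

x = torch.randn(1, 3)
h0 = torch.randn(1, 3)
c0 = torch.randn(1, 3)
with torch.no_grad():
    h_ref, c_ref = reference(x, (h0, c0))
    h_new, c_new = sliced(x, h0, c0)

traced = torch.jit.trace(sliced, (x, h0, c0))
print(torch.allclose(h_ref, h_new, atol=1e-5), torch.allclose(c_ref, c_new, atol=1e-5))
```

The traced module could then be passed to ct.convert with the same three TensorType inputs as in the failing example. Whether that conversion succeeds still needs to be confirmed.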