
Problem using a converted PyTorch vanilla NN in a Pipeline when its hidden layer has a different size than the input layer #2422

@sjordine

Description


🐞Describing the bug

  • I created a very simple vanilla PyTorch neural network:
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)
        # self.activation1 = nn.ReLU()
        # self.linear2 = nn.Linear(4, 3)
        self.linear3 = nn.Linear(4, 1)

    def forward(self, x):
        x = self.linear1(x)
        # x = self.activation1(x)
        # x = self.linear2(x)
        x = self.linear3(x)
        return x
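Only the intermediate activation depends on the hidden width; the model's input is (1, 3) and its output is (1, 1) in both the failing 3→4→1 and the working 3→3→1 configuration. A minimal stdlib-only sketch of the forward pass makes this concrete. It assumes nested lists in place of tensors, and the `linear` helper is hypothetical: it computes y = x·Wᵀ + b, as `nn.Linear` does, with W of shape (out_features, in_features).

```python
def linear(x, weight, bias):
    """Apply y = x @ weight.T + bias to a batch x (a list of rows).

    weight has shape (out_features, in_features), like nn.Linear.weight.
    """
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]


def shape(t):
    """Shape of a 2-D nested list."""
    return (len(t), len(t[0]))


def forward(x, hidden):
    """Mimic SimpleModel.forward for a given hidden width; return shapes."""
    # Constant weights: only the shapes matter for this sketch.
    w1 = [[0.1] * 3 for _ in range(hidden)]  # like nn.Linear(3, hidden).weight
    b1 = [0.0] * hidden
    w3 = [[0.1] * hidden]                    # like nn.Linear(hidden, 1).weight
    b3 = [0.0]
    h = linear(x, w1, b1)
    y = linear(h, w3, b3)
    return shape(h), shape(y)


x = [[1.0, 2.0, 3.0]]  # a batch of one sample, shape (1, 3)
for hidden in (3, 4):
    hidden_shape, out_shape = forward(x, hidden)
    print(f"hidden={hidden}: intermediate {hidden_shape}, output {out_shape}")
```

Since the (1, 1) output matches the `datatypes.Array(1, 1)` declared for `linear_1` in either case, the hidden width alone should not change the pipeline's external schema.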

I then converted it to a Core ML package using coremltools:

import torch
import coremltools as ct

model = SimpleModel()
input = torch.rand(1, 3)
model.eval()
traced_model = torch.jit.trace(model, input)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=input.shape)],
    outputs=[ct.TensorType(name="linear_1")],
)
mlmodel.save("Test.mlpackage")

This model opens in Xcode without any problem.

However, if I wrap it in a Pipeline:

from coremltools.models import datatypes, pipeline

pipeline_network = pipeline.Pipeline(
    input_features=[("input", datatypes.Array(1, 3))],
    output_features=[("linear_1", datatypes.Array(1, 1))],
)
pipeline_network.add_model(mlmodel)
pipeline_spec = pipeline_network.spec
ct.utils.convert_double_to_float_multiarray_type(pipeline_spec)
ct.utils.save_spec(pipeline_spec, "Test-Pipeline.mlpackage")
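A first thing worth ruling out is a plain name or shape mismatch between the features declared on the `Pipeline` and what the converted model actually exposes. The sketch below is a hypothetical stdlib-only helper, not part of coremltools. It assumes the model's side has already been collected into a `{name: shape}` dict, for example from `mlmodel.get_spec().description.output`, where each entry has a `name` and, for multiarrays, `type.multiArrayType.shape`.

```python
def schema_mismatches(declared, actual):
    """Compare declared pipeline features with a model's features.

    Both arguments map feature name -> shape (tuple or list). Returns a list
    of human-readable problems; an empty list means names and shapes agree.
    """
    problems = []
    for name, shape in declared.items():
        if name not in actual:
            problems.append(f"{name!r} is declared but not produced by the model")
        elif tuple(actual[name]) != tuple(shape):
            problems.append(
                f"{name!r}: declared shape {tuple(shape)}, "
                f"model shape {tuple(actual[name])}"
            )
    return problems


# Declared side taken from the Pipeline snippet in this issue; the model side
# is what one would expect for both hidden widths (an assumption to verify).
declared_outputs = {"linear_1": (1, 1)}
model_outputs = {"linear_1": (1, 1)}
print(schema_mismatches(declared_outputs, model_outputs))  # []
```

If this comes back empty for both the 3→3→1 and the 3→4→1 model, the declared schema itself is probably not what differs between them.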

The resulting Test-Pipeline.mlpackage does not open in Xcode and gives me the following error:
[Screenshot of the Xcode error, taken 2024-12-19 at 4:47:36 PM]

This problem does not occur if self.linear1 = nn.Linear(3, 3) and self.linear3 = nn.Linear(3, 1), i.e. when the hidden layer has the same width as the input.
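Apart from the hidden width, the two configurations differ in the size of their weight tensors: an `nn.Linear(in, out)` layer has a weight with out × in elements and a bias with out elements. The sketch below only does that arithmetic. The `THRESHOLD` constant encodes an unverified assumption, namely that coremltools writes ML program constants of roughly 10 or more elements to a separate weights file inside the package and keeps smaller ones inline. If that holds, only the failing model would depend on such a file, and a pipeline saved with `ct.utils.save_spec` without its `weights_dir` argument might not contain it.

```python
# ASSUMPTION: element count at or above which a constant is stored in the
# package's separate weights file rather than inline. This is a hypothesis
# to verify, not a documented coremltools constant.
THRESHOLD = 10


def tensor_sizes(dims):
    """Element counts of weight and bias for a chain of nn.Linear layers.

    dims=(3, 4, 1) means Linear(3, 4) followed by Linear(4, 1).
    Layer names here are generic, not the attribute names in SimpleModel.
    """
    sizes = {}
    for i, (n_in, n_out) in enumerate(zip(dims, dims[1:]), start=1):
        sizes[f"layer{i}.weight"] = n_in * n_out  # weight shape (n_out, n_in)
        sizes[f"layer{i}.bias"] = n_out
    return sizes


for dims in [(3, 3, 1), (3, 4, 1)]:
    sizes = tensor_sizes(dims)
    large = [name for name, n in sizes.items() if n >= THRESHOLD]
    print(dims, sizes, "at/above threshold:", large)
```

Only `layer1.weight` of the 3→4→1 model (12 elements) reaches the assumed threshold; every tensor of the 3→3→1 model has at most 9 elements.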

To Reproduce

  • Here is my full example:
!pip install torch==2.1.2
!pip install --upgrade coremltools

!rm -Rf /content/Test.mlpackage
!rm -Rf /content/Test-Pipeline.mlpackage

!rm -Rf /content/simple-export.zip
!rm -Rf /content/pipeline-export.zip

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)  # changing 4 to 3 here and in linear3 makes it work
        # self.activation1 = nn.ReLU()
        # self.linear2 = nn.Linear(4, 3)
        self.linear3 = nn.Linear(4, 1)  # changing 4 to 3 here and in linear1 makes it work

    def forward(self, x):
        x = self.linear1(x)
        # x = self.activation1(x)
        # x = self.linear2(x)
        x = self.linear3(x)
        return x

input = torch.rand(1,3)

model = SimpleModel()

model.eval()

traced_model = torch.jit.trace(model, input)

import coremltools as ct
import numpy as np

mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=input.shape)],
    outputs=[ct.TensorType(name="linear_1")],
    compute_precision=ct.precision.FLOAT32,
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
    # minimum_deployment_target=ct.target.iOS17,
    # convert_to="mlprogram",
)

mlmodel.save("Test.mlpackage")

!zip -r /content/simple-export.zip /content/Test.mlpackage

from coremltools.models import pipeline
from coremltools.models import datatypes

pipeline_network = pipeline.Pipeline(
    input_features=[("input", datatypes.Array(1, 3))],
    output_features=[("linear_1", datatypes.Array(1, 1))],
)

pipeline_network.add_model(mlmodel)

pipeline_spec = pipeline_network.spec

ct.utils.convert_double_to_float_multiarray_type(pipeline_spec)

ct.utils.save_spec(pipeline_spec, "Test-Pipeline.mlpackage")

!zip -r /content/pipeline-export.zip /content/Test-Pipeline.mlpackage

Test.mlpackage opens in Xcode without problems, but Test-Pipeline.mlpackage shows the outputSchema error.
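One way to narrow the problem down is to compare what the two saved packages contain on disk, in particular whether each includes a weights file. A minimal stdlib sketch follows. `package_files` and `has_weights` are hypothetical helpers, and the check assumes the usual `.mlpackage` layout in which ML program weights sit under a `weights` directory (typically `Data/com.apple.CoreML/weights/weight.bin`).

```python
import os


def package_files(package_path):
    """Return all file paths inside an .mlpackage, relative to its root."""
    files = []
    for root, _dirs, names in os.walk(package_path):
        for name in names:
            files.append(os.path.relpath(os.path.join(root, name), package_path))
    return sorted(files)


def has_weights(package_path):
    """True if any file sits under a directory named 'weights' (assumed layout)."""
    return any(
        "weights" in os.path.normpath(f).split(os.sep)[:-1]
        for f in package_files(package_path)
    )


if __name__ == "__main__":
    # Package names from the example above; missing packages are skipped.
    for pkg in ("Test.mlpackage", "Test-Pipeline.mlpackage"):
        if os.path.isdir(pkg):
            print(pkg, "has weights file:", has_weights(pkg))
            for f in package_files(pkg):
                print("   ", f)
```

If Test.mlpackage has a weights file and Test-Pipeline.mlpackage does not, two existing coremltools options would be worth trying: passing `weights_dir=mlmodel.weights_dir` to `ct.utils.save_spec`, or building the pipeline with `ct.utils.make_pipeline(mlmodel)`. Whether either resolves this particular error has not been verified here.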

System environment

  • coremltools version: 8.1
  • Colab environment
  • PyTorch version: 2.1.2

Metadata

Labels: bug (Unexpected behaviour that should be corrected)
