🐞Describing the bug
- This bug is quite hard to describe concisely...
- TL;DR:
- I created a custom layer, and the conversion succeeds in coremltools, but loading the model produces warning logs from CoreML.framework, as shown below.
To Reproduce
- I wrote a minimal demo to reproduce the issue:
import torch
import torch.nn as nn
import torch.nn.functional as F
import coremltools
from collections import OrderedDict
import coremltools.proto.FeatureTypes_pb2 as ft
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import (
    _get_inputs as mil_get_inputs, is_symbolic, _get_scales_from_output_size
)
from coremltools.converters.mil import (
    register_torch_op
)
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil import (
    Operation,
    types
)
from coremltools.converters.mil.mil.input_type import (
    InputSpec,
    TensorInputType,
)
@register_torch_op(torch_alias=['grid_sample'], override=True)
def grid_sampler(context, node):
    # https://github.com/pytorch/pytorch/blob/00d432a1ed179eff52a9d86a0630f623bf20a37a/aten/src/ATen/native/GridSampler.h#L10-L11
    inputs = mil_get_inputs(context, node, expected=5)
    x = mb.custom_op(
        x=inputs[0],
        coordinates=inputs[1],
        name=node.name,
    )
    context.add(x)
@register_op(is_custom_op=True)
class custom_op(Operation):
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        coordinates=TensorInputType(type_domain="T"),
    )
    type_domains = {
        "T": (types.fp16, types.fp32),
        "U": (types.int32,),
    }
    bindings = {
        'class_name': 'CustomGridSample',
        'input_order': ['coordinates', 'x'],
        'description': "custom grid sampler!",
    }

    def __init__(self, **kwargs):
        super(custom_op, self).__init__(**kwargs)

    def type_inference(self):
        input_shape = self.x.shape
        coord_shape = self.coordinates.shape
        ret_shape = list(input_shape)
        ret_shape[2] = coord_shape[1]  # Output height
        ret_shape[3] = coord_shape[2]  # Output width
        return types.tensor(self.x.dtype, ret_shape)
########################################################################
######################## Test ml model ################################
IN_WH = 512
GRID_WH = 256
class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)

    def forward(self, x, grid):
        x = F.relu(self.conv1(x))
        x = F.grid_sample(x, grid)
        x = F.relu(self.conv2(x))
        return x
########################################################################
########################################################################
def convert(output_path):
    torch_model = TestModel()
    # torch_model = torch.jit.load('./flow_480x272_250103.pt', map_location='cpu')
    example_input = torch.rand(1, 3, IN_WH, IN_WH)
    example_grid = torch.ones(1, GRID_WH, GRID_WH, 2)
    # example_input = torch.rand(1, 1, 272, 480)
    # traced_model = torch.jit.trace(torch_model, (example_input, example_input))
    traced_model = torch.export.export(torch_model, (example_input, example_grid))
    mlmodel = coremltools.convert(
        traced_model,
        inputs=[
            coremltools.TensorType(name="input0", shape=example_input.shape),
            coremltools.TensorType(name="input1", shape=example_grid.shape),
        ],
        convert_to="neuralnetwork",
        # convert_to="milinternal",
        # convert_to="mlprogram",
        minimum_deployment_target=coremltools.target["iOS13"],
    )
    print(mlmodel)
    mlmodel_path = output_path + ".mlmodel"
    mlmodel.save(mlmodel_path)
    print(f"Saved to {mlmodel_path}")


def main():
    convert('test')


if __name__ == "__main__":
    main()
Running this code generates the simplest possible nn network as an .mlmodel. Then, loading it in an Objective-C project with just the API

NSError *error = nil;
id model = [MLModel modelWithContentsOfURL:modelUrl error:&error];

causes the error log dump shown in the console.
I don't know what is wrong with this network's inference. I also cannot tell whether this is a coremltools bug, a CoreML framework bug, or a bug in my custom op.
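For reference, here is a small sketch (assuming the model was saved as test.mlmodel by the script above; the inspection code is only illustrative and not part of my app) that loads the raw protobuf spec with standard coremltools accessors and prints the layers, to confirm that the custom layer is actually serialized:

import coremltools

# Load the protobuf spec directly; this does not go through the CoreML runtime.
spec = coremltools.utils.load_spec("test.mlmodel")

# Print every layer of the neuralnetwork model; the custom layer should appear
# with className == "CustomGridSample", matching the bindings above.
for layer in spec.neuralNetwork.layers:
    kind = layer.WhichOneof("layer")
    if kind == "custom":
        print(layer.name, "-> custom layer, className =", layer.custom.className)
    else:
        print(layer.name, "->", kind)

Note that this only checks the serialized spec; it does not exercise the CoreML.framework load path where the warnings appear.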
System environment (please complete the following information):
- coremltools version: tried 7.2, 8.0, 8.1
- PyTorch version: 2.4.0, 2.4.1
- OS: tried 14.4, 14.5
Additional context
@YifanShenSZ I'm not sure if there are any bugs in my toy code, but if you have some free time, would you mind reviewing it for me?