-
Notifications
You must be signed in to change notification settings - Fork 8
OneFlow ONNX如何支持一个新的OP
oneflow->onnx目前所有支持的op都实现在这个目录
对于每一个op来说,都会被@flow_op装饰器包装起来完成oneflow的op到ONNX op的映射,@flow_op的实现在这里。我们看一下这个装饰器的初始化参数:
def __init__(
self,
name,
onnx_op=None,
domain=constants.ONNX_DOMAIN,
flow_ibns=None,
flow_obns=None,
**kwargs
):
...
这里的name就是oneflow user op的name,onnx_op代表需要把oneflow的user op映射到onnx的op的名字,如果oneflow对应的user op无法直接对应某个onnx op而是对应一系列onnx op构成的子图,那么onnx_op可以不写。domain参数可以不管。flow_ibns和flow_obns可以用来指定输入输出blob的顺序,比如在卷积中设置为["in", "weight"],如果op只有单个输入tensor可以不用设置。
接下来我们看几个转换的例子:
HardSwish的转换:
@flow_op("hardswish", onnx_op="HardSwish")
class HardSwish:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
dtypes = node.output_dtypes
node1 = ctx.MakeNode(
"HardSigmoid", [node.input_tensor_names[0]], op_name_scope=node.name, name="hard_sigmoid", dtypes=dtypes, attr={"alpha": 1.0 / 6}
)
ctx.RemoveNode(node.name)
ctx.MakeNode(
"Mul", [node.input_tensor_names[0], node1.output_tensor_names[0]], outputs=[node.output_tensor_names[0]], op_name_scope=node.name, name="mul"
)
@classmethod
def Version_14(cls, ctx, node, **kwargs):
pass
这里的Version1和Version14和ONNX的Opset Version对应,我们可以在ONNX的官方文档看到每个ONNX Op有哪些版本以及每个版本的语意变化。例如对于下面的Unsqueeze来说,在Version 1时axis是一个Scalar Tensor,而在Version 13时则是Attribute,这是为了兼容各种深度学习框架的补丁。由于ONNX没有HardSwish这个Op,所以我们可以用x * hard_sigmoid(x) 的公式来完成转换,具体的过程就是创建HardSigmoid Op取名node1,然后删除当前的Op并创建一个Mul Op完成node1和输入x的相乘。
Unqueeze的转换:
@flow_op("expand_dims", "Unsqueeze")
class ExpandDimsOp:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
axis = node.attrs.get("axis", None)
axis_node = ctx.MakeConst(
oneflow._oneflow_internal.UniqueStr("axis"), np.array(axis)
)
node.input_tensor_names.append(axis_node.output_tensor_names[0])
@classmethod
def Version_11(cls, ctx, node, **kwargs):
# Opset 11 supports negative axis, but core logic is same
axis = node.attrs.get("axis", None)
node.attrs["axes"] = [axis]
这里的转换就是对应的修改输入或者Attreibute了。
其它的转换都是类似的。支持了新的op之后我们可以在这里 更新op列表。