Skip to content

OneFlow ONNX如何支持一个新的OP

Xiaoyu Zhang edited this page Sep 13, 2022 · 1 revision

oneflow->onnx目前所有支持的op都实现在这个目录

对于每一个op来说,都会被@flow_op装饰器包装起来完成oneflow的op到ONNX op的映射,@flow_op的实现在这里。我们看一下这个装饰器的初始化参数:

def __init__(
        self,
        name,
        onnx_op=None,
        domain=constants.ONNX_DOMAIN,
        flow_ibns=None,
        flow_obns=None,
        **kwargs
    ):
    ...

这里的name就是oneflow user op的name,onnx_op代表需要把oneflow的user op映射到onnx的op的名字,如果oneflow对应的user op无法直接对应某个onnx op而是对应一系列onnx op构成的子图,那么onnx_op可以不写。domain参数可以不管。flow_ibns和flow_obns可以用来指定输入输出blob的顺序,比如在卷积中设置为["in", "weight"],如果op只有单个输入tensor可以不用设置。

接下来我们看几个转换的例子:

HardSwish的转换:

@flow_op("hardswish", onnx_op="HardSwish")
class HardSwish:
    @classmethod
    def Version_1(cls, ctx, node, **kwargs):
        dtypes = node.output_dtypes
        node1 = ctx.MakeNode(
            "HardSigmoid", [node.input_tensor_names[0]], op_name_scope=node.name, name="hard_sigmoid", dtypes=dtypes, attr={"alpha": 1.0 / 6}
        )
        ctx.RemoveNode(node.name)
        ctx.MakeNode(
            "Mul", [node.input_tensor_names[0], node1.output_tensor_names[0]], outputs=[node.output_tensor_names[0]], op_name_scope=node.name, name="mul"
        )
    
    @classmethod
    def Version_14(cls, ctx, node, **kwargs):
        pass

这里的Version1和Version14和ONNX的Opset Version对应,我们可以在ONNX的官方文档看到每个ONNX Op有哪些版本以及每个版本的语意变化。例如对于下面的Unsqueeze来说,在Version 1时axis是一个Scalar Tensor,而在Version 13时则是Attribute,这是为了兼容各种深度学习框架的补丁。由于ONNX没有HardSwish这个Op,所以我们可以用x * hard_sigmoid(x) 的公式来完成转换,具体的过程就是创建HardSigmoid Op取名node1,然后删除当前的Op并创建一个Mul Op完成node1和输入x的相乘。

Unqueeze的转换:

@flow_op("expand_dims", "Unsqueeze")
class ExpandDimsOp:
    @classmethod
    def Version_1(cls, ctx, node, **kwargs):
        axis = node.attrs.get("axis", None)
        
        axis_node = ctx.MakeConst(
            oneflow._oneflow_internal.UniqueStr("axis"), np.array(axis)
        )

        node.input_tensor_names.append(axis_node.output_tensor_names[0])
    
    @classmethod
    def Version_11(cls, ctx, node, **kwargs):
        # Opset 11 supports negative axis, but core logic is same
        axis = node.attrs.get("axis", None)
        node.attrs["axes"] = [axis]

这里的转换就是对应的修改输入或者Attreibute了。

其它的转换都是类似的。支持了新的op之后我们可以在这里 更新op列表。

Clone this wiki locally