diff --git a/operator/include/operator/identity.hpp b/operator/include/operator/identity.hpp new file mode 100644 index 0000000..deb5b56 --- /dev/null +++ b/operator/include/operator/identity.hpp @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright (c) 2021, OPEN AI LAB + * Author: ycyang@openailab.com + */ +#ifndef __IDENTITY_HPP__ +#define __IDENTITY_HPP__ + +#include "operator.hpp" + +namespace TEngine { + +class Identity : public OperatorNoParam +{ +public: + Identity() + { + name_ = "Identity"; + } + Identity(const Identity& src) = default; + virtual ~Identity(){}; + + void SetSchema(void) override; +}; + +} // namespace TEngine + +#endif diff --git a/operator/operator/eltwise.cpp b/operator/operator/eltwise.cpp index ade9258..d43b283 100644 --- a/operator/operator/eltwise.cpp +++ b/operator/operator/eltwise.cpp @@ -40,18 +40,44 @@ bool Eltwise::InferShape(const std::vector& ishape, std::vector& return false; } - int i0_size = ishape[0].GetSize(); - int i1_size = ishape[1].GetSize(); + TShape input_shape0 = ishape[0]; + TShape input_shape1 = ishape[1]; + auto& dim0 = input_shape0.GetDim(); + auto& dim1 = input_shape1.GetDim(); - if (i0_size >= i1_size) + int dim_num = dim0.size() >= dim1.size() ? dim0.size():dim1.size(); + std::vector out_dims; + if (dim0.size() >= dim1.size()){ + for (int i=0; i= dim1[i] ? dim0[dim0.size()-dim1.size()+i] : dim1[i]); + } + + } + else{ + for (int i=0; i= dim0[i] ? dim1[dim1.size()-dim0.size()+i] : dim0[i]); + } + } + +/* if (i0_size >= i1_size) { oshape[0] = ishape[0]; } else if (i0_size < i1_size) { oshape[0] = ishape[1]; - } + } */ + TShape shape; + shape.SetDim(out_dims); + shape.SetDataLayout(layout); + oshape[0] = shape; return true; } diff --git a/operator/operator/gather.cpp b/operator/operator/gather.cpp index ec623ca..04b01f1 100644 --- a/operator/operator/gather.cpp +++ b/operator/operator/gather.cpp @@ -8,14 +8,26 @@ namespace TEngine { bool Gather::InferShape(const std::vector& ishape, std::vector& oshape, int layout) { const TShape& input = ishape[0]; + const TShape& input2 = ishape[1]; + std::vector input_dim = input.GetDim(); + std::vector input_dim2 = input2.GetDim(); std::vector output_dim; - + printf ("input2_num: %d\n", input_dim2[0]); if (param_.axis > ( int )input_dim.size()) { return false; } - int indices_size = param_.indices_num; + int indices_size; + if (param_.indices_num != 0){ + + indices_size = param_.indices_num; + + } + else { + indices_size = input_dim2[0]; + } + /* printf("gather input dims: "); for(int i =0; i<(int)input_dim.size(); i++){ diff --git a/operator/operator/identity.cpp b/operator/operator/identity.cpp new file mode 100644 index 0000000..a807104 --- /dev/null +++ b/operator/operator/identity.cpp @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright (c) 2021, OPEN AI LAB + * Author: ycyang@openailab.com + */ +#include "operator/identity.hpp" + +namespace TEngine { + +void Identity::SetSchema(void) +{ + Input({"input:float32"}).Output({"output:float32"}).SetDoc(R"DOC(Identity Operator)DOC"); +} + +} // namespace TEngine diff --git a/operator/plugin/init.cpp b/operator/plugin/init.cpp index 120bffc..e3e709a 100644 --- a/operator/plugin/init.cpp +++ b/operator/plugin/init.cpp @@ -134,6 +134,7 @@ #include "operator/mish.hpp" #include "operator/softplus.hpp" #include "operator/reciprocal.hpp" +#include "operator/identity.hpp" #include "operator/spatialtransformer.hpp" #include "operator/nms.hpp" @@ -251,6 +252,7 @@ int operator_plugin_init(void) RegisterOp("Mish"); RegisterOp("Softplus"); RegisterOp("Reciprocal"); + RegisterOp("Identity"); RegisterOp("SpatialTransformer"); RegisterOp("NMS"); // std::cout<<"OPERATOR PLUGIN INITED\n"; diff --git a/serializer/include/tengine/v2/tm2_format.h b/serializer/include/tengine/v2/tm2_format.h index 5730ccb..881e902 100644 --- a/serializer/include/tengine/v2/tm2_format.h +++ b/serializer/include/tengine/v2/tm2_format.h @@ -150,6 +150,7 @@ typedef uint8_t tm_bool_t; /* bool is 1-byte unsigned integer */ #define TM2_OPSTR_L2NORMALIZATION "L2Normalization" #define TM2_OPSTR_SOFTPLUS "Softplus" #define TM2_OPSTR_RECIPROCAL "Reciprocal" +#define TM2_OPSTR_IDENTITY "Identity" #define TM2_OPSTR_NMS "NMS" #define TM2_OPSTR_SPATIALTRANSFORMER "SpatialTransformer" /* Operator types */ @@ -259,7 +260,9 @@ typedef uint8_t tm_bool_t; /* bool is 1-byte unsigned integer */ #define TM2_OPTYPE_RECIPROCAL 103 #define TM2_OPTYPE_NMS 104 #define TM2_OPTYPE_SPATIALTRANSFORMER 105 -#define TM2_OPTYPE_NUM 106 +#define TM2_OPTYPE_IDENTITY 106 +#define TM2_OPTYPE_NUM 107 + /* --------------------- -------- TM objects -------------------------------- */ diff --git a/serializer/include/tengine/v2/tm2_op_serializer.hpp b/serializer/include/tengine/v2/tm2_op_serializer.hpp index 384d500..fef93cb 100644 --- a/serializer/include/tengine/v2/tm2_op_serializer.hpp +++ b/serializer/include/tengine/v2/tm2_op_serializer.hpp @@ -122,6 +122,7 @@ #include "operator/relu1.hpp" #include "operator/softplus.hpp" #include "operator/reciprocal.hpp" +#include "operator/identity.hpp" #include "operator/nms.hpp" #include "operator/spatialtransformer.hpp" @@ -317,6 +318,7 @@ bool LoadTmLogSoftmaxOp(StaticGraph* graph, StaticNode* node, void* const start_ bool LoadTmReLU1Op(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op); bool LoadTmSoftplusOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op); bool LoadTmReciprocalOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op); +bool LoadTmIdentityOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op); bool LoadTmNMSOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op); bool LoadTmSpatialTransformerOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op); @@ -424,6 +426,7 @@ tm_uoffset_t SaveTmLogSoftmaxOp(void* const start_ptr, tm_uoffset_t* cur_pos, Op tm_uoffset_t SaveTmReLU1Op(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op); tm_uoffset_t SaveTmSoftplusOp(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op); tm_uoffset_t SaveTmReciprocalOp(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op); +tm_uoffset_t SaveTmIdentityOp(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op); tm_uoffset_t SaveTmNMSop(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op); tm_uoffset_t SaveTmSpatialTransformerOp(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op); diff --git a/serializer/tengine/v2/tm2_op_load.cpp b/serializer/tengine/v2/tm2_op_load.cpp index 502447b..adcbfa9 100644 --- a/serializer/tengine/v2/tm2_op_load.cpp +++ b/serializer/tengine/v2/tm2_op_load.cpp @@ -1659,6 +1659,13 @@ bool LoadTmReciprocalOp(StaticGraph* graph, StaticNode* node, void* const start_ SetNodeOp(node, op); return true; } +bool LoadTmIdentityOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op) +{ + StaticOp* op = CreateStaticOp(graph, TM2_OPSTR_IDENTITY); + SetNodeOp(node, op); + return true; +} + bool LoadTmNMSOp(StaticGraph* graph, StaticNode* node, void* const start_ptr, const TM2_Operator* tm_op) { const std::string& op_str = TM2_OPSTR_NMS; @@ -1910,6 +1917,8 @@ op_load_t LoadTmOpFunc(uint32_t op_type) return LoadTmSoftplusOp; case TM2_OPTYPE_RECIPROCAL: return LoadTmReciprocalOp; + case TM2_OPTYPE_IDENTITY: + return LoadTmIdentityOp; case TM2_OPTYPE_NMS: return LoadTmNMSOp; case TM2_OPTYPE_SPATIALTRANSFORMER: @@ -2146,6 +2155,8 @@ std::string GetOpStr(uint32_t op_type) return std::string(TM2_OPSTR_SOFTPLUS); case TM2_OPTYPE_RECIPROCAL: return std::string(TM2_OPSTR_RECIPROCAL); + case TM2_OPTYPE_IDENTITY: + return std::string(TM2_OPSTR_IDENTITY); case TM2_OPTYPE_NMS: return std::string(TM2_OPSTR_NMS); case TM2_OPTYPE_SPATIALTRANSFORMER: diff --git a/serializer/tengine/v2/tm2_op_save.cpp b/serializer/tengine/v2/tm2_op_save.cpp index d69b567..25066dc 100644 --- a/serializer/tengine/v2/tm2_op_save.cpp +++ b/serializer/tengine/v2/tm2_op_save.cpp @@ -1725,6 +1725,14 @@ tm_uoffset_t SaveTmReciprocalOp(void* const start_ptr, tm_uoffset_t* cur_pos, Op SetTmOperator(&tm_op, TM2_OPTYPE_RECIPROCAL, TM2_NOT_SET); return WriteTmObject(start_ptr, cur_pos, &tm_op, sizeof(TM2_Operator)); } + +tm_uoffset_t SaveTmIdentityOp(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op) +{ + TM2_Operator tm_op; + memset(&tm_op, 0, sizeof(TM2_Operator)); + SetTmOperator(&tm_op, TM2_OPTYPE_IDENTITY, TM2_NOT_SET); + return WriteTmObject(start_ptr, cur_pos, &tm_op, sizeof(TM2_Operator)); +} tm_uoffset_t SaveTmNMSOp(void* const start_ptr, tm_uoffset_t* cur_pos, Operator* op) { NMSParam* p = (dynamic_cast(op))->GetParam(); @@ -1982,6 +1990,8 @@ op_save_t SaveTmOpFunc(uint32_t op_type) return SaveTmSoftplusOp; case TM2_OPTYPE_RECIPROCAL: return SaveTmReciprocalOp; + case TM2_OPTYPE_IDENTITY: + return SaveTmIdentityOp; case TM2_OPTYPE_NMS: return SaveTmNMSOp; case TM2_OPTYPE_SPATIALTRANSFORMER: diff --git a/tools/convert_model_to_tm.cpp b/tools/convert_model_to_tm.cpp index 560a7dc..6d6a270 100644 --- a/tools/convert_model_to_tm.cpp +++ b/tools/convert_model_to_tm.cpp @@ -199,6 +199,7 @@ int main(int argc, char* argv[]) std::cout << "prerun failed\n"; return -1; } + //dump_graph(graph); } // Save the tengine model file diff --git a/tools/onnx/onnx_serializer.cpp b/tools/onnx/onnx_serializer.cpp index 4118520..d14c182 100644 --- a/tools/onnx/onnx_serializer.cpp +++ b/tools/onnx/onnx_serializer.cpp @@ -1253,7 +1253,7 @@ static bool LoadOnnxMul(StaticGraph* graph, StaticNode* node, const onnx::NodePr param.type = ELT_PROD; - for(int i = 0; i < onnx_node.input().size(); ++i) +/* for(int i = 0; i < onnx_node.input().size(); ++i) { StaticTensor* tensor = FindTensor(graph, onnx_node.input(i)); std::vector dims = tensor->dims; @@ -1263,7 +1263,7 @@ static bool LoadOnnxMul(StaticGraph* graph, StaticNode* node, const onnx::NodePr new_dims.push_back(1); SetTensorDim(tensor, new_dims); } - } + } */ StaticOp* op = CreateStaticOp(graph, "Eltwise"); @@ -1521,7 +1521,8 @@ static bool LoadOnnxGather(StaticGraph* graph, StaticNode* node, const onnx::Nod { GatherParam param = any_cast(OpManager::GetOpDefParam("Gather")); StaticTensor* indices_tensor = FindTensor(graph, onnx_node.input(1)); - + printf ("input_size: %d\n",onnx_node.input_size()); + printf("gather_tensor_size: %d\n", indices_tensor->dims.size()); for (int k = 0; k < onnx_node.attribute_size(); k++) { const onnx::AttributeProto& attr = onnx_node.attribute(k); @@ -1530,8 +1531,14 @@ static bool LoadOnnxGather(StaticGraph* graph, StaticNode* node, const onnx::Nod param.axis = attr.i(); } } - int64_t* data = ( int64_t* )GetConstTensorBuffer(indices_tensor); - param.indices_num = *data; + if (indices_tensor->dims.size() != 0){ + int64_t* data = ( int64_t* )GetConstTensorBuffer(indices_tensor); + param.indices_num = *data; + } + else { + param.indices_num = 0; + } + param.is_onnx = true; StaticOp* op = CreateStaticOp(graph, "Gather"); @@ -2891,7 +2898,13 @@ static bool LoadOnnxReciprocal(StaticGraph* graph, StaticNode* node, const onnx: return true; } +static bool LoadOnnxIdentity(StaticGraph* graph, StaticNode* node, const onnx::NodeProto& onnx_node) +{ + StaticOp* op = CreateStaticOp(graph, "Identity"); + SetNodeOp(node, op); + return true; +} static bool LoadOnnxResize(StaticGraph* graph, StaticNode* node, const onnx::NodeProto& onnx_node) { StaticOp* op = CreateStaticOp(graph, "Interp"); @@ -3077,6 +3090,7 @@ bool OnnxSerializerRegisterOpLoader(void) p_onnx->RegisterOpLoadMethod("Sqrt", op_load_t(LoadOnnxSqrt)); p_onnx->RegisterOpLoadMethod("Resize", op_load_t(LoadOnnxResize)); p_onnx->RegisterOpLoadMethod("Reciprocal", op_load_t(LoadOnnxReciprocal)); + p_onnx->RegisterOpLoadMethod("Identity", op_load_t(LoadOnnxIdentity)); p_onnx->RegisterOpLoadMethod("InstanceNormalization", op_load_t(LoadOnnxInstanceNormalization)); return true;