From 571f3507cbd7c93c32e6c301ff718a2883045cac Mon Sep 17 00:00:00 2001 From: haitao Date: Tue, 26 Jun 2018 13:56:37 +0800 Subject: [PATCH] 1. update install.md 2. using the NamedParam for operator param definition so that the parameter can be retrievaled by those APIs in applications get_node_param_int()/get_node_param_float()/get_node_param_generic() Former-commit-id: 8659f3b03880fd6786055e67ebef832df2f8dbb3 --- .gitignore | 1 + CMakeLists.txt | 2 + cmake/executor.cmake | 28 +- core/include/operator.hpp | 66 +- core/include/parameter.hpp | 159 +++- core/include/tengine_c_api.h | 46 + core/include/tengine_config.hpp | 4 + core/lib/Makefile | 3 +- core/lib/node.cpp | 24 + core/lib/parameter.cpp | 68 -- core/lib/tengine_c_api.cpp | 42 + core/lib/tengine_config.cpp | 39 +- doc/install.md | 36 +- doc/operator_dev.md | 2 +- examples/yolov2/CMakeLists.txt | 3 + examples/yolov2/yolov2.cpp | 795 +++++++++--------- .../include/operator/batch_norm_param.hpp | 2 +- operator/include/operator/concat_param.hpp | 2 +- operator/include/operator/conv_param.hpp | 4 +- operator/include/operator/deconv_param.hpp | 2 +- .../operator/detection_output_param.hpp | 2 +- operator/include/operator/eltwise.hpp | 2 +- operator/include/operator/eltwise_param.hpp | 2 +- operator/include/operator/fc_param.hpp | 2 +- operator/include/operator/flatten_param.hpp | 2 +- operator/include/operator/lrn_param.hpp | 2 +- operator/include/operator/normalize_param.hpp | 2 +- operator/include/operator/permute_param.hpp | 2 +- operator/include/operator/pool_param.hpp | 2 +- operator/include/operator/pooling.hpp | 2 +- operator/include/operator/priorbox_param.hpp | 2 +- operator/include/operator/region_param.hpp | 4 +- operator/include/operator/relu_param.hpp | 2 +- operator/include/operator/reorg_param.hpp | 2 +- operator/include/operator/reshape_param.hpp | 2 +- operator/include/operator/resize_param.hpp | 2 +- .../include/operator/roi_pooling_param.hpp | 2 +- operator/include/operator/rpn_param.hpp | 2 +- operator/include/operator/scale_param.hpp | 3 +- operator/include/operator/slice_param.hpp | 2 +- operator/include/operator/softmax_param.hpp | 2 +- 41 files changed, 832 insertions(+), 541 deletions(-) delete mode 100644 core/lib/parameter.cpp diff --git a/.gitignore b/.gitignore index da27c9d9d..6b3e3568a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ protobuf/ OpenBLAS/ protobuf_lib/ sysroot/ +android_config.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d1af8da6..01b5d9c1b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,8 @@ execute_process( COMMAND git rev-parse HEAD STRING(STRIP ${git_commit_id} stripped_commit_id) set(GIT_COMMIT_ID -DGIT_COMMIT_ID="0x${stripped_commit_id}") +message("GIT COMMIT ID: " 0x${stripped_commit_id}) + if (CONFIG_ARCH_ARM64) add_definitions(-DCONFIG_ARCH_ARM64=1) endif() diff --git a/cmake/executor.cmake b/cmake/executor.cmake index 0f9c4357b..64df99d37 100644 --- a/cmake/executor.cmake +++ b/cmake/executor.cmake @@ -1,13 +1,14 @@ include_directories(executor/include executor/operator/include) + +FILE(GLOB_RECURSE COMMON_LIB_CPP_SRCS executor/engine/*.cpp executor/lib/*.cpp executor/plugin/*.cpp) +FILE(GLOB COMMON_CPP_SRCS executor/operator/common/*.cpp executor/operator/common/fused/*.cpp) if(CONFIG_ARCH_BLAS) - FILE(GLOB_RECURSE COMMON_LIB_CPP_SRCS executor/engine/*.cpp executor/lib/*.cpp executor/plugin/*.cpp executor/operator/common/*.cpp ) -else() - FILE(GLOB_RECURSE COMMON_LIB_CPP_SRCS executor/engine/*.cpp executor/lib/*.cpp executor/plugin/*.cpp) - FILE(GLOB COMMON_CPP_SRCS executor/operator/common/*.cpp executor/operator/common/fused/*.cpp) - list(APPEND COMMON_LIB_CPP_SRCS ${COMMON_CPP_SRCS}) + FILE(GLOB COMMON_BLAS_SRCS executor/operator/common/blas/*.cpp) + list(APPEND COMMON_CPP_SRCS ${COMMON_BLAS_SRCS}) endif() list(APPEND TENGINE_LIB_SRCS ${COMMON_LIB_CPP_SRCS}) +list(APPEND TENGINE_LIB_SRCS ${COMMON_CPP_SRCS}) include_directories(driver/cpu) @@ -24,14 +25,25 @@ endif() # Now, handle the .S file if(CONFIG_ARCH_ARM64) - FILE(GLOB_RECURSE ARCH_LIB_CPP_SRCS executor/operator/arm64/*.cpp) + FILE(GLOB_RECURSE ARCH64_LIB_CPP_SRCS executor/operator/arm64/*.cpp) include_directories(executor/operator/arm64/include) - FOREACH(file ${ARCH_LIB_CPP_SRCS}) - list(APPEND TENGINE_LIB_SRCS ${file}) + FOREACH(file ${ARCH64_LIB_CPP_SRCS}) + set(ACL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/executor/operator/arm64/conv/conv_2d_acl") + STRING(REGEX MATCH ${ACL_PREFIX} skip_file2 ${file}) + + if( NOT skip_file2) + list(APPEND ARCH_LIB_CPP_SRCS ${file}) + endif() + endforeach() endif() + +list(APPEND TENGINE_LIB_SRCS ${ARCH_LIB_CPP_SRCS}) + +# Now, handle the .S file + if( CONFIG_ARCH_ARM64) set(src_path executor/operator/arm64) diff --git a/core/include/operator.hpp b/core/include/operator.hpp index ff8bc7df5..ecddd2cef 100644 --- a/core/include/operator.hpp +++ b/core/include/operator.hpp @@ -56,6 +56,9 @@ class Operator: public BaseObject { virtual void ParseParam(void) {}; virtual bool ParamFromStaticOp(StaticOp * s_op) {return true;} + virtual bool GetParamItem(const char * param_name, const std::type_info * type_info, void * val) { return false;} + virtual bool SetParamItem(const char * param_name, const std::type_info * type_info, const void * val) { return false;} + virtual any GetDefParam(void) { return any(); @@ -127,6 +130,7 @@ class Operator: public BaseObject { return *this; } + const std::string& GetDoc(void) const { return doc_;} int GetInputNum(void) const { return inputs_.size();} @@ -233,9 +237,59 @@ class OperatorWithParam: public Operator { } + /* a complicated one, now */ + void ParsePredefinedParam(P& param, Operator * op) + { + auto map=param.GetItemMap(); + + auto ir=map.begin(); + auto end=map.end(); + + while(ir!=end) + { + if(!op->ExistAttr(ir->first)) + { + ir++; + continue; + } + + const any& data=op->GetAttr(ir->first); + + if(param.SetItemFromAny(ir->first,data)) + { + ir++; + continue; + } + + //type mismatch + //possible reason: + // 1. require float, while input is int + // 2. require const char *, while input is std::string + + // otherwise, failed + + const std::type_info & data_type=data.type(); + + if(data_type==typeid(std::string)) + { + const std::string& str=any_cast(data); + + param.SetItemVal(ir->first,&typeid(const char *),str.c_str()); + } + else if(data_type==typeid(int)) + { + float f=(float)any_cast(data); + param.SetItemVal(ir->first,&typeid(float),&f); + } + + ir++; + } + } + + virtual void ParseParam(P& param, Operator * op) { - P::Parse(param,op); + ParsePredefinedParam(param,op); } @@ -255,6 +309,16 @@ class OperatorWithParam: public Operator { return param; } + bool GetParamItem(const char * param_name, const std::type_info * type_info, void * val) override + { + return param_.GetItemVal(param_name,type_info,val); + } + + bool SetParamItem(const char * param_name, const std::type_info * type_info, const void * val) override + { + return param_.SetItemVal(param_name,type_info,val); + } + protected: P param_; diff --git a/core/include/parameter.hpp b/core/include/parameter.hpp index d965010d7..033368103 100644 --- a/core/include/parameter.hpp +++ b/core/include/parameter.hpp @@ -26,46 +26,133 @@ #include +#include -#include "base_object.hpp" +#include "any.hpp" namespace TEngine { - -using entry_parser_t=std::function; - -template -bool ConvertSpecialAny(T& entry, const std::type_info & info, any& data); - - -#define DECLARE_PARSER_STRUCTURE(param) \ - static void Parse(param& param_obj, BaseObject * p_obj)\ - - -#define DECLARE_PARSER_ENTRY(entry) \ - {\ - typedef decltype(param_obj.entry) type0;\ - any& content=(*p_obj)[#entry];\ - if(typeid(type0)== content.type()) \ - param_obj.entry=any_cast(content); \ - else\ - {\ - if(!ConvertSpecialAny(param_obj.entry,content.type(),content))\ - std::cerr<<"cannot parser entry: "<<#entry<(content); \ - else\ - {\ - std::cerr<<"cannot parser entry: "<<#entry<name(),entry.type_info->name()); + return nullptr; + } + + return &entry; + } + + bool GetItemVal(const std::string& name, const std::type_info * val_type, void * val) + { + ItemInfo * entry=FindItem(name,val_type); + + if(entry==nullptr) + return false; + + entry->cpy_func(val,(char *)this+entry->data); + + return true; + } + + bool SetItemVal(const std::string& name, const std::type_info * val_type, const void * val) + { + ItemInfo * entry=FindItem(name,val_type); + + if(entry==nullptr) + return false; + + + entry->cpy_func((char *)this+entry->data,val); + + return true; + } + + bool SetItemCompatibleAny(const std::string& name, const any& n) + { + if(item_map_.count(name)==0) + return false; + + ItemInfo& entry=item_map_.at(name); + const std::type_info * item_type=entry.type_info; + const std::type_info& any_type=n.type(); + + /* several special cases */ + if(*item_type== typeid(const char *) && any_type==typeid(std::string)) + { + const char ** ptr=(const char **)((char *) this+entry.data); + const std::string& str=any_cast(n); + + ptr[0]=str.c_str(); //unsafe, since any may be destroyed soon + + return true; + } + + if(*item_type==typeid(std::string) && any_type==typeid(const char *)) + { + std::string * p_str=(std::string*)((char *)this+entry.data); + const char * ptr=any_cast(n); + + *p_str=ptr; + + return true; + } + + return false; + + } + + bool SetItemFromAny(const std::string& name, const any& n) + { + + ItemInfo * entry=FindItem(name,&n.type()); + + if(entry==nullptr) + return SetItemCompatibleAny(name,n); + + entry->cpy_any((char*)this+entry->data,n); + + return true; + } + + const std::unordered_map & GetItemMap(void) { return item_map_;} + + +protected: + std::unordered_map item_map_; +}; + + +#define DECLARE_PARSER_STRUCTURE(s) \ + s(void) + +#define DECLARE_PARSER_ENTRY(e) \ +{\ + typedef decltype(e) T ;\ + ItemInfo info; \ + info.type_info=&typeid(T);\ + info.data=(char*)&e -(char *)this; \ + info.cpy_func=[](void * data, const void * v){ *(T*)data=*(const T*)v;}; \ + info.cpy_any=[](void * data, const any& n){ *(T*)data=any_cast(n);};\ + item_map_[# e]=info; \ +} diff --git a/core/include/tengine_c_api.h b/core/include/tengine_c_api.h index ad35a57b0..f799b7f2d 100644 --- a/core/include/tengine_c_api.h +++ b/core/include/tengine_c_api.h @@ -533,6 +533,52 @@ const char * get_tensor_name(tensor_t tensor); */ node_t get_graph_node(graph_t graph, const char * node_name); + +/*! +* @brief get the param value (int) of a node +* +* @param node, the target node +* @param param_name, the name of the param to be retrieval +* @param param_val, pointer to the int val to be saved +* +* @return 0, retrieval value successfully; +* <0, failed; probably the name does not exist or the type mismatch +*/ + +int get_node_param_int(node_t node, const char * param_name, int * param_val); + +/*! +* @brief get the param value (float) of a node +* +* @param node, the target node +* @param param_name, the name of the param to be retrieval +* @param param_val, pointer to the float val to be saved +* +* @return 0, retrieval value successfully; +* <0, failed; probably the name does not exist or the type mismatch +*/ + +int get_node_param_float(node_t node, const char * param_name, float * param_val); + +/*! +* @brief get the param value of a node, the data type is indicated by type_info +* this interface only works in c++, as type_info refers std::type_info +* +* @param node, the target node +* @param param_name, the name of the param to be retrieval +* @param type_info, pointer to the std::type_info of wanted type +* @param param_val, pointer to the float val to be saved +* +* @return 0, retrieval value successfully; +* <0, failed; probably the name does not exist or the type mismatch +*/ + +int get_node_param_generic(node_t node, const char * param_name, const void * type_info, void * param_val); + +int set_node_param_int(node_t node, const char * param_name, const int * param_val); +int set_node_param_float(node_t node, const char * param_name, const float * param_val); +int set_node_param_generic(node_t node, const char * param_name, const void * type_info, const void * param_val); + /*! * @brief initialize resource for graph execution * diff --git a/core/include/tengine_config.hpp b/core/include/tengine_config.hpp index cd85fe3a3..d4cb7715a 100644 --- a/core/include/tengine_config.hpp +++ b/core/include/tengine_config.hpp @@ -36,6 +36,10 @@ namespace TEngine { +template +bool ConvertSpecialAny(T& entry, const std::type_info & info, any& data); + + struct TEngineConfig { static bool tengine_mt_mode; // multithread mode diff --git a/core/lib/Makefile b/core/lib/Makefile index f0c90d3cf..7f9c67971 100644 --- a/core/lib/Makefile +++ b/core/lib/Makefile @@ -2,7 +2,6 @@ obj-y+=data_type.o obj-y+=data_layout.o obj-y+=tensor.o obj-y+=tensor_shape.o -obj-y+=parameter.o obj-y+=graph.o obj-y+=node.o obj-y+=static_graph.o @@ -19,6 +18,6 @@ obj-y+=tengine_plugin.o obj-y+=logger/ #graph_executor_CXXFLAGS+=-I../../executor/include -tengine_config_CXXFLAGS+=-DGIT_COMMIT_ID=$(GIT_COMMIT_ID) +tengine_config_CXXFLAGS+=-DGIT_COMMIT_ID=\"$(GIT_COMMIT_ID)\" diff --git a/core/lib/node.cpp b/core/lib/node.cpp index 692be4298..680144730 100644 --- a/core/lib/node.cpp +++ b/core/lib/node.cpp @@ -149,6 +149,30 @@ void Node::MergeAttr(Node * orig) } +int NodeGetParamGeneric(void * node, const char * param_name, const void * type_info, void * param_val) +{ + Node * real_node=(Node *)node; + + Operator *op=real_node->GetOp(); + + if(op->GetParamItem(param_name,(const std::type_info *)type_info,param_val)) + return 0; + else + return -1; + +} +int NodeSetParamGeneric(void * node, const char * param_name, const void * type_info, const void * param_val) +{ + Node * real_node=(Node *)node; + + Operator *op=real_node->GetOp(); + + if(op->SetParamItem(param_name,(const std::type_info *)type_info,param_val)) + return 0; + else + return -1; + +} } //namespace TEngine diff --git a/core/lib/parameter.cpp b/core/lib/parameter.cpp deleted file mode 100644 index 56fec5d7c..000000000 --- a/core/lib/parameter.cpp +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * Copyright (c) 2017, Open AI Lab - * Author: haitao@openailab.com - */ -#include "parameter.hpp" - - -namespace TEngine { - - - - -template<> -bool ConvertSpecialAny(int& entry, const std::type_info& info, any& data) -{ - if(info == typeid(double)) - { - entry=any_cast(data); - return true; - } - - return false; -} - -template<> -bool ConvertSpecialAny(float& entry, const std::type_info& info, any& data) -{ - if(info == typeid(double)) - { - entry=any_cast(data); - return true; - } - - return false; -} - -template<> -bool ConvertSpecialAny(std::string& entry, const std::type_info& info, any& data) -{ - if(info == typeid(const char *)) - { - entry=any_cast(data); - return true; - } - - return false; -} - -} //namespace TEngine diff --git a/core/lib/tengine_c_api.cpp b/core/lib/tengine_c_api.cpp index 86772c955..cbc51a584 100644 --- a/core/lib/tengine_c_api.cpp +++ b/core/lib/tengine_c_api.cpp @@ -767,7 +767,49 @@ void * get_graph_node(graph_t graph, const char * node_name) return executor->FindNode(node_name); } + +int get_node_param_int(node_t node, const char * param_name, int * param_val) +{ + return get_node_param_generic(node,param_name,&typeid(int),param_val); +} + +int get_node_param_float(node_t node, const char * param_name, float * param_val) +{ + return get_node_param_generic(node,param_name,&typeid(float),param_val); +} + +/* a temporary solution: + * Define an intermidate function + * NodeGetParamGeneric(): defined in node.cpp + * + */ +namespace TEngine { + +extern int NodeGetParamGeneric(void * node, const char * param_name, const void * type_info, void * param_val); +extern int NodeSetParamGeneric(void * node, const char * param_name, const void * type_info, const void * param_val); + +} + +int get_node_param_generic(node_t node, const char * param_name, const void * type_info, void * param_val) +{ + return NodeGetParamGeneric(node,param_name,type_info,param_val); +} + +int set_node_param_int(node_t node, const char * param_name, const int * param_val) +{ + return set_node_param_generic(node,param_name,&typeid(int),param_val); +} + +int set_node_param_float(node_t node, const char * param_name, const float * param_val) +{ + return set_node_param_generic(node,param_name,&typeid(float),param_val); +} + +int set_node_param_generic(node_t node, const char * param_name, const void * type_info, const void * param_val) +{ + return NodeSetParamGeneric(node,param_name,type_info,param_val); +} int prerun_graph(graph_t graph) { diff --git a/core/lib/tengine_config.cpp b/core/lib/tengine_config.cpp index 3158bb4cc..2116a1527 100644 --- a/core/lib/tengine_config.cpp +++ b/core/lib/tengine_config.cpp @@ -28,7 +28,7 @@ namespace TEngine { using ConfManager = Attribute; const std::string TEngineConfig::version("0.5.0"); -const char * TEngine_git_commit_id="@ ## GIT_COMMIT_ID ## @"; +const char * TEngine_git_commit_id="@" GIT_COMMIT_ID "@"; bool TEngineConfig::tengine_mt_mode = true; char TEngineConfig::delim_ch = '='; @@ -163,4 +163,41 @@ bool GetSyncRunMode(void) } +template<> +bool ConvertSpecialAny(int& entry, const std::type_info& info, any& data) +{ + if(info == typeid(double)) + { + entry=any_cast(data); + return true; + } + + return false; +} + +template<> +bool ConvertSpecialAny(float& entry, const std::type_info& info, any& data) +{ + if(info == typeid(double)) + { + entry=any_cast(data); + return true; + } + + return false; +} + +template<> +bool ConvertSpecialAny(std::string& entry, const std::type_info& info, any& data) +{ + if(info == typeid(const char *)) + { + entry=any_cast(data); + return true; + } + + return false; +} + + } //end of namespace TEngine diff --git a/doc/install.md b/doc/install.md index e46f47f4f..d9fcf5e31 100644 --- a/doc/install.md +++ b/doc/install.md @@ -74,31 +74,31 @@ make Tengine also provides some example programs for tests, and you can easily validate whether your Tengine is successfully built by running these test programs and inspecting the results. ### 3.1 Run SqueezeNet - ``` - ./build/tests/bin/bench_sqz -r1 - ``` + + ./build/tests/bin/bench_sqz -r1 - `-r1` means repeat one time. + `-r1` means repeat one time. Output message: - 0.2763 - "n02123045 tabby, tabby cat" - 0.2673 - "n02123159 tiger cat" - 0.1766 - "n02119789 kit fox, Vulpes macrotis" - 0.0827 - "n02124075 Egyptian cat" - 0.0777 - "n02085620 Chihuahua" + 0.2763 - "n02123045 tabby, tabby cat" + 0.2673 - "n02123159 tiger cat" + 0.1766 - "n02119789 kit fox, Vulpes macrotis" + 0.0827 - "n02124075 Egyptian cat" + 0.0777 - "n02085620 Chihuahua" ### 3.2 Run MobileNet - ``` - ./build/tests/bin/bench_sqz -r1 - ``` + + ./build/tests/bin/bench_mobilenet -r1 Output message: - 8.5976 - "n02123159 tiger cat" - 7.9550 - "n02119022 red fox, Vulpes vulpes" - 7.8679 - "n02119789 kit fox, Vulpes macrotis" - 7.4274 - "n02113023 Pembroke, Pembroke Welsh corgi" - 6.3647 - "n02123045 tabby, tabby cat" + 8.5976 - "n02123159 tiger cat" + 7.9550 - "n02119022 red fox, Vulpes vulpes" + 7.8679 - "n02119789 kit fox, Vulpes macrotis" + 7.4274 - "n02113023 Pembroke, Pembroke Welsh corgi" + 6.3647 - "n02123045 tabby, tabby cat" + +For more information about the performance test of Tengine, please refer to the documentation of **[benchmark](benchmark.md)**. -For more information about the performance test of Tengine, please refer to the documentation of [benchmark](benchmark.md). +Please visit **[exmaples](../examples/readme.md)** for applications on classification/detection etc. diff --git a/doc/operator_dev.md b/doc/operator_dev.md index a4aecdc3e..8c3b5f296 100644 --- a/doc/operator_dev.md +++ b/doc/operator_dev.md @@ -51,7 +51,7 @@ Please refer to: [operator/include/operator/relu.hpp](../operator/include/operat ### 2. Operator with Parameter First, a separate parameter definition file should be created. In order to facilitate the parameter parsing, it is suggested to define the parameter structure following the example below: ```c++ -struct ConvParam { +struct ConvParam : public NamedParam { int kernel_h; int kernel_w; diff --git a/examples/yolov2/CMakeLists.txt b/examples/yolov2/CMakeLists.txt index 4f104d68f..e73c9fb78 100644 --- a/examples/yolov2/CMakeLists.txt +++ b/examples/yolov2/CMakeLists.txt @@ -22,6 +22,9 @@ endif() set( CODE_SRCS yolov2.cpp ../common/common.cpp) set( BIN_EXE YOLOV2) + +set(CMAKE_CXX_FLAGS "-std=c++11 -O3 -g -Wall") + #opencv find_package(OpenCV REQUIRED) diff --git a/examples/yolov2/yolov2.cpp b/examples/yolov2/yolov2.cpp index d6d68b283..ba9b94171 100644 --- a/examples/yolov2/yolov2.cpp +++ b/examples/yolov2/yolov2.cpp @@ -42,443 +42,476 @@ using namespace TEngine; struct Box { - float x; - float y; - float w; - float h; + float x; + float y; + float w; + float h; }; struct Sbox { - int index; - int class_id; - float **probs; + int index; + int class_id; + float **probs; }; static int nms_comparator(const void *pa, const void *pb) { - Sbox a = *(Sbox *)pa; - Sbox b = *(Sbox *)pb; - float diff = a.probs[a.index][b.class_id] - b.probs[b.index][b.class_id]; - if (diff < 0) - return 1; - else if (diff > 0) - return -1; - return 0; + Sbox a = *(Sbox *)pa; + Sbox b = *(Sbox *)pb; + float diff = a.probs[a.index][b.class_id] - b.probs[b.index][b.class_id]; + if (diff < 0) + return 1; + else if (diff > 0) + return -1; + return 0; } int entry_index(int n,int loc, int entry, int hw, int classes) { - int coords = 4; - return n * hw * (coords + classes + 1) + entry * hw + loc; + int coords = 4; + return n * hw * (coords + classes + 1) + entry * hw + loc; } void get_region_box(Box &b, float *x, std::vector &biases, - int n, int index, int i, int j, int w, int h, int stride) + int n, int index, int i, int j, int w, int h, int stride) { - b.x = (i + x[index + 0 * stride]) / w; - b.y = (j + x[index + 1 * stride]) / h; - b.w = exp(x[index + 2 * stride]) * biases[2 * n] / w; - b.h = exp(x[index + 3 * stride]) * biases[2 * n + 1] / h; + b.x = (i + x[index + 0 * stride]) / w; + b.y = (j + x[index + 1 * stride]) / h; + b.w = exp(x[index + 2 * stride]) * biases[2 * n] / w; + b.h = exp(x[index + 3 * stride]) * biases[2 * n + 1] / h; } void correct_region_boxes(std::vector &boxes, int n, int w, int h, - int netw, int neth) + int netw, int neth) { - int i; - int new_w = 0; - int new_h = 0; - if (((float)netw / w) < ((float)neth / h)) - { - new_w = netw; - new_h = (h * netw) / w; - } - else - { - new_h = neth; - new_w = (w * neth) / h; - } - for (i = 0; i < n; ++i) - { - Box b = boxes[i]; - b.x = (b.x - (netw - new_w) / 2. / netw) / ((float)new_w / netw); - b.y = (b.y - (neth - new_h) / 2. / neth) / ((float)new_h / neth); - b.w *= (float)netw / new_w; - b.h *= (float)neth / new_h; - boxes[i] = b; - } + int i; + int new_w = 0; + int new_h = 0; + if (((float)netw / w) < ((float)neth / h)) + { + new_w = netw; + new_h = (h * netw) / w; + } + else + { + new_h = neth; + new_w = (w * neth) / h; + } + for (i = 0; i < n; ++i) + { + Box b = boxes[i]; + b.x = (b.x - (netw - new_w) / 2. / netw) / ((float)new_w / netw); + b.y = (b.y - (neth - new_h) / 2. / neth) / ((float)new_h / neth); + b.w *= (float)netw / new_w; + b.h *= (float)neth / new_h; + boxes[i] = b; + } } void get_region_boxes(float *output, std::vector &biases, - int neth,int netw, - int h, int w, - int img_w, int img_h, - int num_box, int num_classes, float thresh, - float **probs, - std::vector &boxes) + int neth,int netw, + int h, int w, + int img_w, int img_h, + int num_box, int num_classes, float thresh, + float **probs, + std::vector &boxes) { - int coords = 4; - int hw = h * w; - int i, j, n; - float *predictions = output; - - for (i = 0; i < hw; ++i) - { - int row = i / w; - int col = i % w; - for (n = 0; n < num_box; ++n) - { - int index = n * hw + i; - for (j = 0; j < num_classes; ++j) - { - probs[index][j] = 0; - } - int obj_index = entry_index(n, i, coords, hw, num_classes); - int box_index = entry_index(n, i, 0, hw, num_classes); - float scale = predictions[obj_index]; - get_region_box(boxes[index],predictions, biases, n, box_index, col, row, w, h, hw); - - float max = 0; - for (j = 0; j < num_classes; ++j) - { - int class_index = entry_index( n, i, coords + 1 + j, hw, num_classes); - float prob = scale * predictions[class_index]; - probs[index][j] = (prob > thresh) ? prob : 0; - if (prob > max) - max = prob; - } - probs[index][num_classes] = max; - } - } - - correct_region_boxes(boxes, hw * num_box, img_w, img_h, netw, neth); + int coords = 4; + int hw = h * w; + int i, j, n; + float *predictions = output; + + for (i = 0; i < hw; ++i) + { + int row = i / w; + int col = i % w; + for (n = 0; n < num_box; ++n) + { + int index = n * hw + i; + for (j = 0; j < num_classes; ++j) + { + probs[index][j] = 0; + } + int obj_index = entry_index(n, i, coords, hw, num_classes); + int box_index = entry_index(n, i, 0, hw, num_classes); + float scale = predictions[obj_index]; + get_region_box(boxes[index],predictions, biases, n, box_index, col, row, w, h, hw); + + float max = 0; + for (j = 0; j < num_classes; ++j) + { + int class_index = entry_index( n, i, coords + 1 + j, hw, num_classes); + float prob = scale * predictions[class_index]; + probs[index][j] = (prob > thresh) ? prob : 0; + if (prob > max) + max = prob; + } + probs[index][num_classes] = max; + } + } + + correct_region_boxes(boxes, hw * num_box, img_w, img_h, netw, neth); } float overlap(float x1, float w1, float x2, float w2) { - float l1 = x1 - w1 / 2; - float l2 = x2 - w2 / 2; - float left = l1 > l2 ? l1 : l2; - float r1 = x1 + w1 / 2; - float r2 = x2 + w2 / 2; - float right = r1 < r2 ? r1 : r2; - return right - left; + float l1 = x1 - w1 / 2; + float l2 = x2 - w2 / 2; + float left = l1 > l2 ? l1 : l2; + float r1 = x1 + w1 / 2; + float r2 = x2 + w2 / 2; + float right = r1 < r2 ? r1 : r2; + return right - left; } float box_intersection(Box &a, Box &b) { - float w = overlap(a.x, a.w, b.x, b.w); - float h = overlap(a.y, a.h, b.y, b.h); - if (w < 0 || h < 0) - return 0; - float area = w * h; - return area; + float w = overlap(a.x, a.w, b.x, b.w); + float h = overlap(a.y, a.h, b.y, b.h); + if (w < 0 || h < 0) + return 0; + float area = w * h; + return area; } float box_union(Box &a, Box &b) { - float i = box_intersection(a, b); - float u = a.w * a.h + b.w * b.h - i; - return u; + float i = box_intersection(a, b); + float u = a.w * a.h + b.w * b.h - i; + return u; } float box_iou(Box &a, Box &b) { - return box_intersection(a, b) / box_union(a, b); + return box_intersection(a, b) / box_union(a, b); } void do_nms_sort(std::vector &boxes, - float **probs, - int total, int classes, float thresh) + float **probs, + int total, int classes, float thresh) { - int i, j, k; - Sbox *s = (Sbox *)malloc(sizeof(Sbox) * total); - - for (i = 0; i < total; ++i) - { - s[i].index = i; - s[i].class_id = 0; - s[i].probs = probs; - } - - for (k = 0; k < classes; ++k) - { - for (i = 0; i < total; ++i) - { - s[i].class_id = k; - } - qsort(s, total, sizeof(Sbox), nms_comparator); - for (i = 0; i < total; ++i) - { - if (probs[s[i].index][k] == 0) - continue; - Box a = boxes[s[i].index]; - for (j = i + 1; j < total; ++j) - { - Box b = boxes[s[j].index]; - if (box_iou(a, b) > thresh) - { - probs[s[j].index][k] = 0; - } - } - } - } - free(s); + int i, j, k; + Sbox *s = (Sbox *)malloc(sizeof(Sbox) * total); + + for (i = 0; i < total; ++i) + { + s[i].index = i; + s[i].class_id = 0; + s[i].probs = probs; + } + + for (k = 0; k < classes; ++k) + { + for (i = 0; i < total; ++i) + { + s[i].class_id = k; + } + qsort(s, total, sizeof(Sbox), nms_comparator); + for (i = 0; i < total; ++i) + { + if (probs[s[i].index][k] == 0) + continue; + Box a = boxes[s[i].index]; + for (j = i + 1; j < total; ++j) + { + Box b = boxes[s[j].index]; + if (box_iou(a, b) > thresh) + { + probs[s[j].index][k] = 0; + } + } + } + } + free(s); } void draw_detections(std::string &image_file, std::string &save_name, int num, float thresh, std::vector &boxes, - float **probs, int classes) + float **probs, int classes) { - const char *class_names[] = {"background", - "aeroplane", "bicycle", "bird", "boat", - "bottle", "bus", "car", "cat", "chair", - "cow", "diningtable", "dog", "horse", - "motorbike", "person", "pottedplant", - "sheep", "sofa", "train", "tvmonitor"}; - cv::Mat img = cv::imread(image_file); - int img_h = img.size().height; - int img_w = img.size().width; - int line_width=img_w*0.005; - int i, j; - for (i = 0; i < num; ++i) - { - int class_id = -1; - for (j = 0; j < classes; ++j) - { - if (probs[i][j] > thresh) - { - if (class_id < 0) - { - class_id = j; - } - printf("%s\t:%.0f%%\n", class_names[class_id + 1], probs[i][j] * 100); - Box b = boxes[i]; - int left = (b.x - b.w / 2.) * img_w; - int right = (b.x + b.w / 2.) * img_w; - int top = (b.y - b.h / 2.) * img_h; - int bot = (b.y + b.h / 2.) * img_h; - if (left < 0) - left = 0; - if (right > img_w - 1) - right = img_w - 1; - if (top < 0) - top = 0; - if (bot > img_h - 1) - bot = img_h - 1; - printf("BOX:( %d , %d ),( %d , %d )\n",left,top,right,bot); - cv::rectangle(img, cv::Rect(left, top, (right - left), (bot - top)), cv::Scalar(0, 255, 255),line_width); - std::ostringstream score_str; - score_str << probs[i][j]; - std::string label = std::string(class_names[class_id + 1]) + ": " + score_str.str(); - int baseLine = 0; - cv::Size label_size = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); - cv::rectangle(img, - cv::Rect(cv::Point(left, top - label_size.height), - cv::Size(label_size.width, label_size.height + baseLine)), - cv::Scalar(0, 255, 255), - CV_FILLED); - cv::putText(img, label, cv::Point(left, top), - cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); - } - } - } - - cv::imwrite(save_name, img); - std::cout<<"======================================\n"; - std::cout<<"[DETECTED IMAGE SAVED]:\t"<< save_name<<"\n"; - std::cout<<"======================================\n"; + const char *class_names[] = {"background", + "aeroplane", "bicycle", "bird", "boat", + "bottle", "bus", "car", "cat", "chair", + "cow", "diningtable", "dog", "horse", + "motorbike", "person", "pottedplant", + "sheep", "sofa", "train", "tvmonitor"}; + cv::Mat img = cv::imread(image_file); + int img_h = img.size().height; + int img_w = img.size().width; + int line_width=img_w*0.005; + int i, j; + for (i = 0; i < num; ++i) + { + int class_id = -1; + for (j = 0; j < classes; ++j) + { + if (probs[i][j] > thresh) + { + if (class_id < 0) + { + class_id = j; + } + printf("%s\t:%.0f%%\n", class_names[class_id + 1], probs[i][j] * 100); + Box b = boxes[i]; + int left = (b.x - b.w / 2.) * img_w; + int right = (b.x + b.w / 2.) * img_w; + int top = (b.y - b.h / 2.) * img_h; + int bot = (b.y + b.h / 2.) * img_h; + if (left < 0) + left = 0; + if (right > img_w - 1) + right = img_w - 1; + if (top < 0) + top = 0; + if (bot > img_h - 1) + bot = img_h - 1; + printf("BOX:( %d , %d ),( %d , %d )\n",left,top,right,bot); + cv::rectangle(img, cv::Rect(left, top, (right - left), (bot - top)), cv::Scalar(0, 255, 255),line_width); + std::ostringstream score_str; + score_str << probs[i][j]; + std::string label = std::string(class_names[class_id + 1]) + ": " + score_str.str(); + int baseLine = 0; + cv::Size label_size = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + cv::rectangle(img, + cv::Rect(cv::Point(left, top - label_size.height), + cv::Size(label_size.width, label_size.height + baseLine)), + cv::Scalar(0, 255, 255), + CV_FILLED); + cv::putText(img, label, cv::Point(left, top), + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); + } + } + } + + cv::imwrite(save_name, img); + std::cout<<"======================================\n"; + std::cout<<"[DETECTED IMAGE SAVED]:\t"<< save_name<<"\n"; + std::cout<<"======================================\n"; } void preprocess_yolov2(std::string &image_file, float *input_data, int img_h, int img_w, int *raw_h, int *raw_w) { - cv::Mat img = cv::imread(image_file, -1); - if (img.empty()) - { - std::cerr << "failed to read image file " << image_file << "\n"; - return; - } - - *raw_h = img.rows; - *raw_w = img.cols; - - int new_w = img.cols; - int new_h = img.rows; - if (((float)img_w / img.cols) < ((float)img_h / img.rows)) - { - new_w = img_w; - new_h = (img.rows * img_w) / img.cols; - } - else - { - new_h = img_h; - new_w = (img.cols * img_h) / img.rows; - } - - img.convertTo(img, CV_32FC3); - img = img.mul(0.00392156862745098f); - - std::vector channels; - cv::split(img, channels); - cv::Mat temp = channels[2]; - channels[2] = channels[0]; - channels[0] = temp; - cv::merge(channels, img); - cv::resize(img, img, cv::Size(new_w, new_h)); - - int delta_h = (img_h - new_h) * 0.5f; - int delta_w = (img_w - new_w) * 0.5f; - cv::copyMakeBorder(img, img, delta_h, delta_h, delta_w, delta_w, cv::BORDER_CONSTANT, cv::Scalar(0.5f)); - - float *img_data = (float *)img.data; - int hw = img_h * img_w; - for (int h = 0; h < img_h; h++) - { - for (int w = 0; w < img_w; w++) - { - for (int c = 0; c < 3; c++) - { - input_data[c * hw + h * img_w + w] = *img_data; - img_data++; - } - } - } + cv::Mat img = cv::imread(image_file, -1); + if (img.empty()) + { + std::cerr << "failed to read image file " << image_file << "\n"; + return; + } + + *raw_h = img.rows; + *raw_w = img.cols; + + int new_w = img.cols; + int new_h = img.rows; + if (((float)img_w / img.cols) < ((float)img_h / img.rows)) + { + new_w = img_w; + new_h = (img.rows * img_w) / img.cols; + } + else + { + new_h = img_h; + new_w = (img.cols * img_h) / img.rows; + } + + img.convertTo(img, CV_32FC3); + img = img.mul(0.00392156862745098f); + + std::vector channels; + cv::split(img, channels); + cv::Mat temp = channels[2]; + channels[2] = channels[0]; + channels[0] = temp; + cv::merge(channels, img); + cv::resize(img, img, cv::Size(new_w, new_h)); + + int delta_h = (img_h - new_h) * 0.5f; + int delta_w = (img_w - new_w) * 0.5f; + cv::copyMakeBorder(img, img, delta_h, delta_h, delta_w, delta_w, cv::BORDER_CONSTANT, cv::Scalar(0.5f)); + + float *img_data = (float *)img.data; + int hw = img_h * img_w; + for (int h = 0; h < img_h; h++) + { + for (int w = 0; w < img_w; w++) + { + for (int c = 0; c < 3; c++) + { + input_data[c * hw + h * img_w + w] = *img_data; + img_data++; + } + } + } } int main(int argc, char **argv) { - const std::string root_path = get_root_path(); - std::string proto_file; - std::string model_file; - std::string image_file; - std::string save_name="save.jpg"; - - // this thresh can be tuned for higher/lower confidence boxes - float thresh=0.24; - - int res; - while( ( res=getopt(argc,argv,"p:m:i"))!= -1) - { - switch(res) - { - case 'p': - proto_file=optarg; - break; - case 'm': - model_file=optarg; - break; - case 'i': - image_file=optarg; - break; - default: - break; - } - } - - - // init tengine - init_tengine_library(); - if (request_tengine_version("0.1") < 0) - return 1; - - // load model - const char *model_name = "yolov2"; - if(proto_file.empty()) - { - proto_file = root_path + DEF_PROTO; - std::cout<< "proto file not specified,using "<(((Node *)node)->GetOp())->GetParam(); - int num_box = param->num_box; - int num_class = param->num_classes; - int total = out_dim[2] * out_dim[3] * num_box; - //init box and probs - std::vector boxes(total); - float **probs = (float **)calloc(total, sizeof(float *)); - for (int j = 0; j < total; ++j) - { - probs[j] = (float *)calloc(num_class + 1, sizeof(float *)); - } - - get_region_boxes(output, param->biases, - img_h,img_w, - out_dim[2], out_dim[3], - raw_w, raw_h, num_box, - num_class, thresh, - probs, boxes); - - float nms_thresh = 0.3; - do_nms_sort(boxes, probs, total, num_class, nms_thresh); - // if repeat_count=1, print output - if (repeat_count==1) - draw_detections(image_file, save_name, total, thresh, boxes, probs, num_class); - free(probs); - } - std::cout << "--------------------------------------\n"; - std::cout << "repeat " << repeat_count << " times, avg time per run is " << avg_time / repeat_count << " ms\n"; - free(input_data); - postrun_graph(graph); - destroy_runtime_graph(graph); - remove_model(model_name); - return 0; + const std::string root_path = get_root_path(); + std::string proto_file; + std::string model_file; + std::string image_file; + std::string save_name="save.jpg"; + + // this thresh can be tuned for higher/lower confidence boxes + float thresh=0.24; + + int res; + while( ( res=getopt(argc,argv,"p:m:i"))!= -1) + { + switch(res) + { + case 'p': + proto_file=optarg; + break; + case 'm': + model_file=optarg; + break; + case 'i': + image_file=optarg; + break; + default: + break; + } + } + + + // init tengine + init_tengine_library(); + if (request_tengine_version("0.1") < 0) + return 1; + + // load model + const char *model_name = "yolov2"; + if(proto_file.empty()) + { + proto_file = root_path + DEF_PROTO; + std::cout<< "proto file not specified,using "<(((Node *)node)->GetOp())->GetParam(); + int num_box = param->num_box; + int num_class = param->num_classes; + std::vector param_biases=parm->biases; +#else + int num_box=0; + int num_class=0; + + if(get_node_param_int(node,"num_box",&num_box)<0) + { + std::cerr<<"cannot get num box setting\n"; + return 1; + } + + + if(get_node_param_int(node,"num_classes",&num_class)<0) + { + std::cerr<<"cannot get num class setting\n"; + return 1; + } + + std::vector param_biases; + + if(get_node_param_generic(node,"biases",&typeid(std::vector),¶m_biases)<0) + { + std::cout<<"cannot get bias settings\n"; + return 1; + } + + +#endif + printf("num box: %d\n",num_box); + printf("num class: %d\n",num_class); + + + int total = out_dim[2] * out_dim[3] * num_box; + //init box and probs + std::vector boxes(total); + float **probs = (float **)calloc(total, sizeof(float *)); + for (int j = 0; j < total; ++j) + { + probs[j] = (float *)calloc(num_class + 1, sizeof(float *)); + } + + get_region_boxes(output, param_biases, + img_h,img_w, + out_dim[2], out_dim[3], + raw_w, raw_h, num_box, + num_class, thresh, + probs, boxes); + + float nms_thresh = 0.3; + do_nms_sort(boxes, probs, total, num_class, nms_thresh); + // if repeat_count=1, print output + if (repeat_count==1) + draw_detections(image_file, save_name, total, thresh, boxes, probs, num_class); + free(probs); + } + std::cout << "--------------------------------------\n"; + std::cout << "repeat " << repeat_count << " times, avg time per run is " << avg_time / repeat_count << " ms\n"; + free(input_data); + postrun_graph(graph); + destroy_runtime_graph(graph); + remove_model(model_name); + return 0; } diff --git a/operator/include/operator/batch_norm_param.hpp b/operator/include/operator/batch_norm_param.hpp index 64f023b94..cd45f4f7f 100644 --- a/operator/include/operator/batch_norm_param.hpp +++ b/operator/include/operator/batch_norm_param.hpp @@ -30,7 +30,7 @@ namespace TEngine { -struct BatchNormParam { +struct BatchNormParam : public NamedParam { float rescale_factor; float eps; int caffe_flavor; diff --git a/operator/include/operator/concat_param.hpp b/operator/include/operator/concat_param.hpp index e43db69dc..9109839a1 100644 --- a/operator/include/operator/concat_param.hpp +++ b/operator/include/operator/concat_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct ConcatParam { +struct ConcatParam : public NamedParam { int axis; diff --git a/operator/include/operator/conv_param.hpp b/operator/include/operator/conv_param.hpp index 7d4e095ad..9243a0ebf 100644 --- a/operator/include/operator/conv_param.hpp +++ b/operator/include/operator/conv_param.hpp @@ -24,12 +24,14 @@ #ifndef __CONVOLUTION_PARAM_HPP__ #define __CONVOLUTION_PARAM_HPP__ +#include + #include "parameter.hpp" namespace TEngine { -struct ConvParam { +struct ConvParam : public NamedParam { int kernel_h; int kernel_w; diff --git a/operator/include/operator/deconv_param.hpp b/operator/include/operator/deconv_param.hpp index 194b9f0b6..7d6561dc0 100644 --- a/operator/include/operator/deconv_param.hpp +++ b/operator/include/operator/deconv_param.hpp @@ -29,7 +29,7 @@ namespace TEngine { -struct DeconvParam { +struct DeconvParam : public NamedParam { int kernel_size; int stride; diff --git a/operator/include/operator/detection_output_param.hpp b/operator/include/operator/detection_output_param.hpp index 0dea17638..04f6eff1e 100644 --- a/operator/include/operator/detection_output_param.hpp +++ b/operator/include/operator/detection_output_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct DetectionOutputParam { +struct DetectionOutputParam: public NamedParam { int num_classes; int keep_top_k; diff --git a/operator/include/operator/eltwise.hpp b/operator/include/operator/eltwise.hpp index d53124d1f..66e7ead52 100644 --- a/operator/include/operator/eltwise.hpp +++ b/operator/include/operator/eltwise.hpp @@ -50,7 +50,7 @@ class Eltwise: public OperatorWithParam { } void ParseParam(EltwiseParam & param, Operator * op) override { - EltwiseParam::Parse(param,op); + ParsePredefinedParam(param,op); MethodToType(param); } void SetSchema(void) override; diff --git a/operator/include/operator/eltwise_param.hpp b/operator/include/operator/eltwise_param.hpp index 18b5fab70..2ea7503fa 100644 --- a/operator/include/operator/eltwise_param.hpp +++ b/operator/include/operator/eltwise_param.hpp @@ -42,7 +42,7 @@ enum EltType { namespace TEngine { -struct EltwiseParam { +struct EltwiseParam : public NamedParam { std::string method; EltType type; diff --git a/operator/include/operator/fc_param.hpp b/operator/include/operator/fc_param.hpp index f23d42902..b7a784b6b 100644 --- a/operator/include/operator/fc_param.hpp +++ b/operator/include/operator/fc_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct FCParam +struct FCParam : public NamedParam { int num_output; diff --git a/operator/include/operator/flatten_param.hpp b/operator/include/operator/flatten_param.hpp index 88638a490..b82201188 100644 --- a/operator/include/operator/flatten_param.hpp +++ b/operator/include/operator/flatten_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct FlattenParam { +struct FlattenParam : public NamedParam { int axis; int end_axis; diff --git a/operator/include/operator/lrn_param.hpp b/operator/include/operator/lrn_param.hpp index 81cde9921..75d127c97 100644 --- a/operator/include/operator/lrn_param.hpp +++ b/operator/include/operator/lrn_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { #define LRN_ACROSS_CHANNELS 0 #define LRN_WITHIN_CHANNEL 1 -struct LRNParam { +struct LRNParam : public NamedParam { int local_size; float alpha; float beta; diff --git a/operator/include/operator/normalize_param.hpp b/operator/include/operator/normalize_param.hpp index 50b231113..0bd6e5062 100644 --- a/operator/include/operator/normalize_param.hpp +++ b/operator/include/operator/normalize_param.hpp @@ -29,7 +29,7 @@ namespace TEngine { -struct NormalizeParam +struct NormalizeParam : public NamedParam { int across_spatial; int channel_shared; diff --git a/operator/include/operator/permute_param.hpp b/operator/include/operator/permute_param.hpp index d37e7e652..830585d50 100644 --- a/operator/include/operator/permute_param.hpp +++ b/operator/include/operator/permute_param.hpp @@ -30,7 +30,7 @@ namespace TEngine { -struct PermuteParam { +struct PermuteParam : public NamedParam { int flag; int order0; diff --git a/operator/include/operator/pool_param.hpp b/operator/include/operator/pool_param.hpp index 19d8ebb89..8af793971 100644 --- a/operator/include/operator/pool_param.hpp +++ b/operator/include/operator/pool_param.hpp @@ -36,7 +36,7 @@ enum PoolArg { namespace TEngine { -struct PoolParam { +struct PoolParam : public NamedParam { std::string method; PoolArg alg; diff --git a/operator/include/operator/pooling.hpp b/operator/include/operator/pooling.hpp index 1f6099322..6b7b5d072 100644 --- a/operator/include/operator/pooling.hpp +++ b/operator/include/operator/pooling.hpp @@ -60,7 +60,7 @@ class Pooling: public OperatorWithParam { void ParseParam(PoolParam & param, Operator * op) override { - PoolParam::Parse(param,op); + ParsePredefinedParam(param,op); MethodToAlg(param); /* translate to onnx parameters */ diff --git a/operator/include/operator/priorbox_param.hpp b/operator/include/operator/priorbox_param.hpp index 7ac46ae0e..472551a05 100644 --- a/operator/include/operator/priorbox_param.hpp +++ b/operator/include/operator/priorbox_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct PriorBoxParam { +struct PriorBoxParam : public NamedParam { std::vector min_size; std::vector max_size; diff --git a/operator/include/operator/region_param.hpp b/operator/include/operator/region_param.hpp index 193ca757a..ecd187dca 100644 --- a/operator/include/operator/region_param.hpp +++ b/operator/include/operator/region_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct RegionParam { +struct RegionParam: public NamedParam { int num_classes; int side; @@ -42,7 +42,9 @@ float nms_threshold; std::vectorbiases; DECLARE_PARSER_STRUCTURE(RegionParam) { + DECLARE_PARSER_ENTRY(num_box); DECLARE_PARSER_ENTRY(num_classes); + DECLARE_PARSER_ENTRY(biases); } }; diff --git a/operator/include/operator/relu_param.hpp b/operator/include/operator/relu_param.hpp index c48dc50e5..eb8826390 100644 --- a/operator/include/operator/relu_param.hpp +++ b/operator/include/operator/relu_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct ReLuParam { +struct ReLuParam : public NamedParam { float negative_slope ; diff --git a/operator/include/operator/reorg_param.hpp b/operator/include/operator/reorg_param.hpp index 7072009b3..f4d1db204 100644 --- a/operator/include/operator/reorg_param.hpp +++ b/operator/include/operator/reorg_param.hpp @@ -30,7 +30,7 @@ namespace TEngine { -struct ReorgParam { +struct ReorgParam : public NamedParam { int stride; diff --git a/operator/include/operator/reshape_param.hpp b/operator/include/operator/reshape_param.hpp index 2441525b4..b49a01dc5 100644 --- a/operator/include/operator/reshape_param.hpp +++ b/operator/include/operator/reshape_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { -struct ReshapeParam { +struct ReshapeParam : public NamedParam { std::vector dims; int axis; diff --git a/operator/include/operator/resize_param.hpp b/operator/include/operator/resize_param.hpp index 5d95bdde2..b2d908ded 100644 --- a/operator/include/operator/resize_param.hpp +++ b/operator/include/operator/resize_param.hpp @@ -31,7 +31,7 @@ namespace TEngine { struct StaticOp; -struct ResizeParam { +struct ResizeParam : public NamedParam { float scale_x; float scale_y; diff --git a/operator/include/operator/roi_pooling_param.hpp b/operator/include/operator/roi_pooling_param.hpp index c1f80184c..d82426466 100644 --- a/operator/include/operator/roi_pooling_param.hpp +++ b/operator/include/operator/roi_pooling_param.hpp @@ -30,7 +30,7 @@ namespace TEngine { -struct ROIPoolingParam { +struct ROIPoolingParam : public NamedParam { int pooled_h; int pooled_w; diff --git a/operator/include/operator/rpn_param.hpp b/operator/include/operator/rpn_param.hpp index f1d0761d4..63b6384ee 100644 --- a/operator/include/operator/rpn_param.hpp +++ b/operator/include/operator/rpn_param.hpp @@ -46,7 +46,7 @@ struct Box namespace TEngine { -struct RPNParam { +struct RPNParam : public NamedParam { std::vector ratios; std::vector anchor_scales; diff --git a/operator/include/operator/scale_param.hpp b/operator/include/operator/scale_param.hpp index 5b316fdfd..a64a77d0d 100644 --- a/operator/include/operator/scale_param.hpp +++ b/operator/include/operator/scale_param.hpp @@ -29,7 +29,8 @@ namespace TEngine { -struct ScaleParam { +struct ScaleParam :public NamedParam +{ int axis; int num_axes; int bias_term; diff --git a/operator/include/operator/slice_param.hpp b/operator/include/operator/slice_param.hpp index 2909c596a..4948dee4a 100644 --- a/operator/include/operator/slice_param.hpp +++ b/operator/include/operator/slice_param.hpp @@ -30,7 +30,7 @@ namespace TEngine { -struct SliceParam { +struct SliceParam: public NamedParam { int axis; diff --git a/operator/include/operator/softmax_param.hpp b/operator/include/operator/softmax_param.hpp index 50a2fd8c8..8c915689f 100644 --- a/operator/include/operator/softmax_param.hpp +++ b/operator/include/operator/softmax_param.hpp @@ -30,7 +30,7 @@ namespace TEngine { -struct SoftmaxParam { +struct SoftmaxParam: public NamedParam { int axis;