Skip to content

Commit

Permalink
Add ArgMax ONNX operator (neoml-lib#642)
Browse files Browse the repository at this point in the history
Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>
  • Loading branch information
Valeriy Fedyunin committed May 28, 2022
1 parent afba542 commit 0867ea7
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NeoOnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ target_sources( ${PROJECT_NAME} PRIVATE
LayerUtils.cpp
Operator.cpp
Operators/ActivationOperator.cpp
Operators/ArgMaxOperator.cpp
Operators/BatchNormalizationOperator.cpp
Operators/CastOperator.cpp
Operators/ConcatOperator.cpp
Expand Down Expand Up @@ -73,6 +74,7 @@ target_sources( ${PROJECT_NAME} PRIVATE
NeoOnnxCheck.h
Operator.h
Operators/ActivationOperator.h
Operators/ArgMaxOperator.h
Operators/BatchNormalizationOperator.h
Operators/CastOperator.h
Operators/ConcatOperator.h
Expand Down
2 changes: 2 additions & 0 deletions NeoOnnx/src/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "NeoOnnxCheck.h"

#include "Operators/ActivationOperator.h"
#include "Operators/ArgMaxOperator.h"
#include "Operators/BatchNormalizationOperator.h"
#include "Operators/CastOperator.h"
#include "Operators/ConcatOperator.h"
Expand Down Expand Up @@ -114,6 +115,7 @@ namespace {
// Register all operators
REGISTER_OPERATOR( CAbsOperator, "Abs" )
REGISTER_OPERATOR( CAddOperator, "Add" )
REGISTER_OPERATOR( CArgMaxOperator, "ArgMax" )
REGISTER_OPERATOR( CAveragePoolOperator, "AveragePool" )
REGISTER_OPERATOR( CBatchNormalizationOperator, "BatchNormalization" )
REGISTER_OPERATOR( CCastOperator, "Cast" )
Expand Down
69 changes: 69 additions & 0 deletions NeoOnnx/src/Operators/ArgMaxOperator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright © 2017-2022 ABBYY Production LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#include "../common.h"
#pragma hdrstop

#include "onnx.pb.h"

#include "ArgMaxOperator.h"
#include "NeoOnnxCheck.h"

namespace NeoOnnx {

CArgMaxOperator::CArgMaxOperator( const onnx::NodeProto& argMax, int opsetVersion ) :
CLayerOperator( argMax, opsetVersion )
{
// v1 - original
// v11 - negative axis attribute values are supported
// v12 - select_last_index attribute is added
// v13 - bfloat16 data type is supported
CheckOnnxProtocol( InputCount() == 1, "operator must have 1 input", *this );
CheckOnnxProtocol( OutputCount() == 1, "operator must have 1 output", *this );
}

void CArgMaxOperator::AddLayers( const CTensorArray& inputs, CDnn& dnn, CTensorArray& outputs ) const
{
CheckOnnxProtocol( inputs[0] != nullptr, "input can't be optional" );
CPtr<const CUserTensor> input = AsUserTensor( *inputs[0], Name() + "_data", dnn );

// In ONNX ArgMax supports any data type
// In NeoML CArgMaxLayer supports only float input
CBaseLayer* outputLayer = Cast( CT_Float )( Name() + "_cast", CDnnLayerLink( input->Layer(), input->OutputIndex() ) );

int axis = 0;
GetAttribute( "axis", axis );
if( axis < 0 ) {
axis += input->DimCount();
}
outputLayer = Argmax( input->Layout()[axis] )( Name(), outputLayer );

CTensorShape outputShape;
input->Shape().CopyTo( outputShape );
outputShape[axis] = 1;

CTensorLayout outputLayout = input->Layout();

int keepDims = 1;
GetAttribute( "keepdims", keepDims );
if( keepDims == 0 ) {
outputLayout.DeleteAt( axis );
outputShape.DeleteAt( axis );
}

outputs.Add( new CUserTensor( outputShape, outputLayout, CLayerOutput( outputLayer, 0 ) ) );
}

} // namespace NeoOnnx
32 changes: 32 additions & 0 deletions NeoOnnx/src/Operators/ArgMaxOperator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright © 2017-2022 ABBYY Production LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include "../LayerOperator.h"

namespace NeoOnnx {

// ArgMax operator
class CArgMaxOperator : public CLayerOperator {
public:
CArgMaxOperator( const onnx::NodeProto& argMax, int opsetVersion );

protected:
// CLayerOperator methods
void AddLayers( const CTensorArray& input, CDnn& dnn, CTensorArray& outputs ) const override;
};

} // namespace NeoOnnx

0 comments on commit 0867ea7

Please sign in to comment.