Bug fixes for argmin/argmax, maxglobalpooling and squeeze (#331)
kevinch-nv authored Nov 18, 2019
1 parent 2066f53 commit 8716c9b
Showing 2 changed files with 28 additions and 7 deletions.
2 changes: 1 addition & 1 deletion builtin_op_importers.cpp
@@ -889,7 +889,7 @@ DEFINE_BUILTIN_OP_IMPORTER(GlobalMaxPool)
         static_cast<nvinfer1::Dims>(nvinfer1::Dims2{dims.d[2], dims.d[3]}) :
         static_cast<nvinfer1::Dims>(nvinfer1::Dims3{dims.d[2], dims.d[3], dims.d[4]});
     ASSERT(!isDynamic(kernelSize) && "Cannot run GlobalMaxPool on an input with dynamic spatial dimensions!", ErrorCode::kUNSUPPORTED_NODE);
-    RETURN_FIRST_OUTPUT(ctx->network()->addPoolingNd(tensor, nvinfer1::PoolingType::kAVERAGE, kernelSize));
+    RETURN_FIRST_OUTPUT(ctx->network()->addPoolingNd(tensor, nvinfer1::PoolingType::kMAX, kernelSize));
 }

 DEFINE_BUILTIN_OP_IMPORTER(HardSigmoid)
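The one-line fix above corrects what looks like a copy-paste bug: the GlobalMaxPool importer built its TensorRT pooling layer with PoolingType::kAVERAGE. A minimal standalone sketch (plain C++, independent of TensorRT and this patch) shows how the two reductions diverge on the same feature map, which is why the bug silently returned wrong values instead of failing:

    // Standalone illustration: global average vs. global max over one
    // flattened 2x2 feature map. The old kAVERAGE import computed 3.5
    // where GlobalMaxPool should return 8.
    #include <algorithm>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<float> featureMap{1.0f, 2.0f, 3.0f, 8.0f};
        float avg = std::accumulate(featureMap.begin(), featureMap.end(), 0.0f) / featureMap.size();
        float max = *std::max_element(featureMap.begin(), featureMap.end());
        std::cout << "kAVERAGE: " << avg << "\n"; // 3.5
        std::cout << "kMAX:     " << max << "\n"; // 8
    }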
33 changes: 27 additions & 6 deletions onnx2trt_utils.cpp
@@ -229,26 +229,45 @@ Status applyLegacyBinaryOpBroadcasting(IImporterContext* ctx,
 NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
     std::vector<TensorOrWeights>& inputs, nvinfer1::TopKOperation op)
 {
-    nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx);
-    ASSERT(tensor.getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE);
+    nvinfer1::ITensor* tensorPtr = &convertToTensor(inputs.at(0), ctx);
+    ASSERT(tensorPtr->getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE);
+
+    // Support 1D argMin/argMax
+    bool needToExpandDims = (tensorPtr->getDimensions().nbDims == 1);
+    if (needToExpandDims)
+    {
+        // Expand dims from 1D to 2D
+        std::vector<int> axes{1};
+        tensorPtr = unsqueezeTensor(ctx, *tensorPtr, axes);
+        ASSERT(tensorPtr, ErrorCode::kUNSUPPORTED_NODE);
+    }
     // Get attributes.
     OnnxAttrs attrs(node);
     int keepdims = attrs.get("keepdims", 1);
     int axis = attrs.get("axis", 0);

     // Insert a TopK layer with k set to 1.
-    int nbDims = tensor.getDimensions().nbDims;
+    int nbDims = tensorPtr->getDimensions().nbDims;
     TRT_CHECK(convert_axis(axis, nbDims));

     uint32_t axisMask = 1 << axis;
-    nvinfer1::ITopKLayer* layer = ctx->network()->addTopK(tensor, op, 1, axisMask);
+    nvinfer1::ITopKLayer* layer = ctx->network()->addTopK(*tensorPtr, op, 1, axisMask);
     ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
     // We don't care about the TopK values, just the indices.
     nvinfer1::ITensor* indices = layer->getOutput(1);
     indices->setType(nvinfer1::DataType::kINT32);
+
+    // Squeeze back to 1D if applicable
+    if (needToExpandDims)
+    {
+        std::vector<int> axes{1};
+        indices = squeezeTensor(ctx, *indices, axes);
+        ASSERT(indices, ErrorCode::kUNSUPPORTED_NODE);
+    }
+
+    // The default behavior of the TopK layer is to keepdims.
     if (keepdims)
     {
-        // The default behavior of the TopK layer is to keepdims.
         return {{indices}};
     }
     else
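The argmin/argmax changes above switch the helper to a tensor pointer so a rank-1 input can be unsqueezed to 2D before the TopK layer is added, with the resulting index tensor squeezed back to 1D afterwards; the bracketing suggests the TopK layer could not consume rank-1 tensors directly. A standalone sketch (plain C++, hypothetical shapes, not the importer's API) of the same reshape-bracket idea:

    // Standalone illustration of the unsqueeze -> TopK(k=1) -> squeeze
    // bracket: a shape-[3] input is viewed as shape [3, 1], the max
    // index is found along axis 0, and the [1, 1] index result
    // collapses back to a scalar-like answer.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<float> input{3.0f, 9.0f, 1.0f}; // rank-1 tensor, shape [3]

        // "Unsqueeze": reinterpret shape [3] as [3, 1]; no data moves.
        const std::size_t rows = input.size();
        const std::size_t cols = 1;

        // TopK with k = 1 along axis 0: keep only the winning index.
        std::size_t argMax = 0;
        for (std::size_t r = 1; r < rows; ++r)
        {
            if (input[r * cols] > input[argMax * cols])
            {
                argMax = r;
            }
        }

        // "Squeeze": the [1, 1] index tensor becomes a plain index again.
        std::cout << "argmax index: " << argMax << "\n"; // 1
    }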
@@ -1177,9 +1196,11 @@ nvinfer1::ITensor* squeezeStaticTensor(IImporterContext* ctx, nvinfer1::ITensor&
     std::set<int> axesSet(axes.begin(), axes.end());
     std::vector<int> shape{dims.d, dims.d + dims.nbDims};

+    int axisCount = 0;
     for (const auto& axis : axesSet)
     {
-        shape.erase(shape.begin() + axis);
+        shape.erase(shape.begin() + axis - axisCount);
+        axisCount++;
     }

     nvinfer1::Dims newShape{dims.nbDims - static_cast<int>(axesSet.size())};
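The squeezeStaticTensor fix above resolves a shifting-index bug: each erase() moves every later element of the shape vector one slot left, so erasing a second axis by its original index removed the wrong dimension. Counting the axes already erased re-aligns the indices. A standalone sketch (plain C++, hypothetical helper name squeezeShape) reproducing the corrected behavior:

    // Standalone illustration: squeezing axes {0, 2} out of shape
    // [1, 3, 1, 5]. The buggy loop erased index 0 (leaving [3, 1, 5])
    // and then index 2, which by then pointed at the 5, yielding the
    // wrong shape [3, 1]. With the offset, the result is [3, 5].
    #include <iostream>
    #include <set>
    #include <vector>

    std::vector<int> squeezeShape(std::vector<int> shape, const std::set<int>& axes)
    {
        int axisCount = 0; // axes erased so far
        for (const auto& axis : axes)
        {
            shape.erase(shape.begin() + axis - axisCount);
            axisCount++;
        }
        return shape;
    }

    int main()
    {
        for (int d : squeezeShape({1, 3, 1, 5}, {0, 2}))
        {
            std::cout << d << " "; // prints: 3 5
        }
        std::cout << "\n";
    }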
