Issues with ONNX model ReduceMax API not supported by ONNXRT TensorRT EP (but ok with CPU and Cuda EP) #16886
Comments
Thanks for providing these details.
You need to run NVIDIA onnx-graphsurgeon on the model before the TensorRT run, as:

graph_surgeon_faster_rcnn.py --onnx faster_rcnn_fpn.onnx --output onnx_output

Before that, you need to install the following from NVIDIA:

# INSTALL PYTHON MODULES polygraphy, tensorrt, onnx_graphsurgeon
python -m pip install colored polygraphy --trusted-host pypi.ngc.nvidia.com --extra-index-url https://pypi.ngc.nvidia.com
# and build onnx-graphsurgeon since the pip installer did not work for me
# pip install onnx_graphsurgeon>=0.3.21 --trusted-host pypi.ngc.nvidia.com --extra-index-url=https://pypi.ngc.nvidia.com
git clone https://github.com/NVIDIA/TensorRT.git
cd TensorRT/tools/onnx-graphsurgeon && make build && python3 -m pip install dist/onnx_graphsurgeon-*-py2.py3-none-any.whl

Here is graph_surgeon_faster_rcnn.py (tested on Python 3.10):

#! /usr/bin/env python3
## ---------------------------------------------------------------------------
##
## File: graph_surgeon_faster_rcnn.py for detectron2 faster-rcnn model zoo
##
## Created by Zhijin Li
## E-mail: <zhijinl@nvidia.com>
##
## Started on Mon Jul 17 15:53:26 2023 Zhijin Li
## Last update Tue Jul 18 02:08:41 2023 Zhijin Li
## Modified Sun Jul 30 2023 JC Datin
## ---------------------------------------------------------------------------
import onnx
import onnx_graphsurgeon as gs
import os
import urllib
import argparse
import numpy as np
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger
def parse_arguments():
    """
    Parse input command line arguments.

    Returns
    ----------
    Parsed argument object.
    """
    parser = argparse.ArgumentParser(
        description='graph surgeon an ONNX model for TensorRT conversion.')
    parser.add_argument(
        '--onnx',
        type=str,
        required=True,
        help='Path to ONNX model')
    parser.add_argument(
        '--output',
        type=str,
        default='./output',
        help='output directory for the graph surgeoned model')
    parser.add_argument(
        'opts',
        help='Modify config options using the command-line',
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    return args
def sanitize(graph):
    """
    Sanitize the graph by cleaning any unconnected nodes, doing a topological
    resort, and folding constant input values.
    When possible, run shape inference on the ONNX graph to determine tensor shapes.
    """
    for i in range(3):
        count_before = len(graph.nodes)
        graph.cleanup().toposort()
        try:
            for node in graph.nodes:
                for o in node.outputs:
                    o.shape = None
            model = gs.export_onnx(graph)
            model = onnx.shape_inference.infer_shapes(model)
            graph = gs.import_onnx(model)
        except Exception as e:
            logger.info("Shape inference could not be performed at this time:\n{}".format(e))
        try:
            graph.fold_constants(fold_shapes=True)
        except TypeError as e:
            logger.error("This version of ONNX GraphSurgeon does not support folding shapes, please upgrade your "
                         "onnx_graphsurgeon module. Error:\n{}".format(e))
            raise
        count_after = len(graph.nodes)
        if count_before == count_after:
            # No new folding occurred in this iteration, so we can stop for now.
            break
if __name__ == '__main__':

    logger = setup_logger()
    args = parse_arguments()

    # Load graph
    logger.info("loading ONNX model: {}".format(args.onnx))
    original_model = args.onnx
    graph = gs.import_onnx(onnx.load(original_model))

    surgeoned_model_path = original_model.rsplit('.', 1)[0]
    surgeoned_model_name = surgeoned_model_path.rsplit('/', 1)[-1]
    surgeoned_model = surgeoned_model_name + "-graph_surgeon" + ".onnx"
    print(surgeoned_model_path)
    print(surgeoned_model_name)
    print(surgeoned_model)

    graph = graph.cleanup().toposort()
    graph = graph.fold_constants().cleanup()
    # print(graph)
    # for node in graph.nodes:
    #     print(node)

    # Drop recorded output shapes and re-run ONNX shape inference
    for node in graph.nodes:
        for o in node.outputs:
            o.shape = None
    model = gs.export_onnx(graph)
    model = onnx.shape_inference.infer_shapes(model)
    graph = gs.import_onnx(model)
    # print(graph)

    ###### FIX: /model/roi_heads/box_pooler/level_poolers.{0..3}/If
    # Append a Squeeze node to the else branch of each level pooler If node so
    # that both branches produce outputs of the same rank.
    pooler_if_names = [
        '/model/roi_heads/box_pooler/level_poolers.{}/If'.format(i) for i in range(4)]
    for node in graph.nodes:
        if node.name in pooler_if_names:
            else_branch_graph = node.attrs['else_branch']
            squeeze_axes = gs.Constant(
                name='axes',
                values=np.array([1]))
            squeeze_node = gs.Node(
                "Squeeze",
                name="squeeze_fix")
            squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
            squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]
            node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
            node.attrs['else_branch'].outputs = squeeze_node.outputs

    ###### FIX: If_810 (nested inside If_782)
    for node in graph.nodes:
        if node.name == 'If_782':
            sub_graph = node.attrs['else_branch']
            for sub_node in sub_graph.nodes:
                if sub_node.name == 'If_810':
                    else_branch_graph_810 = sub_node.attrs['else_branch']
                    squeeze_axes = gs.Constant(
                        name='axes',
                        values=np.array([1]))
                    squeeze_node = gs.Node(
                        "Squeeze",
                        name="squeeze_fix_if_810")
                    squeeze_node.inputs = [*else_branch_graph_810.outputs, squeeze_axes]
                    squeeze_node.outputs = [gs.Variable('squeeze_810_fix_output', dtype=np.int64)]
                    sub_node.attrs['else_branch'].nodes = [*sub_node.attrs['else_branch'].nodes, squeeze_node]
                    sub_node.attrs['else_branch'].outputs = squeeze_node.outputs

    graph = graph.cleanup().toposort()
    graph = graph.fold_constants().cleanup()
    model = gs.export_onnx(graph)
    model = onnx.shape_inference.infer_shapes(model)
    graph = gs.import_onnx(model)

    # Check one of the patched nodes (or use netron)
    for node in graph.nodes:
        if node.name == '/model/roi_heads/box_pooler/level_poolers.0/If':
            print(node)

    save_path = os.path.join(args.output, surgeoned_model)
    onnx.save(model, save_path)
Sorry, closed by mistake after entering the previous comment.
To be clear: the issue is still open. I guess we are waiting for the ORT team to check that the model runs with native TRT but not with ORT + TRT EP, demonstrating that the problem is in the ONNXRT TRT EP.
@datinje Thanks for providing the script that uses onnx-graphsurgeon. Have you seen any inference failure after running
[08/02/2023-19:15:46] [I] Starting inference
[08/02/2023-19:15:46] [E] Error[7]: [shapeMachine.cpp::executeContinuation::864] Error Code 7: Internal Error (If_1530_OutputLayer: dimensions not compatible for if-conditional outputs Condition '==' violated: 0 != 6. Instruction: CHECK_EQUAL 0 6.)
[08/02/2023-19:15:46] [E] Error occurred during inference
Or would it be convenient to share
My mistake, I did not realize trtexec also failed on a similar node issue on the demonstrator after all the node surgery, albeit towards the end. trtexec worked on my real project once it was itself surgeoned.
Thanks for confirming. We await your new repro test case. If an onnx model runs on native TensorRT (trtexec), it should also run with OnnxRuntime TensorRT EP. If it doesn't, that means there's a bug in OnnxRuntime TensorRT EP.
There are at least three issues:
Thanks a lot for your investigation.
Please find the small ReduceMax demonstrator model (thanks to NVIDIA's Zhijin Li!). When run, onnxruntime + TRT EP passes on the resultant onnx file! (Side note: there is an error reported by the onnx checker tool which should not occur since opset 16 is used; the problem seems to be on the onnx checker tool's side.) However, the onnxruntime TRT EP fails on the bigger faster-rcnn model from the Detectron2 model zoo, above:
So I would like to understand why the onnxruntime TRT EP this time requires a ReduceMax API that takes a dim input, as of opset 18. I suspect that when the onnxruntime EP code finds the opset 18 ReduceMax API it switches to full opset 18, and then, when it finds the opset 16 ReduceMax API later in the onnx file, it fails by requesting the missing dim input. Can you confirm the above?
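For what it's worth, one way to see which flavours of ReduceMax the exported file actually contains is to walk the graph (and its If/Loop subgraphs) with the onnx Python API; a minimal sketch, with the model path as a placeholder:

import onnx

def list_reducemax(graph, prefix=""):
    for node in graph.node:
        if node.op_type == "ReduceMax":
            has_axes_attr = any(a.name == "axes" for a in node.attribute)
            has_axes_input = len(node.input) > 1
            print(prefix + node.name, "axes attribute:", has_axes_attr, "axes input:", has_axes_input)
        # Recurse into subgraphs (If then/else branches, Loop bodies)
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                list_reducemax(attr.g, prefix + node.name + "/")

model = onnx.load("faster_rcnn_fpn.onnx")  # placeholder path to the exported model
list_reducemax(model.graph)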
Maybe the onnxruntime TRT EP can handle BOTH APIs in the same file?
Hi @datinje, thank you for providing this context.
As for your question:
Starting from opset 18, ReduceMax-18 promotes axes from an attribute to an optional input. In your case, if no axes input is provided in the opset 18 block, the default behaviour is to reduce over all dimensions. According to your error msg:
The data input tensor might already be 0-dimensional (a scalar), which couldn't be further dim-reduced over all axes. (Making this op act as an Identity op by setting noop_with_empty_axes = 1 could be a workaround.)
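For illustration, the two flavours look roughly like this when built with onnx.helper; this is a sketch based on the ONNX operator spec, not on the actual exported graph:

from onnx import helper

# opset 13/16/17 style: axes is carried as an attribute
reduce_max_17 = helper.make_node(
    "ReduceMax",
    inputs=["data"],
    outputs=["reduced"],
    axes=[1],          # reduction axes as an attribute
    keepdims=0,
)

# opset 18 style: axes is an (optional) 1-D int64 input tensor
reduce_max_18 = helper.make_node(
    "ReduceMax",
    inputs=["data", "axes"],
    outputs=["reduced"],
    keepdims=0,
    noop_with_empty_axes=1,    # behave as Identity when axes is empty
)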
BTW, I saw you initiated an issue about the ReduceMax op with the pytorch-onnx converter. You might try the workaround there to similarly unify the op and see if that helps.
@datinje,
Actually this error message is coming from the TRT parser, not Onnxruntime. Onnxruntime did verify the graph, and since it's opset 17 and the "axes" attribute of ReduceMax is optional there, it passes. I also generated the small ReduceMax demonstrator model; it passed the model check but failed at graph check as you mentioned:
I checked the opset for the model and it's 16,
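The opset declared by a model can be checked directly from Python, e.g. (a minimal sketch; the filename is a placeholder):

import onnx

model = onnx.load("faster_rcnn_fpn.onnx")  # placeholder path
for opset in model.opset_import:
    # an empty domain string means the default ai.onnx domain
    print(opset.domain or "ai.onnx", opset.version)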
@yf711 is investigating this and will let you know once we have some progress |
@chilo-ms and @yf711 That is also the case in my own faster-rcnn model made with Detectron2. It looks like a pattern in Detectron2 models, which use both torchvision and Detectron2 modules, so the torch.onnx.export call likely generates 2 versions of ReduceMax depending on which module the subgraph comes from; it would be nice to find which piece of code generates which ReduceMax version!

However, to recap: onnxruntime with the CPU EP and CUDA EP seems to have a workaround, since they do not complain about the 2 different ReduceMax APIs: somehow they adapt to the API and generate the right underlying CPU or CUDA instruction whatever the ReduceMax node. As for TRT, you are saying onnxruntime and the TRT EP do no check on the ReduceMax API version but simply delegate to the TRT parser underneath, so it has to be the TRT parser which, being more strict, chooses one API (ReduceMax-18) and fails on the other (ReduceMax-13, I believe).

But then, once I successfully ran onnx-graphsurgeon on my faster-rcnn model, which surgeons all if-then-else blocks to return tensors with the same dimension, why is trtexec running fine on the surgeoned model while onnxruntime + TRT EP is not? I assume trtexec is also using the same parser as ORT + TRT EP? I am sorry I can't share my faster-rcnn model that passes trtexec once surgeoned, due to IP, so that you can try it. I think we are still missing something in the TRT EP of onnxruntime, can you have a double look?

@yf711: how do I set ReduceMax as a no-op by setting noop_with_empty_axes to 1? I want to try. I indeed have an if-then-else block with one branch being a scalar. Thanks a lot for all your time spent already: we are narrowing the issue.
@datinje Onnxruntime and the TRT EP both did verify ReduceMax against its opset before calling the TRT parser (I've run gdb to double check this).
What's the error message you saw when running ORT + TRT EP on your faster-rcnn model? Yes, I think ORT TRT and trtexec use the same TRT parser.
When you run trtexec on that specific subgraph, you get the same error message.
(Note: TRT EP has a provider option "trt_dump_subgraphs". Once you enable it, you can find TensorrtExecutionProvider_TRT_Subgraph.onnx on disk; it is the subgraph being fed to the TRT parser.) Use gdb to trace the TRT parser and you can locate that the error is coming from here.
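For reference, enabling that option from the Python API looks roughly like this (a sketch; the model filename is a placeholder):

import onnxruntime as ort

providers = [
    ("TensorrtExecutionProvider", {
        # dump TensorrtExecutionProvider_TRT_Subgraph.onnx to disk for inspection
        "trt_dump_subgraphs": True,
    }),
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
session = ort.InferenceSession("faster_rcnn_fpn.onnx", providers=providers)  # placeholder path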
I think so, but I need to double check.
The reason the small ReduceMax model passes trtexec and ORT TRT is that it has a valid ReduceMax node.
Hi @datinje I found a workaround to adapt the d2 model to symbolic shape inference script and execute via ORT-TRT:
onnxsim onnx_output/faster_rcnn_fpn.onnx faster_rcnn_fpn_simplified.onnx
python /path/to/onnxruntime/onnxruntime/python/tools/symbolic_shape_infer.py --input=faster_rcnn_fpn_simplified.onnx --output=inferred_faster_rcnn_fpn_simplified.onnx --auto_merge
input name image
input shape [3, 'height', 'width']
input type tensor(float)
output name boxes
output shape ['NonZero_1028_o0__d1', 4]
output type tensor(float)
ort_outputs: [array([], shape=(0, 4), dtype=float32), array([], dtype=int64), array([], dtype=float32), array([480, 640], dtype=int64)]
ort_outputs number: 4
ort_boxes : []
ort scores : []
ort classes : []
detectron2 torch and onnx models results match!
[08/16 18:20:39 detectron2]: Success.

Please see if this workaround also works on your own model. I will check on the ops/args that have been simplified, which might otherwise block the symbolic shape infer script from processing.
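If you would rather call the shape-inference step from Python than from the command line, the script exposes a class for it; a minimal sketch, assuming symbolic_shape_infer.py from the onnxruntime repo is importable (for example placed on PYTHONPATH):

import onnx
# symbolic_shape_infer.py lives under onnxruntime/python/tools/ in the onnxruntime repo
from symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("faster_rcnn_fpn_simplified.onnx")
# auto_merge=True mirrors the --auto_merge flag used in the command above
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "inferred_faster_rcnn_fpn_simplified.onnx")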
@yf711, the onnx_simplifier + symbolic_shape_infer.py is a great hint: it fixes the detectron2 model zoo faster-rcnn model execution on onnxruntime + TRT EP! I wonder how this works: does it fix the model by unifying all the ReduceMax nodes to a single API type (compatible with opset 17)?

Now, on my model, there is no more ReduceMax error (great!), but I fall into another issue with ORT TRT EP =

Inference runs fine on the same (simplified + inferred) reformatted onnx model with the CPU and CUDA EPs.

Note: I had to revert to using TRT 8.5.3.1 for my Python trials since 'pip install tensorrt' (taking 8.6.1) fails due to SSL errors despite using --trusted-host pypi.nvidia.com. This is weird though, since https://github.com/onnx/onnx-tensorrt/blob/8.5-GA/docs/operators.md reports that the Squeeze operator is supported for TRT 8.5 GA.

Sorry for asking for some hints again on the above: you have been great so far.
Config for the above error: I will try whether TRT 8.6.0 installs.
Same Squeeze kernel issue with pip install tensorrt==8.6.0 and runtime TensorRT-8.6.1.6.Linux.x86_64-gnu.cuda-11.8.tar.gz.
Hi @datinje, could you apply this command to your own model and see if the issue still exists? The command excludes 35 optimizations that we don't need on the d2 model, which is the least-optimized case.

onnxsim onnx_output/faster_rcnn_fpn.onnx frcnn_split_expand_optimized.onnx --skip-optimization \
eliminate_nop_cast eliminate_nop_dropout eliminate_nop_flatten \
extract_constant_to_initializer eliminate_consecutive_idempotent_ops \
eliminate_if_with_const_cond eliminate_nop_monotone_argmax eliminate_nop_pad \
eliminate_nop_concat eliminate_shape_gather eliminate_slice_after_shape \
eliminate_nop_transpose fuse_add_bias_into_conv fuse_bn_into_conv \
fuse_consecutive_concats fuse_consecutive_log_softmax fuse_consecutive_reduce_unsqueeze \
fuse_consecutive_squeezes fuse_consecutive_transposes fuse_matmul_add_bias_into_gemm \
fuse_pad_into_conv fuse_pad_into_pool fuse_transpose_into_gemm fuse_concat_into_reshape \
eliminate_nop_reshape eliminate_nop_with_unit eliminate_common_subexpression \
fuse_qkv fuse_consecutive_unsqueezes eliminate_deadend eliminate_identity \
eliminate_shape_op fuse_consecutive_slices eliminate_unused_initializer \
eliminate_duplicate_initializer

Onnxsim applies all 37 optimizations by default, and it turned out only 2 of them eventually made a difference to the d2 model and let it pass symbolic_shape_infer and your testcase (the
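If it is easier to experiment from Python, onnxsim exposes the same switch programmatically; a sketch assuming the onnxsim Python API, with the skip list abbreviated to names from the command above:

import onnx
from onnxsim import simplify

model = onnx.load("onnx_output/faster_rcnn_fpn.onnx")
skipped = [
    "eliminate_nop_cast", "eliminate_nop_dropout", "eliminate_nop_flatten",
    # ... the remaining optimizer names from the command above ...
    "eliminate_duplicate_initializer",
]
# skipped_optimizers mirrors the CLI's --skip-optimization flag
model_simplified, check_ok = simplify(model, skipped_optimizers=skipped)
assert check_ok, "onnxsim consistency check failed"
onnx.save(model_simplified, "frcnn_split_expand_optimized.onnx")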
When I run onnxsim with the --skip-optimization list above and then symbolic_shape_infer subsequently (which passes), my inference is still getting error =

Then, when I just run onnxsim with --skip-optimization only, and run inference with the resulting simplified model (without using symbolic_shape_infer), I still get error =
I got the above results when using my Python onnxruntime + TRT EP inference code. Now, when running the model obtained with onnxsim with all the above --skip-optimization AND symbolic_shape_infer in a C++ program, I am getting the same results with default settings (I don't know which TRT optimisations are used with default settings). BUT now if I use the following C++ settings in the tensorrt_options: Here is my code extract:
Actually, the Squeeze(13) "kernel not found" error comes back with
I think this issue might be in your model's subgraphs, and onnxsim can't process subgraphs.
Would it be possible to share your Squeeze node spec (including its input/output nodes)?
Thanks for the explanations: I understand the reason why it then becomes so long. In my case TRT is no good.
Thanks, @datinje. BTW, the error message "Failed to find kernel for Squeeze(13)" is coming from the onnxruntime function where it duplicates any initializer that is used by both TRT EP nodes and CPU nodes. That's all we can tell so far.
Update: with the help of NVIDIA, we found the cause of the ReduceMax issue, and after applying onnxsim + symbolic_shape_infer.py on the model, I can now run onnxruntime with the TensorRT EP on my converted model. The problem is due to pytorch, in the torchvision batched_nms() function: https://github.com/pytorch/vision/blob/v0.15.2/torchvision/ops/boxes.py#L89

The torchvision function _batched_nms_coordinate_trick(), called from batched_nms() in my faster-rcnn-derived model, causes an if-then-else block where the if block produces a tensor of a different dimension than the else block; this causes ReduceMax to appear twice in the resulting onnx subgraphs. boxes.max() causes a proper ReduceMax with a dimension in the attribute as per opset 17, since boxes were found, whereas the early-return branch for the case where no boxes are found causes a ReduceMax with no attribute, since the result is a constant scalar (you cannot reduce a scalar). Now, if the lines below are commented out, then, after passing onnxsim and symbolic_shape_infer.py, I can execute onnxruntime with the TRT EP:

if boxes.numel() == 0:
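For context, the relevant torchvision helper looks roughly like this (paraphrased from torchvision 0.15; the early return for empty boxes is the branch being commented out in the workaround):

import torch
from torchvision.ops import nms

def _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold):
    # Early return for the "no boxes" case: this branch is what produced the
    # attribute-less ReduceMax in the exported subgraph.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # boxes.max() is the call that exports as the "normal" ReduceMax in the graph
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep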
Note that with the workaround above alone, TRT does not pass: it requires inferring the shapes, which in turn requires passing onnxsim on my model. This is only a workaround, as I still need to deal with the case where the AI did not find any detection boxes in my image. I will try to see how to declare the plugin with onnxruntime + TRT EP and see if it fixes the problem. That really closes the ReduceMax issue in my model with the onnxruntime TRT EP. Forget about the Squeeze(13) kernel not found: likely an artefact of the onnxsim + symbolic_shape_infer.py conversion.

One last thing though. I can do the inference in Python as =

But I am getting an issue in C++ =

Somehow, C++ is missing some treatment Python is doing with the TRT EP: any idea? Here is the C++ code =

Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, instanceName.c_str());
Ort::SessionOptions sessionOptions;
OrtTensorRTProviderOptions tensorrt_options{};
// load model
// get overall shape of input from model
// update sizes: since input image size is variable (model was generated with dynamic dims, meaning model dims are = -1)
// set input names and output names as when the onnx model was generated (alas, don't trust inputTensorInfo as we get the names truncated!)
// define inputs in CPU memory ... and assume the data is there: inputTensorValues
// initialize the onnxrt input tensor vector with the input image tensor - there is only one input
// do inference
It's possible that it falls back to the CUDA EP when the TRT EP fails to run when using Python. Could you help check this? The issue is still there for TRT. Regarding how to use a TRT plugin with ORT TRT, please see the doc. Simply create an onnx custom node with the name
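One way to check for that fallback from Python is to look at which providers the session actually registered and at the verbose log, which reports how nodes were assigned to execution providers; a minimal sketch, reusing the model name from the workaround above as a placeholder:

import onnxruntime as ort

so = ort.SessionOptions()
so.log_severity_level = 0  # verbose logging shows how nodes are placed on execution providers

session = ort.InferenceSession(
    "inferred_faster_rcnn_fpn_simplified.onnx",
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Providers actually registered for this session, in priority order
print(session.get_providers())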
I finally successfully ran my faster-rcnn model on Onnxruntime with the TRT EP by fixing the TopK error =
with the lines' rationale: # sort was faster than topk in older pytorch releases (fixed in PT2.0), see pytorch/pytorch#22812

In summary, to make my detectron2 faster-rcnn-rpn work (note that most of the issues are similar if using the torchvision GeneralizedRCNN module), I had to:

And as a result I am getting an execution which is SLOWER than with the CUDA EP (because several operators still fall back on the CPU EP, causing memory transfers between subnodes): for that I will file a new issue, but THIS issue can be closed. I hope this will benefit others.
Hi @datinje, I am having the same issue. However, I could use the TRT EP after running the ONNX simplifier (onnxsim) followed by symbolic_shape_infer.py. But then, when I tried to convert the ONNX file to a TensorRT engine using trtexec, I ended up with output dimensions of (0,) size. I am not sure how to solve this issue. I think that even though onnxsim + symbolic_shape_infer provided a workaround for running the model with the TRT EP, the underlying problem with "if then else" still remains, which is causing this issue. Could you please tell me how you solved the if-then-else block problem for the detectron2 fast rcnn model? Any help would be highly appreciated. I am okay with any kind of solution, with or without defeating the purpose of ONNX, as my main goal is to build a TensorRT engine from the ONNX. Thanks
Here is how I treated the if-then-else block: you are right, neither onnxsim nor symbolic_shape_infer can do anything about it; instead use onnx graph surgeon (see below). Note that according to NVIDIA this is quite a common case when adapting a model to TRT; hence NVIDIA provides customer support to help you run your model on their board (happened to me). So try contacting your NVIDIA customer support.

import onnx
import onnx_graphsurgeon as gs
import os
import numpy as np

graph = gs.import_onnx(onnx.load('faster_rcnn_fpn.onnx'))

# check before surgeon (or use netron)
# print(graph)
# for node in graph.nodes:
#     print(node)

for node in graph.nodes:
    for o in node.outputs:
        o.shape = None
model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model)
graph = gs.import_onnx(model)

# check after surgeon (or use netron)
# print(graph)

# FIX: /model/roi_heads/pooler/level_poolers.0/If
for node in graph.nodes:
    if node.name == '/model/roi_heads/box_pooler/level_poolers.0/If':
        else_branch_graph = node.attrs['else_branch']
        squeeze_axes = gs.Constant(name='axes', values=np.array([1]))
        squeeze_node = gs.Node("Squeeze", name="squeeze_fix")
        squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
        squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]
        node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
        node.attrs['else_branch'].outputs = squeeze_node.outputs

graph = graph.cleanup().toposort()
model = gs.export_onnx(graph)
save_path = os.path.join(args.output, 'faster_rcnn_fpn-graph_surgeon.onnx')  # args.output as in the full script above
onnx.save(model, save_path)

Hope this helps.
@datinje, thanks a lot! |
@datinje, I just tried. I think we are almost there. The error "IIfConditionalOutputLayer inputs must have the same shape. Shapes are [-1] and [-1,1]." has now become "Assertion failed: !isDynamic(shape) && "Cannot infer squeeze dimensions from a dynamic shape! Please re-export your model with the Squeeze axes input set."" I have shared my original model at the following link. I would appreciate it if you could take a look. Below is the code that I tried to fix the issue:

import onnx
import onnx_graphsurgeon as gs
import numpy as np

graph = gs.import_onnx(onnx.load('model.onnx'))

for node in graph.nodes:
    for o in node.outputs:
        o.shape = None
model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model)
graph = gs.import_onnx(model)

for node in graph.nodes:
    # (same if-then-else Squeeze fix as in the comment above)
    ...

graph = graph.cleanup().toposort()
model = gs.export_onnx(graph)
onnx.save(model, 'model_fixed.onnx')
My mistake: graph_surgeon is an NVIDIA tool; it is part of the TensorRT distribution.
@datinje, I am not sure what you mean by "my mistake : graph_surgeon is an Nvidia tool : it is part of the Tensorrt distribution". Would you be able to look into the model (uploaded to my OneDrive: https://1drv.ms/u/s!Ah7Pr6XgUyqEjCblsEk8S9G7ZiB-) and give me some directions? It would be extremely helpful. Thanks in advance.
Apologies: I did not notice your last update. You already succeeded in importing into onnx-graphsurgeon. Forget about the "my mistake..." comment.
It could be an if-then-else block that yields two outputs of different dimensions and cannot be surgeoned because one is of dim 0. If the number of boxes found is zero, then you cannot squeeze that if block any further. Sorry, I can only provide hints on the problem I had, which may be very specific to my model. But it is tough to say more without pointing to the exact model code.
@datinje, I finally could solve my issues. This link https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples/python/detectron2 describes how to get around those with a sample mask rcnn model. Basically, I had to modify the "create_onnx.py" sample script to create an NMS node for the EfficientNMS_TRT plugin and replace the output, create PyramidROIAlign_TRT plugin nodes to replace the ROIAligns, fold the constants, and modify the Reshape nodes of my onnx model. Once the onnx model is converted using the modified "create_onnx.py", I could generate the engine file without any issues. Thank you for sticking with me and providing me with your valuable feedback.
Describe the issue
The Onnxruntime TensorRT execution provider is crashing on loading the model when parsing the ReduceMax instruction;
see more details below in how to reproduce it.
I used ONNX opset 17 and 16 with pytorch 2.0
Same issue with onnxruntime 1.13 to 1.15
using
TensorRT-8.6.1.6
cudnn 8.9.3
cuda 11.8
python 3.10
pytorch 2.0
same issue with C++ ort APIs
Note that because of #16883
I had to use a released package (1.13.1)
To reproduce
As the traces show, the problem is with the ReduceMax operator, which is not well interpreted by the onnxruntime EP, not by TensorRT (when I run the model on native TensorRT there is no problem, after converting the ONNX to trt format; note that I need a special version of nvidia graph-surgeon to convert to trt).
Note
The following code demonstrates the error:
What exact command you run:
python3 export_model.py --output onnx_output --sample-image input.jpg
Full logs or other relevant observations:
Urgency
I am blocked with onnxrt and need to revert to tensorrt native APIs which defeats our portability strategy.
Platform
Linux
OS Version
SLES15 SP4
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.13.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
TensorRT
Execution Provider Library Version
TensorRT-8.6.1.6