
Issues with ONNX model ReduceMax API not supported by ONNXRT TensorRT EP (but ok with CPU and Cuda EP) #16886

Closed
datinje opened this issue Jul 27, 2023 · 42 comments
Labels: ep:CUDA (issues related to the CUDA execution provider), ep:TensorRT (issues related to TensorRT execution provider)

@datinje

datinje commented Jul 27, 2023

Describe the issue

The onnxruntime TensorRT execution provider crashes when loading the model, while parsing a ReduceMax node.
See below in "To reproduce" for the details.

I used ONNX opset 17 and 16 with pytorch 2.0
Same issue with onnxruntime 1.13 to 1.15

using
TensorRT-8.6.1.6
cudnn 8.9.3
cuda 11.8
python 3.10
pytorch 2.0

same issue with C++ ort APIs

Note that because of #16883
I had to use a released package (1.13.1)

To reproduce

  1. get a standard faster-rcnn model from Facebookresearch detectron2 model zoo in func setup()
  2. convert the torch model to onnx using function export_tracing()
  3. prepare a sample input with function get_sample_inputs() to be used for inferencing with torch and inferencing with the onnxruntime tensorrt EP
  4. check the resultant onnx model with check_onnx_model(): the model is reported valid but the graph is not (don't bother, this is a bug in the onnx checker functions!)
  5. do the inference with torch and onnxruntime and verify the results match
  6. => here we get an error with the TensorRT EP, while inference is OK with the CPU EP and CUDA EP - results match with these.
    As the traces show, the problem is with the ReduceMax operator, which is not well interpreted by the onnxruntime TensorRT EP, not by TensorRT itself (when I run the model on native TensorRT there is no problem, after converting the ONNX to trt format; note that I need a special version of nvidia graph-surgeon to convert to trt).

Note

The following code demonstrates the error:

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# this is an adaptation of detectron2/tools/deploy/export_model.py
# it does export of a faster-rcnn model to onnx and test it vs the original detectron2 model
# requires any RGB input image (jpg or png)
import argparse
import os
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2 import model_zoo
""" 
# cannot use detectron2 export lib since it depends on Caffe2 which is not provided anymore with pytorch dist
from detectron2.export import (
    STABLE_ONNX_OPSET_VERSION,
    TracingAdapter,
    dump_torchscript_IR,
    scripting_with_instances,
)
"""
# # use export lib stripped out from caffe2 (/detectron2/export/__init__.py)
from lib.export import (
    TracingAdapter,
    dump_torchscript_IR,
    scripting_with_instances,
)
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
from detectron2.structures import Boxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger

import onnx
import onnxruntime as ort
import numpy as np
import cv2 as cv2

def setup_cfg(args):   
    cfg = get_cfg()
    
    # use detectron2 standard faster rcnn

    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml") 
    cfg.MODEL.DEVICE = 'cuda'   

    # cuda context is initialized before creating dataloader, so we don't fork anymore
    cfg.DATALOADER.NUM_WORKERS = 0
    add_pointrend_config(cfg)
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
    cfg.merge_from_list(args.opts)
    cfg.freeze()
        
    return cfg

# experimental. API not yet final
def export_tracing(torch_model, inputs):
    assert TORCH_VERSION >= (1, 8)
    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    inference=None
    """
    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly
    """
    
    traceable_model = TracingAdapter(torch_model, inputs, inference)
    
    with PathManager.open(os.path.join(args.output, "faster_rcnn_fpn.onnx"), "wb") as f:
        torch.onnx.export(
              traceable_model, 
              (image,), 
              f, 
              do_constant_folding=True,
              export_params=True,
              input_names=["image"], # the model's input names
              output_names=["boxes", "labels", "scores", "image_dims"], # the model's output names
              dynamic_axes={
                "image"      : {1: "height", 2: "width"},
                "boxes"      : {0: "findings"}, # boxes is a tensor of shape [number of findings, 4] 
                "labels"     : {0: "findings"},
                "scores"     : {0: "findings"}
                },
              verbose=True, 
              opset_version=17) #issue is same with opset 16 and opset 18 is not validated for pytorch 2.0
              
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    onnx_model_path = os.path.join(args.output, "faster_rcnn_fpn.onnx")
    onnx_model = onnx.load(onnx_model_path)
    
    return onnx_model


def get_sample_inputs(args):

    if args.sample_image is None:
        # get a first batch from dataset
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        first_batch = next(iter(data_loader))
        return first_batch
    else:
        # get a sample data
        original_image = cv2.imread("./input.jpg")
        print ("original_image input shape :", original_image.shape)
         
        # Do same preprocessing as DefaultPredictor
        aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        image_with_different_size = aug.get_transform(original_image).apply_image(original_image)
        cv2.imwrite("./inputExpanded.jpg", image_with_different_size)
        
        image = original_image
        height, width = original_image.shape[:2]
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) # need channel first for onnx
        print ("image input shape :", image.shape)
                      
        inputs = {"image": image, "height": height, "width": width}

        # Sample ready
        sample_inputs = [inputs]
        return sample_inputs

def check_onnx_model (onnx_model):
  # Check the model
  try:
    onnx.checker.check_model(onnx_model, full_check=True)
  except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
  else:
    print("The model is valid!")
    
  # check the onnx graph
  try:
     graph = onnx_model.graph
     onnx.checker.check_graph(graph)
  except onnx.checker.ValidationError as e:
    print("The graph is invalid: %s" % e)
  else:
    print("The graph is valid!")
    
  input_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in onnx_model.graph.input]
  print ('onnx model input shapes', input_shapes)
   
  return None

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
  
   
def eval_onnx_model (torch_model, onnx_model, sample_inputs, args): 
  # get D2 results
  torch_model.eval()
  torch_outputs = torch_model(sample_inputs)
  print ('torch_outputs: ', torch_outputs)
  print ('torch size of outputs: ', len(torch_outputs))

  t_outputs_scores =   to_numpy(torch_outputs[0]['instances'].scores)
  print('d2_torch_scores: ', t_outputs_scores)
  t_outputs_boxes =   to_numpy(torch_outputs[0]['instances'].pred_boxes.tensor)
  print('d2_torch_boxes: ', t_outputs_boxes)
  t_outputs_classes =   to_numpy(torch_outputs[0]['instances'].pred_classes)
  print('d2_torch_classes: ', t_outputs_classes)
  print('')
  
  # get ONNXRT results
  onnx_model_path = os.path.join(args.output, "faster_rcnn_fpn.onnx")
  providers = [('TensorrtExecutionProvider')]
  #providers = [('CUDAExecutionProvider')] # works !
       
  sess_opt = ort.SessionOptions()
  sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers)
  
  input_name = sess.get_inputs()[0].name
  print("input name", input_name)
  input_shape = sess.get_inputs()[0].shape
  print("input shape", input_shape)
  input_type = sess.get_inputs()[0].type
  print("input type", input_type)

  output_name = sess.get_outputs()[0].name
  print("output name", output_name)
  output_shape = sess.get_outputs()[0].shape
  print("output shape", output_shape)
  output_type = sess.get_outputs()[0].type
  print("output type", output_type)
 
  image = sample_inputs[0]['image']
  np_image  = image.cpu().numpy()
  
  # compute ONNX Runtime output prediction
  ort_inputs = {sess.get_inputs()[0].name: np_image}
  ort_outputs = sess.run(None, ort_inputs)

  print ('ort_outputs: ', ort_outputs)
  print('ort_outputs number: ', len(ort_outputs))
  print('')
  
  boxes = ort_outputs[0]
  classes =  ort_outputs[1]
  scores = ort_outputs[2]
 
  print ('ort_boxes : ', boxes)
  print ('ort scores : ', scores)
  print ('ort classes : ', classes)
  print('')
  
  # eval torch and onnxrt outputs
  np.testing.assert_allclose(t_outputs_boxes, boxes, rtol=1e-03, atol=1e-05)
  np.testing.assert_allclose(t_outputs_scores, scores, rtol=1e-03, atol=1e-05)
  np.testing.assert_allclose(t_outputs_classes, classes, rtol=1e-03, atol=1e-05)
  print('detectron2 torch and onnx models results match!')
  print('')
  
  return None
  
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--output", help="output directory for the converted model")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    
    PathManager.mkdirs(args.output)

    cfg = setup_cfg(args)

    # create a torch model
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    # convert and save model
    sample_inputs = get_sample_inputs(args)
    onnx_model = export_tracing(torch_model, sample_inputs)
    
    check_onnx_model (onnx_model)
    
    eval_onnx_model(torch_model, onnx_model, sample_inputs, args)

    logger.info("Success.")

What exact command you run:
python3 export_model.py --output onnx_output --sample-image input.jpg

Full logs or other relevant observations:

[04/04 16:14:53 detectron2]: Command line arguments: Namespace(sample_image='input.jpg', output='onnx_output', opts=[])
original_image input shape : (480, 640, 3)
image input shape : torch.Size([3, 480, 640])

  %/model/ReduceMax_output_0 : Long(2, strides=[1], requires_grad=0, device=cpu) = onnx::ReduceMax[axes=[0], keepdims=0, onnx_name="/model/ReduceMax"](%/model/Concat_1_output_0), scope: lib.export.flatten.TracingAdapter::/detectron2.modeling.meta_arch.rcnn.GeneralizedRCNN::model # /usr/local/lib/python3.10/dist-packages/detectron2/structures/image_list.py:83:0


      %max_coordinate.3 : Float(device=cpu) = onnx::ReduceMax[keepdims=0](%/model/roi_heads/Cast_9_output_0) # /usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py:91:21


============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

[04/04 16:15:02 detectron2]: Inputs schema: TupleSchema(schemas=[ListSchema(schemas=[DictSchema(schemas=[IdentitySchema()], sizes=[1], keys=['image'])], sizes=[1])], sizes=[1])
[04/04 16:15:02 detectron2]: Outputs schema: ListSchema(schemas=[DictSchema(schemas=[InstancesSchema(schemas=[TensorWrapSchema(class_name='detectron2.structures.Boxes'), IdentitySchema(), IdentitySchema()], sizes=[1, 1, 1], keys=['pred_boxes', 'pred_classes', 'scores'])], sizes=[4], keys=['instances'])], sizes=[4])
The model is valid!

The graph is invalid: Unrecognized attribute: axes for operator ReduceMax
==> Context: Bad node spec for node. Name: /model/ReduceMax OpType: ReduceMax
onnx model input shapes [[3, 0, 0]]

2023-04-04 16:37:52.173723690 [E:onnxruntime:Default, tensorrt_execution_provider.h:61 log] [2023-04-04 16:37:52   ERROR] ReduceMax_1597: at least 1 dimensions are required for input.
2023-04-04 16:37:52.324418966 [E:onnxruntime:, inference_session.cc:1532 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:897 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /model/proposal_generator/GatherND_2_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs

Traceback (most recent call last):
  File "/cad-engine/export_model.py", line 264, in <module>
    eval_onnx_model(torch_model, onnx_model, sample_inputs, args)
  File "/cad-engine/export_model.py", line 190, in eval_onnx_model
    sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 360, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 408, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception duri

Urgency

I am blocked with onnxrt and need to revert to tensorrt native APIs which defeats our portability strategy.

Platform

Linux

OS Version

SLES15 SP4

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.13.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

TensorRT-8.6.1.6

@github-actions github-actions bot added ep:CUDA issues related to the CUDA execution provider ep:TensorRT issues related to TensorRT execution provider labels Jul 27, 2023
@jywu-msft
Member

Thanks for providing these details.
In step 6, you mention that the model works fine when converted to native TRT format, but that you needed to use a special version of nvidia graphsurgeon. Can you provide more details on that and share the steps for that conversion?
The TensorRT EP relies on a component from Nvidia, the onnx-tensorrt parser, to convert from ONNX to a TensorRT network, so it would be helpful for us to know what graph transform graphsurgeon is doing.

@datinje
Author

datinje commented Jul 30, 2023

You need to use nvidia onnx-graphsurgeon for TensorRT.
(TensorRT is stricter than onnxruntime about which ONNX models it will run: in the faster-rcnn model from the detectron2 model zoo, it does not accept an if-then-else whose branches return tensors with different dimensions, so the script below fixes exactly that. After that you can run trtexec.)

Run as: graph_surgeon_faster_rcnn.py --onnx faster_rcnn_fpn.onnx --output onnx_output
then run: trtexec --onnx=faster_rcnn_fpn-graph_surgeon.onnx --verbose --saveEngine=faster_rcnn_fpn-graph_surgeon.trt

Before that, you need to install the following from nvidia:

# INSTALL PYTHON MODULES polygraphy, tensorrt, onnx_graphsurgeon
python -m pip install colored polygraphy --trusted-host pypi.ngc.nvidia.com --extra-index-url https://pypi.ngc.nvidia.com
# and build onnx-graphsurgeon since the pip installer did not work for me 
# pip install onnx_graphsurgeon>=0.3.21 --trusted-host pypi.ngc.nvidia.com  --extra-index-url=https://pypi.ngc.nvidia.com
git clone https://github.com/NVIDIA/TensorRT.git
cd TensorRT/tools/onnx-graphsurgeon && make build && python3 -m pip install dist/onnx_graphsurgeon-*-py2.py3-none-any.whl

here is the graph_surgeon_faster_rcnn.py (tested on python 3.10)
(I provide the reference to the nice person from nvidia who helped me)

#! /usr/bin/env python3
## ---------------------------------------------------------------------------
##
## File: graph_surgeon_faster_rcnn.py for detectron2 faster-rcnn model zoo
##
## Created by Zhijin Li
## E-mail:   <zhijinl@nvidia.com>
##
## Started on  Mon Jul 17 15:53:26 2023 Zhijin Li
## Last update Tue Jul 18 02:08:41 2023 Zhijin Li
## Modified    Sun Jul 30                2023 JC Datin
## ---------------------------------------------------------------------------
import onnx
import onnx_graphsurgeon as gs

import os
import urllib
import argparse
import numpy as np

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger

def parse_arguments():
  """
  Parse input command line arguments.

  Returns
  ----------
  Parsed argument object.

  """
  parser = argparse.ArgumentParser(
    description='graph surgeon an ONNX model for tensorRT conversion.')

  parser.add_argument(
    '--onnx',
    type=str,
    required=True,
    help='Path to ONNX model')

  parser.add_argument(
    '--output',
    type=str,
    default='./output',
    help='output directory for the graph surgeoned model')

  parser.add_argument(
    'opts',
    help='Modify config options using the command-line',
    default=None,
    nargs=argparse.REMAINDER,
  )

  args = parser.parse_args()
  return args

def sanitize(graph):
  """
  Sanitize the graph by cleaning up any unconnected nodes, doing a topological resort, and folding constant input values.
  When possible, run shape inference on the ONNX graph to determine tensor shapes.
  """

  for i in range(3):

    count_before = len(graph.nodes)
    graph.cleanup().toposort()

    try:
      for node in graph.nodes:
        for o in node.outputs:
          o.shape = None

      model = gs.export_onnx(graph)
      model = onnx.shape_inference.infer_shapes(model)
      graph = gs.import_onnx(model)

    except Exception as e:
      logger.info("Shape inference could not be performed at this time:\n{}".format(e))

    try:
      graph.fold_constants(fold_shapes=True)
    except TypeError as e:
      logger.error("This version of ONNX GraphSurgeon does not support folding shapes, please upgrade your "
                "onnx_graphsurgeon module. Error:\n{}".format(e))
      raise

    count_after = len(graph.nodes)
    if count_before == count_after:
      # No new folding occurred in this iteration, so we can stop for now.
      break

  # Return the re-imported graph so callers see the updated object.
  return graph

if __name__ == '__main__':

  logger = setup_logger()
  args = parse_arguments()

   # Load graph

  logger.info("loading ONNX model: {}".format(args.onnx))
  original_model = args.onnx
  graph = gs.import_onnx(onnx.load(original_model))
  surgeoned_model_path = original_model.rsplit('.', 1)[0]
  surgeoned_model_name = surgeoned_model_path.rsplit('/', 1)[-1] 
  surgeoned_model= surgeoned_model_name + "-graph_surgeon" + ".onnx"
  print(surgeoned_model_path)
  print(surgeoned_model_name)
  print(surgeoned_model)

  graph = graph.cleanup().toposort()
  graph = graph.fold_constants().cleanup()

  # print(graph)

  # for node in graph.nodes:
  #   print(node)

  for node in graph.nodes:
    for o in node.outputs:
      o.shape = None

  model = gs.export_onnx(graph)
  model = onnx.shape_inference.infer_shapes(model)
  graph = gs.import_onnx(model)

  # print(graph)

  ###### FIX: /model/roi_heads/pooler/level_poolers.0/If
  for node in graph.nodes:
    if node.name == '/model/roi_heads/box_pooler/level_poolers.0/If':

      else_branch_graph = node.attrs['else_branch']

      squeeze_axes = gs.Constant(
        name='axes',
        values=np.array([1]))
      squeeze_node = gs.Node(
        "Squeeze",
        name="squeeze_fix")
      squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
      squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]

      node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
      node.attrs['else_branch'].outputs = squeeze_node.outputs
      
    if node.name == '/model/roi_heads/box_pooler/level_poolers.1/If':

      else_branch_graph = node.attrs['else_branch']

      squeeze_axes = gs.Constant(
        name='axes',
        values=np.array([1]))
      squeeze_node = gs.Node(
        "Squeeze",
        name="squeeze_fix")
      squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
      squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]

      node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
      node.attrs['else_branch'].outputs = squeeze_node.outputs

    if node.name == '/model/roi_heads/box_pooler/level_poolers.2/If':

      else_branch_graph = node.attrs['else_branch']

      squeeze_axes = gs.Constant(
        name='axes',
        values=np.array([1]))
      squeeze_node = gs.Node(
        "Squeeze",
        name="squeeze_fix")
      squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
      squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]

      node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
      node.attrs['else_branch'].outputs = squeeze_node.outputs
      
    if node.name == '/model/roi_heads/box_pooler/level_poolers.3/If':

      else_branch_graph = node.attrs['else_branch']

      squeeze_axes = gs.Constant(
        name='axes',
        values=np.array([1]))
      squeeze_node = gs.Node(
        "Squeeze",
        name="squeeze_fix")
      squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
      squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]

      node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
      node.attrs['else_branch'].outputs = squeeze_node.outputs

  ###### FIX: If_810
  for node in graph.nodes:
    if node.name == 'If_782':

      sub_graph = node.attrs['else_branch']

      for sub_node in sub_graph.nodes:
        if sub_node.name == 'If_810':

          else_branch_graph_810 = sub_node.attrs['else_branch']

          squeeze_axes = gs.Constant(
            name='axes',
            values=np.array([1]))
          squeeze_node = gs.Node(
            "Squeeze",
            name="squeeze_fix_if_810")
          squeeze_node.inputs = [*else_branch_graph_810.outputs, squeeze_axes]
          squeeze_node.outputs = [gs.Variable('squeeze_810_fix_output', dtype=np.int64)]

          sub_node.attrs['else_branch'].nodes = [*sub_node.attrs['else_branch'].nodes, squeeze_node]
          sub_node.attrs['else_branch'].outputs = squeeze_node.outputs


  graph = graph.cleanup().toposort()
  graph = graph.fold_constants().cleanup()

  model = gs.export_onnx(graph)
  model = onnx.shape_inference.infer_shapes(model)
  graph = gs.import_onnx(model)

  for node in graph.nodes:
    if node.name == '/model/roi_heads/box_pooler/level_poolers.0/If':
       print (node)

  save_path = os.path.join(args.output, surgeoned_model)
  onnx.save(model, save_path)

@datinje datinje closed this as completed Jul 30, 2023
@datinje datinje reopened this Jul 30, 2023
@datinje
Author

datinje commented Jul 30, 2023

Sorry, closed by mistake after entering the previous comment.

@datinje
Author

datinje commented Aug 2, 2023

To be clear: the issue is still open. I guess we are waiting for the ORT team to check that the model runs on native TRT but not on ORT + TRT EP, demonstrating that the problem is in the ORT TRT EP.

@yf711
Contributor

yf711 commented Aug 2, 2023

@datinje Thanks for providing the script that uses onnx-graphsurgeon.

Have you seen any inference failure after running trtexec --onnx=faster_rcnn_fpn-graph_surgeon.onnx --verbose --saveEngine=faster_rcnn_fpn-graph_surgeon.trt?
The native trt inference failed at:

[08/02/2023-19:15:46] [I] Starting inference
[08/02/2023-19:15:46] [E] Error[7]: [shapeMachine.cpp::executeContinuation::864] Error Code 7: Internal Error (If_1530_OutputLayer: dimensions not compatible for if-conditional outputs Condition '==' violated: 0 != 6. Instruction: CHECK_EQUAL 0 6.)
[08/02/2023-19:15:46] [E] Error occurred during inference

Or would it be convenient to share faster_rcnn_fpn-graph_surgeon.onnx that you had tested?

@datinje
Author

datinje commented Aug 3, 2023

My mistake, I did not realize trtexec also failed on a similar node issue on the demonstrator after all the node surgery, albeit towards the end. trtexec worked on my real project once that model was surgeoned.
If_1595_OutputLayer: dimensions not compatible for if-conditional outputs Condition '==' violated: 0 != 6. Instruction: CHECK_EQUAL 0 6.)
I have to find a way to surgeon this node: according to Netron, one of the if node's branches is also missing a fix to get the same dimensions on all branches.
Meanwhile I will propose a much simpler ReduceMax demonstrator.

@jywu-msft
Member


Thanks for confirming. We await your new repro test case. If an onnx model runs on native TensorRT (trtexec), it should also run with the OnnxRuntime TensorRT EP. If it doesn't, that means there's a bug in the OnnxRuntime TensorRT EP.

@chilo-ms
Contributor

chilo-ms commented Aug 10, 2023

There are at least three issues:

  1. The ReduceMax node converted from the export_model.py you shared specifies opset_version 17. However, in the converted onnx model, the ReduceMax node doesn't have the axes input which is required by opset 17 (it's optional starting from opset 18). A quick way to inspect this is sketched after this list.
    That might be the reason you saw this:

    2023-04-04 16:37:52.173723690 [E:onnxruntime:Default, tensorrt_execution_provider.h:61 log] [2023-04-04 16:37:52   ERROR] ReduceMax_1597: at least 1 dimensions are required for input.

    I think this issue is related to the pytorch converter. But you also mentioned "issue is same with opset 16 and opset 18 is not validated for pytorch 2.0", which is a bit strange, since the axes input of ReduceMax is optional for opset 18.

  2. The issue is that some input tensors don't have shapes because the main graph is partitioned into TRT subgraphs and CUDA/CPU subgraphs, and TRT requires all graph inputs to have shape info. Typically we ask users to run the shape inference script we provide, but right now it has problems running on this model. We are still investigating it.

    2023-04-04 16:37:52.324418966 [E:onnxruntime:, inference_session.cc:1532 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:897 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /model/proposal_generator/GatherND_2_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs

  3. The issue is due to "different output shapes of the conditionals (if node)", and it's a constraint of TensorRT that both branches of a conditional must have the same output shape.

    [08/02/2023-19:15:46] [E] Error[7]: [shapeMachine.cpp::executeContinuation::864] Error Code 7: Internal Error (If_1530_OutputLayer: dimensions not compatible for if-conditional outputs Condition '==' violated: 0 != 6. Instruction: CHECK_EQUAL 0 6.)

    We discussed this with Nvidia; right now TRT doesn't support conditionals with different output shapes.
    So yes, you need to find a way to surgeon this if node.
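For issue 1, a quick way to confirm which opset the exported model actually declares, and how each ReduceMax node is formed, is to inspect it with the onnx Python API. This is a minimal sketch, assuming the file name and output directory used in the repro script above:

import onnx

# Path produced by export_tracing() when run with --output onnx_output.
model = onnx.load("onnx_output/faster_rcnn_fpn.onnx")

# The declared opset decides whether ReduceMax expects `axes` as an
# attribute (opset <= 17) or as an optional second input (opset >= 18).
for imp in model.opset_import:
    print("domain:", imp.domain or "ai.onnx", "version:", imp.version)

# List every top-level ReduceMax node with its inputs and attributes.
for node in model.graph.node:
    if node.op_type == "ReduceMax":
        print(node.name, "inputs:", list(node.input),
              "attributes:", [a.name for a in node.attribute])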

@datinje
Author

datinje commented Aug 12, 2023

Thanks a lot for your investigation.
3. Yes, I am working to fix the if-then-else branches to have the same output dimensions. This is actually fixed on my company's model, thanks to the collaboration with Nvidia on the onnx-graphsurgeon script tailored for the faster-rcnn we use.

  1. Then the pytorch onnx converter truly has a problem, since I am getting the issue with both opset 16 and 17.
  2. Could you tell me how to use your shape inference script when it is ready?
  3. Still working on a simpler model exercising ReduceMax, but I am on vacation and have no access to my computer now. Will be back on August 16.

@datinje
Author

datinje commented Aug 13, 2023

Please find the small ReduceMax demonstrator model below (thanks to Zhijin Li from Nvidia!).

When run, onnxruntime + TRT EP passes on the resultant onnx file!
Similarly, trtexec also passes.
This means that 1. pytorch generates the right API (no axes input required), and 2. the onnxruntime TRT EP recognizes the ReduceMax API with no axes input (as of opset 16). So does trtexec.

(Side note: there is an error reported by the onnx checker tool which should not occur since opset 16 is used; the problem seems to be on the onnx checker side.)

However, the onnxruntime TRT EP fails on the bigger faster-rcnn model from the Detectron2 model zoo above:

2023-08-08 15:39:49.852404881 [E:onnxruntime:Default, tensorrt_execution_provider.h:73 log] [2023-08-08 15:39:49   ERROR] ReduceMax_477: at least 1 dimensions are required for input.

So I would like to understand why the onnxruntime TRT EP this time requires a ReduceMax API with an axes input, as of opset 18.
Could it be because pytorch, when generating the onnx file from the detectron2 model, generates at least 2 different ONNX blocks (this is what I found when looking at the onnx file): one block with the opset 16 ReduceMax API (no axes input parameter) and one with the opset 18 API (axes input parameter)?

Then I suspect that when the onnxruntime EP code finds the opset 18 ReduceMax API it switches to full opset 18, and then when it finds the opset 16 ReduceMax API later in the onnx file it fails, by requesting the missing axes input.

Can you confirm the above?
(If so, then we have to fix the detectron2 model or fix the pytorch onnx exporter.)

Here is the full script (made by Nvidia).

#! /usr/bin/env python3
## ---------------------------------------------------------------------------
##
## File: test_onnx_reducemax.py for project: GEHC Mammo
##
## Created by Zhijin Li
## E-mail:   <zhijinl@nvidia.com>
##
## Started on  Wed Jun 21 15:29:18 2023 Zhijin Li
## Last update Thu Jun 29 01:25:10 2023 Zhijin Li
## ---------------------------------------------------------------------------
import os
import torch
import numpy as np
import onnx
import onnxruntime as ort

ONNX_SAVE_PATH = './output/test-reducemax.onnx'

class ReduceMaxTest(torch.nn.Module):

  def forward(self, x):
    """
    X: list of 2D matrices.
    """
    # result = x.max(0).values
    result = torch.max(x, dim=0).values
    print('--> result:', result)

    return result

if __name__ == '__main__':

  if not os.path.exists('./output'):
    os.makedirs('./output', exist_ok=True)

  print('\n---------- ONNX VERSION: {}'.format(onnx.__version__))
  print('---------- ONNXRUNTIME VERSION: {}\n'.format(ort.__version__))

  # MUST set opset version to 16.
  # This will generate ReduceMax operator with `axes` as attribute.
  # Pytorch 2.0 currently does not support opset >= 17.
  # starting from opset 18, ReduceMax's attribute signature changes
  # by moving `axes` from attribute to input.
  #
  torch.onnx.export(
    ReduceMaxTest().eval(),
    torch.rand(64, 64),
    ONNX_SAVE_PATH,
    input_names=['inp'],
    output_names=['out'],
    dynamic_axes={
      'inp': [0,1]
    },
    verbose=True,
    opset_version=16
  )

  try:
    onnx_model = onnx.load(ONNX_SAVE_PATH)
    onnx.checker.check_model(onnx_model, full_check=True)
    print('---------- ONNX model check successful!')
  except Exception:
    print('\n!!!!! --> ONNX MODEL CHECK FAILED!\n')

  print(onnx.helper.printable_graph(onnx_model.graph))

  # Graph check should work for onnx 1.11 & 1.12, which supports
  # opset versions up to v16 and v17 respectively.
  # Starting from onnx version 1.13, which supports opset v18,
  # graph check will not work. This is probably a bug in onnx.
  #
  try:
    graph = onnx_model.graph
    onnx.checker.check_graph(graph)
    print('---------- ONNX graph check successful!')
  except Exception:
    print('\n!!!!! --> ONNX GRAPH CHECK FAILED!\n')

  # Run with TRT exec provider.
  providers = [(
    'TensorrtExecutionProvider',
    {
      'trt_fp16_enable': True,
      'trt_engine_cache_enable': True,
      'trt_engine_cache_path': './output/trt_engine/'
    })]
  # providers = [('CUDAExecutionProvider')]

  sess_opts = ort.SessionOptions()
  sess = ort.InferenceSession(ONNX_SAVE_PATH, sess_options=sess_opts, providers=providers)

  input = sess.get_inputs()[0]
  print('--> input name:', input.name)

  output = sess.get_outputs()[0]
  print('--> output name:', output.name)

  ort_inputs = {input.name: np.random.randn(57, 89).astype(np.float32)}
  ort_outputs = sess.run(None, ort_inputs)[0]

  print('--> ort_outputs value: ', ort_outputs)
  print('--> ort_outputs shape: ', ort_outputs.shape)
  print('\n---------- Runtime Execution Success.\n')

with regards,
JCD

@datinje
Author

datinje commented Aug 13, 2023

Maybe the onnxruntime TRT EP can handle BOTH APIs in the same file?

@yf711
Contributor

yf711 commented Aug 14, 2023

Hi @datinje thank you for providing this context:

...detectron2 model is generating at least 2 different ONNX blocks (this is what I found out when looking at the onnx file) : one block with opset16 ReduceMax API (no dim input parameter) and one with opset18 API (dim input parameter)

  • Could you point out which block is opset 16 and which block is opset 18?
    I just found that TensorRT 8.6 supports operators up to opset 17, so running a model with mixed opset 18/16 APIs might be problematic for TensorRT.

  • If you can confirm that in your frcnn detectron2 model the opset 18 API is being used (which takes axes as an input), then the frcnn detectron2 model or the model converter needs to be fixed to make sure only opset 16 is used in the model.

@yf711
Contributor

yf711 commented Aug 14, 2023

As for your question:

...why onnxruntime trt EP this time requires a ReduceMax API that requires a dim input as of opset18 ?

Starting from opset 18, ReduceMax-18 promotes axes from a static attribute to a dynamic input, and there is a newly introduced attribute noop_with_empty_axes which interacts with axes.

In your case, if in the opset 18 block the axes input wasn't specified and noop_with_empty_axes is at its default of 0, then all dimensions are reduced.

According to your error msg:

ERROR] ReduceMax_477: at least 1 dimensions are required for input.

The data input tensor might already be 0-dimensional (a scalar), which cannot be reduced over all axes any further. (Making this op act as an Identity op by setting noop_with_empty_axes to 1 could avoid the dim-reducing operation, but it's not clear whether that's what your frcnn detectron2 model expects.)
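For reference, the two ReduceMax forms look roughly like this when built by hand with onnx.helper; this is only a sketch to illustrate the attribute-vs-input difference, not code taken from the exported model:

import numpy as np
import onnx
from onnx import helper, numpy_helper

# Opset <= 17 form: `axes` is a static attribute on the node.
rm_opset17 = helper.make_node(
    "ReduceMax", inputs=["data"], outputs=["reduced"],
    axes=[0], keepdims=0)

# Opset 18 form: `axes` becomes an optional input tensor (here an initializer
# that would be added to the graph), and `noop_with_empty_axes` controls what
# a missing/empty axes means (0 = reduce all dims, 1 = behave like Identity).
axes_initializer = numpy_helper.from_array(
    np.array([0], dtype=np.int64), name="axes")
rm_opset18 = helper.make_node(
    "ReduceMax", inputs=["data", "axes"], outputs=["reduced"],
    keepdims=0, noop_with_empty_axes=0)

print(rm_opset17)
print(rm_opset18)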

@yf711
Contributor

yf711 commented Aug 14, 2023

BTW, I saw you opened an issue about the ReduceMax op with the pytorch-onnx converter. You might try the workaround there to similarly unify the op and see if that helps.

@chilo-ms
Contributor

chilo-ms commented Aug 15, 2023

@datinje,
I took a closer look and found the root cause of the ReduceMax issue.
The detectron2 model is a two-level control-flow graph where the "If_1530" if node's "else_branch" subgraph contains the "ReduceMax_1532" ReduceMax node, which has neither an "axes" attribute nor an "axes" input.
However, the other ReduceMax node in the main graph has an "axes" attribute.
That's why, when you run ORT TRT, you see a similar error message:

2023-08-15 00:05:52.706385374 [E:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-15 00:05:52   ERROR] ReduceMax_1532: at least 1 dimensions are required for input.

Actually this error message comes from the TRT parser, not Onnxruntime. Onnxruntime did verify the graph, and since it's opset 17 and the "axes" attribute of ReduceMax is optional, it doesn't complain. But later, when the TRT EP calls the TRT parser, which has a stricter check, the parser finds that it can't get the "axes" attribute for that ReduceMax, so it fails.

I also generated the small ReduceMax demonstrator model; it passed the model check but failed the graph check, as you mentioned:

 onnx_model = onnx.load(ONNX_SAVE_PATH)
 onnx.checker.check_model(onnx_model, full_check=True)
 ... 
 graph = onnx_model.graph
 onnx.checker.check_graph(graph)

I checked the opset of the model and it's 16,
but there is no opset version found in the graph, so I guess onnx uses the latest opset version (18) by default; that's why it failed with:

Unrecognized attribute: axes for operator ReduceMax
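To locate ReduceMax nodes like ReduceMax_1532 that sit inside nested If branches without any axes information, one can walk the subgraph attributes recursively with the onnx Python API; a minimal sketch (the model path assumes the repro script's output):

import onnx
from onnx import AttributeProto

def find_reducemax_without_axes(graph, path=""):
    # Report ReduceMax nodes that carry neither an `axes` attribute nor an
    # `axes` input, recursing into If/Loop/Scan body graphs.
    for node in graph.node:
        if node.op_type == "ReduceMax":
            has_axes_attr = any(a.name == "axes" for a in node.attribute)
            has_axes_input = len(node.input) > 1
            if not (has_axes_attr or has_axes_input):
                print(f"{path}/{node.name}: ReduceMax with no axes")
        for attr in node.attribute:
            if attr.type == AttributeProto.GRAPH:
                find_reducemax_without_axes(attr.g, f"{path}/{node.name}")
            elif attr.type == AttributeProto.GRAPHS:
                for sub in attr.graphs:
                    find_reducemax_without_axes(sub, f"{path}/{node.name}")

model = onnx.load("onnx_output/faster_rcnn_fpn.onnx")
find_reducemax_without_axes(model.graph)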

@chilo-ms
Contributor

Could you tell me how to use your shape inference script when ready.

@yf711 is investigating this and will let you know once we have some progress

@datinje
Author

datinje commented Aug 16, 2023

@chilo-ms and @yf711
Yes, you confirmed that the detectron2 faster-rcnn model from the D2 model zoo has 2 types of ReduceMax in the generated ONNX model. That is what I noticed when looking at the verbose output of the torch.onnx.export call I used.

That is also the case in my own faster-rcnn model made with Detectron2. It looks like a pattern in Detectron2 models, which use both torchvision and Detectron2 modules, so the torch.onnx.export call likely generates 2 versions of ReduceMax depending on which module the subgraph comes from - it would be nice to find which piece of code generates which ReduceMax version!

However, if I recap: onnxruntime with the CPU EP and CUDA EP seems to cope, since they do not complain about the 2 different ReduceMax APIs: somehow they adapt to the API and pick the right underlying CPU or CUDA kernel whatever the ReduceMax node.

As for TRT, you are saying that onnxruntime and the TRT EP do no check on the ReduceMax API version but simply delegate to the TRT parser underneath, so it has to be the TRT parser which, being stricter, chooses one API (ReduceMax-18) and fails on the other (ReduceMax-13, I believe).

But then, once I successfully ran onnx-graphsurgeon on my faster-rcnn model, which surgeons all if-then-else blocks to return tensors with the same dimensions, why does trtexec run fine on the surgeoned model while onnxruntime + TRT EP does not? I assume trtexec also uses the same parser as ORT + TRT EP?

I am sorry I can't share my faster-rcnn model (which passes trtexec once surgeoned) due to IP, so that you could try it.
I can only share the D2 model zoo faster-rcnn as the problem demonstrator, but unfortunately, as we have seen, it can't be easily surgeoned and so does not pass trtexec.
Since the small Nvidia ReduceMax model passes both trtexec AND onnxruntime + TRT EP, it really has to do with the dual ReduceMax API use in the same model, but I don't think it has to do with the TRT parser.

I think we are still missing something in the TRT EP of onnxruntime, can you have a second look?

@yf711: how do I set the ReduceMax noop_with_empty_axes attribute to 1? I want to try. I indeed have an if-then-else block with one branch producing a scalar.

Thanks a lot for all the time you have already spent: we are narrowing down the issue.

@chilo-ms
Contributor

chilo-ms commented Aug 16, 2023

As for TRT , you are saying onnxruntime and TRT EP does no check on the ReduceMax APi version but simply delegates to the TRT parser underneath so it has to be with TRT parser which being more strict chooses one API (ReduceMax-18) and fails on the other (ReduceMax-13 - I believe).

@datinje
In the Detectron2 model, the opset version is 17 and I didn't see opset 18 in the model. So the model is interpreted as opset 17 by Onnxruntime, the TRT EP and the TRT parser.

Onnxruntime and the TRT EP both verify ReduceMax against its opset before calling the TRT parser (I ran gdb to double-check this).
They call Graph::VerifyNodeAndOpMatch() inside ORT for all nodes as well as their subgraphs' nodes. The node is verified in two places:

  • This function calls checker::check_node and OpSchema::Verify from the ONNX checker, and it only checks whatever attributes actually come with the current node; the problematic ReduceMax node has only the "keepdims" attribute, which is a valid attribute, so it passes the check.

  • Also, Graph::VerifyNodeAndOpMatch() checks the current node against the node's schema for the specific opset, here opset 17, where "axes" is marked as not required. So the problematic ReduceMax node with only a "keepdims" attribute passes this check as well.

@chilo-ms
Contributor

chilo-ms commented Aug 16, 2023

But then , once I successfully passed the onnx-graphsurgeon on my faster-rcnn model which surgeons all if-then-else blocks to return tensor with same dimension , why is trtexec running fine on the surgeoned model while onnxruntime + TRT EP is not ? I assume trtexec is also using the same parser as ort+trt EP ?

What's the error message you saw when running ORT + TRT EP on your faster-rcnn model?

Yes, I think ORT TRT and trtexec use the same TRT parser.
For the Detectron2 model, ORT TRT failed on one of the subgraphs, the one containing the "invalid" ReduceMax node, when calling the TRT parser.

2023-08-15 00:05:52.706385374 [E:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-15 00:05:52   ERROR] ReduceMax_1532: at least 1 dimensions are required for input. 

When you run trtexec on that specific subgraph, you get the same error message.

&&&& RUNNING TensorRT.trtexec [TensorRT v8601] # /home/azureuser/tensorrt/TensorRT-8.6.1.6/bin/trtexec --onnx=TensorrtExecutionProvider_TRT_Subgraph.onnx
...
[08/16/2023-16:43:55] [E] [TRT] ModelImporter.cpp:773: input: "/model/proposal_generator/GatherND_output_0"
output: "max_coordinate"
name: "ReduceMax_1532"
op_type: "ReduceMax"
attribute {
  name: "keepdims"
  i: 0
  type: INT
}
doc_string: "  File \"/home/yifan/miniconda3/envs/detectron2/lib/python3.10/site-packages/torchvision/ops/boxes.py\", line 
91\n    if boxes.numel() == 0:\n        return torch.empty((0,), dtype=torch.int64, device=boxes.device)\n    max_coordinate = 
boxes.max()\n                     ~~~~~~~~~ <--- HERE\n    offsets = idxs.to(boxes) * (max_coordinate + 
torch.tensor(1).to(boxes))\n    boxes_for_nms = boxes + offsets[:, None]\n"

[08/16/2023-16:43:55] [E] [TRT] ModelImporter.cpp:774: --- End node ---
[08/16/2023-16:43:55] [E] [TRT] ModelImporter.cpp:777: ERROR: ModelImporter.cpp:195 In function parseGraph:
[6] Invalid Node - ReduceMax_1532
ReduceMax_1532: at least 1 dimensions are required for input.
[08/16/2023-16:43:55] [E] Failed to parse onnx file

(Note: the TRT EP has a provider option "trt_dump_subgraphs". Once you enable it, you can find TensorrtExecutionProvider_TRT_Subgraph.onnx on disk; it is the subgraph being fed to the TRT parser.)

Using gdb to trace the TRT parser, you can see the error comes from here:
https://github.com/onnx/onnx-tensorrt/blob/main/ModelImporter.cpp#L191
once reduceTensor() has finished:
https://github.com/onnx/onnx-tensorrt/blob/main/onnx2trt_utils.cpp#L1533
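For reference, enabling that dump from Python might look like the following; a sketch, with trt_dump_subgraphs being the TRT EP provider option mentioned above:

import onnxruntime as ort

providers = [(
    "TensorrtExecutionProvider",
    {
        # Dump each subgraph handed to the TRT parser to disk as
        # TensorrtExecutionProvider_TRT_Subgraph*.onnx so it can be
        # re-run standalone with trtexec.
        "trt_dump_subgraphs": True,
    })]

sess = ort.InferenceSession("onnx_output/faster_rcnn_fpn.onnx", providers=providers)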

@chilo-ms
Contributor

However if I recap : onnxruntime and ORT CPU EP + CUDA EP seems to have a workaround since they do not complain about the 2 different ReduceMax APi : somehow they adapt to the API and generate the good underlying CPU or CUDA instruction whatever the ReduceMax node.

I think so but need to double check.

Since the nvidia small ReduceMax model both passes trtexec AND onnxruntime+ TRT EP , it really has to do with dual ReduceMax API use in the same model, but I don't think it has to do with trt parser.

The reason the small ReduceMax model passes trtexec and ORT TRT is that it has a valid ReduceMax node.
The Detectron2 model, however, has the invalid ReduceMax and can't pass the TRT parser verification, whether through trtexec or ORT TRT.
If you want to use TRT, we need to fix this invalid ReduceMax node issue.

@yf711
Contributor

yf711 commented Aug 17, 2023

Hi @datinje, I found a workaround to adapt the d2 model to the symbolic shape inference script and execute it via ORT-TRT:

  1. Simplify the model using onnx-simplifier:
onnxsim onnx_output/faster_rcnn_fpn.onnx faster_rcnn_fpn_simplified.onnx
  2. Execute the symbolic shape inference script:
python /path/to/onnxruntime/onnxruntime/python/tools/symbolic_shape_infer.py --input=faster_rcnn_fpn_simplified.onnx --output=inferred_faster_rcnn_fpn_simplified.onnx --auto_merge
  3. Use inferred_faster_rcnn_fpn_simplified.onnx as input to your eval_onnx_model(). Here's the log I had:
input name image
input shape [3, 'height', 'width']
input type tensor(float)
output name boxes
output shape ['NonZero_1028_o0__d1', 4]
output type tensor(float)

ort_outputs:  [array([], shape=(0, 4), dtype=float32), array([], dtype=int64), array([], dtype=float32), array([480, 640], dtype=int64)]
ort_outputs number:  4

ort_boxes :  []
ort scores :  []
ort classes :  []

detectron2 torch and onnx models results match!

[08/16 18:20:39 detectron2]: Success.

Please see if this workaround also works on your own model.

I will check on the op/args that have been simplified, which might block the symbolic shape infer script from processing.
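For completeness, the same two steps can also be run programmatically instead of through the CLI; a sketch, assuming the module layout of the onnxsim and onnxruntime pip packages:

import onnx
from onnxsim import simplify
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("onnx_output/faster_rcnn_fpn.onnx")

# Step 1: onnx-simplifier (constant folding, nop elimination, ...).
simplified, ok = simplify(model)
assert ok, "onnx-simplifier could not validate the simplified model"

# Step 2: ONNX Runtime's symbolic shape inference, equivalent to --auto_merge.
inferred = SymbolicShapeInference.infer_shapes(simplified, auto_merge=True)
onnx.save(inferred, "inferred_faster_rcnn_fpn_simplified.onnx")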

@datinje
Author

datinje commented Aug 20, 2023

@yf711, the onnx-simplifier + symbolic_shape_infer.py combination is a great hint: it fixes the detectron2 model zoo faster-rcnn model execution on onnxruntime + TRT EP!
I got the same match between pytorch inference and ORT + TRT EP as you.
Thanks a lot!

I wonder how this works: does it fix the model by unifying all the ReduceMax nodes to a single API type (compatible with opset 17)?

Now, on my own model the ReduceMax error is gone (great!), but I run into another issue with the ORT TRT EP:
2023-08-20 12:26:14.971402070 [E:onnxruntime:, inference_session.cc:1644 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/transformer_memcpy.cc:374 bool onnxruntime::TransformerMemcpyImpl::ProcessInitializers(const onnxruntime::KernelRegistryManager&, const InitializedTensorSet&) status.IsOK() was false. Failed to find kernel for Squeeze(13) (node /model/roi_heads/box_pooler/level_poolers.0/Squeeze). Kernel not found

Inference runs fine on the same (simplified + inferred) onnx model with the CPU and CUDA EPs.

Note: I had to revert to TRT 8.5.3.1 for my python trials since 'pip install tensorrt' (pulling 8.6.1) fails due to ssh errors despite using --trusted-host pypi.nvidia.com. This is weird though, since https://github.com/onnx/onnx-tensorrt/blob/8.5-GA/docs/operators.md reports that the Squeeze operator is supported for TRT 8.5 GA.

Sorry for asking for more help on the above: you have been great so far.
Should I close this case, as this may become a different type of issue?

@datinje
Author

datinje commented Aug 20, 2023

config for the above error :
using nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 docker image
pip install onnx onnxruntime-gpu #onnx (1.14) and onnxruntime (1.15.1)
pip install tensorrt==8.5.3.1
installed the TRT runtime from TensorRT-8.5.3.1.Linux.x86_64-gnu.cuda-11.8.cudnn8.6.tar.gz
pip install onnx-simplifier # (0.4.33)

I will try if trt 8.6.0 installs

@datinje
Author

datinje commented Aug 20, 2023

Same Squeeze 'kernel not found' issue with pip install tensorrt==8.6.0 and runtime TensorRT-8.6.1.6.Linux.x86_64-gnu.cuda-11.8.tar.gz.

@yf711
Contributor

yf711 commented Aug 21, 2023

Hi @datinje, could you apply this command to your own model and see if the issue still exists? The command excludes 35 optimizations that we don't need on the d2 model, which is the least-optimized case.

onnxsim onnx_output/faster_rcnn_fpn.onnx frcnn_split_expand_optimized.onnx --skip-optimization \
eliminate_nop_cast eliminate_nop_dropout eliminate_nop_flatten \
extract_constant_to_initializer eliminate_consecutive_idempotent_ops \
eliminate_if_with_const_cond eliminate_nop_monotone_argmax eliminate_nop_pad \
eliminate_nop_concat eliminate_shape_gather eliminate_slice_after_shape \
eliminate_nop_transpose fuse_add_bias_into_conv fuse_bn_into_conv \
fuse_consecutive_concats fuse_consecutive_log_softmax fuse_consecutive_reduce_unsqueeze \
fuse_consecutive_squeezes fuse_consecutive_transposes fuse_matmul_add_bias_into_gemm \
fuse_pad_into_conv fuse_pad_into_pool fuse_transpose_into_gemm fuse_concat_into_reshape \
eliminate_nop_reshape eliminate_nop_with_unit eliminate_common_subexpression \
fuse_qkv fuse_consecutive_unsqueezes eliminate_deadend eliminate_identity \
eliminate_shape_op fuse_consecutive_slices eliminate_unused_initializer \
eliminate_duplicate_initializer

Onnxsim applies all 37 optimizations by default, and it turned out that only 2 of them made a difference on the d2 model and let it pass symbolic_shape_infer and your test case (the eliminate_nop_split and eliminate_nop_expand optimizations).

@datinje
Author

datinje commented Aug 22, 2023

When I run onnxsim with the --skip-optimization list above and then symbolic_shape_infer subsequently (which passes), my inference still gets an error:
2023-08-22 07:13:34.988129481 [E:onnxruntime:, inference_session.cc:1644 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/transformer_memcpy.cc:374 bool onnxruntime::TransformerMemcpyImpl::ProcessInitializers(const onnxruntime::KernelRegistryManager&, const InitializedTensorSet&) status.IsOK() was false. Failed to find kernel for Squeeze(13) (node /model/roi_heads/box_pooler/level_poolers.0/Squeeze). Kernel not found

Then, when I run onnxsim with the --skip-optimization list only and run inference with the resulting simplified model (without using symbolic_shape_infer), I still get an error:
2023-08-22 07:07:57.990154544 [E:onnxruntime:Default, tensorrt_execution_provider.h:73 log] [2023-08-22 07:07:57 ERROR] ReduceMax_505: at least 1 dimensions are required for input.

@datinje
Author

datinje commented Aug 22, 2023

I got the above results when using my python onnxruntime + TRT EP inference code.

Now, when running the model obtained with onnxsim (with all the --skip-optimization flags above) AND symbolic_shape_infer in a C++ program, I get the same result with default settings (I don't know which TRT optimizations are used with default settings):
Failed to find kernel for Squeeze(13)

BUT if I now use the following C++ setting in the tensorrt_options:
tensorrt_options.trt_min_subgraph_size = 5;
THEN inference with the onnxruntime TensorRT EP passes, BUT 7x slower than with the CUDA EP.

Here is my code extract:

if (useTensoRT) {
  std::cout << "\t Inference Execution Provider: TensorRT" << std::endl;
  OrtTensorRTProviderOptions tensorrt_options{};
  tensorrt_options.trt_max_partition_iterations = 10;    // default, must be a positive integer
  tensorrt_options.trt_min_subgraph_size = 5;            // must be a positive integer
  tensorrt_options.trt_max_workspace_size = 4294967296;  // from default 1GB to 4GB
  sessionOptions.AppendExecutionProvider_TensorRT(tensorrt_options);
}
else {
  std::cout << "\t Inference Execution Provider: CUDA" << std::endl;
  OrtCUDAProviderOptions cuda_options{};
  sessionOptions.AppendExecutionProvider_CUDA(cuda_options); // 7x faster than with TRT EP
}

Any hints?

@datinje
Author

datinje commented Aug 22, 2023

Actually the Squeeze(13) 'kernel not found' error comes back with
tensorrt_options.trt_min_subgraph_size = 1; (that must be the default)

@yf711
Contributor

yf711 commented Aug 22, 2023

I think this issue might be in your model's subgraphs, and onnxsim can't process subgraphs.

trt_min_subgraph_size specifies the minimum number of nodes a subgraph must have to be considered for the TensorRT EP.
Setting it to 5 might make most of your subgraphs fall back to the CUDA EP, plus additional conversion overhead, which could make it much slower than pure CUDA EP.
(Maybe try setting it to 2?)
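For the Python runs, the same knob is exposed as a TRT EP provider option; a sketch, with the option key as listed in the TensorRT EP documentation:

import onnxruntime as ort

providers = [(
    "TensorrtExecutionProvider",
    {
        # Minimum node count for a subgraph to be offloaded to TensorRT;
        # smaller fragments fall back to the CUDA/CPU EPs.
        "trt_min_subgraph_size": 2,
    })]

sess = ort.InferenceSession("inferred_faster_rcnn_fpn_simplified.onnx", providers=providers)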

Would it be possible to share your Squeeze node spec (including its input/output nodes)?
As we can't debug your model and the d2 model doesn't have this issue, I wonder if you could share a minimal repro model?

@datinje
Author

datinje commented Aug 23, 2023

Thanks for the explanations: I now understand why it becomes so slow; in my case TRT is then no good.
I need to work on a repro model and, for that, find the Squeeze node details. That will take me some time before I can update this report again.
Or we should rewrite the whole model in pure pytorch to avoid any Detectron2 dependencies, given all the problems they cause when deploying to onnx.
That is what I proposed in the detectron2 github: facebookresearch/detectron2#4896
I want to thank @chilo-ms and @yf711 for your time and advice: they will not be lost.

@chilo-ms
Contributor

chilo-ms commented Aug 23, 2023

Thanks, @datinje.
We will wait for you to share the repro model as well as the squeeze node so that we can further debug this.

BTW, the error message "Failed to find kernel for Squeeze(13)" is coming from the onnxruntime function that duplicates any initializer used by both TRT EP nodes and CPU nodes. That's all we can tell so far.

@datinje
Author

datinje commented Aug 27, 2023

Update: with the help of Nvidia, we found the cause of the ReduceMax issue, and after applying onnxsim + symbolic_shape_infer.py on the model, I can now run onnxruntime with the TensorRT EP on my converted model.

The problem is due to the torchvision function batched_nms(): https://github.com/pytorch/vision/blob/v0.15.2/torchvision/ops/boxes.py#L89

if boxes.numel() == 0:
    return torch.empty((0,), dtype=torch.int64, device=boxes.device) # <=== this is the problem ...
max_coordinate = boxes.max()  # <=== ... with this

The torchvision helper _batched_nms_coordinate_trick(), called by batched_nms() which I use in my faster-rcnn derived model, produces an if-then-else block where the if branch returns a tensor with a different dimension than the else branch (see the marked lines above).

<== this causes ReduceMax to appear twice in the resultant onnx subgraphs. boxes.max() produces a proper ReduceMax with an axes attribute as per opset 17 when boxes are found, whereas the line above it, for the case where no boxes are found, produces a ReduceMax with no attribute since the result is a constant scalar - you cannot reduce a scalar.
@yf711 DID see this problem, but even setting noop_with_empty_axes to 1 as he suggested, I could not resolve it.

Now, if the lines below are commented out, then after passing onnxsim and symbolic_shape_infer.py I can execute onnxruntime with the TRT EP:

if boxes.numel() == 0:

    #return torch.empty((0,), dtype=torch.int64, device=boxes.device) 

Note that the workaround above alone is not enough for TRT to pass: it also requires inferring the shapes, which in turn requires running onnxsim on my model first.
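For reference, a sketch of that two-step conversion as I ran it (file names are placeholders; the same thing can be done with the onnxsim CLI and the symbolic_shape_infer.py script):

import onnx
from onnxsim import simplify
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model.onnx")
model_sim, ok = simplify(model)  # step 1: onnxsim simplification
assert ok, "onnxsim could not validate the simplified model"

# step 2: symbolic shape inference provided by onnxruntime tools
model_inferred = SymbolicShapeInference.infer_shapes(model_sim, auto_merge=True)
onnx.save(model_inferred, "inferred_simplified_model.onnx")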

This is only a workaround, as I still need to deal with the case where the AI does not find any detection boxes in my image.
The idea from Nvidia (Zhijin, the nice guy that is helping me) is to use the Nvidia TRT nmsPlugin https://github.com/NVIDIA/TensorRT/tree/release/8.6/plugin/nmsPlugin
The reason is that the NVIDIA TRT nmsPlugin handles the NMS if-then-else blocks (particularly the case where no detection boxes are found) more smartly than torchvision.

I will try to see how to declare the plugin with onnxruntime + TRT EP and see if it fixes the problem.

That really closes the ReduceMax issue in my model with the onnxruntime TRT EP. Forget about the Squeeze(13) "kernel not found": it is likely an artefact of the onnxsim + symbolic_shape_infer.py conversion.

One last thing though. I can do the inference in Python as:
import os
import onnxruntime as ort

# args and sample_inputs come from the export script above
corrected_onnx_model = "inferred_simplified_model.onnx"
onnx_model_path = os.path.join(args.output, corrected_onnx_model)
providers = ['TensorrtExecutionProvider']
sess_opt = ort.SessionOptions()
sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers)
image = sample_inputs[0]['image']
np_image = image.cpu().numpy()
ort_inputs = {sess.get_inputs()[0].name: np_image}
ort_outputs = sess.run(None, ort_inputs)

But I am getting an issue in C++:
2023-08-27 16:23:24.441734536 [E:onnxruntime:cad-engine-inference, tensorrt_execution_provider.h:75 log] [2023-08-27 16:23:24 ERROR] 4: [graphShapeAnalyzer.cpp::processCheck::862] Error Code 4: Internal Error (/model/proposal_generator/TopK: K exceeds the maximum value allowed (3840).)
2023-08-27 16:23:24.441805954 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running TRTKernel_graph_torch_jit_10911703694145108358_0 node. Name:'TensorrtExecutionProvider_TRTKernel_graph_torch_jit_10911703694145108358_0_0' Status Message: TensorRT EP Failed to Build Engine.

Somehow, C++ is missing some treatment that Python is doing with the TRT EP: any idea?

Here is the C++ code:
std::string modelFilepath{"models/inferred_model_simplified.onnx"};
std::string instanceName{"myModel"};

Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, instanceName.c_str());

Ort::SessionOptions sessionOptions;
sessionOptions.SetIntraOpNumThreads(1);

OrtTensorRTProviderOptions tensorrt_options{};
sessionOptions.AppendExecutionProvider_TensorRT(tensorrt_options);

// load model
Ort::Session session(env, modelFilepath.c_str(), sessionOptions);

// get overall shape of the input from the model
std::vector<int64_t> inputDims = inputTensorInfo.GetShape();

// update sizes: the input image size is variable (model was generated with dynamic dims, i.e. model dims are -1)
inputDims.at(0) = rows;
inputDims.at(1) = cols;

// set input and output names as when the ONNX model was generated (alas, don't trust inputTensorInfo as we get the names truncated!)
std::vector<const char*> inputNames{(const char*)"image"};
std::vector<const char*> outputNames{(const char*)"boxes", (const char*)"labels", (const char*)"scores", (const char*)"image_dims"};

// declare that the input is in CPU memory ... and assume it is already there: inputTensorValues
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

// initialize the onnxrt input tensor vector with the input image tensor - there is only one input
std::vector<Ort::Value> inputTensors;
inputTensors.emplace_back(Ort::Value::CreateTensor(memoryInfo, inputTensorValues.data(), inputTensorSize, inputDims.data(), inputDims.size()));

// do the inference
auto outputTensors = session.Run(Ort::RunOptions{nullptr},
                                 inputNames.data(), inputTensors.data(), 1 /*number of inputs*/,
                                 outputNames.data(), 4 /*number of outputs*/);

@chilo-ms
Contributor

chilo-ms commented Aug 28, 2023

Somehow , C++ is missing some treatment python is doing with TRT EP: any idea ?

It's possible that it falls back to the CUDA EP when the TRT EP fails to run when using Python. Could you help check this? The issue would then still be there for TRT.
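A quick way to check that from the Python side (a sketch; sess is the InferenceSession from the snippet above):

# If the TRT EP failed to register, it will not appear in this list and the
# session silently ran on the CUDA/CPU EPs instead.
print(sess.get_providers())
# e.g. ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']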

Regarding how to use a TRT plugin with ORT TRT, please see the doc. Simply create an ONNX custom node with the name NMS_TRT or NMSDynamic_TRT and the domain name trt.plugins.
Note: please use C++ to run the model with TRT plugins; if you use Python to run it, you might encounter an error that the ONNX custom node is not registered. We are working on a fix for Python.
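A minimal sketch of what such a custom node could look like (input/output tensor names are placeholders, and the plugin-specific attributes must be filled in from the nmsPlugin spec):

import onnx
from onnx import helper

nms_node = helper.make_node(
    "NMS_TRT",                                   # or "NMSDynamic_TRT"
    inputs=["boxes_input", "scores_input"],      # placeholder tensor names
    outputs=["nms_boxes", "nms_scores", "nms_classes", "nms_num_detections"],
    name="nms_plugin_node",
    domain="trt.plugins",                        # this domain tells ORT TRT to look for a TRT plugin
    # plugin-specific attributes go here (see the nmsPlugin README)
)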

@datinje
Author

datinje commented Sep 6, 2023

I finally successfully ran my faster-rcnn model on onnxruntime with the TRT EP by fixing the TopK error:
in detectron2 v0.6 (the last official release), in file dist-packages/detectron2/modeling/proposal_generator/proposal_utils.py,
replace lines 76-79

    # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
    logits_i, idx = logits_i.sort(descending=True, dim=1)
    topk_scores_i = logits_i.narrow(1, 0, num_proposals_i)
    topk_idx = idx.narrow(1, 0, num_proposals_i)

with

    topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
    # logits_i, idx = logits_i.sort(descending=True, dim=1)
    # topk_scores_i = logits_i.narrow(1, 0, num_proposals_i)
    # topk_idx = idx.narrow(1, 0, num_proposals_i)

Rationale: sort used to be faster than topk in older pytorch releases (fixed in PT 2.0), see pytorch/pytorch#22812.
Then I just needed to set the K value to the max value supported by TRT in my Detectron2 model, e.g.:
cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 3840
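For context, a minimal sketch of where that cap would be set on the Detectron2 config before export (the config file path is a placeholder; the key is the standard Detectron2 one):

from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("path/to/faster_rcnn_config.yaml")  # placeholder path
# cap the RPN pre-NMS TopK so the exported TopK node stays within TRT's limit
cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 3840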

In summary, to make my Detectron2 faster-rcnn-rpn work (note that most of the issues are similar if using the torchvision GeneralizedRCNN module), I had to:

  1. Modify the torchvision standard detection operator BatchedNMS (Non Maximum Suppression) to make it compatible with TensorRT: one of its lower-level instructions (max(), when used in an if-then-else block) translates in ONNX into a ReduceMax operator call where one branch has no dimension (you cannot do a max() on a list with zero detections). And this is rejected by TRT (native or ORT EP) while accepted by the ORT CUDA EP. I used a workaround: remove the if branch so max() is always applied to the list of detections. This does not work if the list size is null (no detection). A clean solution is to use the TRT ONNX BatchedNMSPlugin, but I won't, as this makes the ONNX model TensorRT dependent (defeats the purpose of ONNX being neutral) - still looking for a solution, likely from pytorch/torchvision.

  2. Fix elements whose shape is not explicit for TRT (the issue happens in faster-rcnn-rpn whether built with torchvision or Detectron2). For that I used 2 levels of ONNX converters (thx @yf711 and @chilo-ms!)
    a. The first converter is an OSS tool called onnxsim: https://github.com/daquexian/onnx-simplifier that reduces the complexity of the model.
    b. The second converter is the shape inference provided by onnxruntime tools: /usr/local/lib/python3.10/dist-packages/onnxruntime/tools/symbolic_shape_infer.py
    Note that the second converter happens not to work without the ONNX model simplification (onnxsim).

  3. Finally, fix the TopK operator usage in my model: set K to the max of 3840 (see above).

And as a result I am getting an execution which is SLOWER than with the CUDA EP (because several operators still fall back to the CPU EP, causing memory transfers between subnodes): for that I will file a new issue, but THIS issue can be closed. I hope this will benefit others.
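One way to see which operators fall back to the CPU EP is to rebuild the session with verbose logging, which prints the EP each node or subgraph is assigned to; a minimal sketch, assuming the same converted model as above:

import onnxruntime as ort

so = ort.SessionOptions()
so.log_severity_level = 0  # 0 = VERBOSE; node-to-EP placement shows up in the log
sess = ort.InferenceSession(
    "inferred_simplified_model.onnx",
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)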

@caruofc

caruofc commented Sep 19, 2023

Hi @datinje, I am having the same issue. However, I could use the TRT EP after running the ONNX simplifier (onnxsim) followed by symbolic_shape_infer.py. But then, when I tried to convert the ONNX file to a TensorRT engine using trtexec, I ended up with output dimensions of size (0,). I am not sure how to solve this. I think that even though onnxsim + symbolic_shape_infer worked around running the model with the TRT EP, the underlying problem with the if-then-else block still remains, which is causing this issue. Could you please tell me how you solved the if-then-else block problem for the detectron2 fast rcnn model? Any help would be highly appreciated. I am okay with any kind of solution, with or without defeating the purpose of ONNX, as my main goal is to build a TensorRT engine from the ONNX. Thanks

@datinje
Author

datinje commented Sep 20, 2023

Here is how I treated the if-then-else block: you are right, neither onnxsim nor symbolic_shape_infer can do anything here; instead, use onnx graph surgeon (see below).

Note that according to Nvidia, having to adapt your model to TRT like this is quite a common case; hence Nvidia provides customer support to help you run your model on their board (it happened to me). So try contacting your Nvidia customer support.

import os
import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load('faster_rcnn_fpn.onnx'))
graph = graph.cleanup().toposort()
graph = graph.fold_constants().cleanup()

# check before surgeon (or use netron)
print(graph)
for node in graph.nodes:
    print(node)

for node in graph.nodes:
    for o in node.outputs:
        o.shape = None

model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model)
graph = gs.import_onnx(model)

# check after surgeon (or use netron)
print(graph)

# FIX: /model/roi_heads/pooler/level_poolers.0/If
for node in graph.nodes:
    # adapt to your own case
    if node.name == '/model/roi_heads/pooler/level_poolers.0/If':

        else_branch_graph = node.attrs['else_branch']

        squeeze_axes = gs.Constant(
            name='axes',
            values=np.array([1]))
        squeeze_node = gs.Node(
            "Squeeze",
            name="squeeze_fix")
        squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
        squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]

        node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
        node.attrs['else_branch'].outputs = squeeze_node.outputs

graph = graph.cleanup().toposort()
graph = graph.fold_constants().cleanup()

model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model)
graph = gs.import_onnx(model)

# args.output comes from the surrounding export script
save_path = os.path.join(args.output, 'faster_rcnn_fpn-graph_surgeon.onnx')
onnx.save(model, save_path)

Hope this helps.
Be careful that you cannot squeeze a scalar (that was my problem originally with ReduceMax above), so you may have to change your model design.

@caruofc

caruofc commented Sep 21, 2023

@datinje, thanks a lot!
I am going to try this and let you know.

@caruofc

caruofc commented Sep 21, 2023

@datinje, I just tried. I think we are almost there. The error

"IIfConditionalOutputLayer inputs must have the same shape. Shapes are [-1] and [-1,1]."

has now become

"Assertion failed: !isDynamic(shape) && "Cannot infer squeeze dimensions from a dynamic shape! Please re-export your model with the Squeeze axes input set."

I have shared my original model at the following link. I would appreciate it if you could take a look.
https://1drv.ms/u/s!Ah7Pr6XgUyqEjCblsEk8S9G7ZiB-

Below is the code that I tried to fix the issue:

import onnx
import onnx_graphsurgeon as gs
import numpy as np

graph = gs.import_onnx(onnx.load('model.onnx'))
graph = graph.cleanup().toposort()
graph = graph.fold_constants().cleanup()

for node in graph.nodes:
    for o in node.outputs:
        o.shape = None

model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model)
graph = gs.import_onnx(model)

for node in graph.nodes:
    # adapt to your own case
    if node.name == '/roi_heads/box_pooler/level_poolers.0/If':
        else_branch_graph = node.attrs['else_branch']
        squeeze_axes = gs.Constant(
            name='axes',
            values=np.array([1]))
        squeeze_node = gs.Node(
            "Squeeze",
            name="squeeze_fix")
        squeeze_node.inputs = [*else_branch_graph.outputs, squeeze_axes]
        squeeze_node.outputs = [gs.Variable('squeeze_fix_output', dtype=np.float32)]

        node.attrs['else_branch'].nodes = [*node.attrs['else_branch'].nodes, squeeze_node]
        node.attrs['else_branch'].outputs = squeeze_node.outputs

graph = graph.cleanup().toposort()
graph = graph.fold_constants().cleanup()

model = gs.export_onnx(graph)
model = onnx.shape_inference.infer_shapes(model)
graph = gs.import_onnx(model)

onnx.save(model, 'model_fixed.onnx')

@datinje
Author

datinje commented Sep 21, 2023

my mistake : graph_surgeon is an Nvidia tool : it is part of the Tensorrt distribution

@caruofc

caruofc commented Sep 22, 2023

@datinje, I am not sure what you mean by "my mistake : graph_surgeon is an Nvidia tool : it is part of the Tensorrt distribution". Would you be able to look into the model (uploaded to my one drive: https://1drv.ms/u/s!Ah7Pr6XgUyqEjCblsEk8S9G7ZiB-) and give me some directions? It would be extremely helpful. Thanks in advance.

@datinje
Author

datinje commented Sep 25, 2023

Apologies: I did not notice your last update. You already succeeded in importing the model into onnx-graphsurgeon. Forget about the "my mistake.." comment.

  1. Stupid question, maybe, but did you adapt the surgeon code to your model?
    For example, my surgeon code used
    "if node.name == '/roi_heads/box_pooler/level_poolers.0/If':"
    => you have to adapt this to the if-then-else block in your own model.

  2. I got your model. Can you share a small Python script that loads it and runs inference with onnxruntime on a sample input for me to try? I don't know how to run TRT natively in Python, so it is tough to point to the model code that yields the error with trtexec.
    Can you locate where in your PyTorch model TRT is failing?

It could be an if-then-else block that yields outputs of 2 different dimensions and cannot be surgeoned because one of them is of dim 0.
For example, I had the problem in my model with the code:
if boxes.numel() == 0:
    return torch.empty((0,), dtype=torch.int64, device=boxes.device)
max_coordinate = boxes.max()

If the number of boxes found is zero, then you cannot squeeze that if branch any further.
Instead, to work around the problem, I commented out the "if boxes.numel() == 0:" case in the model (I had to patch the torchvision code!).
Then I could run the model on TRT. (The final fix would require using the TRT NMS plugin, which deals properly with the case of 0 boxes found.)

Sorry that I can only provide hints based on the problem I had, which can be very specific to my model. It is tough to say more without pointing to the exact model code.

@caruofc

caruofc commented Sep 28, 2023

@datinje, I finally could solve my issues. This link https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples/python/detectron2 describes how to get around them with a sample mask rcnn model. Basically, I had to modify the "create_onnx.py" sample script to create an NMS node for the EfficientNMS_TRT plugin and replace the output, create a PyramidROIAlign_TRT plugin node to replace the ROIAligns, fold the constants, and modify the Reshape nodes of my ONNX model. Once the ONNX model is converted using the modified "create_onnx.py", I could generate the engine file without any issues.
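For anyone following along, a rough sketch (not the actual create_onnx.py; tensor names and attribute values are placeholders, with attribute names as listed in the efficientNMSPlugin README) of how such an EfficientNMS_TRT node can be grafted in with onnx_graphsurgeon:

import numpy as np
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("model.onnx"))
tensors = graph.tensors()
boxes = tensors["boxes_tensor_name"]    # placeholder: tensor feeding the original NMS boxes
scores = tensors["scores_tensor_name"]  # placeholder: tensor feeding the original NMS scores

num_dets = gs.Variable("num_detections", dtype=np.int32)
det_boxes = gs.Variable("detection_boxes", dtype=np.float32)
det_scores = gs.Variable("detection_scores", dtype=np.float32)
det_classes = gs.Variable("detection_classes", dtype=np.int32)

nms = gs.Node(
    op="EfficientNMS_TRT",
    name="efficient_nms",
    inputs=[boxes, scores],
    outputs=[num_dets, det_boxes, det_scores, det_classes],
    attrs={
        "plugin_version": "1",
        "background_class": -1,
        "max_output_boxes": 100,   # illustrative value
        "score_threshold": 0.05,   # illustrative value
        "iou_threshold": 0.5,      # illustrative value
        "score_activation": False,
        "box_coding": 0,
    },
)
graph.nodes.append(nms)
graph.outputs = [num_dets, det_boxes, det_scores, det_classes]
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_with_efficient_nms.onnx")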

Thank you for sticking with me and providing me with your valuable feedback.
Appreciate it a lot.
