
[Performance] 1.14RC1 TensorRT Regression #14484

Closed
seddonm1 opened this issue Jan 30, 2023 · 29 comments · Fixed by #14719
Labels: ep:CUDA (issues related to the CUDA execution provider), ep:TensorRT (issues related to the TensorRT execution provider), regression (issues that demonstrate a regression in ORT functionality and need to be addressed immediately)

Comments

@seddonm1

Describe the issue

I have been testing 1.14.0 RC1 and am seeing quite a significant performance regression vs 1.13.1 using the C API. You can see that GPU usage is lower (both power draw and volatile GPU utilization), suggesting some bottleneck has been introduced.

onnxruntime 1.14.0

real    2m48.415s
user    5m15.012s
sys     0m4.494s

|===============================+======================+======================|
|   0  NVIDIA A2           Off  | 00000000:51:00.0 Off |                    0 |
|  0%   65C    P0    57W /  60W |   1242MiB / 15356MiB |     90%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

onnxruntime 1.13.1

real    2m30.369s
user    4m46.815s
sys     0m5.421s

|===============================+======================+======================|
|   0  NVIDIA A2           Off  | 00000000:51:00.0 Off |                    0 |
|  0%   63C    P0    60W /  60W |   1010MiB / 15356MiB |     97%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

To reproduce

This process is doing three things in parallel:

  • copying data from cpu to gpu
  • executing a preprocessing model (resizing images)
  • executing a ml model against those images

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04.5LTS

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

rel-1.14.0

ONNX Runtime API

C

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

CUDA 12.0

Model File

No response

Is this a quantized model?

No

@github-actions bot added the ep:CUDA and ep:TensorRT labels Jan 30, 2023
@ytaous added the regression and release:1.14 labels Jan 31, 2023
@RyanUnderhill
Member

If possible, can you give us steps to reproduce the performance you're seeing?

@seddonm1
Author

seddonm1 commented Feb 1, 2023

Hi, I am sorry but I cannot provide the models, and I am executing via the C API through Rust bindings, so it is a bit difficult to share.

I have been running profiling (which I can provide) to understand where the bottleneck is and have found some interesting things.

As I mentioned above, I have three OrtSessions running in parallel (overlapping) that do the following (a rough sketch of one such session follows the list):

  1. model copies image data to gpu (io_binding cpu in, cuda out)
  2. model resizes and pads image (io_binding cuda in, cuda out)
  3. model executes ml inference against that (io_binding cuda in, cpu out)
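To make this concrete, here is a rough sketch of session 1 (cpu in, cuda out) using the ONNX Runtime C++ wrapper over the C API I call from Rust. The model path, tensor names and shape ("copy.onnx", "input", "output", 1x3x720x1280) are placeholders, not the real ones:

// Sketch only: bind a CPU input and let ORT allocate the output on the CUDA
// device so the next session can consume it without another copy.
#include <onnxruntime_cxx_api.h>
#include <cstdint>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "pipeline");
  Ort::SessionOptions opts;
  OrtTensorRTProviderOptions trt{};                    // defaults, device 0
  opts.AppendExecutionProvider_TensorRT(trt);
  Ort::Session session(env, "copy.onnx", opts);        // placeholder model path

  std::vector<uint8_t> pixels(1 * 3 * 720 * 1280);     // placeholder frame
  std::vector<int64_t> shape{1, 3, 720, 1280};
  Ort::MemoryInfo cpu_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<uint8_t>(
      cpu_info, pixels.data(), pixels.size(), shape.data(), shape.size());

  Ort::IoBinding binding(session);
  binding.BindInput("input", input);                   // CPU-resident input
  Ort::MemoryInfo cuda_info("Cuda", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  binding.BindOutput("output", cuda_info);             // output stays on the GPU

  session.Run(Ort::RunOptions{}, binding);             // HtoD copy + small kernel happen here
  std::vector<Ort::Value> outputs = binding.GetOutputValues();  // handed to session 2
  return 0;
}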

Individually, each model's 50th-percentile latency is faster on 1.14.0 vs 1.13.1, yet some sort of scheduling delay is causing the end-to-end time to increase. A speedup of ~2% can be measured for 1.14.0 vs 1.13.1 when executing a single model.

median model_run time in nanoseconds per model/version

|   | 1.13.1   | 1.14.0   | improvement |
|---|----------|----------|-------------|
| 1 |   645000 |   602000 |          7% |
| 2 |  2665000 |  1320000 |        102% |
| 3 | 12530000 | 12359000 |          1% |

FYI, the SQL query used for the calculation (NTILE(10) buckets the runs into deciles, so the median is the 5th bucket; the initial 100 runs are skipped to allow for warm-up):

WITH p AS (SELECT dur, NTILE(10) OVER (ORDER BY dur) AS percentile  -- bucket durations into deciles
           FROM slice
           WHERE name = 'model_run'
           AND id > 100)                                            -- skip the first 100 runs (warm-up)
SELECT percentile, MAX(dur) AS dur                                  -- upper bound of each decile
FROM p
GROUP BY percentile;

I will continue investigating.

@seddonm1
Author

seddonm1 commented Feb 2, 2023

Another update.

I have run NVIDIA Nsight Systems profiling on both 1.13.1 and 1.14.0.

The biggest change I can see is that the Memcpy HtoD runs in parallel on 1.13.1 but sequentially on 1.14.0. Also, that Memcpy runs on an arbitrary stream in 1.13.1 but on the default stream in 1.14.0. I have timed the two parts that previously ran in parallel (the six red parts and the green part) and they sum to ~2 ms, which aligns perfectly with my initial findings.

1.13.1 (Nsight timeline screenshot)

1.14.0 (Nsight timeline screenshot)
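For anyone less familiar with the stream semantics involved, here is a minimal standalone CUDA host sketch (not ORT code) of why a copy queued on the legacy default stream cannot overlap with work on other blocking streams, while a copy on its own stream can:

// Minimal sketch, not ORT code: the legacy default stream (stream 0) implicitly
// synchronizes with all other blocking streams on the device, so a copy queued
// on it serializes with compute queued elsewhere; a copy on its own stream can
// overlap (given pinned host memory and a free copy engine).
#include <cuda_runtime.h>

int main() {
  const size_t n = 1 << 24;
  void* dst = nullptr;
  void* src = nullptr;
  cudaMalloc(&dst, n);
  cudaMallocHost(&src, n);               // pinned host buffer so the copy can be async

  cudaStream_t compute, copy;
  cudaStreamCreate(&compute);            // blocking stream, like ORT's compute streams
  cudaStreamCreate(&copy);

  // Queued on the legacy default stream: waits for, and is waited on by, `compute`.
  cudaMemcpyAsync(dst, src, n, cudaMemcpyHostToDevice, 0);

  // Queued on its own stream: free to overlap with work on `compute`.
  cudaMemcpyAsync(dst, src, n, cudaMemcpyHostToDevice, copy);

  cudaDeviceSynchronize();
  cudaStreamDestroy(copy);
  cudaStreamDestroy(compute);
  cudaFreeHost(src);
  cudaFree(dst);
  return 0;
}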

@jywu-msft
Member

I wanted to confirm all the variables here and ask some follow-up questions.
ORT version changed from 1.13.1 -> 1.14.0
TensorRT version changed from 8.4 -> 8.5.x, correct?

I assume the application code used to run both ORT versions is exactly the same? (no change in API usage?)
Can the C API code be shared (for setting up the session, CUDA stream, io bindings, etc.)?
Does the memcpy HtoD in the profile correspond to model 1's io_binding (cpu in, cuda out)?
I'm trying to better understand how the memcpy could have been parallelized with compute in the 1.13 case; model 2 needs to wait for the HtoD to complete.
Lastly, is there any chance you can provide a minimal repro, since you can't share the models?

@seddonm1
Author

seddonm1 commented Feb 2, 2023

Hi.
Correct, all the code is exactly the same except for the change from 1.13.1 to 1.14.0 (built from the rel-1.14.0 branch yesterday). We have been on TensorRT 8.5 for quite a while.

I will push my updated Rust bindings soon and can link here.

I am executing this per frame of a video, so the parallelism seen must be with the next frame. For clarity, I have three independent sessions running in parallel in their own threads, passing the OrtValue from the io_binding to the next model. We do nothing to explicitly set up any CUDA streams. It looks to me like the CUDA option do_copy_in_default_stream is effectively active now? (A sketch of how that option is normally set is below.)

I can try to put together a minimal repro, but it is quite hard.
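For reference, a sketch of how that option is normally set when the CUDA EP is registered explicitly. We do not set it anywhere ourselves; this only shows the knob being referred to, and the model path is a placeholder:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "cuda-options");
  Ort::SessionOptions opts;

  OrtCUDAProviderOptions cuda{};
  cuda.device_id = 0;
  cuda.do_copy_in_default_stream = 1;   // 1 = copies share the default stream; 0 = use a dedicated copy stream
  opts.AppendExecutionProvider_CUDA(cuda);

  // Ort::Session session(env, "model.onnx", opts);   // placeholder model path
  return 0;
}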

@jywu-msft
Member

I don't believe there are any changes in the TensorRT Execution Provider between 1.13 and 1.14 that could have resulted in this behavior.
And given that there are no application or TensorRT/CUDA version changes,
#13495 is a possible suspect.
Can you test with f4cd35f, which is the commit before that PR was merged, to help us narrow it down?

@seddonm1
Author

seddonm1 commented Feb 4, 2023

Hi.
Thank you for your help so far and suggestion for debug.

I have built both f4cd35f and a81faee, which are the commits directly before and after #13495 was merged.

I think these results conclusively identify #13495 as the cause of the performance regression.

BEFORE #13495:

f4cd35f9b1301f54d65a3e59c525b92e85bf384e 2m32.615s
|===============================+======================+======================|
|   0  NVIDIA A2           Off  | 00000000:51:00.0 Off |                    0 |
|  0%   60C    P0    60W /  60W |   1050MiB / 15356MiB |     97%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

(Nsight timeline screenshot for f4cd35f9b1301f54d65a3e59c525b92e85bf384e)

AFTER #13495:

a81faee41ef2344de448caecb0f42a34fdc9ead7 2m43.328s
|===============================+======================+======================|
|   0  NVIDIA A2           Off  | 00000000:51:00.0 Off |                    0 |
|  0%   56C    P0    58W /  60W |   1106MiB / 15356MiB |     95%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

(Nsight timeline screenshot for a81faee41ef2344de448caecb0f42a34fdc9ead7)

For clarity these are the versions involved:

NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0
libnvinfer8/unknown,now 8.5.2-1+cuda11.8 amd64 [installed,upgradable to: 8.5.3-1+cuda11.8]
libnvinfer-plugin8/unknown,now 8.5.2-1+cuda11.8 amd64 [installed,upgradable to: 8.5.3-1+cuda11.8]

@jywu-msft
Member


Thanks a lot for running these experiments and helping narrow it down!
I will loop in other devs to this thread.

@jywu-msft
Member

+@souptc FYI

@souptc
Member

souptc commented Feb 6, 2023

Hi @seddonm1, could you please share how you copy the input to the GPU? Based on my understanding, in ONNX Runtime we never put the input memory copy on a separate stream. I'm not sure whether there is some advanced feature that we missed during the refactoring.

@souptc
Member

souptc commented Feb 6, 2023

From the profiling you shared, it seems that in ORT 1.13 the MemCpy HtoD is put on a stream with other ORT kernels. I'd like to confirm the following things:

  1. in the ORT 1.13 profiling, is the stream with MemCpy HtoD associated with your pre-processing model (cpu in, cuda out)?
  2. if the answer to 1) is yes, which EP did you register for the pre-processing model, the CUDA EP or the TensorRT EP?
  3. you are not using an ORT minimal build, right?
  4. you are not using an ORT training build, right?

@seddonm1
Author

seddonm1 commented Feb 6, 2023

Hi @souptc.

  1. The copy from CPU to GPU happens via the TensorRT provider using io_binding, with a BindInput from a CPU OrtValue and an output CUDA OrtValue. I am leaving all stream allocation to ONNX Runtime.
  2. See below.
  3. TensorRT
  4. I am using this process to build ORT: https://gist.github.com/seddonm1/5dc956b136a98b8d46d03ca205252b4a
  5. See point 3.

I have exported more detailed views of the Nsight profile showing the individual threads.

~1.13.1 (f4cd35f) build:

(Nsight timeline screenshot for f4cd35f9b1301f54d65a3e59c525b92e85bf384e, per-thread view)

The streams are:

  1. GPU Hardware Video decoder decodes and copies data to CPU - Default stream 7. You can see the memcpy in red.
  2. model copies image data to gpu (io_binding cpu in, cuda out) - Stream 15. You can see the big memcpy in green followed by a small kernel execution.
  3. model resizes and pads image (io_binding cuda in, cuda out) - Stream 13. You can see the preprocessing steps executed using TensorRT.
  4. model executes ml inference against that (io_binding cuda in, cpu out) - Stream 14. You can see the many steps of the CNN.

~1.14.0 (a81faee)

(Nsight timeline screenshot for a81faee41ef2344de448caecb0f42a34fdc9ead7, per-thread view)

In this one the streams are:

  1. GPU Hardware Video decoder decodes and copies data to CPU - Default stream 7. You can see the memcpy in red copying the data from the GPU to the CPU (seen at around +111ms).
  2. model copies image data to gpu (io_binding cpu in, cuda out) - Default stream 7. You can see the memcpy but maybe there is now contention due to the use of the default stream?
  3. model resizes and pads image (io_binding cuda in, cuda out) - Stream 55. You can see the preprocessing steps executed using TensorRT.
  4. model executes ml inference against that (io_binding cuda in, cpu out) - Stream 90. You can see the many steps of the CNN.

I guess the big questions are:

  • why is the model execution in step 2 happening on the default stream now?
  • is there something special about the default stream that stops it executing in parallel?

@souptc
Member

souptc commented Feb 7, 2023

The model execution in step 2 (memcpy HtoD + a small kernel) is not on the default stream; you can see in your ORT 1.14 profiling that the small kernel is executed on stream 16. It is just the memcpy that has been put on the default stream.

In the implementation, we only put the graph input copy on the default stream in two cases:

  1. there is no stream created on this device. see here.
  2. it is a minimal build. see here.

But I don't see how your case could fall into either of those two categories. We will try to reproduce what happened with io_binding, but could you share what you do in that small kernel?

@seddonm1
Author

seddonm1 commented Feb 7, 2023

Thanks @souptc - you are better at interpreting these Nsight reports than I am.

The operation that you are referring to on stream 16 after the memcpy is a Scale operation.

Step 2 does three things:

  • copy a tensor input from host to device (automatically handled by ORT)
  • Cast uint8 -> fp32
  • Mul operator with 0.0039 (normalization), which is the 'Scale' you can see executed using TensorRT.

(Zoomed Nsight timeline screenshot for a81faee41ef2344de448caecb0f42a34fdc9ead7)

@souptc
Member

souptc commented Feb 7, 2023

Thanks @seddonm1. Could you help me with two more experiments:

  1. start a process that runs just step 2 (memcpy HtoD + Cast + Mul); does the MemCpy still happen on the default stream?
  2. if yes, could you share how you create the inference session and io_binding for this small model? If you can share that model, that would be great.

@seddonm1
Author

seddonm1 commented Feb 7, 2023

@souptc.

Yes, you can see the issue happens with just this one model.

(Nsight timeline screenshot of step 2 run in isolation)

I am using the C API via Rust bindings but not doing anything exciting. I can post my bindings tomorrow, but I am creating an OrtEnv with global inter/intra-op threads set to 16, then an OrtSession with execution_mode parallel, graph_optimization all, mem_pattern enabled, and per-session threads disabled.

Do you have an email address I can send the model to? I will need to verify with my employer before sending. In the meantime, I can easily add some cout statements to https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/utils.cc to help debug.
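For clarity, a rough C++ equivalent of that configuration (assuming my Rust bindings map straight onto the C/C++ API; the model path is a placeholder):

// Sketch of the environment/session setup described above: global thread pools
// (16 inter/intra-op threads), parallel execution mode, full graph optimization,
// memory pattern enabled and per-session threads disabled.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::ThreadingOptions tp;
  tp.SetGlobalIntraOpNumThreads(16);
  tp.SetGlobalInterOpNumThreads(16);
  Ort::Env env(tp, ORT_LOGGING_LEVEL_WARNING, "shared-env");

  Ort::SessionOptions opts;
  opts.SetExecutionMode(ExecutionMode::ORT_PARALLEL);
  opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
  opts.EnableMemPattern();
  opts.DisablePerSessionThreads();        // sessions share the env's global thread pools

  OrtTensorRTProviderOptions trt{};
  opts.AppendExecutionProvider_TensorRT(trt);

  // Ort::Session session(env, "step2.onnx", opts);   // placeholder model path
  return 0;
}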

@souptc
Member

souptc commented Feb 7, 2023

Sure, could you mail it to chenta@microsoft.com? Beyond the model, I am more interested in how you configure the session and io_binding; probably there is some usage I previously missed. Do you mind sharing that script as well?

If you want to debug it locally, I would suggest checking here what the stream is and what "stream->handle" you get underneath. If it is the default stream, you will get a NULL here.
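Something along these lines (a stand-in illustration only; the real types live in ORT's stream handling code, so the names here are not the actual API):

// Stand-in illustration of the null-handle check suggested above; `Stream`
// here is a mock struct, not the real onnxruntime Stream type.
#include <cstdio>

struct Stream {
  void* handle = nullptr;   // a null handle means work goes to the default CUDA stream
};

void report_copy_stream(const Stream* stream) {
  if (stream == nullptr || stream->handle == nullptr) {
    std::printf("graph input copy would run on the default stream\n");
  } else {
    std::printf("graph input copy stream handle: %p\n", stream->handle);
  }
}

int main() {
  Stream s;                  // default-constructed: null handle
  report_copy_stream(&s);
  return 0;
}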

@seddonm1
Author

seddonm1 commented Feb 8, 2023

Hi @souptc, I have sent the ONNX model to your email address and aim to do some local debugging today.

@noujaimc

Any updates? We have the same problem.

@seddonm1
Author

@noujaimc I have been working with @souptc and @jslhcl to reproduce the issue. It looks like an issue related to io_binding, and they created a PR a few minutes ago that will hopefully fix it.

Thank you very much @souptc and @jslhcl for your excellent support.

@seddonm1
Author

Hi,
I have now run benchmarks on three builds: the 1.13.1 official release, the 1.14.0 official release, and 1.14.x with PR #14719 compiled in.

| version | commit                                   | runtime  | regression |
|---------|------------------------------------------|----------|------------|
| 1.13.1  | b353e0b41d588605958b03f9a223d10a2fbeb514 | 00:02:28 |            |
| 1.14.0  | 6ccaeddefa65ccac402a47fa4d9cad8229794bb2 | 00:02:43 |        10% |
| 1.14.x  | e4b9d54d81442b4f36634fb8379d2cef1b506769 | 00:02:39 |         7% |

My results show that PR #14719 does help reduce the regression, but the runtime vs the 1.13.1 baseline is +7% with the PR applied and +10% without it - so still a regression.

Are these results in line with your testing @noujaimc?

@souptc @jslhcl should we continue investigating?

@jslhcl
Contributor

jslhcl commented Feb 20, 2023

@seddonm1 Thanks for the test results. Is the memcpy op still being launched on the default stream?

@jslhcl
Contributor

jslhcl commented Feb 20, 2023

Are you testing on the small model shared with us or the original model? @seddonm1

jslhcl added a commit that referenced this issue Feb 20, 2023
### Description
Create a new stream for the data copy in the IOBinding input scenario.

### Motivation and Context
Previously in bindInput(), a nullptr Stream was passed to copy data across devices. This caused the default stream to be used, which hurt performance.
This PR fixes #14484

---------

Co-authored-by: Lei Cao <leca@microsoft.com>
@seddonm1
Author

@jslhcl
These benchmarks were run with the original model from the start of this thread (i.e. three ONNX models in parallel). I think #14719 has solved part of the problem, but there is still a clear regression from 1.13.1 (setting aside that 1.14.x should ideally perform better than 1.13.1). I don't think this issue should be closed?

Here is an updated profile which shows that the memcpy HtoD is no longer on the default stream (#14719):
1.14.x (Nsight timeline screenshot)

@jslhcl
Contributor

jslhcl commented Feb 21, 2023

Thanks for the result. I didn't close the issue manually; it was closed automatically by the PR. Let me talk to Cheng tomorrow about this issue.

@jslhcl
Contributor

jslhcl commented Feb 21, 2023

@seddonm1 do you mind sharing the nsys profiling reports for both the feature branch and the main branch? Thank you very much!

PatriceVignola pushed commits that referenced this issue Feb 22, 2023
@seddonm1
Author

@jslhcl I have sent the profiles for 1.13.1, 1.14.0 and 1.14.x to you via email.

If anyone reading this is able to provide a reproducible example using public models, that would really help.

@jslhcl
Contributor

jslhcl commented Apr 4, 2023

It should be fixed in the latest main branch code: the created stream now works concurrently with the default stream without any explicit synchronization.
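For background, a minimal CUDA host snippet (not the ORT change itself) of the property being described: a stream created as non-blocking does not implicitly synchronize with the legacy default stream, so work on the two can overlap.

#include <cuda_runtime.h>

int main() {
  cudaStream_t copy_stream;
  // Non-blocking: no implicit synchronization with the legacy default stream.
  cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking);

  // ... issue cudaMemcpyAsync / kernels on copy_stream; they may overlap
  // with work queued on the default stream.

  cudaStreamDestroy(copy_stream);
  return 0;
}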

@seddonm1
Author

seddonm1 commented Apr 4, 2023

Hi,
I have tested 1.13.1 vs 1.14.1 with this fix applied and can confirm that the runtime is now exactly the same for my test job. I think this probably justifies a 1.14.2 release, as a ~5-10% regression is significant.

Thanks everyone 👍
