
Update custom Triton kernel documentation and examples #20883

Draft · wants to merge 21 commits into main

Conversation

@Numeri commented May 31, 2024

Description

Updates and improves the existing documentation for using custom Triton kernels as operators in ONNX Runtime. It also includes a small fix in the Python script for compiling Triton kernels, and allows compiling kernels outside of the Docker build step.

Motivation and Context

I wanted to write my own ORT operators using Triton, but the only examples in the codebase are for adding a kernel to an existing TunableOp, with code spread over a large number of files.

This is my best attempt at making a more minimal example, with documentation of each step needed to do this independently.
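For readers skimming the thread, here is a rough, hypothetical sketch of the shape such an operator header takes on the C++ side. It is not the PR's actual my_triton_kernel.h (which may differ in details, e.g. by being templated on the element type); it only assumes ORT's CudaKernel base class and the usual contrib-op layout:

// Hypothetical sketch of a CUDA contrib op that forwards to a compiled Triton
// kernel, simplified from the pattern used by ORT's existing CUDA contrib ops.
#pragma once

#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

class MyTritonKernel final : public onnxruntime::cuda::CudaKernel {
 public:
  explicit MyTritonKernel(const OpKernelInfo& info)
      : onnxruntime::cuda::CudaKernel(info) {}

  // Reads the inputs, allocates the output, and launches the pre-compiled
  // Triton kernel on the op's CUDA stream.
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime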

Note

Calling the operator defined here currently triggers a CUDA error:

2024-05-31 15:58:30.844121139 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=998ab211f19f ; file=/code/onnxruntime/core/providers/cuda/gpu_data_transfer.cc ; line=73 ; expr=cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle())); 
2024-05-31 15:58:30.844186034 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=998ab211f19f ; file=/code/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=446 ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_)); 

I'm opening this as a draft PR until I can fix this, and I'm hoping to get help with it, as I assume it's just a small mistake in how I'm passing the CUDA stream to the kernel. If that isn't welcome, I can close this PR until I get it fixed.

@tianleiwu (Contributor) commented May 31, 2024

Even though this approach works when you build from source and run on the current machine, the binary might not be able to run on another GPU.

If we want the binary to support different GPUs, we need to either invoke it in Python (so the kernel is compiled by Triton on the machine running it, though that adds a dependency on Triton and Python) or use the Triton ahead-of-time (AoT) compiler (which requires compiling cubins for GPUs with different compute capabilities).

We have an example (the SparseAttention operator) using the AoT compiler: https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1

@Numeri (Author) commented Jun 2, 2024

@tianleiwu Hi! Thanks for responding. I'm aware that this technique only works for the machine/GPU architecture you compile it on – that is also a shortcoming of the SoftmaxTriton kernel provided as an example in the existing documentation, and it's the reason these features are correctly disabled unless the flag --use_triton_kernel is provided.

There are plenty of people (myself included) who only need their Triton kernels to run on a certain GPU, and who would love to be able to quickly add Triton kernels as ORT operators. The code for the sparse attention op is clever, but really not very approachable or maintainable for someone who just wants to get a Triton kernel working quickly.

@Numeri (Author) commented Jun 2, 2024

Even though this approach works when you build from source and run in current machine

I do understand if this isn't accepted in the ORT codebase because of this, but maybe then we could work together on a better way to do it.

@@ -0,0 +1,77 @@
#include "contrib_ops/cuda/my_triton_kernel.h"

Check warning (Code scanning / lintrunner): CLANGFORMAT/format Warning. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.
@@ -0,0 +1,23 @@
#pragma once

Check warning (Code scanning / lintrunner): CLANGFORMAT/format Warning. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.
@@ -24,6 +24,8 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace cuda {
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MyTritonKernel);
Contributor:

Use #ifdef USE_TRITON_KERNEL to conditionally compile the Triton operator.
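A minimal sketch of the suggested guard (assuming USE_TRITON_KERNEL is the preprocessor define enabled by the --use_triton_kernel build flag; illustrative only):

// Only declare/register the Triton-backed operator when the build was
// configured with Triton kernel support.
#ifdef USE_TRITON_KERNEL
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MyTritonKernel);
#endif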

Author:

Done

@tianleiwu (Contributor)

I do understand if this isn't accepted in the ORT codebase because of this, but maybe then we could work together on a better way to do it.

It's great that you can contribute a simple example and update the documentation.

It may be helpful to add an example of vector-add or layer norm (only the forward pass is needed for inference).

@Numeri (Author) commented Jun 4, 2024

@microsoft-github-policy-service agree company="Lilt"

@Numeri (Author) commented Jun 4, 2024

@tianleiwu I think switching to a somewhat more interesting kernel is a good idea – I just reduced it to this minimal working example to try to debug this: #20885 (also mentioned above).

Do you have any idea why my call to LaunchTritonKernel is causing memory access errors?

@Numeri (Author) commented Jun 26, 2024

Sorry to ping this – is there anyone that I could ask about this?

@tianleiwu (Contributor) commented Jun 26, 2024

@Numeri, I don't know why LaunchTritonKernel is causing memory access errors.

You can do some debugging: for example, start with a tensor of one element, add some printf calls before calling the Triton kernel, and add some prints inside the Triton kernel to make sure that the memory pointers and other parameters are passed correctly.
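On the host side, that could look roughly like the sketch below (hypothetical, simplified, non-templated names; it assumes ORT's Tensor/CudaKernel API and elides the actual LaunchTritonKernel call from the PR):

// Debugging sketch: dump the raw pointers and launch parameters right before
// handing them to the Triton kernel, so a bad pointer or stream shows up in
// the log before the asynchronous CUDA failure does.
#include <cstdio>

#include "contrib_ops/cuda/my_triton_kernel.h"   // the PR's operator header
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

Status MyTritonKernel::ComputeInternal(OpKernelContext* ctx) const {
  const Tensor* X = ctx->Input<Tensor>(0);
  Tensor* Y = ctx->Output(0, X->Shape());
  const float* x = X->Data<float>();
  float* y = Y->MutableData<float>();
  const long long n = static_cast<long long>(X->Shape().Size());
  cudaStream_t stream = Stream(ctx);  // the op's compute stream

  // If any of these look wrong (null pointers, n == 0, null stream), the later
  // cudaMemcpyAsync / cudaStreamSynchronize failure is a symptom, not the cause.
  std::printf("MyTritonKernel launch: x=%p y=%p n=%lld stream=%p\n",
              static_cast<const void*>(x), static_cast<void*>(y), n,
              static_cast<void*>(stream));

  // ... existing LaunchTritonKernel(...) call from the PR goes here ...
  return Status::OK();
}

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime

If the printed values look sane, running the failing test under NVIDIA's compute-sanitizer can also help narrow down the faulting access inside the generated kernel.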

@xiaoyu1215 commented Sep 14, 2024

@Numeri Hi, I also implemented a Triton kernel in onnxruntime based on your guide.
I also hit your issue with CUDA memory, but my error information is different from yours.

2024-09-14 11:08:23.259158746 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 716: misaligned address ; GPU=0 ; hostname=a100-5 ; file=/onnxruntime-1.19.0/onnxruntime/core/providers/cuda/gpu_data_transfer.cc ; line=73 ; expr=cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle()));
2024-09-14 11:08:23.259366382 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDA failure 716: misaligned address ; GPU=0 ; hostname=a100-5 ; file=/onnxruntime-1.19.0/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=446 ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_));

Sometimes I encounter the error message above, and sometimes the program executes properly (the error seems to be sporadic). [Maybe my input size is too small.]
Do you have any suggestions for how to debug or solve this issue?

==========================================================
Separately: when the program does execute properly, I also encounter a very strange problem. I wonder if you have run into it before or have a good solution.

The input passed to the Triton kernel was 0, which was not the input I expected, so the output of the Triton kernel is also 0.
I'm not sure if I missed something when I integrated the code, or if there is some other reason. Could you give me some advice on how to debug this problem?

The log looks like this:

$ ./onnxruntime_test_all  --gtest_filter=VecAddTritonKernelContribOpTest*  # the test case I implemented
Note: Google Test filter = VecAddTritonKernelContribOpTest*
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from VecAddTritonKernelContribOpTest
[ RUN      ] VecAddTritonKernelContribOpTest.float_plus_float_float
pid (0, 0, 0) idx (   0) [debug] x -: 0.000000
pid (0, 0, 0) idx (   1) [debug] x -: 0.000000
pid (0, 0, 0) idx (   2) [debug] x -: 0.000000
pid (0, 0, 0) idx (   3) [debug] x -: 0.000000
pid (0, 0, 0) idx (   4) [debug] x -: 0.000000
pid (0, 0, 0) idx (   5) [debug] x -: 0.000000
pid (0, 0, 0) idx (   6) [debug] x -: 0.000000
pid (0, 0, 0) idx (   7) [debug] x -: 0.000000
pid (0, 0, 0) idx (   8) [debug] x -: 0.000000
pid (0, 0, 0) idx (   9) [debug] x -: 0.000000
......

/onnxruntime-1.19.0/onnxruntime/test/providers/checkers.cc:393: Failure
The difference between cur_expected[i] and cur_actual[i] is 4, which exceeds tolerance, where
cur_expected[i] evaluates to 4,
cur_actual[i] evaluates to 0, and
tolerance evaluates to 0.0014000000664964318.
i:0
Google Test trace:
/onnxruntime-1.19.0/onnxruntime/test/providers/checkers.cc:569: provider type: CUDAExecutionProvider
/onnxruntime-1.19.0/onnxruntime/test/providers/base_tester.cc:832: registered execution providers: CUDAExecutionProvider
Stack trace:
  0x5588bee3570f: onnxruntime::test::(anonymous namespace)::InternalNumericalCheck<>()
  0x5588bee32ece: onnxruntime::test::(anonymous namespace)::TensorCheck<>::operator()()
  0x5588bee38501: onnxruntime::utils::mltype_dispatcher_internal::CallableDispatchableHelper::Invoke<>()
  0x5588bee36b9c: onnxruntime::utils::MLTypeCallDispatcher<>::InvokeWithLeadingTemplateArgs<>()
  0x5588bee35d80: onnxruntime::utils::MLTypeCallDispatcher<>::Invoke<>()
  0x5588bee34031: onnxruntime::test::Check<>()
  0x5588bee35f32: onnxruntime::test::CheckDispatch<>()
  0x5588bee345d0: onnxruntime::test::CheckOrtValuesAreEqual()
  0x5588bee2ace8: onnxruntime::test::BaseTester::ExecuteModel<>()
  0x5588bee2507d: onnxruntime::test::BaseTester::ExecuteModelForEps()
  0x5588bee23834: onnxruntime::test::BaseTester::RunWithConfig()
  0x5588bee22306: onnxruntime::test::BaseTester::Run()
  0x5588bee22194: onnxruntime::test::BaseTester::Run()
  0x5588bedec24d: onnxruntime::test::VecAddTritonKernelContribOpTest_float_plus_float_float_Test::TestBody()
  0x5588c15a6bf2: testing::internal::HandleSehExceptionsInMethodIfSupported<>()
  0x5588c159fcba: testing::internal::HandleExceptionsInMethodIfSupported<>()
  0x5588c1583bcc: testing::Test::Run()
  0x5588c1584610: testing::TestInfo::Run()
... Google Test internal frames ...

This is the test script I implemented. It is located at onnxruntime/test/contrib_ops/test_vec_add_triton.cc:

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/util/math.h"
namespace onnxruntime {
namespace test {
TEST(VecAddTritonKernelContribOpTest, float_plus_float_float) {
  OpTester test("VecAddTritonKernel", 1, kMSDomain);
 
  std::vector<float> X1 = {2.0f, 2.0f};
  std::vector<float> X2 = {2.0f, 2.0f};
  std::vector<float> output = {4.0f, 4.0f};  

  test.AddInput<float>("X", {2}, X1);
  test.AddInput<float>("Y", {2}, X2);
  test.AddOutput<float>("output", {2}, output);

  test.Run();
}

}  // namespace test
}  // namespace onnxruntime
