Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pynvjitlink as a dependency #14763

Merged
merged 19 commits into from
Jan 19, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from functools import lru_cache

from numba import config as numba_config
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)


# Use an lru_cache with a single value to allow a delayed import of
Expand Down Expand Up @@ -135,7 +132,9 @@ def _setup_numba():
if driver_version < (12, 0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the comment above to mention pynvjitlink and the corresponding role of that package? This comment:

    # ptxcompiler is a requirement for cuda 11.x packages but not
    # cuda 12.x packages. However its version checking machinery
    # is still necessary. If a user happens to have ptxcompiler
    # in a cuda 12 environment, it's use for the purposes of
    # checking the driver and runtime versions is harmless

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brandon-b-miller I would generally advocate reviewing this entire file and any other files that relate to ptxcompiler/pynvjitlink to make sure things are named sensibly, etc. in a way that will support both CUDA 11 and CUDA 12+. I want the code comments and docs to reflect the implemented design going forward.

Keep in mind that we don't want to name things "CUDA 12" in the code if we can avoid it if it is likely that later versions will act in the same way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about something like 7dbf9f2 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a CUDA 12.x environment, ptxcompiler provides version checking, but not MVC directly

Is this true? We don't use ptxcompiler in CUDA 12 environments. No environment should have both ptxcompiler and pynvjitlink installed at the same time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's technically _ptxcompiler.py in this case - our slimmed down, vendored version of the few functions we need.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooo. But I don't know how to distinguish ptxcompiler the package (only used when on CUDA 11) from _ptxcompiler.py the internal helper file (always active) from the text of this comment. Documenting that kind of thing clearly is what I want to achieve before merging this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some reworking in e8a90b9

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much clearer! Thanks for iterating on this.

patch_numba_linker_cuda_11()
else:
patch_numba_linker_pynvjitlink()
from pynvjitlink.patch import patch_numba_linker

patch_numba_linker()


def _get_cuda_version_from_ptx_file(path):
Expand Down
Loading