Skip to content

Commit

Permalink
cuda module: fully type annotate
Browse files Browse the repository at this point in the history
Special notes:
- _nvcc_arch_flags is always called with all of its arguments given
  explicitly, so its parameters no longer need default values
- min_driver_version has its args annotation loosened (from T.Tuple[str] to
  T.List[TYPE_var]), presumably because it has to fit the constraints of the
  module interface
  • Loading branch information
eli-schwartz committed Feb 13, 2024
1 parent 5899daf commit 65ee397
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
48 changes: 25 additions & 23 deletions mesonbuild/modules/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,39 @@

from __future__ import annotations

import typing as T
import re
import typing as T

from ..mesonlib import version_compare
from ..mesonlib import listify, version_compare
from ..compilers.cuda import CudaCompiler
from ..interpreter.type_checking import NoneType

from . import NewExtensionModule, ModuleInfo

from ..interpreterbase import (
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
ContainerTypeInfo, InvalidArguments, KwargInfo, noKwargs, typed_kwargs, typed_pos_args,
)

if T.TYPE_CHECKING:
from typing_extensions import TypedDict

from . import ModuleState
from ..interpreter import Interpreter
from ..interpreterbase import TYPE_var

class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]

AutoArch = T.Union[str, T.List[str]]


DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)

class CudaModule(NewExtensionModule):

INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)

def __init__(self, *args, **kwargs):
def __init__(self, interp: Interpreter):
super().__init__()
self.methods.update({
"min_driver_version": self.min_driver_version,
Expand All @@ -41,7 +45,7 @@ def __init__(self, *args, **kwargs):

@noKwargs
def min_driver_version(self, state: 'ModuleState',
args: T.Tuple[str],
args: T.List[TYPE_var],
kwargs: T.Dict[str, T.Any]) -> str:
argerror = InvalidArguments('min_driver_version must have exactly one positional argument: ' +
'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' +
Expand Down Expand Up @@ -113,41 +117,39 @@ def nvcc_arch_readable(self, state: 'ModuleState',
return ret

@staticmethod
def _break_arch_string(s):
def _break_arch_string(s: str) -> T.List[str]:
s = re.sub('[ \t\r\n,;]+', ';', s)
s = s.strip(';').split(';')
return s
return s.strip(';').split(';')

@staticmethod
def _detected_cc_from_compiler(c) -> T.List[str]:
def _detected_cc_from_compiler(c: T.Union[str, CudaCompiler]) -> T.List[str]:
if isinstance(c, CudaCompiler):
return [c.detected_cc]
return []

def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.Tuple[str, AutoArch, T.List[str]]:

compiler = args[0]
if isinstance(compiler, CudaCompiler):
cuda_version = compiler.version
else:
cuda_version = compiler

arch_list = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list]
arch_list = flatten(arch_list)
arch_list: AutoArch = args[1]
arch_list = listify([self._break_arch_string(a) for a in arch_list])
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list

detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
detected = [self._break_arch_string(a) for a in detected]
detected = flatten(detected)
detected = [x for a in detected for x in self._break_arch_string(a)]
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')

return cuda_version, arch_list, detected

def _filter_cuda_arch_list(self, cuda_arch_list, lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
def _filter_cuda_arch_list(self, cuda_arch_list: T.List[str], lo: str, hi: T.Optional[str], saturate: str) -> T.List[str]:
"""
Filter CUDA arch list (no codenames) for >= low and < hi architecture
bounds, and deduplicate.
Expand All @@ -165,7 +167,7 @@ def _filter_cuda_arch_list(self, cuda_arch_list, lo: str, hi: T.Optional[str], s
filtered_cuda_arch_list.append(arch)
return filtered_cuda_arch_list

def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''):
def _nvcc_arch_flags(self, cuda_version: str, cuda_arch_list: AutoArch, detected: T.List[str]) -> T.Tuple[T.List[str], T.List[str]]:
"""
Using the CUDA Toolkit version and the target architectures, compute
the NVCC architecture flags.
Expand Down Expand Up @@ -288,11 +290,11 @@ def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''):

cuda_arch_list = sorted(x for x in set(cuda_arch_list) if x)

cuda_arch_bin = []
cuda_arch_ptx = []
cuda_arch_bin: T.List[str] = []
cuda_arch_ptx: T.List[str] = []
for arch_name in cuda_arch_list:
arch_bin = []
arch_ptx = []
arch_bin: T.Optional[T.List[str]]
arch_ptx: T.Optional[T.List[str]]
add_ptx = arch_name.endswith('+PTX')
if add_ptx:
arch_name = arch_name[:-len('+PTX')]
Expand Down Expand Up @@ -371,5 +373,5 @@ def _nvcc_arch_flags(self, cuda_version, cuda_arch_list='Auto', detected=''):

return nvcc_flags, nvcc_archs_readable

def initialize(*args, **kwargs):
return CudaModule(*args, **kwargs)
def initialize(interp: Interpreter) -> CudaModule:
return CudaModule(interp)
1 change: 1 addition & 0 deletions run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
'mesonbuild/mlog.py',
'mesonbuild/msubprojects.py',
'mesonbuild/modules/__init__.py',
'mesonbuild/modules/cuda.py',
'mesonbuild/modules/external_project.py',
'mesonbuild/modules/fs.py',
'mesonbuild/modules/gnome.py',
Expand Down

0 comments on commit 65ee397

Please sign in to comment.