Skip to content

Commit

Permalink
cuda module: use typed_pos_args for most methods
Browse files Browse the repository at this point in the history
The min_driver_version function has an extensive, informative custom
error message, so leave that in place.

The other two functions didn't have much information there, and it's
fairly evident that the cuda compiler itself is the best thing to have
here. Moreover, there was some fairly gnarly code to validate the
allowed values, which we can greatly simplify by uplifting the
typechecking parts to the dedicated decorators that are both really good
at it, and have nicely formatted error messages complete with reference
to the problematic functions.
  • Loading branch information
eli-schwartz committed Feb 13, 2024
1 parent 1b15176 commit 5899daf
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions mesonbuild/modules/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
from . import NewExtensionModule, ModuleInfo

from ..interpreterbase import (
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
)

if T.TYPE_CHECKING:
from typing_extensions import TypedDict

from . import ModuleState
from ..compilers import Compiler

class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
Expand Down Expand Up @@ -95,17 +94,19 @@ def min_driver_version(self, state: 'ModuleState',

return driver_version

@typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret

@typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
Expand All @@ -123,21 +124,15 @@ def _detected_cc_from_compiler(c) -> T.List[str]:
return [c.detected_cc]
return []

def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):

if len(args) < 1:
raise argerror
compiler = args[0]
if isinstance(compiler, CudaCompiler):
cuda_version = compiler.version
else:
compiler = args[0]
if isinstance(compiler, CudaCompiler):
cuda_version = compiler.version
elif isinstance(compiler, str):
cuda_version = compiler
else:
raise argerror
cuda_version = compiler

arch_list = [] if len(args) <= 1 else flatten(args[1:])
arch_list = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list]
arch_list = flatten(arch_list)
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):
Expand Down

0 comments on commit 5899daf

Please sign in to comment.