diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index b415266d38886..ce84a55e6dee0 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -14,7 +14,7 @@ use crate::stdlib::{ builtins_symbol_ty, types_symbol_ty, typeshed_symbol_ty, typing_extensions_symbol_ty, }; use crate::types::narrow::narrowing_constraint; -use crate::{Db, FxOrderSet}; +use crate::{Db, FxOrderSet, Module}; pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; pub(crate) use self::diagnostic::TypeCheckDiagnostics; @@ -385,14 +385,6 @@ impl<'db> Type<'db> { } } - pub fn builtin_str_instance(db: &'db dyn Db) -> Self { - builtins_symbol_ty(db, "str").to_instance(db) - } - - pub fn builtin_int_instance(db: &'db dyn Db) -> Self { - builtins_symbol_ty(db, "int").to_instance(db) - } - pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool { match self { Type::Class(class) => class.is_stdlib_symbol(db, module_name, name), @@ -423,19 +415,17 @@ impl<'db> Type<'db> { (_, Type::Unknown | Type::Any | Type::Todo) => false, (Type::Never, _) => true, (_, Type::Never) => false, - (Type::IntLiteral(_), Type::Instance(class)) - if class.is_stdlib_symbol(db, "builtins", "int") => - { + (Type::IntLiteral(_), Type::Instance(class)) if class.is_known(db, KnownClass::Int) => { true } (Type::StringLiteral(_), Type::LiteralString) => true, (Type::StringLiteral(_) | Type::LiteralString, Type::Instance(class)) - if class.is_stdlib_symbol(db, "builtins", "str") => + if class.is_known(db, KnownClass::Str) => { true } (Type::BytesLiteral(_), Type::Instance(class)) - if class.is_stdlib_symbol(db, "builtins", "bytes") => + if class.is_known(db, KnownClass::Bytes) => { true } @@ -443,8 +433,8 @@ impl<'db> Type<'db> { .elements(db) .iter() .any(|&elem_ty| ty.is_subtype_of(db, elem_ty)), - (_, Type::Instance(class)) if class.is_stdlib_symbol(db, "builtins", "object") => true, - (Type::Instance(class), _) if class.is_stdlib_symbol(db, "builtins", "object") => false, + (_, Type::Instance(class)) if class.is_known(db, KnownClass::Object) => true, + (Type::Instance(class), _) if class.is_known(db, KnownClass::Object) => false, // TODO _ => false, } @@ -600,9 +590,9 @@ impl<'db> Type<'db> { fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> { match self { // TODO validate typed call arguments vs callable signature - Type::Function(function_type) => match function_type.kind(db) { - FunctionKind::Ordinary => CallOutcome::callable(function_type.return_type(db)), - FunctionKind::RevealType => CallOutcome::revealed( + Type::Function(function_type) => match function_type.known(db) { + None => CallOutcome::callable(function_type.return_type(db)), + Some(KnownFunction::RevealType) => CallOutcome::revealed( function_type.return_type(db), *arg_types.first().unwrap_or(&Type::Unknown), ), @@ -610,16 +600,15 @@ impl<'db> Type<'db> { // TODO annotated return type on `__new__` or metaclass `__call__` Type::Class(class) => { - // If the class is the builtin-bool class (for example `bool(1)`), we try to return - // the specific truthiness value of the input arg, `Literal[True]` for the example above. - let is_bool = class.is_stdlib_symbol(db, "builtins", "bool"); - CallOutcome::callable(if is_bool { - arg_types + CallOutcome::callable(match class.known(db) { + // If the class is the builtin-bool class (for example `bool(1)`), we try to + // return the specific truthiness value of the input arg, `Literal[True]` for + // the example above. + Some(KnownClass::Bool) => arg_types .first() .map(|arg| arg.bool(db).into_type(db)) - .unwrap_or(Type::BooleanLiteral(false)) - } else { - Type::Instance(class) + .unwrap_or(Type::BooleanLiteral(false)), + _ => Type::Instance(class), }) } @@ -714,7 +703,7 @@ impl<'db> Type<'db> { let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__"); dunder_get_item_method - .call(db, &[self, builtins_symbol_ty(db, "int").to_instance(db)]) + .call(db, &[self, KnownClass::Int.to_instance(db)]) .return_ty(db) .map(|element_ty| IterationOutcome::Iterable { element_ty }) .unwrap_or(IterationOutcome::NotIterable { @@ -758,17 +747,17 @@ impl<'db> Type<'db> { Type::Never => Type::Never, Type::Instance(class) => Type::Class(*class), Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)), - Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"), - Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"), - Type::IntLiteral(_) => builtins_symbol_ty(db, "int"), - Type::Function(_) => types_symbol_ty(db, "FunctionType"), - Type::Module(_) => types_symbol_ty(db, "ModuleType"), - Type::Tuple(_) => builtins_symbol_ty(db, "tuple"), - Type::None => typeshed_symbol_ty(db, "NoneType"), + Type::BooleanLiteral(_) => KnownClass::Bool.to_class(db), + Type::BytesLiteral(_) => KnownClass::Bytes.to_class(db), + Type::IntLiteral(_) => KnownClass::Int.to_class(db), + Type::Function(_) => KnownClass::FunctionType.to_class(db), + Type::Module(_) => KnownClass::ModuleType.to_class(db), + Type::Tuple(_) => KnownClass::Tuple.to_class(db), + Type::None => KnownClass::NoneType.to_class(db), // TODO not accurate if there's a custom metaclass... - Type::Class(_) => builtins_symbol_ty(db, "type"), + Type::Class(_) => KnownClass::Type.to_class(db), // TODO can we do better here? `type[LiteralString]`? - Type::StringLiteral(_) | Type::LiteralString => builtins_symbol_ty(db, "str"), + Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class(db), // TODO: `type[Any]`? Type::Any => Type::Todo, // TODO: `type[Unknown]`? @@ -790,7 +779,7 @@ impl<'db> Type<'db> { Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db), Type::StringLiteral(_) | Type::LiteralString => *self, // TODO: handle more complex types - _ => Type::builtin_str_instance(db), + _ => KnownClass::Str.to_instance(db), } } @@ -813,7 +802,7 @@ impl<'db> Type<'db> { })), Type::LiteralString => Type::LiteralString, // TODO: handle more complex types - _ => Type::builtin_str_instance(db), + _ => KnownClass::Str.to_instance(db), } } } @@ -824,6 +813,133 @@ impl<'db> From<&Type<'db>> for Type<'db> { } } +/// Non-exhaustive enumeration of known classes (e.g. `builtins.int`, `typing.Any`, ...) to allow +/// for easier syntax when interacting with very common classes. +/// +/// Feel free to expand this enum if you ever find yourself using the same class in multiple +/// places. +/// Note: good candidates are any classes in `[crate::stdlib::CoreStdlibModule]` +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KnownClass { + // To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored` + // and grep for the symbol name in any `.pyi` file. + + // Builtins + Bool, + Object, + Bytes, + Type, + Int, + Float, + Str, + List, + Tuple, + Set, + Dict, + // Types + ModuleType, + FunctionType, + // Typeshed + NoneType, // Part of `types` for Python >= 3.10 +} + +impl<'db> KnownClass { + pub const fn as_str(&self) -> &'static str { + match self { + Self::Bool => "bool", + Self::Object => "object", + Self::Bytes => "bytes", + Self::Tuple => "tuple", + Self::Int => "int", + Self::Float => "float", + Self::Str => "str", + Self::Set => "set", + Self::Dict => "dict", + Self::List => "list", + Self::Type => "type", + Self::ModuleType => "ModuleType", + Self::FunctionType => "FunctionType", + Self::NoneType => "NoneType", + } + } + + pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> { + self.to_class(db).to_instance(db) + } + + pub fn to_class(&self, db: &'db dyn Db) -> Type<'db> { + match self { + Self::Bool + | Self::Object + | Self::Bytes + | Self::Type + | Self::Int + | Self::Float + | Self::Str + | Self::List + | Self::Tuple + | Self::Set + | Self::Dict => builtins_symbol_ty(db, self.as_str()), + Self::ModuleType | Self::FunctionType => types_symbol_ty(db, self.as_str()), + Self::NoneType => typeshed_symbol_ty(db, self.as_str()), + } + } + + pub fn maybe_from_module(module: &Module, class_name: &str) -> Option { + let candidate = Self::from_name(class_name)?; + if candidate.check_module(module) { + Some(candidate) + } else { + None + } + } + + fn from_name(name: &str) -> Option { + // Note: if this becomes hard to maintain (as rust can't ensure at compile time that all + // variants of `Self` are covered), we might use a macro (in-house or dependency) + // See: https://stackoverflow.com/q/39070244 + match name { + "bool" => Some(Self::Bool), + "object" => Some(Self::Object), + "bytes" => Some(Self::Bytes), + "tuple" => Some(Self::Tuple), + "type" => Some(Self::Type), + "int" => Some(Self::Int), + "float" => Some(Self::Float), + "str" => Some(Self::Str), + "set" => Some(Self::Set), + "dict" => Some(Self::Dict), + "list" => Some(Self::List), + "NoneType" => Some(Self::NoneType), + "ModuleType" => Some(Self::ModuleType), + "FunctionType" => Some(Self::FunctionType), + _ => None, + } + } + + /// Private method checking if known class can be defined in the given module. + fn check_module(self, module: &Module) -> bool { + if !module.search_path().is_standard_library() { + return false; + } + match self { + Self::Bool + | Self::Object + | Self::Bytes + | Self::Type + | Self::Int + | Self::Float + | Self::Str + | Self::List + | Self::Tuple + | Self::Set + | Self::Dict => module.name() == "builtins", + Self::ModuleType | Self::FunctionType => module.name() == "types", + Self::NoneType => matches!(module.name().as_str(), "_typeshed" | "types"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] enum CallOutcome<'db> { Callable { @@ -1128,7 +1244,7 @@ impl Truthiness { match self { Self::AlwaysTrue => Type::BooleanLiteral(true), Self::AlwaysFalse => Type::BooleanLiteral(false), - Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db), + Self::Ambiguous => KnownClass::Bool.to_instance(db), } } } @@ -1150,7 +1266,7 @@ pub struct FunctionType<'db> { pub name: ast::name::Name, /// Is this a function that we special-case somehow? If so, which one? - kind: FunctionKind, + known: Option, definition: Definition<'db>, @@ -1202,11 +1318,10 @@ impl<'db> FunctionType<'db> { } } -#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)] -pub enum FunctionKind { - /// Just a normal function for which we have no particular special casing - #[default] - Ordinary, +/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might +/// have special behavior. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum KnownFunction { /// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type` RevealType, } @@ -1220,9 +1335,18 @@ pub struct ClassType<'db> { definition: Definition<'db>, body_scope: ScopeId<'db>, + + known: Option, } impl<'db> ClassType<'db> { + pub fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { + match self.known(db) { + Some(known) => known == known_class, + None => false, + } + } + /// Return true if this class is a standard library type with given module name and name. pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool { name == self.name(db) diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 1ba6bc72c4cb6..3164e9e3592d6 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -25,10 +25,12 @@ //! * No type in an intersection can be a supertype of any other type in the intersection (just //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. -use crate::types::{builtins_symbol_ty, IntersectionType, Type, UnionType}; +use crate::types::{IntersectionType, Type, UnionType}; use crate::{Db, FxOrderSet}; use smallvec::SmallVec; +use super::KnownClass; + pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, @@ -64,7 +66,7 @@ impl<'db> UnionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 2]>::new(); for (index, element) in self.elements.iter().enumerate() { if Some(*element) == bool_pair { - to_add = builtins_symbol_ty(self.db, "bool"); + to_add = KnownClass::Bool.to_class(self.db); to_remove.push(index); // The type we are adding is a BooleanLiteral, which doesn't have any // subtypes. And we just found that the union already contained our @@ -300,7 +302,7 @@ mod tests { use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; - use crate::types::{builtins_symbol_ty, UnionBuilder}; + use crate::types::{KnownClass, UnionBuilder}; use crate::ProgramSettings; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; @@ -360,7 +362,7 @@ mod tests { #[test] fn build_union_bool() { let db = setup_db(); - let bool_ty = builtins_symbol_ty(&db, "bool"); + let bool_ty = KnownClass::Bool.to_class(&db); let t0 = Type::BooleanLiteral(true); let t1 = Type::BooleanLiteral(true); @@ -389,7 +391,7 @@ mod tests { #[test] fn build_union_simplify_subtype() { let db = setup_db(); - let t0 = Type::builtin_str_instance(&db); + let t0 = KnownClass::Str.to_instance(&db); let t1 = Type::LiteralString; let u0 = UnionType::from_elements(&db, [t0, t1]); let u1 = UnionType::from_elements(&db, [t1, t0]); @@ -401,7 +403,7 @@ mod tests { #[test] fn build_union_no_simplify_unknown() { let db = setup_db(); - let t0 = Type::builtin_str_instance(&db); + let t0 = KnownClass::Str.to_instance(&db); let t1 = Type::Unknown; let u0 = UnionType::from_elements(&db, [t0, t1]); let u1 = UnionType::from_elements(&db, [t1, t0]); @@ -413,9 +415,9 @@ mod tests { #[test] fn build_union_subsume_multiple() { let db = setup_db(); - let str_ty = Type::builtin_str_instance(&db); - let int_ty = Type::builtin_int_instance(&db); - let object_ty = builtins_symbol_ty(&db, "object").to_instance(&db); + let str_ty = KnownClass::Str.to_instance(&db); + let int_ty = KnownClass::Int.to_instance(&db); + let object_ty = KnownClass::Object.to_instance(&db); let unknown_ty = Type::Unknown; let u0 = UnionType::from_elements(&db, [str_ty, unknown_ty, int_ty, object_ty]); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 690f73a696d87..eef0e9d16dcdd 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -51,11 +51,13 @@ use crate::stdlib::builtins_module_scope; use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; use crate::types::{ bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty, - typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionKind, FunctionType, + typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, KnownFunction, StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType, }; use crate::Db; +use super::KnownClass; + /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. @@ -518,8 +520,8 @@ impl<'db> TypeInferenceBuilder<'db> { match left { Type::IntLiteral(_) => {} Type::Instance(cls) - if cls.is_stdlib_symbol(self.db, "builtins", "float") - || cls.is_stdlib_symbol(self.db, "builtins", "int") => {} + if cls.is_known(self.db, KnownClass::Float) + || cls.is_known(self.db, KnownClass::Int) => {} _ => return, }; @@ -749,8 +751,10 @@ impl<'db> TypeInferenceBuilder<'db> { } let function_kind = match &**name { - "reveal_type" if definition.is_typing_definition(self.db) => FunctionKind::RevealType, - _ => FunctionKind::Ordinary, + "reveal_type" if definition.is_typing_definition(self.db) => { + Some(KnownFunction::RevealType) + } + _ => None, }; let function_ty = Type::Function(FunctionType::new( self.db, @@ -861,11 +865,15 @@ impl<'db> TypeInferenceBuilder<'db> { .node_scope(NodeWithScopeRef::Class(class)) .to_scope_id(self.db, self.file); + let maybe_known_class = file_to_module(self.db, body_scope.file(self.db)) + .as_ref() + .and_then(|module| KnownClass::maybe_from_module(module, name.as_str())); let class_ty = Type::Class(ClassType::new( self.db, name.id.clone(), definition, body_scope, + maybe_known_class, )); self.add_declaration_with_binding(class.into(), definition, class_ty, class_ty); @@ -1708,8 +1716,8 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Number::Int(n) => n .as_i64() .map(Type::IntLiteral) - .unwrap_or_else(|| Type::builtin_int_instance(self.db)), - ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(self.db), + .unwrap_or_else(|| KnownClass::Int.to_instance(self.db)), + ast::Number::Float(_) => KnownClass::Float.to_instance(self.db), ast::Number::Complex { .. } => { builtins_symbol_ty(self.db, "complex").to_instance(self.db) } @@ -1826,7 +1834,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty(self.db, "list").to_instance(self.db) + KnownClass::List.to_instance(self.db) } fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { @@ -1837,7 +1845,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty(self.db, "set").to_instance(self.db) + KnownClass::Set.to_instance(self.db) } fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { @@ -1849,7 +1857,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty(self.db, "dict").to_instance(self.db) + KnownClass::Dict.to_instance(self.db) } /// Infer the type of the `iter` expression of the first comprehension. @@ -2347,31 +2355,31 @@ impl<'db> TypeInferenceBuilder<'db> { (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n .checked_add(m) .map(Type::IntLiteral) - .unwrap_or_else(|| Type::builtin_int_instance(self.db)), + .unwrap_or_else(|| KnownClass::Int.to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n .checked_sub(m) .map(Type::IntLiteral) - .unwrap_or_else(|| Type::builtin_int_instance(self.db)), + .unwrap_or_else(|| KnownClass::Int.to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n .checked_mul(m) .map(Type::IntLiteral) - .unwrap_or_else(|| Type::builtin_int_instance(self.db)), + .unwrap_or_else(|| KnownClass::Int.to_instance(self.db)), (Type::IntLiteral(_), Type::IntLiteral(_), ast::Operator::Div) => { - builtins_symbol_ty(self.db, "float").to_instance(self.db) + KnownClass::Float.to_instance(self.db) } (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::FloorDiv) => n .checked_div(m) .map(Type::IntLiteral) - .unwrap_or_else(|| Type::builtin_int_instance(self.db)), + .unwrap_or_else(|| KnownClass::Int.to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n .checked_rem(m) .map(Type::IntLiteral) - .unwrap_or_else(|| Type::builtin_int_instance(self.db)), + .unwrap_or_else(|| KnownClass::Int.to_instance(self.db)), (Type::BytesLiteral(lhs), Type::BytesLiteral(rhs), ast::Operator::Add) => { Type::BytesLiteral(BytesLiteralType::new( @@ -2581,10 +2589,10 @@ impl<'db> TypeInferenceBuilder<'db> { ast::CmpOp::In | ast::CmpOp::NotIn => None, }, (Type::IntLiteral(_), Type::Instance(_)) => { - self.infer_binary_type_comparison(Type::builtin_int_instance(self.db), op, right) + self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right) } (Type::Instance(_), Type::IntLiteral(_)) => { - self.infer_binary_type_comparison(left, op, Type::builtin_int_instance(self.db)) + self.infer_binary_type_comparison(left, op, KnownClass::Int.to_instance(self.db)) } // Booleans are coded as integers (False = 0, True = 1) (Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison( @@ -3124,7 +3132,7 @@ impl StringPartsCollector { fn ty(self, db: &dyn Db) -> Type { if self.expression { - Type::builtin_str_instance(db) + KnownClass::Str.to_instance(db) } else if let Some(concatenated) = self.concatenated { Type::StringLiteral(StringLiteralType::new(db, concatenated.into_boxed_str())) } else {