Skip to content

Commit

Permalink
[red-knot] more ergonomic and efficient handling of known builtin cla…
Browse files Browse the repository at this point in the history
…sses (astral-sh#13615)
  • Loading branch information
Slyces authored Oct 5, 2024
1 parent 7c5a7d9 commit 1c2cafc
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 74 deletions.
216 changes: 170 additions & 46 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::stdlib::{
builtins_symbol_ty, types_symbol_ty, typeshed_symbol_ty, typing_extensions_symbol_ty,
};
use crate::types::narrow::narrowing_constraint;
use crate::{Db, FxOrderSet};
use crate::{Db, FxOrderSet, Module};

pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder};
pub(crate) use self::diagnostic::TypeCheckDiagnostics;
Expand Down Expand Up @@ -385,14 +385,6 @@ impl<'db> Type<'db> {
}
}

pub fn builtin_str_instance(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "str").to_instance(db)
}

pub fn builtin_int_instance(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "int").to_instance(db)
}

pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
Expand Down Expand Up @@ -423,28 +415,26 @@ impl<'db> Type<'db> {
(_, Type::Unknown | Type::Any | Type::Todo) => false,
(Type::Never, _) => true,
(_, Type::Never) => false,
(Type::IntLiteral(_), Type::Instance(class))
if class.is_stdlib_symbol(db, "builtins", "int") =>
{
(Type::IntLiteral(_), Type::Instance(class)) if class.is_known(db, KnownClass::Int) => {
true
}
(Type::StringLiteral(_), Type::LiteralString) => true,
(Type::StringLiteral(_) | Type::LiteralString, Type::Instance(class))
if class.is_stdlib_symbol(db, "builtins", "str") =>
if class.is_known(db, KnownClass::Str) =>
{
true
}
(Type::BytesLiteral(_), Type::Instance(class))
if class.is_stdlib_symbol(db, "builtins", "bytes") =>
if class.is_known(db, KnownClass::Bytes) =>
{
true
}
(ty, Type::Union(union)) => union
.elements(db)
.iter()
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
(_, Type::Instance(class)) if class.is_stdlib_symbol(db, "builtins", "object") => true,
(Type::Instance(class), _) if class.is_stdlib_symbol(db, "builtins", "object") => false,
(_, Type::Instance(class)) if class.is_known(db, KnownClass::Object) => true,
(Type::Instance(class), _) if class.is_known(db, KnownClass::Object) => false,
// TODO
_ => false,
}
Expand Down Expand Up @@ -600,26 +590,25 @@ impl<'db> Type<'db> {
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {
// TODO validate typed call arguments vs callable signature
Type::Function(function_type) => match function_type.kind(db) {
FunctionKind::Ordinary => CallOutcome::callable(function_type.return_type(db)),
FunctionKind::RevealType => CallOutcome::revealed(
Type::Function(function_type) => match function_type.known(db) {
None => CallOutcome::callable(function_type.return_type(db)),
Some(KnownFunction::RevealType) => CallOutcome::revealed(
function_type.return_type(db),
*arg_types.first().unwrap_or(&Type::Unknown),
),
},

// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => {
// If the class is the builtin-bool class (for example `bool(1)`), we try to return
// the specific truthiness value of the input arg, `Literal[True]` for the example above.
let is_bool = class.is_stdlib_symbol(db, "builtins", "bool");
CallOutcome::callable(if is_bool {
arg_types
CallOutcome::callable(match class.known(db) {
// If the class is the builtin-bool class (for example `bool(1)`), we try to
// return the specific truthiness value of the input arg, `Literal[True]` for
// the example above.
Some(KnownClass::Bool) => arg_types
.first()
.map(|arg| arg.bool(db).into_type(db))
.unwrap_or(Type::BooleanLiteral(false))
} else {
Type::Instance(class)
.unwrap_or(Type::BooleanLiteral(false)),
_ => Type::Instance(class),
})
}

Expand Down Expand Up @@ -714,7 +703,7 @@ impl<'db> Type<'db> {
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");

dunder_get_item_method
.call(db, &[self, builtins_symbol_ty(db, "int").to_instance(db)])
.call(db, &[self, KnownClass::Int.to_instance(db)])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
Expand Down Expand Up @@ -758,17 +747,17 @@ impl<'db> Type<'db> {
Type::Never => Type::Never,
Type::Instance(class) => Type::Class(*class),
Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)),
Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"),
Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"),
Type::IntLiteral(_) => builtins_symbol_ty(db, "int"),
Type::Function(_) => types_symbol_ty(db, "FunctionType"),
Type::Module(_) => types_symbol_ty(db, "ModuleType"),
Type::Tuple(_) => builtins_symbol_ty(db, "tuple"),
Type::None => typeshed_symbol_ty(db, "NoneType"),
Type::BooleanLiteral(_) => KnownClass::Bool.to_class(db),
Type::BytesLiteral(_) => KnownClass::Bytes.to_class(db),
Type::IntLiteral(_) => KnownClass::Int.to_class(db),
Type::Function(_) => KnownClass::FunctionType.to_class(db),
Type::Module(_) => KnownClass::ModuleType.to_class(db),
Type::Tuple(_) => KnownClass::Tuple.to_class(db),
Type::None => KnownClass::NoneType.to_class(db),
// TODO not accurate if there's a custom metaclass...
Type::Class(_) => builtins_symbol_ty(db, "type"),
Type::Class(_) => KnownClass::Type.to_class(db),
// TODO can we do better here? `type[LiteralString]`?
Type::StringLiteral(_) | Type::LiteralString => builtins_symbol_ty(db, "str"),
Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class(db),
// TODO: `type[Any]`?
Type::Any => Type::Todo,
// TODO: `type[Unknown]`?
Expand All @@ -790,7 +779,7 @@ impl<'db> Type<'db> {
Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db),
Type::StringLiteral(_) | Type::LiteralString => *self,
// TODO: handle more complex types
_ => Type::builtin_str_instance(db),
_ => KnownClass::Str.to_instance(db),
}
}

Expand All @@ -813,7 +802,7 @@ impl<'db> Type<'db> {
})),
Type::LiteralString => Type::LiteralString,
// TODO: handle more complex types
_ => Type::builtin_str_instance(db),
_ => KnownClass::Str.to_instance(db),
}
}
}
Expand All @@ -824,6 +813,133 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}

/// Non-exhaustive enumeration of known classes (e.g. `builtins.int`, `types.ModuleType`, ...) to
/// allow for easier syntax when interacting with very common classes.
///
/// Feel free to expand this enum if you ever find yourself using the same class in multiple
/// places. When adding a variant, remember that the name <-> variant mappings are written out by
/// hand in the `impl` block, so every `match` there needs a matching arm.
///
/// Note: good candidates are any classes in `crate::stdlib::CoreStdlibModule`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KnownClass {
// To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored`
// and grep for the symbol name in any `.pyi` file.

// Classes looked up in (and only accepted from) the `builtins` module
Bool,
Object,
Bytes,
Type,
Int,
Float,
Str,
List,
Tuple,
Set,
Dict,
// Classes looked up in (and only accepted from) the `types` module
ModuleType,
FunctionType,
// Classes looked up through the typeshed symbol helper
NoneType, // Accepted from `_typeshed`, and from `types` (where it lives for Python >= 3.10)
}

impl<'db> KnownClass {
    /// Return the Python-level name of this class (e.g. `"int"` for [`KnownClass::Int`]).
    pub const fn as_str(&self) -> &'static str {
        match self {
            // Looked up in `builtins`
            Self::Bool => "bool",
            Self::Object => "object",
            Self::Bytes => "bytes",
            Self::Type => "type",
            Self::Int => "int",
            Self::Float => "float",
            Self::Str => "str",
            Self::List => "list",
            Self::Tuple => "tuple",
            Self::Set => "set",
            Self::Dict => "dict",
            // Looked up in `types`
            Self::ModuleType => "ModuleType",
            Self::FunctionType => "FunctionType",
            // Looked up through the typeshed helper
            Self::NoneType => "NoneType",
        }
    }

    /// Return the type of an instance of this class, i.e. the class type returned by
    /// [`KnownClass::to_class`] converted with `to_instance`.
    pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> {
        let class_ty = self.to_class(db);
        class_ty.to_instance(db)
    }

    /// Return the type of the class object itself, resolved by name through the symbol
    /// lookup helper of the module this class is expected to live in.
    pub fn to_class(&self, db: &'db dyn Db) -> Type<'db> {
        let class_name = self.as_str();
        match self {
            Self::NoneType => typeshed_symbol_ty(db, class_name),
            Self::ModuleType | Self::FunctionType => types_symbol_ty(db, class_name),
            Self::Bool
            | Self::Object
            | Self::Bytes
            | Self::Type
            | Self::Int
            | Self::Float
            | Self::Str
            | Self::List
            | Self::Tuple
            | Self::Set
            | Self::Dict => builtins_symbol_ty(db, class_name),
        }
    }

    /// Return the known class called `class_name`, provided `module` is a module in which
    /// that class can be defined; return `None` for unknown names or mismatching modules.
    pub fn maybe_from_module(module: &Module, class_name: &str) -> Option<Self> {
        Self::from_name(class_name).filter(|known| known.check_module(module))
    }

    /// Map a class name back to its variant (the inverse of [`KnownClass::as_str`]).
    fn from_name(name: &str) -> Option<Self> {
        // Note: if this becomes hard to maintain (as rust can't ensure at compile time that all
        // variants of `Self` are covered), we might use a macro (in-house or dependency)
        // See: https://stackoverflow.com/q/39070244
        let known = match name {
            "bool" => Self::Bool,
            "object" => Self::Object,
            "bytes" => Self::Bytes,
            "type" => Self::Type,
            "int" => Self::Int,
            "float" => Self::Float,
            "str" => Self::Str,
            "list" => Self::List,
            "tuple" => Self::Tuple,
            "set" => Self::Set,
            "dict" => Self::Dict,
            "ModuleType" => Self::ModuleType,
            "FunctionType" => Self::FunctionType,
            "NoneType" => Self::NoneType,
            _ => return None,
        };
        Some(known)
    }

    /// Private method checking if known class can be defined in the given module.
    ///
    /// Only modules found on a standard-library search path are ever accepted.
    fn check_module(self, module: &Module) -> bool {
        if !module.search_path().is_standard_library() {
            return false;
        }

        // Names of the modules that are allowed to define this class.
        let allowed_modules: &[&str] = match self {
            Self::NoneType => &["_typeshed", "types"],
            Self::ModuleType | Self::FunctionType => &["types"],
            Self::Bool
            | Self::Object
            | Self::Bytes
            | Self::Type
            | Self::Int
            | Self::Float
            | Self::Str
            | Self::List
            | Self::Tuple
            | Self::Set
            | Self::Dict => &["builtins"],
        };

        let module_name = module.name();
        allowed_modules.contains(&module_name.as_str())
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum CallOutcome<'db> {
Callable {
Expand Down Expand Up @@ -1128,7 +1244,7 @@ impl Truthiness {
match self {
Self::AlwaysTrue => Type::BooleanLiteral(true),
Self::AlwaysFalse => Type::BooleanLiteral(false),
Self::Ambiguous => builtins_symbol_ty(db, "bool").to_instance(db),
Self::Ambiguous => KnownClass::Bool.to_instance(db),
}
}
}
Expand All @@ -1150,7 +1266,7 @@ pub struct FunctionType<'db> {
pub name: ast::name::Name,

/// Is this a function that we special-case somehow? If so, which one?
kind: FunctionKind,
known: Option<KnownFunction>,

definition: Definition<'db>,

Expand Down Expand Up @@ -1202,11 +1318,10 @@ impl<'db> FunctionType<'db> {
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
pub enum FunctionKind {
/// Just a normal function for which we have no particular special casing
#[default]
Ordinary,
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
/// have special behavior.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum KnownFunction {
/// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type`
RevealType,
}
Expand All @@ -1220,9 +1335,18 @@ pub struct ClassType<'db> {
definition: Definition<'db>,

body_scope: ScopeId<'db>,

known: Option<KnownClass>,
}

impl<'db> ClassType<'db> {
/// Return true if this class has been recognized as exactly the given [`KnownClass`].
///
/// A class that is not recognized as any known class never matches.
pub fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool {
    self.known(db) == Some(known_class)
}

/// Return true if this class is a standard library type with given module name and name.
pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
name == self.name(db)
Expand Down
Loading

0 comments on commit 1c2cafc

Please sign in to comment.