Skip to content

Commit

Permalink
red-knot: infer multiplication for strings and integers (astral-sh#13117
Browse files Browse the repository at this point in the history
)

## Summary

The resulting type when multiplying a string literal by an integer
literal is one of two types:

- `StringLiteral`, in the case where it is a reasonably small resulting
string (arbitrarily bounded here to 4096 bytes, roughly a page on many
operating systems), including the fully expanded string.
- `LiteralString`, matching Pyright etc., for strings larger than that.

Additionally:

- Switch to using `Box<str>` instead of `String` for the internal value
of `StringLiteral`, saving some non-trivial byte overhead (and keeping
the total number of allocations the same).
- Be clearer and more accurate about which types we ought to defer to in
`StringLiteral` and `LiteralString` member lookup.

## Test Plan

Added a test case covering multiplication times integers: positive,
negative, zero, and in and out of bounds.

---------

Co-authored-by: Alex Waygood <alex.waygood@gmail.com>
Co-authored-by: Carl Meyer <carl@astral.sh>
  • Loading branch information
3 people authored Aug 27, 2024
1 parent 96b42b0 commit aba1802
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
13 changes: 11 additions & 2 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ pub enum Type<'db> {
BooleanLiteral(bool),
/// A string literal
StringLiteral(StringLiteralType<'db>),
/// A string known to originate only from literal values, but whose value is not known (unlike
/// `StringLiteral` above).
LiteralString,
/// A bytes literal
BytesLiteral(BytesLiteralType<'db>),
// TODO protocols, callable types, overloads, generics, type vars
Expand Down Expand Up @@ -281,7 +284,13 @@ impl<'db> Type<'db> {
}
Type::BooleanLiteral(_) => Type::Unknown,
Type::StringLiteral(_) => {
// TODO defer to Type::Instance(<str from typeshed>).member
// TODO defer to `typing.LiteralString`/`builtins.str` methods
// from typeshed's stubs
Type::Unknown
}
Type::LiteralString => {
// TODO defer to `typing.LiteralString`/`builtins.str` methods
// from typeshed's stubs
Type::Unknown
}
Type::BytesLiteral(_) => {
Expand Down Expand Up @@ -387,7 +396,7 @@ pub struct IntersectionType<'db> {
#[salsa::interned]
pub struct StringLiteralType<'db> {
#[return_ref]
value: String,
value: Box<str>,
}

#[salsa::interned]
Expand Down
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/types/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl Display for DisplayType<'_> {
r#"Literal["{}"]"#,
string.value(self.db).replace('"', r#"\""#)
),
Type::LiteralString => write!(f, "LiteralString"),
Type::BytesLiteral(bytes) => {
let escape =
AsciiEscape::with_preferred_quote(bytes.value(self.db).as_ref(), Quote::Double);
Expand Down
82 changes: 76 additions & 6 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ struct TypeInferenceBuilder<'db> {
}

impl<'db> TypeInferenceBuilder<'db> {
/// How big a string do we build before bailing?
///
/// This is a fairly arbitrary number. It should be *far* more than enough
/// for most use cases, but we can reevaluate it later if useful.
const MAX_STRING_LITERAL_SIZE: usize = 4096;

/// Creates a new builder for inferring types in a region.
pub(super) fn new(
db: &'db dyn Db,
Expand Down Expand Up @@ -1259,12 +1265,16 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::BooleanLiteral(*value)
}

#[allow(clippy::unused_self)]
fn infer_string_literal_expression(&mut self, literal: &ast::ExprStringLiteral) -> Type<'db> {
Type::StringLiteral(StringLiteralType::new(self.db, literal.value.to_string()))
let value = if literal.value.len() <= Self::MAX_STRING_LITERAL_SIZE {
literal.value.to_str().into()
} else {
Box::default()
};

Type::StringLiteral(StringLiteralType::new(self.db, value))
}

#[allow(clippy::unused_self)]
fn infer_bytes_literal_expression(&mut self, literal: &ast::ExprBytesLiteral) -> Type<'db> {
// TODO: ignoring r/R prefixes for now, should normalize bytes values
Type::BytesLiteral(BytesLiteralType::new(
Expand Down Expand Up @@ -1787,11 +1797,30 @@ impl<'db> TypeInferenceBuilder<'db> {

(Type::StringLiteral(lhs), Type::StringLiteral(rhs), ast::Operator::Add) => {
Type::StringLiteral(StringLiteralType::new(self.db, {
let lhs_value = lhs.value(self.db);
let rhs_value = rhs.value(self.db);
lhs_value.clone() + rhs_value
let lhs_value = lhs.value(self.db).to_string();
let rhs_value = rhs.value(self.db).as_ref();
(lhs_value + rhs_value).into()
}))
}

(Type::StringLiteral(s), Type::IntLiteral(n), ast::Operator::Mult)
| (Type::IntLiteral(n), Type::StringLiteral(s), ast::Operator::Mult) => {
if n < 1 {
Type::StringLiteral(StringLiteralType::new(self.db, Box::default()))
} else if let Ok(n) = usize::try_from(n) {
if n.checked_mul(s.value(self.db).len())
.is_some_and(|new_length| new_length <= Self::MAX_STRING_LITERAL_SIZE)
{
let new_literal = s.value(self.db).repeat(n);
Type::StringLiteral(StringLiteralType::new(self.db, new_literal.into()))
} else {
Type::LiteralString
}
} else {
Type::LiteralString
}
}

_ => Type::Unknown, // TODO
}
}
Expand Down Expand Up @@ -1951,6 +1980,7 @@ enum ModuleNameResolutionError {

#[cfg(test)]
mod tests {

use anyhow::Context;

use ruff_db::files::{system_path_to_file, File};
Expand All @@ -1969,6 +1999,8 @@ mod tests {
use crate::types::{global_symbol_ty_by_name, infer_definition_types, symbol_ty_by_name, Type};
use crate::{HasTy, ProgramSettings, SemanticModel};

use super::TypeInferenceBuilder;

fn setup_db() -> TestDb {
let db = TestDb::new();

Expand Down Expand Up @@ -2378,6 +2410,44 @@ mod tests {
Ok(())
}

#[test]
fn multiplied_string() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
&format!(
r#"
w = 2 * "hello"
x = "goodbye" * 3
y = "a" * {y}
z = {z} * "b"
a = 0 * "hello"
b = -3 * "hello"
"#,
y = TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE,
z = TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + 1
),
)?;

assert_public_ty(&db, "src/a.py", "w", r#"Literal["hellohello"]"#);
assert_public_ty(&db, "src/a.py", "x", r#"Literal["goodbyegoodbyegoodbye"]"#);
assert_public_ty(
&db,
"src/a.py",
"y",
&format!(
r#"Literal["{}"]"#,
"a".repeat(TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE)
),
);
assert_public_ty(&db, "src/a.py", "z", "LiteralString");
assert_public_ty(&db, "src/a.py", "a", r#"Literal[""]"#);
assert_public_ty(&db, "src/a.py", "b", r#"Literal[""]"#);

Ok(())
}

#[test]
fn bytes_type() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit aba1802

Please sign in to comment.