Skip to content

Commit

Permalink
[red-knot] feat: add StringLiteral and LiteralString comparison (#…
Browse files Browse the repository at this point in the history
…13634)

## Summary

Implements string literal comparisons and fallbacks to `str` instance
for `LiteralString`.
Completes an item in #13618

## Test Plan

- Adds a dedicated test with non exhaustive cases

---------

Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
Slyces and AlexWaygood authored Oct 5, 2024
1 parent f120517 commit 8108f83
Showing 1 changed file with 83 additions and 5 deletions.
88 changes: 83 additions & 5 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2535,9 +2535,7 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::CmpOp::In
| ast::CmpOp::NotIn
| ast::CmpOp::Is
| ast::CmpOp::IsNot => {
builtins_symbol_ty(self.db, "bool").to_instance(self.db)
}
| ast::CmpOp::IsNot => KnownClass::Bool.to_instance(self.db),
// Other operators can return arbitrary types
_ => Type::Unknown,
}
Expand Down Expand Up @@ -2573,14 +2571,14 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
Some(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
Some(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
}
Expand All @@ -2594,6 +2592,7 @@ impl<'db> TypeInferenceBuilder<'db> {
(Type::Instance(_), Type::IntLiteral(_)) => {
self.infer_binary_type_comparison(left, op, KnownClass::Int.to_instance(self.db))
}

// Booleans are coded as integers (False = 0, True = 1)
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
Type::IntLiteral(n),
Expand All @@ -2611,6 +2610,49 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
Type::IntLiteral(i64::from(b)),
),

(Type::StringLiteral(salsa_s1), Type::StringLiteral(salsa_s2)) => {
let s1 = salsa_s1.value(self.db);
let s2 = salsa_s2.value(self.db);
match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)),
ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
ast::CmpOp::Is => {
if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
}
}
}
}
(Type::StringLiteral(_), _) => {
self.infer_binary_type_comparison(KnownClass::Str.to_instance(self.db), op, right)
}
(_, Type::StringLiteral(_)) => {
self.infer_binary_type_comparison(left, op, KnownClass::Str.to_instance(self.db))
}

(Type::LiteralString, _) => {
self.infer_binary_type_comparison(KnownClass::Str.to_instance(self.db), op, right)
}
(_, Type::LiteralString) => {
self.infer_binary_type_comparison(left, op, KnownClass::Str.to_instance(self.db))
}

// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
Expand Down Expand Up @@ -4110,6 +4152,42 @@ mod tests {
Ok(())
}

#[test]
fn comparison_string_literals() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
def str_instance() -> str: ...
a = "abc" == "abc"
b = "ab_cd" <= "ab_ce"
c = "abc" in "ab cd"
d = "" not in "hello"
e = "--" is "--"
f = "A" is "B"
g = "--" is not "--"
h = "A" is not "B"
i = str_instance() < "..."
j = "ab" < "ab_cd"
"#,
)?;

assert_public_ty(&db, "src/a.py", "a", "Literal[True]");
assert_public_ty(&db, "src/a.py", "b", "Literal[True]");
assert_public_ty(&db, "src/a.py", "c", "Literal[False]");
assert_public_ty(&db, "src/a.py", "d", "Literal[False]");
assert_public_ty(&db, "src/a.py", "e", "bool");
assert_public_ty(&db, "src/a.py", "f", "Literal[False]");
assert_public_ty(&db, "src/a.py", "g", "bool");
assert_public_ty(&db, "src/a.py", "h", "Literal[True]");
assert_public_ty(&db, "src/a.py", "i", "bool");
// Very cornercase test ensuring we're not comparing the interned salsa symbols, which
// compare by order of declaration
assert_public_ty(&db, "src/a.py", "j", "Literal[True]");

Ok(())
}

#[test]
fn comparison_unsupported_operators() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit 8108f83

Please sign in to comment.