From 8108f8381056641543c0255ea5b54d650d1c5332 Mon Sep 17 00:00:00 2001 From: Simon Date: Sat, 5 Oct 2024 21:22:30 +0200 Subject: [PATCH] [red-knot] feat: add `StringLiteral` and `LiteralString` comparison (#13634) ## Summary Implements string literal comparisons and fallbacks to `str` instance for `LiteralString`. Completes an item in #13618 ## Test Plan - Adds a dedicated test with non exhaustive cases --------- Co-authored-by: Alex Waygood --- .../src/types/infer.rs | 88 +++++++++++++++++-- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 9d0583c260c6c..3de456dded065 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2535,9 +2535,7 @@ impl<'db> TypeInferenceBuilder<'db> { ast::CmpOp::In | ast::CmpOp::NotIn | ast::CmpOp::Is - | ast::CmpOp::IsNot => { - builtins_symbol_ty(self.db, "bool").to_instance(self.db) - } + | ast::CmpOp::IsNot => KnownClass::Bool.to_instance(self.db), // Other operators can return arbitrary types _ => Type::Unknown, } @@ -2573,14 +2571,14 @@ impl<'db> TypeInferenceBuilder<'db> { ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)), ast::CmpOp::Is => { if n == m { - Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db)) + Some(KnownClass::Bool.to_instance(self.db)) } else { Some(Type::BooleanLiteral(false)) } } ast::CmpOp::IsNot => { if n == m { - Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db)) + Some(KnownClass::Bool.to_instance(self.db)) } else { Some(Type::BooleanLiteral(true)) } @@ -2594,6 +2592,7 @@ impl<'db> TypeInferenceBuilder<'db> { (Type::Instance(_), Type::IntLiteral(_)) => { self.infer_binary_type_comparison(left, op, KnownClass::Int.to_instance(self.db)) } + // Booleans are coded as integers (False = 0, True = 1) (Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison( Type::IntLiteral(n), @@ -2611,6 +2610,49 @@ impl<'db> TypeInferenceBuilder<'db> { op, Type::IntLiteral(i64::from(b)), ), + + (Type::StringLiteral(salsa_s1), Type::StringLiteral(salsa_s2)) => { + let s1 = salsa_s1.value(self.db); + let s2 = salsa_s2.value(self.db); + match op { + ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)), + ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)), + ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)), + ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)), + ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)), + ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)), + ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))), + ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))), + ast::CmpOp::Is => { + if s1 == s2 { + Some(KnownClass::Bool.to_instance(self.db)) + } else { + Some(Type::BooleanLiteral(false)) + } + } + ast::CmpOp::IsNot => { + if s1 == s2 { + Some(KnownClass::Bool.to_instance(self.db)) + } else { + Some(Type::BooleanLiteral(true)) + } + } + } + } + (Type::StringLiteral(_), _) => { + self.infer_binary_type_comparison(KnownClass::Str.to_instance(self.db), op, right) + } + (_, Type::StringLiteral(_)) => { + self.infer_binary_type_comparison(left, op, KnownClass::Str.to_instance(self.db)) + } + + (Type::LiteralString, _) => { + self.infer_binary_type_comparison(KnownClass::Str.to_instance(self.db), op, right) + } + (_, Type::LiteralString) => { + self.infer_binary_type_comparison(left, op, KnownClass::Str.to_instance(self.db)) + } + // Lookup the rich comparison `__dunder__` methods on instances (Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { ast::CmpOp::Lt => { @@ -4110,6 +4152,42 @@ mod tests { Ok(()) } + #[test] + fn comparison_string_literals() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "src/a.py", + r#" + def str_instance() -> str: ... + a = "abc" == "abc" + b = "ab_cd" <= "ab_ce" + c = "abc" in "ab cd" + d = "" not in "hello" + e = "--" is "--" + f = "A" is "B" + g = "--" is not "--" + h = "A" is not "B" + i = str_instance() < "..." + j = "ab" < "ab_cd" + "#, + )?; + + assert_public_ty(&db, "src/a.py", "a", "Literal[True]"); + assert_public_ty(&db, "src/a.py", "b", "Literal[True]"); + assert_public_ty(&db, "src/a.py", "c", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "d", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "e", "bool"); + assert_public_ty(&db, "src/a.py", "f", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "g", "bool"); + assert_public_ty(&db, "src/a.py", "h", "Literal[True]"); + assert_public_ty(&db, "src/a.py", "i", "bool"); + // Very cornercase test ensuring we're not comparing the interned salsa symbols, which + // compare by order of declaration + assert_public_ty(&db, "src/a.py", "j", "Literal[True]"); + + Ok(()) + } + #[test] fn comparison_unsupported_operators() -> anyhow::Result<()> { let mut db = setup_db();