From 6964eef369b0a19e8fbc758e22ab8b756b319381 Mon Sep 17 00:00:00 2001 From: David Peter Date: Fri, 18 Oct 2024 23:34:43 +0200 Subject: [PATCH] [red knot] add `Type::is_disjoint_from` and intersection simplifications (#13775) ## Summary - Add `Type::is_disjoint_from` as a way to test whether two types overlap - Add a first set of simplification rules for intersection types - `S & T = S` for `S <: T` - `S & ~T = Never` for `S <: T` - `~S & ~T = ~T` for `S <: T` - `A & ~B = A` for `A` disjoint from `B` - `A & B = Never` for `A` disjoint from `B` - `bool & ~Literal[bool] = Literal[!bool]` resolves one item in #12694 ## Open questions: - Can we somehow leverage the (anti) symmetry between `positive` and `negative` contributions? I could imagine that there would be a way if we had `Type::Not(type)`/`Type::Negative(type)`, but with the `positive`/`negative` architecture, I'm not sure. Note that there is a certain duplication in the `add_positive`/`add_negative` functions (e.g. `S & ~T = Never` is implemented twice), but other rules are actually not perfectly symmetric: `S & T = S` vs `~S & ~T = ~T`. - I'm not particularly proud of the way `add_positive`/`add_negative` turned out. They are long imperative-style functions with some mutability mixed in (`to_remove`). I'm happy to look into ways to improve this code *if we decide to go with this approach* of implementing a set of ad-hoc rules for simplification. - ~~Is it useful to perform simplifications eagerly in `add_positive`/`add_negative`? (@carljm)~~ This is what I did for now. ## Test Plan - Unit tests for `Type::is_disjoint_from` - Observe changes in Markdown-based tests - Unit tests for `IntersectionBuilder::build()` --------- Co-authored-by: Carl Meyer --- .../mdtest/narrow/conditionals_is.md | 6 +- .../mdtest/narrow/conditionals_is_not.md | 3 +- .../resources/mdtest/narrow/match.md | 3 +- crates/red_knot_python_semantic/src/types.rs | 246 +++++++++- .../src/types/builder.rs | 431 +++++++++++++++--- 5 files changed, 626 insertions(+), 63 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md index 1f51771dc035c..d215be4bc2995 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md @@ -6,8 +6,7 @@ x = None if flag else 1 if x is None: - # TODO the following should be simplified to 'None' - reveal_type(x) # revealed: None | Literal[1] & None + reveal_type(x) # revealed: None reveal_type(x) # revealed: None | Literal[1] ``` @@ -22,8 +21,7 @@ x = A() y = x if flag else None if y is x: - # TODO the following should be simplified to 'A' - reveal_type(y) # revealed: A | None & A + reveal_type(y) # revealed: A reveal_type(y) # revealed: A | None ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md index b1c75d053c1d3..dc094096a94a5 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md @@ -20,8 +20,7 @@ x = True if flag else False reveal_type(x) # revealed: bool if x is not False: - # TODO the following should be `Literal[True]` - reveal_type(x) # revealed: bool & ~Literal[False] + reveal_type(x) # revealed: Literal[True] ``` ## `is not` for non-singleton types diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md index b3218d2c6ed1b..0c8ea0e363cbe 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md @@ -12,6 +12,5 @@ match x: case None: y = x -# TODO intersection simplification: should be just Literal[0] | None -reveal_type(y) # revealed: Literal[0] | None | Literal[1] & None +reveal_type(y) # revealed: Literal[0] | None ``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 070bcbe127029..8f5412ef4373f 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -471,6 +471,150 @@ impl<'db> Type<'db> { self == other } + /// Return true if this type and `other` have no common elements. + /// + /// Note: This function aims to have no false positives, but might return + /// wrong `false` answers in some cases. + pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool { + match (self, other) { + (Type::Never, _) | (_, Type::Never) => true, + + (Type::Any, _) | (_, Type::Any) => false, + (Type::Unknown, _) | (_, Type::Unknown) => false, + (Type::Unbound, _) | (_, Type::Unbound) => false, + (Type::Todo, _) | (_, Type::Todo) => false, + + (Type::Union(union), other) | (other, Type::Union(union)) => union + .elements(db) + .iter() + .all(|e| e.is_disjoint_from(db, other)), + + (Type::Intersection(intersection), other) + | (other, Type::Intersection(intersection)) => { + if intersection + .positive(db) + .iter() + .any(|p| p.is_disjoint_from(db, other)) + { + true + } else { + // TODO we can do better here. For example: + // X & ~Literal[1] is disjoint from Literal[1] + false + } + } + + ( + left @ (Type::None + | Type::BooleanLiteral(..) + | Type::IntLiteral(..) + | Type::StringLiteral(..) + | Type::BytesLiteral(..) + | Type::Function(..) + | Type::Module(..) + | Type::Class(..)), + right @ (Type::None + | Type::BooleanLiteral(..) + | Type::IntLiteral(..) + | Type::StringLiteral(..) + | Type::BytesLiteral(..) + | Type::Function(..) + | Type::Module(..) + | Type::Class(..)), + ) => left != right, + + (Type::None, Type::Instance(class_type)) | (Type::Instance(class_type), Type::None) => { + !matches!( + class_type.known(db), + Some(KnownClass::NoneType | KnownClass::Object) + ) + } + (Type::None, _) | (_, Type::None) => true, + + (Type::BooleanLiteral(..), Type::Instance(class_type)) + | (Type::Instance(class_type), Type::BooleanLiteral(..)) => !matches!( + class_type.known(db), + Some(KnownClass::Bool | KnownClass::Int | KnownClass::Object) + ), + (Type::BooleanLiteral(..), _) | (_, Type::BooleanLiteral(..)) => true, + + (Type::IntLiteral(..), Type::Instance(class_type)) + | (Type::Instance(class_type), Type::IntLiteral(..)) => !matches!( + class_type.known(db), + Some(KnownClass::Int | KnownClass::Object) + ), + (Type::IntLiteral(..), _) | (_, Type::IntLiteral(..)) => true, + + (Type::StringLiteral(..), Type::LiteralString) + | (Type::LiteralString, Type::StringLiteral(..)) => false, + (Type::StringLiteral(..), Type::Instance(class_type)) + | (Type::Instance(class_type), Type::StringLiteral(..)) => !matches!( + class_type.known(db), + Some(KnownClass::Str | KnownClass::Object) + ), + (Type::StringLiteral(..), _) | (_, Type::StringLiteral(..)) => true, + + (Type::LiteralString, Type::LiteralString) => false, + (Type::LiteralString, Type::Instance(class_type)) + | (Type::Instance(class_type), Type::LiteralString) => !matches!( + class_type.known(db), + Some(KnownClass::Str | KnownClass::Object) + ), + (Type::LiteralString, _) | (_, Type::LiteralString) => true, + + (Type::BytesLiteral(..), Type::Instance(class_type)) + | (Type::Instance(class_type), Type::BytesLiteral(..)) => !matches!( + class_type.known(db), + Some(KnownClass::Bytes | KnownClass::Object) + ), + (Type::BytesLiteral(..), _) | (_, Type::BytesLiteral(..)) => true, + + ( + Type::Function(..) | Type::Module(..) | Type::Class(..), + Type::Instance(class_type), + ) + | ( + Type::Instance(class_type), + Type::Function(..) | Type::Module(..) | Type::Class(..), + ) => !class_type.is_known(db, KnownClass::Object), + + (Type::Instance(..), Type::Instance(..)) => { + // TODO: once we have support for `final`, there might be some cases where + // we can determine that two types are disjoint. For non-final classes, we + // return false (multiple inheritance). + + // TODO: is there anything specific to do for instances of KnownClass::Type? + + false + } + + (Type::Tuple(tuple), other) | (other, Type::Tuple(tuple)) => { + if let Type::Tuple(other_tuple) = other { + if tuple.len(db) == other_tuple.len(db) { + tuple + .elements(db) + .iter() + .zip(other_tuple.elements(db).iter()) + .any(|(e1, e2)| e1.is_disjoint_from(db, *e2)) + } else { + true + } + } else { + // We can not be sure if the tuple is disjoint from 'other' because: + // - 'other' might be the homogeneous arbitrary-length tuple type + // tuple[T, ...] (which we don't have support for yet); if all of + // our element types are not disjoint with T, this is not disjoint + // - 'other' might be a user subtype of tuple, which, if generic + // over the same or compatible *Ts, would overlap with tuple. + // + // TODO: add checks for the above cases once we support them + + false + } + } + } + } + /// Return true if there is just a single inhabitant for this type. /// /// Note: This function aims to have no false positives, but might return `false` @@ -1558,8 +1702,8 @@ impl<'db> TupleType<'db> { #[cfg(test)] mod tests { use super::{ - builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type, - UnionType, + builtins_symbol_ty, BytesLiteralType, IntersectionBuilder, StringLiteralType, Truthiness, + TupleType, Type, UnionType, }; use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; @@ -1603,6 +1747,7 @@ mod tests { BytesLiteral(&'static str), BuiltinInstance(&'static str), Union(Vec), + Intersection { pos: Vec, neg: Vec }, Tuple(Vec), } @@ -1622,6 +1767,16 @@ mod tests { Ty::Union(tys) => { UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db))) } + Ty::Intersection { pos, neg } => { + let mut builder = IntersectionBuilder::new(db); + for p in pos { + builder = builder.add_positive(p.into_type(db)); + } + for n in neg { + builder = builder.add_negative(n.into_type(db)); + } + builder.build() + } Ty::Tuple(tys) => { let elements: Box<_> = tys.into_iter().map(|ty| ty.into_type(db)).collect(); Type::Tuple(TupleType::new(db, elements)) @@ -1697,6 +1852,93 @@ mod tests { assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db))); } + #[test_case(Ty::Never, Ty::Never)] + #[test_case(Ty::Never, Ty::None)] + #[test_case(Ty::Never, Ty::BuiltinInstance("int"))] + #[test_case(Ty::None, Ty::BoolLiteral(true))] + #[test_case(Ty::None, Ty::IntLiteral(1))] + #[test_case(Ty::None, Ty::StringLiteral("test"))] + #[test_case(Ty::None, Ty::BytesLiteral("test"))] + #[test_case(Ty::None, Ty::LiteralString)] + #[test_case(Ty::None, Ty::BuiltinInstance("int"))] + #[test_case(Ty::None, Ty::Tuple(vec![Ty::None]))] + #[test_case(Ty::BoolLiteral(true), Ty::BoolLiteral(false))] + #[test_case(Ty::BoolLiteral(true), Ty::Tuple(vec![Ty::None]))] + #[test_case(Ty::BoolLiteral(true), Ty::IntLiteral(1))] + #[test_case(Ty::BoolLiteral(false), Ty::IntLiteral(0))] + #[test_case(Ty::IntLiteral(1), Ty::IntLiteral(2))] + #[test_case(Ty::IntLiteral(1), Ty::Tuple(vec![Ty::None]))] + #[test_case(Ty::StringLiteral("a"), Ty::StringLiteral("b"))] + #[test_case(Ty::StringLiteral("a"), Ty::Tuple(vec![Ty::None]))] + #[test_case(Ty::LiteralString, Ty::BytesLiteral("a"))] + #[test_case(Ty::BytesLiteral("a"), Ty::BytesLiteral("b"))] + #[test_case(Ty::BytesLiteral("a"), Ty::Tuple(vec![Ty::None]))] + #[test_case(Ty::BytesLiteral("a"), Ty::StringLiteral("a"))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::IntLiteral(3))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(3), Ty::IntLiteral(4)]))] + #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int"), Ty::IntLiteral(1)], neg: vec![]}, Ty::IntLiteral(2))] + #[test_case(Ty::Tuple(vec![Ty::IntLiteral(1)]), Ty::Tuple(vec![Ty::IntLiteral(2)]))] + #[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1)]))] + #[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(3)]))] + fn is_disjoint_from(a: Ty, b: Ty) { + let db = setup_db(); + let a = a.into_type(&db); + let b = b.into_type(&db); + + assert!(a.is_disjoint_from(&db, b)); + assert!(b.is_disjoint_from(&db, a)); + } + + #[test_case(Ty::Any, Ty::BuiltinInstance("int"))] + #[test_case(Ty::None, Ty::None)] + #[test_case(Ty::None, Ty::BuiltinInstance("object"))] + #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("int"))] + #[test_case(Ty::BuiltinInstance("str"), Ty::LiteralString)] + #[test_case(Ty::BoolLiteral(true), Ty::BoolLiteral(true))] + #[test_case(Ty::BoolLiteral(false), Ty::BoolLiteral(false))] + #[test_case(Ty::BoolLiteral(true), Ty::BuiltinInstance("bool"))] + #[test_case(Ty::BoolLiteral(true), Ty::BuiltinInstance("int"))] + #[test_case(Ty::IntLiteral(1), Ty::IntLiteral(1))] + #[test_case(Ty::StringLiteral("a"), Ty::StringLiteral("a"))] + #[test_case(Ty::StringLiteral("a"), Ty::LiteralString)] + #[test_case(Ty::StringLiteral("a"), Ty::BuiltinInstance("str"))] + #[test_case(Ty::LiteralString, Ty::LiteralString)] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::IntLiteral(2))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(2), Ty::IntLiteral(3)]))] + #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int"), Ty::IntLiteral(2)], neg: vec![]}, Ty::IntLiteral(2))] + #[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1), Ty::BuiltinInstance("int")]))] + fn is_not_disjoint_from(a: Ty, b: Ty) { + let db = setup_db(); + let a = a.into_type(&db); + let b = b.into_type(&db); + + assert!(!a.is_disjoint_from(&db, b)); + assert!(!b.is_disjoint_from(&db, a)); + } + + #[test] + fn is_disjoint_from_union_of_class_types() { + let mut db = setup_db(); + db.write_dedented( + "/src/module.py", + " + class A: ... + class B: ... + x = A if flag else B + ", + ) + .unwrap(); + let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap(); + + let type_a = super::global_symbol_ty(&db, module, "A"); + let type_x = super::global_symbol_ty(&db, module, "x"); + + assert!(matches!(type_a, Type::Class(_))); + assert!(matches!(type_x, Type::Union(_))); + + assert!(!type_a.is_disjoint_from(&db, type_x)); + } + #[test_case(Ty::None)] #[test_case(Ty::BoolLiteral(true))] #[test_case(Ty::BoolLiteral(false))] diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 013e6988cccbc..4cf17ba4f6ddc 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -216,74 +216,140 @@ impl<'db> InnerIntersectionBuilder<'db> { } /// Adds a positive type to this intersection. - fn add_positive(&mut self, db: &'db dyn Db, ty: Type<'db>) { - // TODO `Any`/`Unknown`/`Todo` actually should not self-cancel - match ty { - Type::Intersection(inter) => { - let pos = inter.positive(db); - let neg = inter.negative(db); - self.positive.extend(pos.difference(&self.negative)); - self.negative.extend(neg.difference(&self.positive)); - self.positive.retain(|elem| !neg.contains(elem)); - self.negative.retain(|elem| !pos.contains(elem)); + fn add_positive(&mut self, db: &'db dyn Db, new_positive: Type<'db>) { + if let Type::Intersection(other) = new_positive { + for pos in other.positive(db) { + self.add_positive(db, *pos); } - _ => { - if !self.negative.remove(&ty) { - self.positive.insert(ty); - }; + for neg in other.negative(db) { + self.add_negative(db, *neg); + } + } else { + // ~Literal[True] & bool = Literal[False] + if let Type::Instance(class_type) = new_positive { + if class_type.is_known(db, KnownClass::Bool) { + if let Some(&Type::BooleanLiteral(value)) = self + .negative + .iter() + .find(|element| matches!(element, Type::BooleanLiteral(..))) + { + *self = Self::new(); + self.positive.insert(Type::BooleanLiteral(!value)); + return; + } + } } + + let mut to_remove = SmallVec::<[usize; 1]>::new(); + for (index, existing_positive) in self.positive.iter().enumerate() { + // S & T = S if S <: T + if existing_positive.is_subtype_of(db, new_positive) { + return; + } + // same rule, reverse order + if new_positive.is_subtype_of(db, *existing_positive) { + to_remove.push(index); + } + // A & B = Never if A and B are disjoint + if new_positive.is_disjoint_from(db, *existing_positive) { + *self = Self::new(); + return; + } + } + for index in to_remove.iter().rev() { + self.positive.swap_remove_index(*index); + } + + let mut to_remove = SmallVec::<[usize; 1]>::new(); + for (index, existing_negative) in self.negative.iter().enumerate() { + // S & ~T = Never if S <: T + if new_positive.is_subtype_of(db, *existing_negative) { + *self = Self::new(); + return; + } + // A & ~B = A if A and B are disjoint + if existing_negative.is_disjoint_from(db, new_positive) { + to_remove.push(index); + } + } + for index in to_remove.iter().rev() { + self.negative.swap_remove_index(*index); + } + + self.positive.insert(new_positive); } } /// Adds a negative type to this intersection. - fn add_negative(&mut self, db: &'db dyn Db, ty: Type<'db>) { - // TODO `Any`/`Unknown`/`Todo` actually should not self-cancel - match ty { - Type::Intersection(intersection) => { - let pos = intersection.negative(db); - let neg = intersection.positive(db); - self.positive.extend(pos.difference(&self.negative)); - self.negative.extend(neg.difference(&self.positive)); - self.positive.retain(|elem| !neg.contains(elem)); - self.negative.retain(|elem| !pos.contains(elem)); + fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) { + match new_negative { + Type::Intersection(inter) => { + for pos in inter.positive(db) { + self.add_negative(db, *pos); + } + for neg in inter.negative(db) { + self.add_positive(db, *neg); + } } - Type::Never => {} Type::Unbound => {} - _ => { - if !self.positive.remove(&ty) { - self.negative.insert(ty); - }; + ty @ (Type::Any | Type::Unknown | Type::Todo) => { + // Adding any of these types to the negative side of an intersection + // is equivalent to adding it to the positive side. We do this to + // simplify the representation. + self.positive.insert(ty); } - } - } + // ~Literal[True] & bool = Literal[False] + Type::BooleanLiteral(bool) + if self + .positive + .iter() + .any(|pos| *pos == KnownClass::Bool.to_instance(db)) => + { + *self = Self::new(); + self.positive.insert(Type::BooleanLiteral(!bool)); + } + _ => { + let mut to_remove = SmallVec::<[usize; 1]>::new(); + for (index, existing_negative) in self.negative.iter().enumerate() { + // ~S & ~T = ~T if S <: T + if existing_negative.is_subtype_of(db, new_negative) { + to_remove.push(index); + } + // same rule, reverse order + if new_negative.is_subtype_of(db, *existing_negative) { + return; + } + } + for index in to_remove.iter().rev() { + self.negative.swap_remove_index(*index); + } - fn simplify(&mut self) { - // TODO this should be generalized based on subtyping, for now we just handle a few cases + for existing_positive in &self.positive { + // S & ~T = Never if S <: T + if existing_positive.is_subtype_of(db, new_negative) { + *self = Self::new(); + return; + } + // A & ~B = A if A and B are disjoint + if existing_positive.is_disjoint_from(db, new_negative) { + return; + } + } - // Never is a subtype of all types - if self.positive.contains(&Type::Never) { - self.positive.retain(Type::is_never); - self.negative.clear(); + self.negative.insert(new_negative); + } } + } + fn simplify_unbound(&mut self) { if self.positive.contains(&Type::Unbound) { self.positive.retain(Type::is_unbound); self.negative.clear(); } - - // None intersects only with object - for pos in &self.positive { - if let Type::Instance(_) = pos { - // could be `object` type - } else { - self.negative.remove(&Type::None); - break; - } - } } fn build(mut self, db: &'db dyn Db) -> Type<'db> { - self.simplify(); + self.simplify_unbound(); match (self.positive.len(), self.negative.len()) { (0, 0) => Type::Never, (1, 0) => self.positive[0], @@ -302,9 +368,10 @@ mod tests { use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; - use crate::types::{KnownClass, UnionBuilder}; + use crate::types::{KnownClass, StringLiteralType, UnionBuilder}; use crate::ProgramSettings; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; + use test_case::test_case; fn setup_db() -> TestDb { let db = TestDb::new(); @@ -473,7 +540,7 @@ mod tests { .expect_intersection(); assert_eq!(intersection.pos_vec(&db), &[t2, ta]); - assert_eq!(intersection.neg_vec(&db), &[t1]); + assert_eq!(intersection.neg_vec(&db), &[]); } #[test] @@ -481,7 +548,7 @@ mod tests { let db = setup_db(); let ta = Type::Any; let t1 = Type::IntLiteral(1); - let t2 = Type::IntLiteral(2); + let t2 = KnownClass::Int.to_instance(&db); let i0 = IntersectionBuilder::new(&db) .add_positive(ta) .add_negative(t1) @@ -492,8 +559,8 @@ mod tests { .build() .expect_intersection(); - assert_eq!(intersection.pos_vec(&db), &[t2, t1]); - assert_eq!(intersection.neg_vec(&db), &[ta]); + assert_eq!(intersection.pos_vec(&db), &[ta, t1]); + assert_eq!(intersection.neg_vec(&db), &[]); } #[test] @@ -574,11 +641,269 @@ mod tests { #[test] fn build_intersection_simplify_negative_none() { let db = setup_db(); + let ty = IntersectionBuilder::new(&db) .add_negative(Type::None) .add_positive(Type::IntLiteral(1)) .build(); + assert_eq!(ty, Type::IntLiteral(1)); + let ty = IntersectionBuilder::new(&db) + .add_positive(Type::IntLiteral(1)) + .add_negative(Type::None) + .build(); assert_eq!(ty, Type::IntLiteral(1)); } + + #[test] + fn build_intersection_simplify_positive_type_and_positive_subtype() { + let db = setup_db(); + + let t = KnownClass::Str.to_instance(&db); + let s = Type::LiteralString; + + let ty = IntersectionBuilder::new(&db) + .add_positive(t) + .add_positive(s) + .build(); + assert_eq!(ty, s); + + let ty = IntersectionBuilder::new(&db) + .add_positive(s) + .add_positive(t) + .build(); + assert_eq!(ty, s); + + let literal = Type::StringLiteral(StringLiteralType::new(&db, "a")); + let expected = IntersectionBuilder::new(&db) + .add_positive(s) + .add_negative(literal) + .build(); + + let ty = IntersectionBuilder::new(&db) + .add_positive(t) + .add_negative(literal) + .add_positive(s) + .build(); + assert_eq!(ty, expected); + + let ty = IntersectionBuilder::new(&db) + .add_positive(s) + .add_negative(literal) + .add_positive(t) + .build(); + assert_eq!(ty, expected); + } + + #[test] + fn build_intersection_simplify_negative_type_and_negative_subtype() { + let db = setup_db(); + + let t = KnownClass::Str.to_instance(&db); + let s = Type::LiteralString; + + let expected = IntersectionBuilder::new(&db).add_negative(t).build(); + + let ty = IntersectionBuilder::new(&db) + .add_negative(t) + .add_negative(s) + .build(); + assert_eq!(ty, expected); + + let ty = IntersectionBuilder::new(&db) + .add_negative(s) + .add_negative(t) + .build(); + assert_eq!(ty, expected); + + let object = KnownClass::Object.to_instance(&db); + let expected = IntersectionBuilder::new(&db) + .add_negative(t) + .add_positive(object) + .build(); + + let ty = IntersectionBuilder::new(&db) + .add_negative(t) + .add_positive(object) + .add_negative(s) + .build(); + assert_eq!(ty, expected); + } + + #[test] + fn build_intersection_simplify_negative_type_and_multiple_negative_subtypes() { + let db = setup_db(); + + let s1 = Type::IntLiteral(1); + let s2 = Type::IntLiteral(2); + let t = KnownClass::Int.to_instance(&db); + + let expected = IntersectionBuilder::new(&db).add_negative(t).build(); + + let ty = IntersectionBuilder::new(&db) + .add_negative(s1) + .add_negative(s2) + .add_negative(t) + .build(); + assert_eq!(ty, expected); + } + + #[test] + fn build_intersection_simplify_negative_type_and_positive_subtype() { + let db = setup_db(); + + let t = KnownClass::Str.to_instance(&db); + let s = Type::LiteralString; + + let ty = IntersectionBuilder::new(&db) + .add_negative(t) + .add_positive(s) + .build(); + assert_eq!(ty, Type::Never); + + let ty = IntersectionBuilder::new(&db) + .add_positive(s) + .add_negative(t) + .build(); + assert_eq!(ty, Type::Never); + + // This should also work in the presence of additional contributions: + let ty = IntersectionBuilder::new(&db) + .add_positive(KnownClass::Object.to_instance(&db)) + .add_negative(t) + .add_positive(s) + .build(); + assert_eq!(ty, Type::Never); + + let ty = IntersectionBuilder::new(&db) + .add_positive(s) + .add_negative(Type::StringLiteral(StringLiteralType::new(&db, "a"))) + .add_negative(t) + .build(); + assert_eq!(ty, Type::Never); + } + + #[test] + fn build_intersection_simplify_disjoint_positive_types() { + let db = setup_db(); + + let t1 = Type::IntLiteral(1); + let t2 = Type::None; + + let ty = IntersectionBuilder::new(&db) + .add_positive(t1) + .add_positive(t2) + .build(); + assert_eq!(ty, Type::Never); + + // If there are any negative contributions, they should + // be removed too. + let ty = IntersectionBuilder::new(&db) + .add_positive(KnownClass::Str.to_instance(&db)) + .add_negative(Type::LiteralString) + .add_positive(t2) + .build(); + assert_eq!(ty, Type::Never); + } + + #[test] + fn build_intersection_simplify_disjoint_positive_and_negative_types() { + let db = setup_db(); + + let t_p = KnownClass::Int.to_instance(&db); + let t_n = Type::StringLiteral(StringLiteralType::new(&db, "t_n")); + + let ty = IntersectionBuilder::new(&db) + .add_positive(t_p) + .add_negative(t_n) + .build(); + assert_eq!(ty, t_p); + + let ty = IntersectionBuilder::new(&db) + .add_negative(t_n) + .add_positive(t_p) + .build(); + assert_eq!(ty, t_p); + + let int_literal = Type::IntLiteral(1); + let expected = IntersectionBuilder::new(&db) + .add_positive(t_p) + .add_negative(int_literal) + .build(); + + let ty = IntersectionBuilder::new(&db) + .add_positive(t_p) + .add_negative(int_literal) + .add_negative(t_n) + .build(); + assert_eq!(ty, expected); + + let ty = IntersectionBuilder::new(&db) + .add_negative(t_n) + .add_negative(int_literal) + .add_positive(t_p) + .build(); + assert_eq!(ty, expected); + } + + #[test_case(true)] + #[test_case(false)] + fn build_intersection_simplify_split_bool(bool_value: bool) { + let db = setup_db(); + + let t_bool = KnownClass::Bool.to_instance(&db); + let t_bool_literal = Type::BooleanLiteral(bool_value); + + // We add t_object in various orders (in first or second position) in + // the tests below to ensure that the boolean simplification eliminates + // everything from the intersection, not just `bool`. + let t_object = KnownClass::Object.to_instance(&db); + + let ty = IntersectionBuilder::new(&db) + .add_positive(t_object) + .add_positive(t_bool) + .add_negative(t_bool_literal) + .build(); + assert_eq!(ty, Type::BooleanLiteral(!bool_value)); + + let ty = IntersectionBuilder::new(&db) + .add_positive(t_bool) + .add_positive(t_object) + .add_negative(t_bool_literal) + .build(); + assert_eq!(ty, Type::BooleanLiteral(!bool_value)); + + let ty = IntersectionBuilder::new(&db) + .add_positive(t_object) + .add_negative(t_bool_literal) + .add_positive(t_bool) + .build(); + assert_eq!(ty, Type::BooleanLiteral(!bool_value)); + + let ty = IntersectionBuilder::new(&db) + .add_negative(t_bool_literal) + .add_positive(t_object) + .add_positive(t_bool) + .build(); + assert_eq!(ty, Type::BooleanLiteral(!bool_value)); + } + + #[test_case(Type::Any)] + #[test_case(Type::Unknown)] + #[test_case(Type::Todo)] + fn build_intersection_t_and_negative_t_does_not_simplify(ty: Type) { + let db = setup_db(); + + let result = IntersectionBuilder::new(&db) + .add_positive(ty) + .add_negative(ty) + .build(); + assert_eq!(result, ty); + + let result = IntersectionBuilder::new(&db) + .add_negative(ty) + .add_positive(ty) + .build(); + assert_eq!(result, ty); + } }