Skip to content

Commit

Permalink
[red knot] add Type::is_disjoint_from and intersection simplificati…
Browse files Browse the repository at this point in the history
…ons (#13775)

## Summary

- Add `Type::is_disjoint_from` as a way to test whether two types
overlap
- Add a first set of simplification rules for intersection types
  - `S & T = S` for `S <: T`
  - `S & ~T = Never` for `S <: T`
  - `~S & ~T = ~T` for `S <: T`
  - `A & ~B = A` for `A` disjoint from `B`
  - `A & B = Never` for `A` disjoint from `B`
  - `bool & ~Literal[bool] = Literal[!bool]`

resolves one item in #12694

## Open questions:

- Can we somehow leverage the (anti) symmetry between `positive` and
`negative` contributions? I could imagine that there would be a way if
we had `Type::Not(type)`/`Type::Negative(type)`, but with the
`positive`/`negative` architecture, I'm not sure. Note that there is a
certain duplication in the `add_positive`/`add_negative` functions (e.g.
`S & ~T = Never` is implemented twice), but other rules are actually not
perfectly symmetric: `S & T = S` vs `~S & ~T = ~T`.
- I'm not particularly proud of the way `add_positive`/`add_negative`
turned out. They are long imperative-style functions with some
mutability mixed in (`to_remove`). I'm happy to look into ways to
improve this code *if we decide to go with this approach* of
implementing a set of ad-hoc rules for simplification.
- ~~Is it useful to perform simplifications eagerly in
`add_positive`/`add_negative`? (@carljm)~~ This is what I did for now.

## Test Plan

- Unit tests for `Type::is_disjoint_from`
- Observe changes in Markdown-based tests
- Unit tests for `IntersectionBuilder::build()`

---------

Co-authored-by: Carl Meyer <[email protected]>
  • Loading branch information
sharkdp and carljm authored Oct 18, 2024
1 parent c93a7c7 commit 6964eef
Show file tree
Hide file tree
Showing 5 changed files with 626 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
x = None if flag else 1

if x is None:
# TODO the following should be simplified to 'None'
reveal_type(x) # revealed: None | Literal[1] & None
reveal_type(x) # revealed: None

reveal_type(x) # revealed: None | Literal[1]
```
Expand All @@ -22,8 +21,7 @@ x = A()
y = x if flag else None

if y is x:
# TODO the following should be simplified to 'A'
reveal_type(y) # revealed: A | None & A
reveal_type(y) # revealed: A

reveal_type(y) # revealed: A | None
```
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ x = True if flag else False
reveal_type(x) # revealed: bool

if x is not False:
# TODO the following should be `Literal[True]`
reveal_type(x) # revealed: bool & ~Literal[False]
reveal_type(x) # revealed: Literal[True]
```

## `is not` for non-singleton types
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,5 @@ match x:
case None:
y = x

# TODO intersection simplification: should be just Literal[0] | None
reveal_type(y) # revealed: Literal[0] | None | Literal[1] & None
reveal_type(y) # revealed: Literal[0] | None
```
246 changes: 244 additions & 2 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,150 @@ impl<'db> Type<'db> {
self == other
}

/// Return true if this type and `other` have no common elements.
///
/// Note: This function aims to have no false positives, but might return
/// wrong `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
match (self, other) {
(Type::Never, _) | (_, Type::Never) => true,

(Type::Any, _) | (_, Type::Any) => false,
(Type::Unknown, _) | (_, Type::Unknown) => false,
(Type::Unbound, _) | (_, Type::Unbound) => false,
(Type::Todo, _) | (_, Type::Todo) => false,

(Type::Union(union), other) | (other, Type::Union(union)) => union
.elements(db)
.iter()
.all(|e| e.is_disjoint_from(db, other)),

(Type::Intersection(intersection), other)
| (other, Type::Intersection(intersection)) => {
if intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
{
true
} else {
// TODO we can do better here. For example:
// X & ~Literal[1] is disjoint from Literal[1]
false
}
}

(
left @ (Type::None
| Type::BooleanLiteral(..)
| Type::IntLiteral(..)
| Type::StringLiteral(..)
| Type::BytesLiteral(..)
| Type::Function(..)
| Type::Module(..)
| Type::Class(..)),
right @ (Type::None
| Type::BooleanLiteral(..)
| Type::IntLiteral(..)
| Type::StringLiteral(..)
| Type::BytesLiteral(..)
| Type::Function(..)
| Type::Module(..)
| Type::Class(..)),
) => left != right,

(Type::None, Type::Instance(class_type)) | (Type::Instance(class_type), Type::None) => {
!matches!(
class_type.known(db),
Some(KnownClass::NoneType | KnownClass::Object)
)
}
(Type::None, _) | (_, Type::None) => true,

(Type::BooleanLiteral(..), Type::Instance(class_type))
| (Type::Instance(class_type), Type::BooleanLiteral(..)) => !matches!(
class_type.known(db),
Some(KnownClass::Bool | KnownClass::Int | KnownClass::Object)
),
(Type::BooleanLiteral(..), _) | (_, Type::BooleanLiteral(..)) => true,

(Type::IntLiteral(..), Type::Instance(class_type))
| (Type::Instance(class_type), Type::IntLiteral(..)) => !matches!(
class_type.known(db),
Some(KnownClass::Int | KnownClass::Object)
),
(Type::IntLiteral(..), _) | (_, Type::IntLiteral(..)) => true,

(Type::StringLiteral(..), Type::LiteralString)
| (Type::LiteralString, Type::StringLiteral(..)) => false,
(Type::StringLiteral(..), Type::Instance(class_type))
| (Type::Instance(class_type), Type::StringLiteral(..)) => !matches!(
class_type.known(db),
Some(KnownClass::Str | KnownClass::Object)
),
(Type::StringLiteral(..), _) | (_, Type::StringLiteral(..)) => true,

(Type::LiteralString, Type::LiteralString) => false,
(Type::LiteralString, Type::Instance(class_type))
| (Type::Instance(class_type), Type::LiteralString) => !matches!(
class_type.known(db),
Some(KnownClass::Str | KnownClass::Object)
),
(Type::LiteralString, _) | (_, Type::LiteralString) => true,

(Type::BytesLiteral(..), Type::Instance(class_type))
| (Type::Instance(class_type), Type::BytesLiteral(..)) => !matches!(
class_type.known(db),
Some(KnownClass::Bytes | KnownClass::Object)
),
(Type::BytesLiteral(..), _) | (_, Type::BytesLiteral(..)) => true,

(
Type::Function(..) | Type::Module(..) | Type::Class(..),
Type::Instance(class_type),
)
| (
Type::Instance(class_type),
Type::Function(..) | Type::Module(..) | Type::Class(..),
) => !class_type.is_known(db, KnownClass::Object),

(Type::Instance(..), Type::Instance(..)) => {
// TODO: once we have support for `final`, there might be some cases where
// we can determine that two types are disjoint. For non-final classes, we
// return false (multiple inheritance).

// TODO: is there anything specific to do for instances of KnownClass::Type?

false
}

(Type::Tuple(tuple), other) | (other, Type::Tuple(tuple)) => {
if let Type::Tuple(other_tuple) = other {
if tuple.len(db) == other_tuple.len(db) {
tuple
.elements(db)
.iter()
.zip(other_tuple.elements(db).iter())
.any(|(e1, e2)| e1.is_disjoint_from(db, *e2))
} else {
true
}
} else {
// We can not be sure if the tuple is disjoint from 'other' because:
// - 'other' might be the homogeneous arbitrary-length tuple type
// tuple[T, ...] (which we don't have support for yet); if all of
// our element types are not disjoint with T, this is not disjoint
// - 'other' might be a user subtype of tuple, which, if generic
// over the same or compatible *Ts, would overlap with tuple.
//
// TODO: add checks for the above cases once we support them

false
}
}
}
}

/// Return true if there is just a single inhabitant for this type.
///
/// Note: This function aims to have no false positives, but might return `false`
Expand Down Expand Up @@ -1558,8 +1702,8 @@ impl<'db> TupleType<'db> {
#[cfg(test)]
mod tests {
use super::{
builtins_symbol_ty, BytesLiteralType, StringLiteralType, Truthiness, TupleType, Type,
UnionType,
builtins_symbol_ty, BytesLiteralType, IntersectionBuilder, StringLiteralType, Truthiness,
TupleType, Type, UnionType,
};
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
Expand Down Expand Up @@ -1603,6 +1747,7 @@ mod tests {
BytesLiteral(&'static str),
BuiltinInstance(&'static str),
Union(Vec<Ty>),
Intersection { pos: Vec<Ty>, neg: Vec<Ty> },
Tuple(Vec<Ty>),
}

Expand All @@ -1622,6 +1767,16 @@ mod tests {
Ty::Union(tys) => {
UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db)))
}
Ty::Intersection { pos, neg } => {
let mut builder = IntersectionBuilder::new(db);
for p in pos {
builder = builder.add_positive(p.into_type(db));
}
for n in neg {
builder = builder.add_negative(n.into_type(db));
}
builder.build()
}
Ty::Tuple(tys) => {
let elements: Box<_> = tys.into_iter().map(|ty| ty.into_type(db)).collect();
Type::Tuple(TupleType::new(db, elements))
Expand Down Expand Up @@ -1697,6 +1852,93 @@ mod tests {
assert!(from.into_type(&db).is_equivalent_to(&db, to.into_type(&db)));
}

#[test_case(Ty::Never, Ty::Never)]
#[test_case(Ty::Never, Ty::None)]
#[test_case(Ty::Never, Ty::BuiltinInstance("int"))]
#[test_case(Ty::None, Ty::BoolLiteral(true))]
#[test_case(Ty::None, Ty::IntLiteral(1))]
#[test_case(Ty::None, Ty::StringLiteral("test"))]
#[test_case(Ty::None, Ty::BytesLiteral("test"))]
#[test_case(Ty::None, Ty::LiteralString)]
#[test_case(Ty::None, Ty::BuiltinInstance("int"))]
#[test_case(Ty::None, Ty::Tuple(vec![Ty::None]))]
#[test_case(Ty::BoolLiteral(true), Ty::BoolLiteral(false))]
#[test_case(Ty::BoolLiteral(true), Ty::Tuple(vec![Ty::None]))]
#[test_case(Ty::BoolLiteral(true), Ty::IntLiteral(1))]
#[test_case(Ty::BoolLiteral(false), Ty::IntLiteral(0))]
#[test_case(Ty::IntLiteral(1), Ty::IntLiteral(2))]
#[test_case(Ty::IntLiteral(1), Ty::Tuple(vec![Ty::None]))]
#[test_case(Ty::StringLiteral("a"), Ty::StringLiteral("b"))]
#[test_case(Ty::StringLiteral("a"), Ty::Tuple(vec![Ty::None]))]
#[test_case(Ty::LiteralString, Ty::BytesLiteral("a"))]
#[test_case(Ty::BytesLiteral("a"), Ty::BytesLiteral("b"))]
#[test_case(Ty::BytesLiteral("a"), Ty::Tuple(vec![Ty::None]))]
#[test_case(Ty::BytesLiteral("a"), Ty::StringLiteral("a"))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::IntLiteral(3))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(3), Ty::IntLiteral(4)]))]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int"), Ty::IntLiteral(1)], neg: vec![]}, Ty::IntLiteral(2))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1)]), Ty::Tuple(vec![Ty::IntLiteral(2)]))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1)]))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(3)]))]
fn is_disjoint_from(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(a.is_disjoint_from(&db, b));
assert!(b.is_disjoint_from(&db, a));
}

#[test_case(Ty::Any, Ty::BuiltinInstance("int"))]
#[test_case(Ty::None, Ty::None)]
#[test_case(Ty::None, Ty::BuiltinInstance("object"))]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("int"))]
#[test_case(Ty::BuiltinInstance("str"), Ty::LiteralString)]
#[test_case(Ty::BoolLiteral(true), Ty::BoolLiteral(true))]
#[test_case(Ty::BoolLiteral(false), Ty::BoolLiteral(false))]
#[test_case(Ty::BoolLiteral(true), Ty::BuiltinInstance("bool"))]
#[test_case(Ty::BoolLiteral(true), Ty::BuiltinInstance("int"))]
#[test_case(Ty::IntLiteral(1), Ty::IntLiteral(1))]
#[test_case(Ty::StringLiteral("a"), Ty::StringLiteral("a"))]
#[test_case(Ty::StringLiteral("a"), Ty::LiteralString)]
#[test_case(Ty::StringLiteral("a"), Ty::BuiltinInstance("str"))]
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::IntLiteral(2))]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(2), Ty::IntLiteral(3)]))]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int"), Ty::IntLiteral(2)], neg: vec![]}, Ty::IntLiteral(2))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::IntLiteral(1), Ty::BuiltinInstance("int")]))]
fn is_not_disjoint_from(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(!a.is_disjoint_from(&db, b));
assert!(!b.is_disjoint_from(&db, a));
}

#[test]
fn is_disjoint_from_union_of_class_types() {
let mut db = setup_db();
db.write_dedented(
"/src/module.py",
"
class A: ...
class B: ...
x = A if flag else B
",
)
.unwrap();
let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap();

let type_a = super::global_symbol_ty(&db, module, "A");
let type_x = super::global_symbol_ty(&db, module, "x");

assert!(matches!(type_a, Type::Class(_)));
assert!(matches!(type_x, Type::Union(_)));

assert!(!type_a.is_disjoint_from(&db, type_x));
}

#[test_case(Ty::None)]
#[test_case(Ty::BoolLiteral(true))]
#[test_case(Ty::BoolLiteral(false))]
Expand Down
Loading

0 comments on commit 6964eef

Please sign in to comment.