Skip to content

Commit

Permalink
Optimize TypeIndirectionVisitor
Browse files Browse the repository at this point in the history
This was a performance bottleneck when type checking torch. It used
to perform lots of set unions and hash value calculations on
mypy type objects, which are both pretty expensive. Now we mostly
rely on set contains and set add operations with strings, which are
much faster. We also avoid constructing many temporary objects.

Speeds up type checking torch by about 3%. Also appears to speed up
self check by about 2%.
  • Loading branch information
JukkaL committed Dec 16, 2024
1 parent be87d3d commit 0b04f3a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 64 deletions.
140 changes: 78 additions & 62 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Iterable, Set
from typing import Iterable

import mypy.types as types
from mypy.types import TypeVisitor
Expand All @@ -17,105 +17,121 @@ def extract_module_names(type_name: str | None) -> list[str]:
return []


class TypeIndirectionVisitor(TypeVisitor[Set[str]]):
class TypeIndirectionVisitor(TypeVisitor[None]):
"""Returns all module references within a particular type."""

def __init__(self) -> None:
self.cache: dict[types.Type, set[str]] = {}
# Module references are collected here
self.modules: set[str] = set()
# User to avoid infinite recursion with recursive type aliases
self.seen_aliases: set[types.TypeAliasType] = set()
# The following two are used to avoid redundant work
self.seen_fullnames: set[str] = set()
self.seen_module_names: set[str] = set()

def find_modules(self, typs: Iterable[types.Type]) -> set[str]:
self.seen_aliases.clear()
return self._visit(typs)

def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> set[str]:
self.modules = set()
self.seen_fullnames = set()
self.seen_module_names = set()
self.seen_aliases = set()
self._visit(typs)
return self.modules

def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> None:
typs = [typ_or_typs] if isinstance(typ_or_typs, types.Type) else typ_or_typs
output: set[str] = set()
for typ in typs:
if isinstance(typ, types.TypeAliasType):
# Avoid infinite recursion for recursive type aliases.
if typ in self.seen_aliases:
continue
self.seen_aliases.add(typ)
if typ in self.cache:
modules = self.cache[typ]
else:
modules = typ.accept(self)
self.cache[typ] = set(modules)
output.update(modules)
return output
typ.accept(self)

def _visit_module_name(self, module_name: str) -> None:
if module_name not in self.seen_module_names:
self.modules.update(split_module_names(module_name))
self.seen_module_names.add(module_name)

def visit_unbound_type(self, t: types.UnboundType) -> set[str]:
return self._visit(t.args)
def visit_unbound_type(self, t: types.UnboundType) -> None:
self._visit(t.args)

def visit_any(self, t: types.AnyType) -> set[str]:
return set()
def visit_any(self, t: types.AnyType) -> None:
pass

def visit_none_type(self, t: types.NoneType) -> set[str]:
return set()
def visit_none_type(self, t: types.NoneType) -> None:
pass

def visit_uninhabited_type(self, t: types.UninhabitedType) -> set[str]:
return set()
def visit_uninhabited_type(self, t: types.UninhabitedType) -> None:
pass

def visit_erased_type(self, t: types.ErasedType) -> set[str]:
return set()
def visit_erased_type(self, t: types.ErasedType) -> None:
pass

def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
return set()
def visit_deleted_type(self, t: types.DeletedType) -> None:
pass

def visit_type_var(self, t: types.TypeVarType) -> set[str]:
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
def visit_type_var(self, t: types.TypeVarType) -> None:
self._visit(t.values)
self._visit(t.upper_bound)
self._visit(t.default)

def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
return self._visit(t.upper_bound) | self._visit(t.default)
def visit_param_spec(self, t: types.ParamSpecType) -> None:
self._visit(t.upper_bound)
self._visit(t.default)

def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
return self._visit(t.upper_bound) | self._visit(t.default)
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> None:
self._visit(t.upper_bound)
self._visit(t.default)

def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
return t.type.accept(self)
def visit_unpack_type(self, t: types.UnpackType) -> None:
t.type.accept(self)

def visit_parameters(self, t: types.Parameters) -> set[str]:
return self._visit(t.arg_types)
def visit_parameters(self, t: types.Parameters) -> None:
self._visit(t.arg_types)

def visit_instance(self, t: types.Instance) -> set[str]:
out = self._visit(t.args)
def visit_instance(self, t: types.Instance) -> None:
self._visit(t.args)
if t.type:
# Uses of a class depend on everything in the MRO,
# as changes to classes in the MRO can add types to methods,
# change property types, change the MRO itself, etc.
for s in t.type.mro:
out.update(split_module_names(s.module_name))
self._visit_module_name(s.module_name)
if t.type.metaclass_type is not None:
out.update(split_module_names(t.type.metaclass_type.type.module_name))
return out
self._visit_module_name(t.type.metaclass_type.type.module_name)

def visit_callable_type(self, t: types.CallableType) -> set[str]:
out = self._visit(t.arg_types) | self._visit(t.ret_type)
def visit_callable_type(self, t: types.CallableType) -> None:
self._visit(t.arg_types)
self._visit(t.ret_type)
if t.definition is not None:
out.update(extract_module_names(t.definition.fullname))
return out
fullname = t.definition.fullname
if fullname not in self.seen_fullnames:
self.modules.update(extract_module_names(t.definition.fullname))
self.seen_fullnames.add(fullname)

def visit_overloaded(self, t: types.Overloaded) -> set[str]:
return self._visit(t.items) | self._visit(t.fallback)
def visit_overloaded(self, t: types.Overloaded) -> None:
self._visit(t.items)
self._visit(t.fallback)

def visit_tuple_type(self, t: types.TupleType) -> set[str]:
return self._visit(t.items) | self._visit(t.partial_fallback)
def visit_tuple_type(self, t: types.TupleType) -> None:
self._visit(t.items)
self._visit(t.partial_fallback)

def visit_typeddict_type(self, t: types.TypedDictType) -> set[str]:
return self._visit(t.items.values()) | self._visit(t.fallback)
def visit_typeddict_type(self, t: types.TypedDictType) -> None:
self._visit(t.items.values())
self._visit(t.fallback)

def visit_literal_type(self, t: types.LiteralType) -> set[str]:
return self._visit(t.fallback)
def visit_literal_type(self, t: types.LiteralType) -> None:
self._visit(t.fallback)

def visit_union_type(self, t: types.UnionType) -> set[str]:
return self._visit(t.items)
def visit_union_type(self, t: types.UnionType) -> None:
self._visit(t.items)

def visit_partial_type(self, t: types.PartialType) -> set[str]:
return set()
def visit_partial_type(self, t: types.PartialType) -> None:
pass

def visit_type_type(self, t: types.TypeType) -> set[str]:
return self._visit(t.item)
def visit_type_type(self, t: types.TypeType) -> None:
self._visit(t.item)

def visit_type_alias_type(self, t: types.TypeAliasType) -> set[str]:
return self._visit(types.get_proper_type(t))
def visit_type_alias_type(self, t: types.TypeAliasType) -> None:
self._visit(types.get_proper_type(t))
6 changes: 4 additions & 2 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,14 @@ def test_recursive_nested_in_non_recursive(self) -> None:
def test_indirection_no_infinite_recursion(self) -> None:
A, _ = self.fx.def_alias_1(self.fx.a)
visitor = TypeIndirectionVisitor()
modules = A.accept(visitor)
A.accept(visitor)
modules = visitor.modules
assert modules == {"__main__", "builtins"}

A, _ = self.fx.def_alias_2(self.fx.a)
visitor = TypeIndirectionVisitor()
modules = A.accept(visitor)
A.accept(visitor)
modules = visitor.modules
assert modules == {"__main__", "builtins"}


Expand Down

0 comments on commit 0b04f3a

Please sign in to comment.