Skip to content

perf: try to cache inner contexts of overloads #19408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@
"builtins.memoryview",
}

POISON_KEY: Final = (-1,)


class TooManyUnions(Exception):
"""Indicates that we need to stop splitting unions in an attempt
Expand Down Expand Up @@ -356,7 +358,12 @@ def __init__(

self._arg_infer_context_cache = None

self.overload_stack_depth = 0
self._args_cache: dict[tuple[int, ...], list[Type]] = {}

def reset(self) -> None:
assert self.overload_stack_depth == 0
assert not self._args_cache
self.resolved_type = {}

def visit_name_expr(self, e: NameExpr) -> Type:
Expand Down Expand Up @@ -1613,9 +1620,10 @@ def check_call(
object_type,
)
elif isinstance(callee, Overloaded):
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
with self.overload_context():
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
return self.check_any_type_call(args, callee)
elif isinstance(callee, UnionType):
Expand Down Expand Up @@ -1674,6 +1682,14 @@ def check_call(
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

@contextmanager
def overload_context(self) -> Iterator[None]:
self.overload_stack_depth += 1
yield
self.overload_stack_depth -= 1
if self.overload_stack_depth == 0:
self._args_cache.clear()

def check_callable_call(
self,
callee: CallableType,
Expand Down Expand Up @@ -1937,6 +1953,17 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
In short, we basically recurse on each argument without considering
in what context the argument was called.
"""
# We can only use this hack locally while checking a single nested overloaded
# call. This saves a lot of rechecking, but is not generally safe. Cache is
# pruned upon leaving the outermost overload.
can_cache = (
self.overload_stack_depth > 0
and POISON_KEY not in self._args_cache
and not any(isinstance(t, TempNode) for t in args)
)
key = tuple(map(id, args))
if can_cache and key in self._args_cache:
return self._args_cache[key]
res: list[Type] = []

for arg in args:
Expand All @@ -1945,6 +1972,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
res.append(NoneType())
else:
res.append(arg_type)
if can_cache:
self._args_cache[key] = res
return res

def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool:
Expand Down Expand Up @@ -2917,17 +2946,16 @@ def infer_overload_return_type(

for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
Expand Down Expand Up @@ -3474,6 +3502,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
if isinstance(e.left, StrExpr):
return self.strfrm_checker.check_str_interpolation(e.left, e.right)

left_type = self.accept(e.left)

proper_left_type = get_proper_type(left_type)
Expand Down Expand Up @@ -5401,6 +5430,9 @@ def find_typeddict_context(

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
if self.overload_stack_depth > 0:
# Poison cache when we encounter lambdas - it isn't safe to cache their types.
self._args_cache[POISON_KEY] = []
self.chk.check_default_args(e, body_is_trivial=False)
inferred_type, type_override = self.infer_lambda_type_using_context(e)
if not inferred_type:
Expand Down