Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix typing for @cached wrapped functions #8240

Merged
merged 7 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints for functions decorated with `@cached`.
3 changes: 2 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent
check_untyped_defs = True
show_error_codes = True
Expand Down Expand Up @@ -51,6 +51,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,
Expand Down
85 changes: 85 additions & 0 deletions scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This is a mypy plugin for Synpase to deal with some of the funky typing that
can crop up, e.g the cache descriptors.
"""

from typing import Callable, Optional

from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType


class SynapsePlugin(Plugin):
def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
):
return cached_function_method_signature
return None


def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.

It already has *almost* the correct signature, except:

1. the `self` argument needs to be marked as "bound"; and
2. any `cache_context` argument should be removed.
"""

# First we mark this as a bound function signature.
signature = bind_self(ctx.default_signature)

# Secondly, we remove any "cache_context" args.
#
# Note: We should be only doing this if `cache_context=True` is set, but if
# it isn't then the code will raise an exception when its called anyway, so
# its not the end of the world.
Comment on lines +52 to +54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This vaguely feels like how overload works, but probably isn't worth the trouble figuring out here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, its entirely doable I think by threading through the cache_context literal, but yeah, it's not worth it.

context_arg_index = None
for idx, name in enumerate(signature.arg_names):
if name == "cache_context":
context_arg_index = idx
break

if context_arg_index:
arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)

arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)

arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)

signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)

return signature


def plugin(version: str):
# This is the entry point of the plugin, and let's us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
#
# However, since we pin the version of mypy Synapse uses in CI, we don't
# really care.
return SynapsePlugin
10 changes: 5 additions & 5 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,11 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
if not prevs - seen:
return

latest = await self.store.get_latest_event_ids_in_room(room_id)
latest_list = await self.store.get_latest_event_ids_in_room(room_id)

# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest = set(latest_list)
latest |= seen

logger.info(
Expand Down Expand Up @@ -781,7 +781,7 @@ async def _process_received_pdu(
# keys across all devices.
current_keys = [
key
for device in cached_devices
for device in cached_devices.values()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might have just been a plain bug, though I don't see it getting hit on matrix.org :/

for key in device.get("keys", {}).get("keys", {}).values()
]

Expand Down Expand Up @@ -2119,8 +2119,8 @@ async def _check_for_soft_fail(
if backfilled or event.internal_metadata.is_outlier():
return

extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())

if extrem_ids == prev_event_ids:
Expand Down
42 changes: 28 additions & 14 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
import inspect
import logging
import threading
from typing import Any, Tuple, Union, cast
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary

from prometheus_client import Gauge
from typing_extensions import Protocol

from twisted.internet import defer

Expand All @@ -38,17 +37,22 @@

CacheKey = Union[Tuple, Any]

F = TypeVar("F", bound=Callable[..., Any])

class _CachedFunction(Protocol):

class _CachedFunction(Generic[F]):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any

def __name__(self):
...
__name__ = None # type: str

# Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
__call__ = None # type: F


cache_pending_metric = Gauge(
Expand Down Expand Up @@ -123,7 +127,7 @@ def __init__(

self.name = name
self.keylen = keylen
self.thread = None
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
Expand Down Expand Up @@ -662,9 +666,13 @@ def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheCon


def cached(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
max_entries: int = 1000,
num_args: Optional[int] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
Expand All @@ -673,8 +681,12 @@ def cached(
iterable=iterable,
)

return cast(Callable[[F], _CachedFunction[F]], func)

def cachedList(cached_method_name, list_name, num_args=None):

def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.

Used to do batch lookups for an already created cache. A single argument
Expand All @@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
cache.

Args:
cached_method_name (str): The name of the single-item lookup method.
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
list_name: The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.

Example:
Expand All @@ -702,9 +714,11 @@ def do_something(self, first_arg):
def batch_do_something(self, first_arg, second_args):
...
"""
return lambda orig: CacheListDescriptor(
func = lambda orig: CacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
)

return cast(Callable[[F], _CachedFunction[F]], func)