The following modules now pass `disallow_untyped_defs`: * synapse.util.caches.cached_call * synapse.util.caches.lrucache * synapse.util.caches.response_cache * synapse.util.caches.stream_change_cache * synapse.util.caches.ttlcache pass * synapse.util.daemonize * synapse.util.patch_inline_callbacks pass `no-untyped-defs` * synapse.util.versionstring Additional typing in synapse.util.metrics. Didn't get this to pass `no-untyped-defs`, think I'll need to watch #10847tags/v1.45.0rc1
@@ -0,0 +1 @@ | |||
Improve type hinting in `synapse.util`. |
@@ -102,9 +102,27 @@ disallow_untyped_defs = True | |||
[mypy-synapse.util.batching_queue] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.cached_call] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.dictionary_cache] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.lrucache] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.response_cache] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.stream_change_cache] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.ttl_cache] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.daemonize] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.file_consumer] | |||
disallow_untyped_defs = True | |||
@@ -141,6 +159,9 @@ disallow_untyped_defs = True | |||
[mypy-synapse.util.msisdn] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.patch_inline_callbacks] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.ratelimitutils] | |||
disallow_untyped_defs = True | |||
@@ -162,6 +183,9 @@ disallow_untyped_defs = True | |||
[mypy-synapse.util.wheel_timer] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.versionstring] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_user_directory] | |||
disallow_untyped_defs = True | |||
@@ -85,7 +85,7 @@ class CachedCall(Generic[TV]): | |||
# result in the deferred, since `awaiting` a deferred destroys its result. | |||
# (Also, if it's a Failure, GCing the deferred would log a critical error | |||
# about unhandled Failures) | |||
def got_result(r): | |||
def got_result(r: Union[TV, Failure]) -> None: | |||
self._result = r | |||
self._deferred.addBoth(got_result) | |||
@@ -31,6 +31,7 @@ from prometheus_client import Gauge | |||
from twisted.internet import defer | |||
from twisted.python import failure | |||
from twisted.python.failure import Failure | |||
from synapse.util.async_helpers import ObservableDeferred | |||
from synapse.util.caches.lrucache import LruCache | |||
@@ -112,7 +113,7 @@ class DeferredCache(Generic[KT, VT]): | |||
self.thread: Optional[threading.Thread] = None | |||
@property | |||
def max_entries(self): | |||
def max_entries(self) -> int: | |||
return self.cache.max_size | |||
def check_thread(self) -> None: | |||
@@ -258,7 +259,7 @@ class DeferredCache(Generic[KT, VT]): | |||
return False | |||
def cb(result) -> None: | |||
def cb(result: VT) -> None: | |||
if compare_and_pop(): | |||
self.cache.set(key, result, entry.callbacks) | |||
else: | |||
@@ -270,7 +271,7 @@ class DeferredCache(Generic[KT, VT]): | |||
# not have been. Either way, let's double-check now. | |||
entry.invalidate() | |||
def eb(_fail) -> None: | |||
def eb(_fail: Failure) -> None: | |||
compare_and_pop() | |||
entry.invalidate() | |||
@@ -284,11 +285,11 @@ class DeferredCache(Generic[KT, VT]): | |||
def prefill( | |||
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None | |||
): | |||
) -> None: | |||
callbacks = [callback] if callback else [] | |||
self.cache.set(key, value, callbacks=callbacks) | |||
def invalidate(self, key): | |||
def invalidate(self, key) -> None: | |||
"""Delete a key, or tree of entries | |||
If the cache is backed by a regular dict, then "key" must be of | |||
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) | |||
try: | |||
from pympler.asizeof import Asizer | |||
def _get_size_of(val: Any, *, recurse=True) -> int: | |||
def _get_size_of(val: Any, *, recurse: bool = True) -> int: | |||
"""Get an estimate of the size in bytes of the object. | |||
Args: | |||
@@ -71,7 +71,7 @@ try: | |||
except ImportError: | |||
def _get_size_of(val: Any, *, recurse=True) -> int: | |||
def _get_size_of(val: Any, *, recurse: bool = True) -> int: | |||
return 0 | |||
@@ -85,15 +85,6 @@ VT = TypeVar("VT") | |||
# a general type var, distinct from either KT or VT | |||
T = TypeVar("T") | |||
def enumerate_leaves(node, depth): | |||
if depth == 0: | |||
yield node | |||
else: | |||
for n in node.values(): | |||
yield from enumerate_leaves(n, depth - 1) | |||
P = TypeVar("P") | |||
@@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]): | |||
__slots__ = ["last_access_ts_secs"] | |||
def update_last_access(self, clock: Clock): | |||
def update_last_access(self, clock: Clock) -> None: | |||
self.last_access_ts_secs = int(clock.time()) | |||
@@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node() | |||
@wrap_as_background_process("LruCache._expire_old_entries") | |||
async def _expire_old_entries(clock: Clock, expiry_seconds: int): | |||
async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: | |||
"""Walks the global cache list to find cache entries that haven't been | |||
accessed in the given number of seconds. | |||
""" | |||
@@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int): | |||
logger.info("Dropped %d items from caches", i) | |||
def setup_expire_lru_cache_entries(hs: "HomeServer"): | |||
def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: | |||
"""Start a background job that expires all cache entries if they have not | |||
been accessed for the given number of seconds. | |||
""" | |||
@@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"): | |||
) | |||
class _Node: | |||
class _Node(Generic[KT, VT]): | |||
__slots__ = [ | |||
"_list_node", | |||
"_global_list_node", | |||
@@ -197,8 +188,8 @@ class _Node: | |||
def __init__( | |||
self, | |||
root: "ListNode[_Node]", | |||
key, | |||
value, | |||
key: KT, | |||
value: VT, | |||
cache: "weakref.ReferenceType[LruCache]", | |||
clock: Clock, | |||
callbacks: Collection[Callable[[], None]] = (), | |||
@@ -409,7 +400,7 @@ class LruCache(Generic[KT, VT]): | |||
def synchronized(f: FT) -> FT: | |||
@wraps(f) | |||
def inner(*args, **kwargs): | |||
def inner(*args: Any, **kwargs: Any) -> Any: | |||
with lock: | |||
return f(*args, **kwargs) | |||
@@ -418,17 +409,19 @@ class LruCache(Generic[KT, VT]): | |||
cached_cache_len = [0] | |||
if size_callback is not None: | |||
def cache_len(): | |||
def cache_len() -> int: | |||
return cached_cache_len[0] | |||
else: | |||
def cache_len(): | |||
def cache_len() -> int: | |||
return len(cache) | |||
self.len = synchronized(cache_len) | |||
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): | |||
def add_node( | |||
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () | |||
) -> None: | |||
node = _Node( | |||
list_root, | |||
key, | |||
@@ -446,7 +439,7 @@ class LruCache(Generic[KT, VT]): | |||
if caches.TRACK_MEMORY_USAGE and metrics: | |||
metrics.inc_memory_usage(node.memory) | |||
def move_node_to_front(node: _Node): | |||
def move_node_to_front(node: _Node) -> None: | |||
node.move_to_front(real_clock, list_root) | |||
def delete_node(node: _Node) -> int: | |||
@@ -488,7 +481,7 @@ class LruCache(Generic[KT, VT]): | |||
default: Optional[T] = None, | |||
callbacks: Collection[Callable[[], None]] = (), | |||
update_metrics: bool = True, | |||
): | |||
) -> Union[None, T, VT]: | |||
node = cache.get(key, None) | |||
if node is not None: | |||
move_node_to_front(node) | |||
@@ -502,7 +495,9 @@ class LruCache(Generic[KT, VT]): | |||
return default | |||
@synchronized | |||
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()): | |||
def cache_set( | |||
key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () | |||
) -> None: | |||
node = cache.get(key, None) | |||
if node is not None: | |||
# We sometimes store large objects, e.g. dicts, which cause | |||
@@ -547,7 +542,7 @@ class LruCache(Generic[KT, VT]): | |||
... | |||
@synchronized | |||
def cache_pop(key: KT, default: Optional[T] = None): | |||
def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: | |||
node = cache.get(key, None) | |||
if node: | |||
delete_node(node) | |||
@@ -612,25 +607,25 @@ class LruCache(Generic[KT, VT]): | |||
self.contains = cache_contains | |||
self.clear = cache_clear | |||
def __getitem__(self, key): | |||
def __getitem__(self, key: KT) -> VT: | |||
result = self.get(key, self.sentinel) | |||
if result is self.sentinel: | |||
raise KeyError() | |||
else: | |||
return result | |||
return cast(VT, result) | |||
def __setitem__(self, key, value): | |||
def __setitem__(self, key: KT, value: VT) -> None: | |||
self.set(key, value) | |||
def __delitem__(self, key, value): | |||
def __delitem__(self, key: KT, value: VT) -> None: | |||
result = self.pop(key, self.sentinel) | |||
if result is self.sentinel: | |||
raise KeyError() | |||
def __len__(self): | |||
def __len__(self) -> int: | |||
return self.len() | |||
def __contains__(self, key): | |||
def __contains__(self, key: KT) -> bool: | |||
return self.contains(key) | |||
def set_cache_factor(self, factor: float) -> bool: | |||
@@ -104,8 +104,8 @@ class ResponseCache(Generic[KV]): | |||
return None | |||
def _set( | |||
self, context: ResponseCacheContext[KV], deferred: defer.Deferred | |||
) -> defer.Deferred: | |||
self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]" | |||
) -> "defer.Deferred[RV]": | |||
"""Set the entry for the given key to the given deferred. | |||
*deferred* should run its callbacks in the sentinel logcontext (ie, | |||
@@ -126,7 +126,7 @@ class ResponseCache(Generic[KV]): | |||
key = context.cache_key | |||
self.pending_result_cache[key] = result | |||
def on_complete(r): | |||
def on_complete(r: RV) -> RV: | |||
# if this cache has a non-zero timeout, and the callback has not cleared | |||
# the should_cache bit, we leave it in the cache for now and schedule | |||
# its removal later. | |||
@@ -40,10 +40,10 @@ class StreamChangeCache: | |||
self, | |||
name: str, | |||
current_stream_pos: int, | |||
max_size=10000, | |||
max_size: int = 10000, | |||
prefilled_cache: Optional[Mapping[EntityType, int]] = None, | |||
): | |||
self._original_max_size = max_size | |||
) -> None: | |||
self._original_max_size: int = max_size | |||
self._max_size = math.floor(max_size) | |||
self._entity_to_key: Dict[EntityType, int] = {} | |||
@@ -159,12 +159,12 @@ class TTLCache(Generic[KT, VT]): | |||
del self._expiry_list[0] | |||
@attr.s(frozen=True, slots=True) | |||
class _CacheEntry: | |||
@attr.s(frozen=True, slots=True, auto_attribs=True) | |||
class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313 | |||
"""TTLCache entry""" | |||
# expiry_time is the first attribute, so that entries are sorted by expiry. | |||
expiry_time = attr.ib(type=float) | |||
ttl = attr.ib(type=float) | |||
key = attr.ib() | |||
value = attr.ib() | |||
expiry_time: float | |||
ttl: float | |||
key: Any # should be KT | |||
value: Any # should be VT |
@@ -19,6 +19,8 @@ import logging | |||
import os | |||
import signal | |||
import sys | |||
from types import FrameType, TracebackType | |||
from typing import NoReturn, Type | |||
def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: | |||
@@ -97,7 +99,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - | |||
# (we don't normally expect reactor.run to raise any exceptions, but this will | |||
# also catch any other uncaught exceptions before we get that far.) | |||
def excepthook(type_, value, traceback): | |||
def excepthook( | |||
type_: Type[BaseException], value: BaseException, traceback: TracebackType | |||
) -> None: | |||
logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) | |||
sys.excepthook = excepthook | |||
@@ -119,7 +123,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - | |||
sys.exit(1) | |||
# write a log line on SIGTERM. | |||
def sigterm(signum, frame): | |||
def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn: | |||
logger.warning("Caught signal %s. Stopping daemon." % signum) | |||
sys.exit(0) | |||
@@ -14,9 +14,11 @@ | |||
import logging | |||
from functools import wraps | |||
from typing import Any, Callable, Optional, TypeVar, cast | |||
from types import TracebackType | |||
from typing import Any, Callable, Optional, Type, TypeVar, cast | |||
from prometheus_client import Counter | |||
from typing_extensions import Protocol | |||
from synapse.logging.context import ( | |||
ContextResourceUsage, | |||
@@ -24,6 +26,7 @@ from synapse.logging.context import ( | |||
current_context, | |||
) | |||
from synapse.metrics import InFlightGauge | |||
from synapse.util import Clock | |||
logger = logging.getLogger(__name__) | |||
@@ -64,6 +67,10 @@ in_flight = InFlightGauge( | |||
T = TypeVar("T", bound=Callable[..., Any]) | |||
class HasClock(Protocol): | |||
clock: Clock | |||
def measure_func(name: Optional[str] = None) -> Callable[[T], T]: | |||
""" | |||
Used to decorate an async function with a `Measure` context manager. | |||
@@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]: | |||
block_name = func.__name__ if name is None else name | |||
@wraps(func) | |||
async def measured_func(self, *args, **kwargs): | |||
async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any: | |||
with Measure(self.clock, block_name): | |||
r = await func(self, *args, **kwargs) | |||
return r | |||
@@ -104,10 +111,10 @@ class Measure: | |||
"start", | |||
] | |||
def __init__(self, clock, name: str): | |||
def __init__(self, clock: Clock, name: str) -> None: | |||
""" | |||
Args: | |||
clock: A n object with a "time()" method, which returns the current | |||
clock: An object with a "time()" method, which returns the current | |||
time in seconds. | |||
name: The name of the metric to report. | |||
""" | |||
@@ -124,7 +131,7 @@ class Measure: | |||
assert isinstance(curr_context, LoggingContext) | |||
parent_context = curr_context | |||
self._logging_context = LoggingContext(str(curr_context), parent_context) | |||
self.start: Optional[int] = None | |||
self.start: Optional[float] = None | |||
def __enter__(self) -> "Measure": | |||
if self.start is not None: | |||
@@ -138,7 +145,12 @@ class Measure: | |||
return self | |||
def __exit__(self, exc_type, exc_val, exc_tb): | |||
def __exit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
exc_val: Optional[BaseException], | |||
exc_tb: Optional[TracebackType], | |||
) -> None: | |||
if self.start is None: | |||
raise RuntimeError("Measure() block exited without being entered") | |||
@@ -168,8 +180,9 @@ class Measure: | |||
""" | |||
return self._logging_context.get_resource_usage() | |||
def _update_in_flight(self, metrics): | |||
def _update_in_flight(self, metrics) -> None: | |||
"""Gets called when processing in flight metrics""" | |||
assert self.start is not None | |||
duration = self.clock.time() - self.start | |||
metrics.real_time_max = max(metrics.real_time_max, duration) | |||
@@ -14,7 +14,7 @@ | |||
import functools | |||
import sys | |||
from typing import Any, Callable, List | |||
from typing import Any, Callable, Generator, List, TypeVar | |||
from twisted.internet import defer | |||
from twisted.internet.defer import Deferred | |||
@@ -24,6 +24,9 @@ from twisted.python.failure import Failure | |||
_already_patched = False | |||
T = TypeVar("T") | |||
def do_patch() -> None: | |||
""" | |||
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit | |||
@@ -37,15 +40,19 @@ def do_patch() -> None: | |||
if _already_patched: | |||
return | |||
def new_inline_callbacks(f): | |||
def new_inline_callbacks( | |||
f: Callable[..., Generator["Deferred[object]", object, T]] | |||
) -> Callable[..., "Deferred[T]"]: | |||
@functools.wraps(f) | |||
def wrapped(*args, **kwargs): | |||
def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": | |||
start_context = current_context() | |||
changes: List[str] = [] | |||
orig = orig_inline_callbacks(_check_yield_points(f, changes)) | |||
orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( | |||
_check_yield_points(f, changes) | |||
) | |||
try: | |||
res = orig(*args, **kwargs) | |||
res: "Deferred[T]" = orig(*args, **kwargs) | |||
except Exception: | |||
if current_context() != start_context: | |||
for err in changes: | |||
@@ -84,7 +91,7 @@ def do_patch() -> None: | |||
print(err, file=sys.stderr) | |||
raise Exception(err) | |||
def check_ctx(r): | |||
def check_ctx(r: T) -> T: | |||
if current_context() != start_context: | |||
for err in changes: | |||
print(err, file=sys.stderr) | |||
@@ -107,7 +114,10 @@ def do_patch() -> None: | |||
_already_patched = True | |||
def _check_yield_points(f: Callable, changes: List[str]) -> Callable: | |||
def _check_yield_points( | |||
f: Callable[..., Generator["Deferred[object]", object, T]], | |||
changes: List[str], | |||
) -> Callable: | |||
"""Wraps a generator that is about to be passed to defer.inlineCallbacks | |||
checking that after every yield the log contexts are correct. | |||
@@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable: | |||
from synapse.logging.context import current_context | |||
@functools.wraps(f) | |||
def check_yield_points_inner(*args, **kwargs): | |||
def check_yield_points_inner( | |||
*args: Any, **kwargs: Any | |||
) -> Generator["Deferred[object]", object, T]: | |||
gen = f(*args, **kwargs) | |||
last_yield_line_no = gen.gi_frame.f_lineno | |||
@@ -15,14 +15,18 @@ | |||
import logging | |||
import os | |||
import subprocess | |||
from types import ModuleType | |||
from typing import Dict | |||
logger = logging.getLogger(__name__) | |||
version_cache: Dict[ModuleType, str] = {} | |||
def get_version_string(module) -> str: | |||
def get_version_string(module: ModuleType) -> str: | |||
"""Given a module calculate a git-aware version string for it. | |||
If called on a module not in a git checkout will return `__verison__`. | |||
If called on a module not in a git checkout will return `__version__`. | |||
Args: | |||
module (module) | |||
@@ -31,11 +35,13 @@ def get_version_string(module) -> str: | |||
str | |||
""" | |||
cached_version = getattr(module, "_synapse_version_string_cache", None) | |||
if cached_version: | |||
cached_version = version_cache.get(module) | |||
if cached_version is not None: | |||
return cached_version | |||
version_string = module.__version__ | |||
# We want this to fail loudly with an AttributeError. Type-ignore this so | |||
# mypy only considers the happy path. | |||
version_string = module.__version__ # type: ignore[attr-defined] | |||
try: | |||
null = open(os.devnull, "w") | |||
@@ -97,10 +103,15 @@ def get_version_string(module) -> str: | |||
s for s in (git_branch, git_tag, git_commit, git_dirty) if s | |||
) | |||
version_string = "%s (%s)" % (module.__version__, git_version) | |||
version_string = "%s (%s)" % ( | |||
# If the __version__ attribute doesn't exist, we'll have failed | |||
# loudly above. | |||
module.__version__, # type: ignore[attr-defined] | |||
git_version, | |||
) | |||
except Exception as e: | |||
logger.info("Failed to check for git repository: %s", e) | |||
module._synapse_version_string_cache = version_string | |||
version_cache[module] = version_string | |||
return version_string |