@@ -0,0 +1 @@ | |||
Add missing type hints to `synapse.logging.context`. |
@@ -167,6 +167,9 @@ disallow_untyped_defs = True | |||
[mypy-synapse.http.server] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.logging.context] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.metrics.*] | |||
disallow_untyped_defs = True | |||
@@ -17,11 +17,12 @@ | |||
from typing import Any, List, Optional, Type, Union | |||
from twisted.internet import protocol | |||
from twisted.internet.defer import Deferred | |||
class RedisProtocol(protocol.Protocol): | |||
def publish(self, channel: str, message: bytes): ... | |||
async def ping(self) -> None: ... | |||
async def set( | |||
def ping(self) -> "Deferred[None]": ... | |||
def set( | |||
self, | |||
key: str, | |||
value: Any, | |||
@@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol): | |||
pexpire: Optional[int] = None, | |||
only_if_not_exists: bool = False, | |||
only_if_exists: bool = False, | |||
) -> None: ... | |||
async def get(self, key: str) -> Any: ... | |||
) -> "Deferred[None]": ... | |||
def get(self, key: str) -> "Deferred[Any]": ... | |||
class SubscriberProtocol(RedisProtocol): | |||
def __init__(self, *args, **kwargs): ... | |||
@@ -30,7 +30,6 @@ from typing import ( | |||
from prometheus_client import Counter, Gauge, Histogram | |||
from twisted.internet import defer | |||
from twisted.internet.abstract import isIPAddress | |||
from twisted.python import failure | |||
@@ -67,7 +66,7 @@ from synapse.replication.http.federation import ( | |||
from synapse.storage.databases.main.lock import Lock | |||
from synapse.types import JsonDict, get_domain_from_id | |||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError | |||
from synapse.util.async_helpers import Linearizer, concurrently_execute | |||
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results | |||
from synapse.util.caches.response_cache import ResponseCache | |||
from synapse.util.stringutils import parse_server_name | |||
@@ -360,13 +359,13 @@ class FederationServer(FederationBase): | |||
# want to block things like to device messages from reaching clients | |||
# behind the potentially expensive handling of PDUs. | |||
pdu_results, _ = await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[ | |||
gather_results( | |||
( | |||
run_in_background( | |||
self._handle_pdus_in_txn, origin, transaction, request_time | |||
), | |||
run_in_background(self._handle_edus_in_txn, origin, transaction), | |||
], | |||
), | |||
consumeErrors=True, | |||
).addErrback(unwrapFirstError) | |||
) | |||
@@ -360,31 +360,34 @@ class FederationHandler: | |||
logger.debug("calling resolve_state_groups in _maybe_backfill") | |||
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) | |||
states = await make_deferred_yieldable( | |||
states_list = await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True | |||
) | |||
) | |||
# dict[str, dict[tuple, str]], a map from event_id to state map of | |||
# event_ids. | |||
states = dict(zip(event_ids, [s.state for s in states])) | |||
# A map from event_id to state map of event_ids. | |||
state_ids: Dict[str, StateMap[str]] = dict( | |||
zip(event_ids, [s.state for s in states_list]) | |||
) | |||
state_map = await self.store.get_events( | |||
[e_id for ids in states.values() for e_id in ids.values()], | |||
[e_id for ids in state_ids.values() for e_id in ids.values()], | |||
get_prev_content=False, | |||
) | |||
states = { | |||
# A map from event_id to state map of events. | |||
state_events: Dict[str, StateMap[EventBase]] = { | |||
key: { | |||
k: state_map[e_id] | |||
for k, e_id in state_dict.items() | |||
if e_id in state_map | |||
} | |||
for key, state_dict in states.items() | |||
for key, state_dict in state_ids.items() | |||
} | |||
for e_id in event_ids: | |||
likely_extremeties_domains = get_domains_from_state(states[e_id]) | |||
likely_extremeties_domains = get_domains_from_state(state_events[e_id]) | |||
success = await try_backfill( | |||
[ | |||
@@ -13,21 +13,27 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, List, Optional, Tuple | |||
from twisted.internet import defer | |||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast | |||
from synapse.api.constants import EduTypes, EventTypes, Membership | |||
from synapse.api.errors import SynapseError | |||
from synapse.events import EventBase | |||
from synapse.events.validator import EventValidator | |||
from synapse.handlers.presence import format_user_presence_state | |||
from synapse.handlers.receipts import ReceiptEventSource | |||
from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.storage.roommember import RoomsForUser | |||
from synapse.streams.config import PaginationConfig | |||
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID | |||
from synapse.types import ( | |||
JsonDict, | |||
Requester, | |||
RoomStreamToken, | |||
StateMap, | |||
StreamToken, | |||
UserID, | |||
) | |||
from synapse.util import unwrapFirstError | |||
from synapse.util.async_helpers import concurrently_execute | |||
from synapse.util.async_helpers import concurrently_execute, gather_results | |||
from synapse.util.caches.response_cache import ResponseCache | |||
from synapse.visibility import filter_events_for_client | |||
@@ -190,14 +196,13 @@ class InitialSyncHandler: | |||
) | |||
deferred_room_state = run_in_background( | |||
self.state_store.get_state_for_events, [event.event_id] | |||
) | |||
deferred_room_state.addCallback( | |||
lambda states: states[event.event_id] | |||
).addCallback( | |||
lambda states: cast(StateMap[EventBase], states[event.event_id]) | |||
) | |||
(messages, token), current_state = await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[ | |||
gather_results( | |||
( | |||
run_in_background( | |||
self.store.get_recent_events_for_room, | |||
event.room_id, | |||
@@ -205,7 +210,7 @@ class InitialSyncHandler: | |||
end_token=room_end_token, | |||
), | |||
deferred_room_state, | |||
] | |||
) | |||
) | |||
).addErrback(unwrapFirstError) | |||
@@ -454,8 +459,8 @@ class InitialSyncHandler: | |||
return receipts | |||
presence, receipts, (messages, token) = await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[ | |||
gather_results( | |||
( | |||
run_in_background(get_presence), | |||
run_in_background(get_receipts), | |||
run_in_background( | |||
@@ -464,7 +469,7 @@ class InitialSyncHandler: | |||
limit=limit, | |||
end_token=now_token.room_key, | |||
), | |||
], | |||
), | |||
consumeErrors=True, | |||
).addErrback(unwrapFirstError) | |||
) | |||
@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple | |||
from canonicaljson import encode_canonical_json | |||
from twisted.internet import defer | |||
from twisted.internet.interfaces import IDelayedCall | |||
from synapse import event_auth | |||
@@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.state import StateFilter | |||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester | |||
from synapse.util import json_decoder, json_encoder, log_failure | |||
from synapse.util.async_helpers import Linearizer, unwrapFirstError | |||
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
from synapse.util.metrics import measure_func | |||
from synapse.visibility import filter_events_for_client | |||
@@ -1168,9 +1167,9 @@ class EventCreationHandler: | |||
# We now persist the event (and update the cache in parallel, since we | |||
# don't want to block on it). | |||
result = await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[ | |||
result, _ = await make_deferred_yieldable( | |||
gather_results( | |||
( | |||
run_in_background( | |||
self._persist_event, | |||
requester=requester, | |||
@@ -1182,12 +1181,12 @@ class EventCreationHandler: | |||
run_in_background( | |||
self.cache_joined_hosts_for_event, event, context | |||
).addErrback(log_failure, "cache_joined_hosts_for_event failed"), | |||
], | |||
), | |||
consumeErrors=True, | |||
) | |||
).addErrback(unwrapFirstError) | |||
return result[0] | |||
return result | |||
async def _persist_event( | |||
self, | |||
@@ -25,6 +25,7 @@ from zope.interface import implementer | |||
from twisted.internet import defer | |||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS | |||
from twisted.internet.interfaces import ( | |||
IProtocol, | |||
IProtocolFactory, | |||
IReactorCore, | |||
IStreamClientEndpoint, | |||
@@ -309,12 +310,14 @@ class MatrixHostnameEndpoint: | |||
self._srv_resolver = srv_resolver | |||
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: | |||
def connect( | |||
self, protocol_factory: IProtocolFactory | |||
) -> "defer.Deferred[IProtocol]": | |||
"""Implements IStreamClientEndpoint interface""" | |||
return run_in_background(self._do_connect, protocol_factory) | |||
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: | |||
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol: | |||
first_exception = None | |||
server_list = await self._resolve_server() | |||
@@ -22,20 +22,33 @@ them. | |||
See doc/log_contexts.rst for details on how this works. | |||
""" | |||
import inspect | |||
import logging | |||
import threading | |||
import typing | |||
import warnings | |||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union | |||
from types import TracebackType | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Awaitable, | |||
Callable, | |||
Optional, | |||
Tuple, | |||
Type, | |||
TypeVar, | |||
Union, | |||
overload, | |||
) | |||
import attr | |||
from typing_extensions import Literal | |||
from twisted.internet import defer, threads | |||
from twisted.python.threadpool import ThreadPool | |||
if TYPE_CHECKING: | |||
from synapse.logging.scopecontextmanager import _LogContextScope | |||
from synapse.types import ISynapseReactor | |||
logger = logging.getLogger(__name__) | |||
@@ -66,7 +79,7 @@ except Exception: | |||
# a hook which can be set during testing to assert that we aren't abusing logcontexts. | |||
def logcontext_error(msg: str): | |||
def logcontext_error(msg: str) -> None: | |||
logger.warning(msg) | |||
@@ -223,22 +236,19 @@ class _Sentinel: | |||
def __str__(self) -> str: | |||
return "sentinel" | |||
def copy_to(self, record): | |||
pass | |||
def start(self, rusage: "Optional[resource.struct_rusage]"): | |||
def start(self, rusage: "Optional[resource.struct_rusage]") -> None: | |||
pass | |||
def stop(self, rusage: "Optional[resource.struct_rusage]"): | |||
def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: | |||
pass | |||
def add_database_transaction(self, duration_sec): | |||
def add_database_transaction(self, duration_sec: float) -> None: | |||
pass | |||
def add_database_scheduled(self, sched_sec): | |||
def add_database_scheduled(self, sched_sec: float) -> None: | |||
pass | |||
def record_event_fetch(self, event_count): | |||
def record_event_fetch(self, event_count: int) -> None: | |||
pass | |||
def __bool__(self) -> Literal[False]: | |||
@@ -379,7 +389,12 @@ class LoggingContext: | |||
) | |||
return self | |||
def __exit__(self, type, value, traceback) -> None: | |||
def __exit__( | |||
self, | |||
type: Optional[Type[BaseException]], | |||
value: Optional[BaseException], | |||
traceback: Optional[TracebackType], | |||
) -> None: | |||
"""Restore the logging context in thread local storage to the state it | |||
was before this context was entered. | |||
Returns: | |||
@@ -399,17 +414,6 @@ class LoggingContext: | |||
# recorded against the correct metrics. | |||
self.finished = True | |||
def copy_to(self, record) -> None: | |||
"""Copy logging fields from this context to a log record or | |||
another LoggingContext | |||
""" | |||
# we track the current request | |||
record.request = self.request | |||
# we also track the current scope: | |||
record.scope = self.scope | |||
def start(self, rusage: "Optional[resource.struct_rusage]") -> None: | |||
""" | |||
Record that this logcontext is currently running. | |||
@@ -626,7 +630,12 @@ class PreserveLoggingContext: | |||
def __enter__(self) -> None: | |||
self._old_context = set_current_context(self._new_context) | |||
def __exit__(self, type, value, traceback) -> None: | |||
def __exit__( | |||
self, | |||
type: Optional[Type[BaseException]], | |||
value: Optional[BaseException], | |||
traceback: Optional[TracebackType], | |||
) -> None: | |||
context = set_current_context(self._old_context) | |||
if context != self._new_context: | |||
@@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext: | |||
) | |||
def preserve_fn(f): | |||
R = TypeVar("R") | |||
@overload | |||
def preserve_fn( # type: ignore[misc] | |||
f: Callable[..., Awaitable[R]], | |||
) -> Callable[..., "defer.Deferred[R]"]: | |||
# The `type: ignore[misc]` above suppresses | |||
# "Overloaded function signatures 1 and 2 overlap with incompatible return types" | |||
... | |||
@overload | |||
def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: | |||
... | |||
def preserve_fn( | |||
f: Union[ | |||
Callable[..., R], | |||
Callable[..., Awaitable[R]], | |||
] | |||
) -> Callable[..., "defer.Deferred[R]"]: | |||
"""Function decorator which wraps the function with run_in_background""" | |||
def g(*args, **kwargs): | |||
def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": | |||
return run_in_background(f, *args, **kwargs) | |||
return g | |||
def run_in_background(f, *args, **kwargs) -> defer.Deferred: | |||
@overload | |||
def run_in_background( # type: ignore[misc] | |||
f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any | |||
) -> "defer.Deferred[R]": | |||
# The `type: ignore[misc]` above suppresses | |||
# "Overloaded function signatures 1 and 2 overlap with incompatible return types" | |||
... | |||
@overload | |||
def run_in_background( | |||
f: Callable[..., R], *args: Any, **kwargs: Any | |||
) -> "defer.Deferred[R]": | |||
... | |||
def run_in_background( | |||
f: Union[ | |||
Callable[..., R], | |||
Callable[..., Awaitable[R]], | |||
], | |||
*args: Any, | |||
**kwargs: Any, | |||
) -> "defer.Deferred[R]": | |||
"""Calls a function, ensuring that the current context is restored after | |||
return from the function, and that the sentinel context is set once the | |||
deferred returned by the function completes. | |||
@@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: | |||
# At this point we should have a Deferred, if not then f was a synchronous | |||
# function, wrap it in a Deferred for consistency. | |||
if not isinstance(res, defer.Deferred): | |||
# `res` is not a `Deferred` and not a `Coroutine`. | |||
# There are no other types of `Awaitable`s we expect to encounter in Synapse. | |||
assert not isinstance(res, Awaitable) | |||
return defer.succeed(res) | |||
if res.called and not res.paused: | |||
@@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: | |||
return res | |||
def make_deferred_yieldable(deferred): | |||
"""Given a deferred (or coroutine), make it follow the Synapse logcontext | |||
rules: | |||
T = TypeVar("T") | |||
If the deferred has completed (or is not actually a Deferred), essentially | |||
does nothing (just returns another completed deferred with the | |||
result/failure). | |||
def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": | |||
"""Given a deferred, make it follow the Synapse logcontext rules: | |||
If the deferred has completed, essentially does nothing (just returns another | |||
completed deferred with the result/failure). | |||
If the deferred has not yet completed, resets the logcontext before | |||
returning a deferred. Then, when the deferred completes, restores the | |||
@@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred): | |||
(This is more-or-less the opposite operation to run_in_background.) | |||
""" | |||
if inspect.isawaitable(deferred): | |||
# If we're given a coroutine we convert it to a deferred so that we | |||
# run it and find out if it immediately finishes, it it does then we | |||
# don't need to fiddle with log contexts at all and can return | |||
# immediately. | |||
deferred = defer.ensureDeferred(deferred) | |||
if not isinstance(deferred, defer.Deferred): | |||
return deferred | |||
if deferred.called and not deferred.paused: | |||
# it looks like this deferred is ready to run any callbacks we give it | |||
# immediately. We may as well optimise out the logcontext faffery. | |||
@@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: | |||
return result | |||
def defer_to_thread(reactor, f, *args, **kwargs): | |||
def defer_to_thread( | |||
reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any | |||
) -> "defer.Deferred[R]": | |||
""" | |||
Calls the function `f` using a thread from the reactor's default threadpool and | |||
returns the result as a Deferred. | |||
@@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs): | |||
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) | |||
def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): | |||
def defer_to_threadpool( | |||
reactor: "ISynapseReactor", | |||
threadpool: ThreadPool, | |||
f: Callable[..., R], | |||
*args: Any, | |||
**kwargs: Any, | |||
) -> "defer.Deferred[R]": | |||
""" | |||
A wrapper for twisted.internet.threads.deferToThreadpool, which handles | |||
logcontexts correctly. | |||
@@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): | |||
assert isinstance(curr_context, LoggingContext) | |||
parent_context = curr_context | |||
def g(): | |||
def g() -> R: | |||
with LoggingContext(str(curr_context), parent_context=parent_context): | |||
return f(*args, **kwargs) | |||
@@ -30,9 +30,11 @@ from typing import ( | |||
Iterator, | |||
Optional, | |||
Set, | |||
Tuple, | |||
TypeVar, | |||
Union, | |||
cast, | |||
overload, | |||
) | |||
import attr | |||
@@ -234,6 +236,59 @@ def yieldable_gather_results( | |||
).addErrback(unwrapFirstError) | |||
T1 = TypeVar("T1") | |||
T2 = TypeVar("T2") | |||
T3 = TypeVar("T3") | |||
@overload | |||
def gather_results( | |||
deferredList: Tuple[()], consumeErrors: bool = ... | |||
) -> "defer.Deferred[Tuple[()]]": | |||
... | |||
@overload | |||
def gather_results( | |||
deferredList: Tuple["defer.Deferred[T1]"], | |||
consumeErrors: bool = ..., | |||
) -> "defer.Deferred[Tuple[T1]]": | |||
... | |||
@overload | |||
def gather_results( | |||
deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"], | |||
consumeErrors: bool = ..., | |||
) -> "defer.Deferred[Tuple[T1, T2]]": | |||
... | |||
@overload | |||
def gather_results( | |||
deferredList: Tuple[ | |||
"defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]" | |||
], | |||
consumeErrors: bool = ..., | |||
) -> "defer.Deferred[Tuple[T1, T2, T3]]": | |||
... | |||
def gather_results( # type: ignore[misc] | |||
deferredList: Tuple["defer.Deferred[T1]", ...], | |||
consumeErrors: bool = False, | |||
) -> "defer.Deferred[Tuple[T1, ...]]": | |||
"""Combines a tuple of `Deferred`s into a single `Deferred`. | |||
Wraps `defer.gatherResults` to provide type annotations that support heterogenous | |||
lists of `Deferred`s. | |||
""" | |||
# The `type: ignore[misc]` above suppresses | |||
# "Overloaded function implementation cannot produce return type of signature 1/2/3" | |||
deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors) | |||
return deferred.addCallback(tuple) | |||
@attr.s(slots=True) | |||
class _LinearizerEntry: | |||
# The number of things executing. | |||
@@ -352,7 +407,7 @@ class Linearizer: | |||
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) | |||
new_defer = make_deferred_yieldable(defer.Deferred()) | |||
new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred()) | |||
entry.deferreds[new_defer] = 1 | |||
def cb(_r: None) -> "defer.Deferred[None]": | |||
@@ -76,6 +76,7 @@ class CachedCall(Generic[TV]): | |||
# Fire off the callable now if this is our first time | |||
if not self._deferred: | |||
assert self._callable is not None | |||
self._deferred = run_in_background(self._callable) | |||
# we will never need the callable again, so make sure it can be GCed | |||
@@ -142,6 +142,7 @@ class BackgroundFileConsumer: | |||
def wait(self) -> "Deferred[None]": | |||
"""Returns a deferred that resolves when finished writing to file""" | |||
assert self._finished_deferred is not None | |||
return make_deferred_yieldable(self._finished_deferred) | |||
def _resume_paused_producer(self) -> None: | |||
@@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase): | |||
# now it should be restored | |||
self._check_test_key("one") | |||
@defer.inlineCallbacks | |||
def test_make_deferred_yieldable_on_non_deferred(self): | |||
"""Check that make_deferred_yieldable does the right thing when its | |||
argument isn't actually a deferred""" | |||
with LoggingContext("one"): | |||
d1 = make_deferred_yieldable("bum") | |||
self._check_test_key("one") | |||
r = yield d1 | |||
self.assertEqual(r, "bum") | |||
self._check_test_key("one") | |||
def test_nested_logging_context(self): | |||
with LoggingContext("foo"): | |||
nested_context = nested_logging_context(suffix="bar") | |||
self.assertEqual(nested_context.name, "foo-bar") | |||
@defer.inlineCallbacks | |||
def test_make_deferred_yieldable_with_await(self): | |||
# an async function which returns an incomplete coroutine, but doesn't | |||
# follow the synapse rules. | |||
async def blocking_function(): | |||
d = defer.Deferred() | |||
reactor.callLater(0, d.callback, None) | |||
await d | |||
sentinel_context = current_context() | |||
with LoggingContext("one"): | |||
d1 = make_deferred_yieldable(blocking_function()) | |||
# make sure that the context was reset by make_deferred_yieldable | |||
self.assertIs(current_context(), sentinel_context) | |||
yield d1 | |||
# now it should be restored | |||
self._check_test_key("one") | |||
# a function which returns a deferred which has been "called", but | |||
# which had a function which returned another incomplete deferred on | |||