Explorar el Código

Add missing type hints to `synapse.logging.context` (#11556)

tags/v1.50.0rc1
Sean Quah hace 2 años
committed by GitHub
padre
commit
0147b3de20
No se encontró ninguna clave conocida en la base de datos para esta firma ID de clave GPG: 4AEE18F83AFDEB23
Se han modificado 13 ficheros con 215 adiciones y 122 borrados
  1. +1
    -0
      changelog.d/11556.misc
  2. +3
    -0
      mypy.ini
  3. +5
    -4
      stubs/txredisapi.pyi
  4. +4
    -5
      synapse/federation/federation_server.py
  5. +11
    -8
      synapse/handlers/federation.py
  6. +19
    -14
      synapse/handlers/initial_sync.py
  7. +6
    -7
      synapse/handlers/message.py
  8. +5
    -2
      synapse/http/federation/matrix_federation_agent.py
  9. +103
    -46
      synapse/logging/context.py
  10. +56
    -1
      synapse/util/async_helpers.py
  11. +1
    -0
      synapse/util/caches/cached_call.py
  12. +1
    -0
      synapse/util/file_consumer.py
  13. +0
    -35
      tests/util/test_logcontext.py

+ 1
- 0
changelog.d/11556.misc Ver fichero

@@ -0,0 +1 @@
Add missing type hints to `synapse.logging.context`.

+ 3
- 0
mypy.ini Ver fichero

@@ -167,6 +167,9 @@ disallow_untyped_defs = True
[mypy-synapse.http.server]
disallow_untyped_defs = True

[mypy-synapse.logging.context]
disallow_untyped_defs = True

[mypy-synapse.metrics.*]
disallow_untyped_defs = True



+ 5
- 4
stubs/txredisapi.pyi Ver fichero

@@ -17,11 +17,12 @@
from typing import Any, List, Optional, Type, Union

from twisted.internet import protocol
from twisted.internet.defer import Deferred

class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
def ping(self) -> "Deferred[None]": ...
def set(
self,
key: str,
value: Any,
@@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
pexpire: Optional[int] = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> None: ...
async def get(self, key: str) -> Any: ...
) -> "Deferred[None]": ...
def get(self, key: str) -> "Deferred[Any]": ...

class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...


+ 4
- 5
synapse/federation/federation_server.py Ver fichero

@@ -30,7 +30,6 @@ from typing import (

from prometheus_client import Counter, Gauge, Histogram

from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure

@@ -67,7 +66,7 @@ from synapse.replication.http.federation import (
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name

@@ -360,13 +359,13 @@ class FederationServer(FederationBase):
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
pdu_results, _ = await make_deferred_yieldable(
defer.gatherResults(
[
gather_results(
(
run_in_background(
self._handle_pdus_in_txn, origin, transaction, request_time
),
run_in_background(self._handle_edus_in_txn, origin, transaction),
],
),
consumeErrors=True,
).addErrback(unwrapFirstError)
)


+ 11
- 8
synapse/handlers/federation.py Ver fichero

@@ -360,31 +360,34 @@ class FederationHandler:

logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
states = await make_deferred_yieldable(
states_list = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
)

# dict[str, dict[tuple, str]], a map from event_id to state map of
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
# A map from event_id to state map of event_ids.
state_ids: Dict[str, StateMap[str]] = dict(
zip(event_ids, [s.state for s in states_list])
)

state_map = await self.store.get_events(
[e_id for ids in states.values() for e_id in ids.values()],
[e_id for ids in state_ids.values() for e_id in ids.values()],
get_prev_content=False,
)
states = {

# A map from event_id to state map of events.
state_events: Dict[str, StateMap[EventBase]] = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.items()
if e_id in state_map
}
for key, state_dict in states.items()
for key, state_dict in state_ids.items()
}

for e_id in event_ids:
likely_extremeties_domains = get_domains_from_state(states[e_id])
likely_extremeties_domains = get_domains_from_state(state_events[e_id])

success = await try_backfill(
[


+ 19
- 14
synapse/handlers/initial_sync.py Ver fichero

@@ -13,21 +13,27 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, List, Optional, Tuple

from twisted.internet import defer
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.types import (
JsonDict,
Requester,
RoomStreamToken,
StateMap,
StreamToken,
UserID,
)
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.async_helpers import concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client

@@ -190,14 +196,13 @@ class InitialSyncHandler:
)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)

(messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
gather_results(
(
run_in_background(
self.store.get_recent_events_for_room,
event.room_id,
@@ -205,7 +210,7 @@ class InitialSyncHandler:
end_token=room_end_token,
),
deferred_room_state,
]
)
)
).addErrback(unwrapFirstError)

@@ -454,8 +459,8 @@ class InitialSyncHandler:
return receipts

presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
gather_results(
(
run_in_background(get_presence),
run_in_background(get_receipts),
run_in_background(
@@ -464,7 +469,7 @@ class InitialSyncHandler:
limit=limit,
end_token=now_token.room_key,
),
],
),
consumeErrors=True,
).addErrback(unwrapFirstError)
)


+ 6
- 7
synapse/handlers/message.py Ver fichero

@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple

from canonicaljson import encode_canonical_json

from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall

from synapse import event_auth
@@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder, log_failure
from synapse.util.async_helpers import Linearizer, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -1168,9 +1167,9 @@ class EventCreationHandler:

# We now persist the event (and update the cache in parallel, since we
# don't want to block on it).
result = await make_deferred_yieldable(
defer.gatherResults(
[
result, _ = await make_deferred_yieldable(
gather_results(
(
run_in_background(
self._persist_event,
requester=requester,
@@ -1182,12 +1181,12 @@ class EventCreationHandler:
run_in_background(
self.cache_joined_hosts_for_event, event, context
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
],
),
consumeErrors=True,
)
).addErrback(unwrapFirstError)

return result[0]
return result

async def _persist_event(
self,


+ 5
- 2
synapse/http/federation/matrix_federation_agent.py Ver fichero

@@ -25,6 +25,7 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import (
IProtocol,
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
@@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:

self._srv_resolver = srv_resolver

def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
def connect(
self, protocol_factory: IProtocolFactory
) -> "defer.Deferred[IProtocol]":
"""Implements IStreamClientEndpoint interface"""

return run_in_background(self._do_connect, protocol_factory)

async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
first_exception = None

server_list = await self._resolve_server()


+ 103
- 46
synapse/logging/context.py Ver fichero

@@ -22,20 +22,33 @@ them.

See doc/log_contexts.rst for details on how this works.
"""
import inspect
import logging
import threading
import typing
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)

import attr
from typing_extensions import Literal

from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool

if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
from synapse.types import ISynapseReactor

logger = logging.getLogger(__name__)

@@ -66,7 +79,7 @@ except Exception:


# a hook which can be set during testing to assert that we aren't abusing logcontexts.
def logcontext_error(msg: str):
def logcontext_error(msg: str) -> None:
logger.warning(msg)


@@ -223,22 +236,19 @@ class _Sentinel:
def __str__(self) -> str:
return "sentinel"

def copy_to(self, record):
pass

def start(self, rusage: "Optional[resource.struct_rusage]"):
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass

def stop(self, rusage: "Optional[resource.struct_rusage]"):
def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass

def add_database_transaction(self, duration_sec):
def add_database_transaction(self, duration_sec: float) -> None:
pass

def add_database_scheduled(self, sched_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
pass

def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
pass

def __bool__(self) -> Literal[False]:
@@ -379,7 +389,12 @@ class LoggingContext:
)
return self

def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@@ -399,17 +414,6 @@ class LoggingContext:
# recorded against the correct metrics.
self.finished = True

def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""

# we track the current request
record.request = self.request

# we also track the current scope:
record.scope = self.scope

def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@@ -626,7 +630,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)

def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
context = set_current_context(self._old_context)

if context != self._new_context:
@@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)


def preserve_fn(f):
R = TypeVar("R")


@overload
def preserve_fn( # type: ignore[misc]
f: Callable[..., Awaitable[R]],
) -> Callable[..., "defer.Deferred[R]"]:
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...


@overload
def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
...


def preserve_fn(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
]
) -> Callable[..., "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background"""

def g(*args, **kwargs):
def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
return run_in_background(f, *args, **kwargs)

return g


def run_in_background(f, *args, **kwargs) -> defer.Deferred:
@overload
def run_in_background( # type: ignore[misc]
f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...


@overload
def run_in_background(
f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
...


def run_in_background(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)

return defer.succeed(res)

if res.called and not res.paused:
@@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
return res


def make_deferred_yieldable(deferred):
"""Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
T = TypeVar("T")


If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
result/failure).
def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
"""Given a deferred, make it follow the Synapse logcontext rules:

If the deferred has completed, essentially does nothing (just returns another
completed deferred with the result/failure).

If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the
@@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred):

(This is more-or-less the opposite operation to run_in_background.)
"""
if inspect.isawaitable(deferred):
# If we're given a coroutine we convert it to a deferred so that we
# run it and find out if it immediately finishes, it it does then we
# don't need to fiddle with log contexts at all and can return
# immediately.
deferred = defer.ensureDeferred(deferred)

if not isinstance(deferred, defer.Deferred):
return deferred

if deferred.called and not deferred.paused:
# it looks like this deferred is ready to run any callbacks we give it
# immediately. We may as well optimise out the logcontext faffery.
@@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result


def defer_to_thread(reactor, f, *args, **kwargs):
def defer_to_thread(
reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)


def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
def defer_to_threadpool(
reactor: "ISynapseReactor",
threadpool: ThreadPool,
f: Callable[..., R],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context

def g():
def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)



+ 56
- 1
synapse/util/async_helpers.py Ver fichero

@@ -30,9 +30,11 @@ from typing import (
Iterator,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)

import attr
@@ -234,6 +236,59 @@ def yieldable_gather_results(
).addErrback(unwrapFirstError)


T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


@overload
def gather_results(
deferredList: Tuple[()], consumeErrors: bool = ...
) -> "defer.Deferred[Tuple[()]]":
...


@overload
def gather_results(
deferredList: Tuple["defer.Deferred[T1]"],
consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1]]":
...


@overload
def gather_results(
deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2]]":
...


@overload
def gather_results(
deferredList: Tuple[
"defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
],
consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2, T3]]":
...


def gather_results( # type: ignore[misc]
deferredList: Tuple["defer.Deferred[T1]", ...],
consumeErrors: bool = False,
) -> "defer.Deferred[Tuple[T1, ...]]":
"""Combines a tuple of `Deferred`s into a single `Deferred`.

Wraps `defer.gatherResults` to provide type annotations that support heterogenous
lists of `Deferred`s.
"""
# The `type: ignore[misc]` above suppresses
# "Overloaded function implementation cannot produce return type of signature 1/2/3"
deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
return deferred.addCallback(tuple)


@attr.s(slots=True)
class _LinearizerEntry:
# The number of things executing.
@@ -352,7 +407,7 @@ class Linearizer:

logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)

new_defer = make_deferred_yieldable(defer.Deferred())
new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1

def cb(_r: None) -> "defer.Deferred[None]":


+ 1
- 0
synapse/util/caches/cached_call.py Ver fichero

@@ -76,6 +76,7 @@ class CachedCall(Generic[TV]):

# Fire off the callable now if this is our first time
if not self._deferred:
assert self._callable is not None
self._deferred = run_in_background(self._callable)

# we will never need the callable again, so make sure it can be GCed


+ 1
- 0
synapse/util/file_consumer.py Ver fichero

@@ -142,6 +142,7 @@ class BackgroundFileConsumer:

def wait(self) -> "Deferred[None]":
"""Returns a deferred that resolves when finished writing to file"""
assert self._finished_deferred is not None
return make_deferred_yieldable(self._finished_deferred)

def _resume_paused_producer(self) -> None:


+ 0
- 35
tests/util/test_logcontext.py Ver fichero

@@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored
self._check_test_key("one")

@defer.inlineCallbacks
def test_make_deferred_yieldable_on_non_deferred(self):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""

with LoggingContext("one"):
d1 = make_deferred_yieldable("bum")
self._check_test_key("one")

r = yield d1
self.assertEqual(r, "bum")
self._check_test_key("one")

def test_nested_logging_context(self):
with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar")

@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
# an async function which returns an incomplete coroutine, but doesn't
# follow the synapse rules.

async def blocking_function():
d = defer.Deferred()
reactor.callLater(0, d.callback, None)
await d

sentinel_context = current_context()

with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)

yield d1

# now it should be restored
self._check_test_key("one")


# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on


Cargando…
Cancelar
Guardar