@@ -0,0 +1 @@ | |||
Use `ParamSpec` to refine type hints. |
@@ -1563,7 +1563,7 @@ url_preview = ["lxml"] | |||
[metadata] | |||
lock-version = "1.1" | |||
python-versions = "^3.7.1" | |||
content-hash = "eebc9e1d720e2e866f5fddda98ce83d858949a6fdbe30c7e5aef4cf9d17be498" | |||
content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1" | |||
[metadata.files] | |||
attrs = [ | |||
@@ -143,7 +143,9 @@ netaddr = ">=0.7.18" | |||
Jinja2 = ">=3.0" | |||
bleach = ">=1.4.3" | |||
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0. | |||
typing-extensions = ">=3.10.0" | |||
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be | |||
# generic over ParamSpecs. | |||
typing-extensions = ">=3.10.0.1" | |||
# We enforce that we have a `cryptography` version that bundles an `openssl` | |||
# with the latest security patches. | |||
cryptography = ">=3.4.7" | |||
@@ -38,6 +38,7 @@ from typing import ( | |||
from cryptography.utils import CryptographyDeprecationWarning | |||
from matrix_common.versionstring import get_distribution_version_string | |||
from typing_extensions import ParamSpec | |||
import twisted | |||
from twisted.internet import defer, error, reactor as _reactor | |||
@@ -81,11 +82,12 @@ logger = logging.getLogger(__name__) | |||
# list of tuples of function, args list, kwargs dict | |||
_sighup_callbacks: List[ | |||
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] | |||
Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]] | |||
] = [] | |||
P = ParamSpec("P") | |||
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None: | |||
def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: | |||
""" | |||
Register a function to be called when a SIGHUP occurs. | |||
@@ -93,7 +95,9 @@ def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> Non | |||
func: Function to be called when sent a SIGHUP signal. | |||
*args, **kwargs: args and kwargs to be passed to the target function. | |||
""" | |||
_sighup_callbacks.append((func, args, kwargs)) | |||
# This type-ignore should be redundant once we use a mypy release with | |||
# https://github.com/python/mypy/pull/12668. | |||
_sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] | |||
def start_worker_reactor( | |||
@@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None: | |||
print("Redirected stdout/stderr to logs") | |||
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None: | |||
def register_start( | |||
cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs | |||
) -> None: | |||
"""Register a callback with the reactor, to be called once it is running | |||
This can be used to initialise parts of the system which require an asynchronous | |||
@@ -22,9 +22,12 @@ from typing import ( | |||
List, | |||
Optional, | |||
Set, | |||
TypeVar, | |||
Union, | |||
) | |||
from typing_extensions import ParamSpec | |||
from synapse.api.presence import UserPresenceState | |||
from synapse.util.async_helpers import maybe_awaitable | |||
@@ -40,6 +43,10 @@ GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]] | |||
logger = logging.getLogger(__name__) | |||
P = ParamSpec("P") | |||
R = TypeVar("R") | |||
def load_legacy_presence_router(hs: "HomeServer") -> None: | |||
"""Wrapper that loads a presence router module configured using the old | |||
configuration, and registers the hooks they implement. | |||
@@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: | |||
# All methods that the module provides should be async, but this wasn't enforced | |||
# in the old module system, so we wrap them if needed | |||
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: | |||
def async_wrapper( | |||
f: Optional[Callable[P, R]] | |||
) -> Optional[Callable[P, Awaitable[R]]]: | |||
# f might be None if the callback isn't implemented by the module. In this | |||
# case we don't want to register a callback at all so we return None. | |||
if f is None: | |||
return None | |||
def run(*args: Any, **kwargs: Any) -> Awaitable: | |||
def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]: | |||
# Assertion required because mypy can't prove we won't change `f` | |||
# back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
@@ -80,7 +89,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: | |||
return run | |||
# Register the hooks through the module API. | |||
hooks = { | |||
hooks: Dict[str, Optional[Callable[..., Any]]] = { | |||
hook: async_wrapper(getattr(presence_router, hook, None)) | |||
for hook in presence_router_methods | |||
} | |||
@@ -30,6 +30,7 @@ from typing import ( | |||
import attr | |||
import jinja2 | |||
from typing_extensions import ParamSpec | |||
from twisted.internet import defer | |||
from twisted.web.resource import Resource | |||
@@ -129,6 +130,7 @@ if TYPE_CHECKING: | |||
T = TypeVar("T") | |||
P = ParamSpec("P") | |||
""" | |||
This package defines the 'stable' API which can be used by extension modules which | |||
@@ -799,9 +801,9 @@ class ModuleApi: | |||
def run_db_interaction( | |||
self, | |||
desc: str, | |||
func: Callable[..., T], | |||
*args: Any, | |||
**kwargs: Any, | |||
func: Callable[P, T], | |||
*args: P.args, | |||
**kwargs: P.kwargs, | |||
) -> "defer.Deferred[T]": | |||
"""Run a function with a database connection | |||
@@ -817,8 +819,9 @@ class ModuleApi: | |||
Returns: | |||
Deferred[object]: result of func | |||
""" | |||
# type-ignore: See https://github.com/python/mypy/issues/8862 | |||
return defer.ensureDeferred( | |||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) | |||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] | |||
) | |||
def complete_sso_login( | |||
@@ -1296,9 +1299,9 @@ class ModuleApi: | |||
async def defer_to_thread( | |||
self, | |||
f: Callable[..., T], | |||
*args: Any, | |||
**kwargs: Any, | |||
f: Callable[P, T], | |||
*args: P.args, | |||
**kwargs: P.kwargs, | |||
) -> T: | |||
"""Runs the given function in a separate thread from Synapse's thread pool. | |||
@@ -15,8 +15,6 @@ | |||
import logging | |||
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple | |||
from twisted.web.server import Request | |||
from synapse.api.constants import Membership | |||
from synapse.api.errors import SynapseError | |||
from synapse.http.server import HttpServer | |||
@@ -97,7 +95,7 @@ class KnockRoomAliasServlet(RestServlet): | |||
return 200, {"room_id": room_id} | |||
def on_PUT( | |||
self, request: Request, room_identifier: str, txn_id: str | |||
self, request: SynapseRequest, room_identifier: str, txn_id: str | |||
) -> Awaitable[Tuple[int, JsonDict]]: | |||
set_tag("txn_id", txn_id) | |||
@@ -15,7 +15,9 @@ | |||
"""This module contains logic for storing HTTP PUT transactions. This is used | |||
to ensure idempotency when performing PUTs using the REST API.""" | |||
import logging | |||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple | |||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple | |||
from typing_extensions import ParamSpec | |||
from twisted.python.failure import Failure | |||
from twisted.web.server import Request | |||
@@ -32,6 +34,9 @@ logger = logging.getLogger(__name__) | |||
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins | |||
P = ParamSpec("P") | |||
class HttpTransactionCache: | |||
def __init__(self, hs: "HomeServer"): | |||
self.hs = hs | |||
@@ -65,9 +70,9 @@ class HttpTransactionCache: | |||
def fetch_or_execute_request( | |||
self, | |||
request: Request, | |||
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], | |||
*args: Any, | |||
**kwargs: Any, | |||
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], | |||
*args: P.args, | |||
**kwargs: P.kwargs, | |||
) -> Awaitable[Tuple[int, JsonDict]]: | |||
"""A helper function for fetch_or_execute which extracts | |||
a transaction key from the given request. | |||
@@ -82,9 +87,9 @@ class HttpTransactionCache: | |||
def fetch_or_execute( | |||
self, | |||
txn_key: str, | |||
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], | |||
*args: Any, | |||
**kwargs: Any, | |||
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], | |||
*args: P.args, | |||
**kwargs: P.kwargs, | |||
) -> Awaitable[Tuple[int, JsonDict]]: | |||
"""Fetches the response for this transaction, or executes the given function | |||
to produce a response for this transaction. | |||
@@ -192,7 +192,7 @@ class LoggingDatabaseConnection: | |||
# The type of entry which goes on our after_callbacks and exception_callbacks lists. | |||
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] | |||
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] | |||
P = ParamSpec("P") | |||
R = TypeVar("R") | |||
@@ -239,7 +239,9 @@ class LoggingTransaction: | |||
self.after_callbacks = after_callbacks | |||
self.exception_callbacks = exception_callbacks | |||
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): | |||
def call_after( | |||
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs | |||
) -> None: | |||
"""Call the given callback on the main twisted thread after the transaction has | |||
finished. | |||
@@ -256,11 +258,12 @@ class LoggingTransaction: | |||
# LoggingTransaction isn't expecting there to be any callbacks; assert that | |||
# is not the case. | |||
assert self.after_callbacks is not None | |||
self.after_callbacks.append((callback, args, kwargs)) | |||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 | |||
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] | |||
def call_on_exception( | |||
self, callback: Callable[..., object], *args: Any, **kwargs: Any | |||
): | |||
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs | |||
) -> None: | |||
"""Call the given callback on the main twisted thread after the transaction has | |||
failed. | |||
@@ -274,7 +277,8 @@ class LoggingTransaction: | |||
# LoggingTransaction isn't expecting there to be any callbacks; assert that | |||
# is not the case. | |||
assert self.exception_callbacks is not None | |||
self.exception_callbacks.append((callback, args, kwargs)) | |||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 | |||
self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] | |||
def fetchone(self) -> Optional[Tuple]: | |||
return self.txn.fetchone() | |||
@@ -549,9 +553,9 @@ class DatabasePool: | |||
desc: str, | |||
after_callbacks: List[_CallbackListEntry], | |||
exception_callbacks: List[_CallbackListEntry], | |||
func: Callable[..., R], | |||
*args: Any, | |||
**kwargs: Any, | |||
func: Callable[Concatenate[LoggingTransaction, P], R], | |||
*args: P.args, | |||
**kwargs: P.kwargs, | |||
) -> R: | |||
"""Start a new database transaction with the given connection. | |||
@@ -581,7 +585,10 @@ class DatabasePool: | |||
# will fail if we have to repeat the transaction. | |||
# For now, we just log an error, and hope that it works on the first attempt. | |||
# TODO: raise an exception. | |||
for i, arg in enumerate(args): | |||
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see | |||
# https://github.com/python/mypy/pull/12668 | |||
for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] | |||
if inspect.isgenerator(arg): | |||
logger.error( | |||
"Programming error: generator passed to new_transaction as " | |||
@@ -589,7 +596,9 @@ class DatabasePool: | |||
i, | |||
func, | |||
) | |||
for name, val in kwargs.items(): | |||
# Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see | |||
# https://github.com/python/mypy/pull/12668 | |||
for name, val in kwargs.items(): # type: ignore[attr-defined] | |||
if inspect.isgenerator(val): | |||
logger.error( | |||
"Programming error: generator passed to new_transaction as " | |||
@@ -1648,8 +1648,12 @@ class PersistEventsStore: | |||
txn.call_after(prefill) | |||
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: | |||
# Invalidate the caches for the redacted event, note that these caches | |||
# are also cleared as part of event replication in _invalidate_caches_for_event. | |||
"""Invalidate the caches for the redacted event. | |||
Note that these caches are also cleared as part of event replication in | |||
_invalidate_caches_for_event. | |||
""" | |||
assert event.redacts is not None | |||
txn.call_after(self.store._invalidate_get_event_cache, event.redacts) | |||
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) | |||
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) | |||
@@ -42,7 +42,7 @@ from typing import ( | |||
) | |||
import attr | |||
from typing_extensions import AsyncContextManager, Literal | |||
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec | |||
from twisted.internet import defer | |||
from twisted.internet.defer import CancelledError | |||
@@ -237,9 +237,16 @@ async def concurrently_execute( | |||
) | |||
P = ParamSpec("P") | |||
R = TypeVar("R") | |||
async def yieldable_gather_results( | |||
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any | |||
) -> List[T]: | |||
func: Callable[Concatenate[T, P], Awaitable[R]], | |||
iter: Iterable[T], | |||
*args: P.args, | |||
**kwargs: P.kwargs, | |||
) -> List[R]: | |||
"""Executes the function with each argument concurrently. | |||
Args: | |||
@@ -255,7 +262,15 @@ async def yieldable_gather_results( | |||
try: | |||
return await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[run_in_background(func, item, *args, **kwargs) for item in iter], | |||
# type-ignore: mypy reports two errors: | |||
# error: Argument 1 to "run_in_background" has incompatible type | |||
# "Callable[[T, **P], Awaitable[R]]"; expected | |||
# "Callable[[T, **P], Awaitable[R]]" [arg-type] | |||
# error: Argument 2 to "run_in_background" has incompatible type | |||
# "T"; expected "[T, **P.args]" [arg-type] | |||
# The former looks like a mypy bug, and the latter looks like a | |||
# false positive. | |||
[run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] | |||
consumeErrors=True, | |||
) | |||
) | |||
@@ -577,9 +592,6 @@ class ReadWriteLock: | |||
return _ctx_manager() | |||
R = TypeVar("R") | |||
def timeout_deferred( | |||
deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime | |||
) -> "defer.Deferred[_T]": | |||
@@ -12,7 +12,19 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, Callable, Dict, List | |||
from typing import ( | |||
Any, | |||
Awaitable, | |||
Callable, | |||
Dict, | |||
Generic, | |||
List, | |||
Optional, | |||
TypeVar, | |||
Union, | |||
) | |||
from typing_extensions import ParamSpec | |||
from twisted.internet import defer | |||
@@ -75,7 +87,11 @@ class Distributor: | |||
run_as_background_process(name, self.signals[name].fire, *args, **kwargs) | |||
class Signal: | |||
P = ParamSpec("P") | |||
R = TypeVar("R") | |||
class Signal(Generic[P]): | |||
"""A Signal is a dispatch point that stores a list of callables as | |||
observers of it. | |||
@@ -87,16 +103,16 @@ class Signal: | |||
def __init__(self, name: str): | |||
self.name: str = name | |||
self.observers: List[Callable] = [] | |||
self.observers: List[Callable[P, Any]] = [] | |||
def observe(self, observer: Callable) -> None: | |||
def observe(self, observer: Callable[P, Any]) -> None: | |||
"""Adds a new callable to the observer list which will be invoked by | |||
the 'fire' method. | |||
Each observer callable may return a Deferred.""" | |||
self.observers.append(observer) | |||
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]": | |||
def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]": | |||
"""Invokes every callable in the observer list, passing in the args and | |||
kwargs. Exceptions thrown by observers are logged but ignored. It is | |||
not an error to fire a signal with no observers. | |||
@@ -104,7 +120,7 @@ class Signal: | |||
Returns a Deferred that will complete when all the observers have | |||
completed.""" | |||
async def do(observer: Callable[..., Any]) -> Any: | |||
async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]: | |||
try: | |||
return await maybe_awaitable(observer(*args, **kwargs)) | |||
except Exception as e: | |||
@@ -114,6 +130,7 @@ class Signal: | |||
observer, | |||
e, | |||
) | |||
return None | |||
deferreds = [run_in_background(do, o) for o in self.observers] | |||
@@ -15,10 +15,10 @@ | |||
import logging | |||
from functools import wraps | |||
from types import TracebackType | |||
from typing import Any, Callable, Optional, Type, TypeVar, cast | |||
from typing import Awaitable, Callable, Optional, Type, TypeVar | |||
from prometheus_client import Counter | |||
from typing_extensions import Protocol | |||
from typing_extensions import Concatenate, ParamSpec, Protocol | |||
from synapse.logging.context import ( | |||
ContextResourceUsage, | |||
@@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( | |||
) | |||
T = TypeVar("T", bound=Callable[..., Any]) | |||
P = ParamSpec("P") | |||
R = TypeVar("R") | |||
class HasClock(Protocol): | |||
clock: Clock | |||
def measure_func(name: Optional[str] = None) -> Callable[[T], T]: | |||
""" | |||
Used to decorate an async function with a `Measure` context manager. | |||
def measure_func( | |||
name: Optional[str] = None, | |||
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: | |||
"""Decorate an async method with a `Measure` context manager. | |||
The Measure is created using `self.clock`; it should only be used to decorate | |||
methods in classes defining an instance-level `clock` attribute. | |||
Usage: | |||
@@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]: | |||
""" | |||
def wrapper(func: T) -> T: | |||
def wrapper( | |||
func: Callable[Concatenate[HasClock, P], Awaitable[R]] | |||
) -> Callable[P, Awaitable[R]]: | |||
block_name = func.__name__ if name is None else name | |||
@wraps(func) | |||
async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any: | |||
async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R: | |||
with Measure(self.clock, block_name): | |||
r = await func(self, *args, **kwargs) | |||
return r | |||
return cast(T, measured_func) | |||
# There are some shenanigans here, because we're decorating a method but | |||
# explicitly making use of the `self` parameter. The key thing here is that the | |||
# return type within the return type for `measure_func` itself describes how the | |||
# decorated function will be called. | |||
return measured_func # type: ignore[return-value] | |||
return wrapper | |||
return wrapper # type: ignore[return-value] | |||
class Measure: | |||
@@ -16,6 +16,8 @@ import functools | |||
import sys | |||
from typing import Any, Callable, Generator, List, TypeVar, cast | |||
from typing_extensions import ParamSpec | |||
from twisted.internet import defer | |||
from twisted.internet.defer import Deferred | |||
from twisted.python.failure import Failure | |||
@@ -25,6 +27,7 @@ _already_patched = False | |||
T = TypeVar("T") | |||
P = ParamSpec("P") | |||
def do_patch() -> None: | |||
@@ -41,13 +44,13 @@ def do_patch() -> None: | |||
return | |||
def new_inline_callbacks( | |||
f: Callable[..., Generator["Deferred[object]", object, T]] | |||
) -> Callable[..., "Deferred[T]"]: | |||
f: Callable[P, Generator["Deferred[object]", object, T]] | |||
) -> Callable[P, "Deferred[T]"]: | |||
@functools.wraps(f) | |||
def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": | |||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]": | |||
start_context = current_context() | |||
changes: List[str] = [] | |||
orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( | |||
orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks( | |||
_check_yield_points(f, changes) | |||
) | |||
@@ -115,7 +118,7 @@ def do_patch() -> None: | |||
def _check_yield_points( | |||
f: Callable[..., Generator["Deferred[object]", object, T]], | |||
f: Callable[P, Generator["Deferred[object]", object, T]], | |||
changes: List[str], | |||
) -> Callable: | |||
"""Wraps a generator that is about to be passed to defer.inlineCallbacks | |||
@@ -138,7 +141,7 @@ def _check_yield_points( | |||
@functools.wraps(f) | |||
def check_yield_points_inner( | |||
*args: Any, **kwargs: Any | |||
*args: P.args, **kwargs: P.kwargs | |||
) -> Generator["Deferred[object]", object, T]: | |||
gen = f(*args, **kwargs) | |||