@@ -0,0 +1 @@ | |||
Add type annotations to the synapse.util package. |
@@ -74,17 +74,7 @@ files = | |||
synapse/storage/util, | |||
synapse/streams, | |||
synapse/types.py, | |||
synapse/util/async_helpers.py, | |||
synapse/util/caches, | |||
synapse/util/daemonize.py, | |||
synapse/util/hash.py, | |||
synapse/util/iterutils.py, | |||
synapse/util/linked_list.py, | |||
synapse/util/metrics.py, | |||
synapse/util/macaroons.py, | |||
synapse/util/module_loader.py, | |||
synapse/util/msisdn.py, | |||
synapse/util/stringutils.py, | |||
synapse/util, | |||
synapse/visibility.py, | |||
tests/replication, | |||
tests/test_event_auth.py, | |||
@@ -102,6 +92,69 @@ files = | |||
[mypy-synapse.rest.client.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.batching_queue] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.caches.dictionary_cache] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.file_consumer] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.frozenutils] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.hash] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.httpresourcetree] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.iterutils] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.linked_list] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.logcontext] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.logformatter] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.macaroons] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.manhole] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.module_loader] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.msisdn] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.ratelimitutils] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.retryutils] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.rlimit] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.stringutils] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.templates] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.threepids] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.wheel_timer] | |||
disallow_untyped_defs = True | |||
[mypy-pymacaroons.*] | |||
ignore_missing_imports = True | |||
@@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory): | |||
def buildProtocol(self, addr) -> RedisProtocol: ... | |||
class SubscriberFactory(RedisFactory): | |||
def __init__(self): ... | |||
def __init__(self) -> None: ... |
@@ -46,7 +46,7 @@ class Ratelimiter: | |||
# * How many times an action has occurred since a point in time | |||
# * The point in time | |||
# * The rate_hz of this particular entry. This can vary per request | |||
self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() | |||
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() | |||
async def can_do_action( | |||
self, | |||
@@ -56,7 +56,7 @@ class Ratelimiter: | |||
burst_count: Optional[int] = None, | |||
update: bool = True, | |||
n_actions: int = 1, | |||
_time_now_s: Optional[int] = None, | |||
_time_now_s: Optional[float] = None, | |||
) -> Tuple[bool, float]: | |||
"""Can the entity (e.g. user or IP address) perform the action? | |||
@@ -160,7 +160,7 @@ class Ratelimiter: | |||
return allowed, time_allowed | |||
def _prune_message_counts(self, time_now_s: int): | |||
def _prune_message_counts(self, time_now_s: float): | |||
"""Remove message count entries that have not exceeded their defined | |||
rate_hz limit | |||
@@ -188,7 +188,7 @@ class Ratelimiter: | |||
burst_count: Optional[int] = None, | |||
update: bool = True, | |||
n_actions: int = 1, | |||
_time_now_s: Optional[int] = None, | |||
_time_now_s: Optional[float] = None, | |||
): | |||
"""Checks if an action can be performed. If not, raises a LimitExceededError | |||
@@ -14,6 +14,8 @@ | |||
from typing import Dict, Optional | |||
import attr | |||
from ._base import Config | |||
@@ -29,18 +31,13 @@ class RateLimitConfig: | |||
self.burst_count = int(config.get("burst_count", defaults["burst_count"])) | |||
@attr.s(auto_attribs=True) | |||
class FederationRateLimitConfig: | |||
_items_and_default = { | |||
"window_size": 1000, | |||
"sleep_limit": 10, | |||
"sleep_delay": 500, | |||
"reject_limit": 50, | |||
"concurrent": 3, | |||
} | |||
def __init__(self, **kwargs): | |||
for i in self._items_and_default.keys(): | |||
setattr(self, i, kwargs.get(i) or self._items_and_default[i]) | |||
window_size: int = 1000 | |||
sleep_limit: int = 10 | |||
sleep_delay: int = 500 | |||
reject_limit: int = 50 | |||
concurrent: int = 3 | |||
class RatelimitConfig(Config): | |||
@@ -69,11 +66,15 @@ class RatelimitConfig(Config): | |||
else: | |||
self.rc_federation = FederationRateLimitConfig( | |||
**{ | |||
"window_size": config.get("federation_rc_window_size"), | |||
"sleep_limit": config.get("federation_rc_sleep_limit"), | |||
"sleep_delay": config.get("federation_rc_sleep_delay"), | |||
"reject_limit": config.get("federation_rc_reject_limit"), | |||
"concurrent": config.get("federation_rc_concurrent"), | |||
k: v | |||
for k, v in { | |||
"window_size": config.get("federation_rc_window_size"), | |||
"sleep_limit": config.get("federation_rc_sleep_limit"), | |||
"sleep_delay": config.get("federation_rc_sleep_delay"), | |||
"reject_limit": config.get("federation_rc_reject_limit"), | |||
"concurrent": config.get("federation_rc_concurrent"), | |||
}.items() | |||
if v is not None | |||
} | |||
) | |||
@@ -22,6 +22,7 @@ from prometheus_client import Counter | |||
from typing_extensions import Literal | |||
from twisted.internet import defer | |||
from twisted.internet.interfaces import IDelayedCall | |||
import synapse.metrics | |||
from synapse.api.presence import UserPresenceState | |||
@@ -284,7 +285,9 @@ class FederationSender(AbstractFederationSender): | |||
) | |||
# wake up destinations that have outstanding PDUs to be caught up | |||
self._catchup_after_startup_timer = self.clock.call_later( | |||
self._catchup_after_startup_timer: Optional[ | |||
IDelayedCall | |||
] = self.clock.call_later( | |||
CATCH_UP_STARTUP_DELAY_SEC, | |||
run_as_background_process, | |||
"wake_destinations_needing_catchup", | |||
@@ -406,7 +409,7 @@ class FederationSender(AbstractFederationSender): | |||
now = self.clock.time_msec() | |||
ts = await self.store.get_received_ts(event.event_id) | |||
assert ts is not None | |||
synapse.metrics.event_processing_lag_by_event.labels( | |||
"federation_sender" | |||
).observe((now - ts) / 1000) | |||
@@ -435,6 +438,7 @@ class FederationSender(AbstractFederationSender): | |||
if events: | |||
now = self.clock.time_msec() | |||
ts = await self.store.get_received_ts(events[-1].event_id) | |||
assert ts is not None | |||
synapse.metrics.event_processing_lag.labels( | |||
"federation_sender" | |||
@@ -398,6 +398,7 @@ class AccountValidityHandler: | |||
""" | |||
now = self.clock.time_msec() | |||
if expiration_ts is None: | |||
assert self._account_validity_period is not None | |||
expiration_ts = now + self._account_validity_period | |||
await self.store.set_account_validity_for_user( | |||
@@ -131,6 +131,8 @@ class ApplicationServicesHandler: | |||
now = self.clock.time_msec() | |||
ts = await self.store.get_received_ts(event.event_id) | |||
assert ts is not None | |||
synapse.metrics.event_processing_lag_by_event.labels( | |||
"appservice_sender" | |||
).observe((now - ts) / 1000) | |||
@@ -166,6 +168,7 @@ class ApplicationServicesHandler: | |||
if events: | |||
now = self.clock.time_msec() | |||
ts = await self.store.get_received_ts(events[-1].event_id) | |||
assert ts is not None | |||
synapse.metrics.event_processing_lag.labels( | |||
"appservice_sender" | |||
@@ -28,6 +28,7 @@ from bisect import bisect | |||
from contextlib import contextmanager | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Callable, | |||
Collection, | |||
Dict, | |||
@@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler): | |||
super().__init__(hs) | |||
self.hs = hs | |||
self.server_name = hs.hostname | |||
self.wheel_timer = WheelTimer() | |||
self.wheel_timer: WheelTimer[str] = WheelTimer() | |||
self.notifier = hs.get_notifier() | |||
self._presence_enabled = hs.config.use_presence | |||
@@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler): | |||
prev_state = await self.current_state_for_user(user_id) | |||
new_fields = {"last_active_ts": self.clock.time_msec()} | |||
new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()} | |||
if prev_state.state == PresenceState.UNAVAILABLE: | |||
new_fields["state"] = PresenceState.ONLINE | |||
@@ -73,7 +73,7 @@ class FollowerTypingHandler: | |||
self._room_typing: Dict[str, Set[str]] = {} | |||
self._member_last_federation_poke: Dict[RoomMember, int] = {} | |||
self.wheel_timer = WheelTimer(bucket_size=5000) | |||
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000) | |||
self._latest_room_serial = 0 | |||
self.clock.looping_call(self._handle_timeouts, 5000) | |||
@@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet): | |||
# Artificially delay requests if rate > sleep_limit/window_size | |||
sleep_limit=1, | |||
# Amount of artificial delay to apply | |||
sleep_msec=1000, | |||
sleep_delay=1000, | |||
# Error with 429 if more than reject_limit requests are queued | |||
reject_limit=1, | |||
# Allow 1 request at a time | |||
concurrent_requests=1, | |||
concurrent=1, | |||
), | |||
) | |||
@@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet): | |||
Returns: | |||
dictionary for response from /register | |||
""" | |||
result = {"user_id": user_id, "home_server": self.hs.hostname} | |||
result: JsonDict = { | |||
"user_id": user_id, | |||
"home_server": self.hs.hostname, | |||
} | |||
if not params.get("inhibit_login", False): | |||
device_id = params.get("device_id") | |||
initial_display_name = params.get("initial_device_display_name") | |||
@@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet): | |||
user_id, device_id, initial_display_name, is_guest=True | |||
) | |||
result = { | |||
result: JsonDict = { | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"access_token": access_token, | |||
@@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource): | |||
yield hs.config.sso.sso_template_dir | |||
yield hs.config.sso.default_template_dir | |||
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) | |||
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) | |||
async def _async_render_GET(self, request: Request) -> None: | |||
try: | |||
@@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource): | |||
yield hs.config.sso.sso_template_dir | |||
yield hs.config.sso.default_template_dir | |||
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) | |||
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) | |||
async def _async_render_GET(self, request: Request) -> None: | |||
try: | |||
@@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
delta equal to 10% of the validity period. | |||
""" | |||
now_ms = self._clock.time_msec() | |||
assert self._account_validity_period is not None | |||
expiration_ts = now_ms + self._account_validity_period | |||
if use_delta: | |||
@@ -38,6 +38,7 @@ from twisted.internet.interfaces import ( | |||
IReactorCore, | |||
IReactorPluggableNameResolver, | |||
IReactorTCP, | |||
IReactorThreads, | |||
IReactorTime, | |||
) | |||
@@ -63,7 +64,12 @@ JsonDict = Dict[str, Any] | |||
# Note that this seems to require inheriting *directly* from Interface in order | |||
# for mypy-zope to realize it is an interface. | |||
class ISynapseReactor( | |||
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface | |||
IReactorTCP, | |||
IReactorPluggableNameResolver, | |||
IReactorTime, | |||
IReactorCore, | |||
IReactorThreads, | |||
Interface, | |||
): | |||
"""The interfaces necessary for Synapse to function.""" | |||
@@ -15,27 +15,35 @@ | |||
import json | |||
import logging | |||
import re | |||
from typing import Pattern | |||
import typing | |||
from typing import Any, Callable, Dict, Generator, Pattern | |||
import attr | |||
from frozendict import frozendict | |||
from twisted.internet import defer, task | |||
from twisted.internet.defer import Deferred | |||
from twisted.internet.interfaces import IDelayedCall, IReactorTime | |||
from twisted.internet.task import LoopingCall | |||
from twisted.python.failure import Failure | |||
from synapse.logging import context | |||
if typing.TYPE_CHECKING: | |||
pass | |||
logger = logging.getLogger(__name__) | |||
_WILDCARD_RUN = re.compile(r"([\?\*]+)") | |||
def _reject_invalid_json(val): | |||
def _reject_invalid_json(val: Any) -> None: | |||
"""Do not allow Infinity, -Infinity, or NaN values in JSON.""" | |||
raise ValueError("Invalid JSON value: '%s'" % val) | |||
def _handle_frozendict(obj): | |||
def _handle_frozendict(obj: Any) -> Dict[Any, Any]: | |||
"""Helper for json_encoder. Makes frozendicts serializable by returning | |||
the underlying dict | |||
""" | |||
@@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder( | |||
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) | |||
def unwrapFirstError(failure): | |||
def unwrapFirstError(failure: Failure) -> Failure: | |||
# defer.gatherResults and DeferredLists wrap failures. | |||
failure.trap(defer.FirstError) | |||
return failure.value.subFailure | |||
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations | |||
@attr.s(slots=True) | |||
@@ -75,25 +83,25 @@ class Clock: | |||
reactor: The Twisted reactor to use. | |||
""" | |||
_reactor = attr.ib() | |||
_reactor: IReactorTime = attr.ib() | |||
@defer.inlineCallbacks | |||
def sleep(self, seconds): | |||
d = defer.Deferred() | |||
@defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations | |||
def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": | |||
d: defer.Deferred[float] = defer.Deferred() | |||
with context.PreserveLoggingContext(): | |||
self._reactor.callLater(seconds, d.callback, seconds) | |||
res = yield d | |||
return res | |||
def time(self): | |||
def time(self) -> float: | |||
"""Returns the current system time in seconds since epoch.""" | |||
return self._reactor.seconds() | |||
def time_msec(self): | |||
def time_msec(self) -> int: | |||
"""Returns the current system time in milliseconds since epoch.""" | |||
return int(self.time() * 1000) | |||
def looping_call(self, f, msec, *args, **kwargs): | |||
def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: | |||
"""Call a function repeatedly. | |||
Waits `msec` initially before calling `f` for the first time. | |||
@@ -102,8 +110,8 @@ class Clock: | |||
other than trivial, you probably want to wrap it in run_as_background_process. | |||
Args: | |||
f(function): The function to call repeatedly. | |||
msec(float): How long to wait between calls in milliseconds. | |||
f: The function to call repeatedly. | |||
msec: How long to wait between calls in milliseconds. | |||
*args: Postional arguments to pass to function. | |||
**kwargs: Key arguments to pass to function. | |||
""" | |||
@@ -113,7 +121,7 @@ class Clock: | |||
d.addErrback(log_failure, "Looping call died", consumeErrors=False) | |||
return call | |||
def call_later(self, delay, callback, *args, **kwargs): | |||
def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: | |||
"""Call something later | |||
Note that the function will be called with no logcontext, so if it is anything | |||
@@ -133,7 +141,7 @@ class Clock: | |||
with context.PreserveLoggingContext(): | |||
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) | |||
def cancel_call_later(self, timer, ignore_errs=False): | |||
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: | |||
try: | |||
timer.cancel() | |||
except Exception: | |||
@@ -37,6 +37,7 @@ import attr | |||
from typing_extensions import ContextManager | |||
from twisted.internet import defer | |||
from twisted.internet.base import ReactorBase | |||
from twisted.internet.defer import CancelledError | |||
from twisted.internet.interfaces import IReactorTime | |||
from twisted.python import failure | |||
@@ -268,6 +269,7 @@ class Linearizer: | |||
if not clock: | |||
from twisted.internet import reactor | |||
assert isinstance(reactor, ReactorBase) | |||
clock = Clock(reactor) | |||
self._clock = clock | |||
self.max_count = max_count | |||
@@ -411,7 +413,7 @@ class ReadWriteLock: | |||
# writers and readers have been resolved. The new writer replaces the latest | |||
# writer. | |||
def __init__(self): | |||
def __init__(self) -> None: | |||
# Latest readers queued | |||
self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} | |||
@@ -503,7 +505,7 @@ def timeout_deferred( | |||
timed_out = [False] | |||
def time_it_out(): | |||
def time_it_out() -> None: | |||
timed_out[0] = True | |||
try: | |||
@@ -550,19 +552,21 @@ def timeout_deferred( | |||
return new_d | |||
# This class can't be generic because it uses slots with attrs. | |||
# See: https://github.com/python-attrs/attrs/issues/313 | |||
@attr.s(slots=True, frozen=True) | |||
class DoneAwaitable: | |||
class DoneAwaitable: # should be: Generic[R] | |||
"""Simple awaitable that returns the provided value.""" | |||
value = attr.ib() | |||
value = attr.ib(type=Any) # should be: R | |||
def __await__(self): | |||
return self | |||
def __iter__(self): | |||
def __iter__(self) -> "DoneAwaitable": | |||
return self | |||
def __next__(self): | |||
def __next__(self) -> None: | |||
raise StopIteration(self.value) | |||
@@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]): | |||
# First we create a defer and add it and the value to the list of | |||
# pending items. | |||
d = defer.Deferred() | |||
d: defer.Deferred[R] = defer.Deferred() | |||
self._next_values.setdefault(key, []).append((value, d)) | |||
# If we're not currently processing the key fire off a background | |||
@@ -64,32 +64,32 @@ class CacheMetric: | |||
evicted_size = attr.ib(default=0) | |||
memory_usage = attr.ib(default=None) | |||
def inc_hits(self): | |||
def inc_hits(self) -> None: | |||
self.hits += 1 | |||
def inc_misses(self): | |||
def inc_misses(self) -> None: | |||
self.misses += 1 | |||
def inc_evictions(self, size=1): | |||
def inc_evictions(self, size: int = 1) -> None: | |||
self.evicted_size += size | |||
def inc_memory_usage(self, memory: int): | |||
def inc_memory_usage(self, memory: int) -> None: | |||
if self.memory_usage is None: | |||
self.memory_usage = 0 | |||
self.memory_usage += memory | |||
def dec_memory_usage(self, memory: int): | |||
def dec_memory_usage(self, memory: int) -> None: | |||
self.memory_usage -= memory | |||
def clear_memory_usage(self): | |||
def clear_memory_usage(self) -> None: | |||
if self.memory_usage is not None: | |||
self.memory_usage = 0 | |||
def describe(self): | |||
return [] | |||
def collect(self): | |||
def collect(self) -> None: | |||
try: | |||
if self._cache_type == "response_cache": | |||
response_cache_size.labels(self._cache_name).set(len(self._cache)) | |||
@@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]): | |||
TreeCache, "MutableMapping[KT, CacheEntry]" | |||
] = cache_type() | |||
def metrics_cb(): | |||
def metrics_cb() -> None: | |||
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) | |||
# cache is used for completed results and maps to the result itself, rather than | |||
@@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]): | |||
def max_entries(self): | |||
return self.cache.max_size | |||
def check_thread(self): | |||
def check_thread(self) -> None: | |||
expected_thread = self.thread | |||
if expected_thread is None: | |||
self.thread = threading.current_thread() | |||
@@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]): | |||
self._pending_deferred_cache[key] = entry | |||
def compare_and_pop(): | |||
def compare_and_pop() -> bool: | |||
"""Check if our entry is still the one in _pending_deferred_cache, and | |||
if so, pop it. | |||
@@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]): | |||
return False | |||
def cb(result): | |||
def cb(result) -> None: | |||
if compare_and_pop(): | |||
self.cache.set(key, result, entry.callbacks) | |||
else: | |||
@@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]): | |||
# not have been. Either way, let's double-check now. | |||
entry.invalidate() | |||
def eb(_fail): | |||
def eb(_fail) -> None: | |||
compare_and_pop() | |||
entry.invalidate() | |||
@@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]): | |||
for entry in iterate_tree_cache_entry(entry): | |||
entry.invalidate() | |||
def invalidate_all(self): | |||
def invalidate_all(self) -> None: | |||
self.check_thread() | |||
self.cache.clear() | |||
for entry in self._pending_deferred_cache.values(): | |||
@@ -332,7 +332,7 @@ class CacheEntry: | |||
self.callbacks = set(callbacks) | |||
self.invalidated = False | |||
def invalidate(self): | |||
def invalidate(self) -> None: | |||
if not self.invalidated: | |||
self.invalidated = True | |||
for callback in self.callbacks: | |||
@@ -27,10 +27,14 @@ logger = logging.getLogger(__name__) | |||
KT = TypeVar("KT") | |||
# The type of the dictionary keys. | |||
DKT = TypeVar("DKT") | |||
# The type of the dictionary values. | |||
DV = TypeVar("DV") | |||
# This class can't be generic because it uses slots with attrs. | |||
# See: https://github.com/python-attrs/attrs/issues/313 | |||
@attr.s(slots=True) | |||
class DictionaryEntry: | |||
class DictionaryEntry: # should be: Generic[DKT, DV]. | |||
"""Returned when getting an entry from the cache | |||
Attributes: | |||
@@ -43,10 +47,10 @@ class DictionaryEntry: | |||
""" | |||
full = attr.ib(type=bool) | |||
known_absent = attr.ib() | |||
value = attr.ib() | |||
known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT] | |||
value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV] | |||
def __len__(self): | |||
def __len__(self) -> int: | |||
return len(self.value) | |||
@@ -56,7 +60,7 @@ class _Sentinel(enum.Enum): | |||
sentinel = object() | |||
class DictionaryCache(Generic[KT, DKT]): | |||
class DictionaryCache(Generic[KT, DKT, DV]): | |||
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e. | |||
fetching a subset of dictionary keys for a particular key. | |||
""" | |||
@@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]): | |||
Args: | |||
key | |||
dict_key: If given a set of keys then return only those keys | |||
dict_keys: If given a set of keys then return only those keys | |||
that exist in the cache. | |||
Returns: | |||
@@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]): | |||
self, | |||
sequence: int, | |||
key: KT, | |||
value: Dict[DKT, Any], | |||
value: Dict[DKT, DV], | |||
fetched_keys: Optional[Set[DKT]] = None, | |||
) -> None: | |||
"""Updates the entry in the cache | |||
@@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]): | |||
self._update_or_insert(key, value, fetched_keys) | |||
def _update_or_insert( | |||
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] | |||
self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] | |||
) -> None: | |||
# We pop and reinsert as we need to tell the cache the size may have | |||
# changed | |||
entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) | |||
entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {})) | |||
entry.value.update(value) | |||
entry.known_absent.update(known_absent) | |||
self.cache[key] = entry | |||
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: | |||
def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None: | |||
self.cache[key] = DictionaryEntry(True, known_absent, value) |
@@ -35,6 +35,7 @@ from typing import ( | |||
from typing_extensions import Literal | |||
from twisted.internet import reactor | |||
from twisted.internet.interfaces import IReactorTime | |||
from synapse.config import cache as cache_config | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
@@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]): | |||
# Default `clock` to something sensible. Note that we rename it to | |||
# `real_clock` so that mypy doesn't think its still `Optional`. | |||
if clock is None: | |||
real_clock = Clock(reactor) | |||
real_clock = Clock(cast(IReactorTime, reactor)) | |||
else: | |||
real_clock = clock | |||
@@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]): | |||
lock = threading.Lock() | |||
def evict(): | |||
def evict() -> None: | |||
while cache_len() > self.max_size: | |||
# Get the last node in the list (i.e. the oldest node). | |||
todelete = list_root.prev_node | |||
@@ -195,7 +195,7 @@ class StreamChangeCache: | |||
for entity in r: | |||
del self._entity_to_key[entity] | |||
def _evict(self): | |||
def _evict(self) -> None: | |||
while len(self._cache) > self._max_size: | |||
k, r = self._cache.popitem(0) | |||
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) | |||
@@ -35,17 +35,17 @@ class TreeCache: | |||
root = {key_1: {key_2: _value}} | |||
""" | |||
def __init__(self): | |||
self.size = 0 | |||
def __init__(self) -> None: | |||
self.size: int = 0 | |||
self.root = TreeCacheNode() | |||
def __setitem__(self, key, value): | |||
return self.set(key, value) | |||
def __setitem__(self, key, value) -> None: | |||
self.set(key, value) | |||
def __contains__(self, key): | |||
def __contains__(self, key) -> bool: | |||
return self.get(key, SENTINEL) is not SENTINEL | |||
def set(self, key, value): | |||
def set(self, key, value) -> None: | |||
if isinstance(value, TreeCacheNode): | |||
# this would mean we couldn't tell where our tree ended and the value | |||
# started. | |||
@@ -73,7 +73,7 @@ class TreeCache: | |||
return default | |||
return node.get(key[-1], default) | |||
def clear(self): | |||
def clear(self) -> None: | |||
self.size = 0 | |||
self.root = TreeCacheNode() | |||
@@ -128,7 +128,7 @@ class TreeCache: | |||
def values(self): | |||
return iterate_tree_cache_entry(self.root) | |||
def __len__(self): | |||
def __len__(self) -> int: | |||
return self.size | |||
@@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - | |||
signal.signal(signal.SIGTERM, sigterm) | |||
# Cleanup pid file at exit. | |||
def exit(): | |||
def exit() -> None: | |||
logger.warning("Stopping daemon.") | |||
os.remove(pid_file) | |||
sys.exit(0) | |||
@@ -12,6 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, Callable, Dict, List | |||
from twisted.internet import defer | |||
@@ -37,11 +38,11 @@ class Distributor: | |||
model will do for today. | |||
""" | |||
def __init__(self): | |||
self.signals = {} | |||
self.pre_registration = {} | |||
def __init__(self) -> None: | |||
self.signals: Dict[str, Signal] = {} | |||
self.pre_registration: Dict[str, List[Callable]] = {} | |||
def declare(self, name): | |||
def declare(self, name: str) -> None: | |||
if name in self.signals: | |||
raise KeyError("%r already has a signal named %s" % (self, name)) | |||
@@ -52,7 +53,7 @@ class Distributor: | |||
for observer in self.pre_registration[name]: | |||
signal.observe(observer) | |||
def observe(self, name, observer): | |||
def observe(self, name: str, observer: Callable) -> None: | |||
if name in self.signals: | |||
self.signals[name].observe(observer) | |||
else: | |||
@@ -62,7 +63,7 @@ class Distributor: | |||
self.pre_registration[name] = [] | |||
self.pre_registration[name].append(observer) | |||
def fire(self, name, *args, **kwargs): | |||
def fire(self, name: str, *args, **kwargs) -> None: | |||
"""Dispatches the given signal to the registered observers. | |||
Runs the observers as a background process. Does not return a deferred. | |||
@@ -83,18 +84,18 @@ class Signal: | |||
method into all of the observers. | |||
""" | |||
def __init__(self, name): | |||
self.name = name | |||
self.observers = [] | |||
def __init__(self, name: str): | |||
self.name: str = name | |||
self.observers: List[Callable] = [] | |||
def observe(self, observer): | |||
def observe(self, observer: Callable) -> None: | |||
"""Adds a new callable to the observer list which will be invoked by | |||
the 'fire' method. | |||
Each observer callable may return a Deferred.""" | |||
self.observers.append(observer) | |||
def fire(self, *args, **kwargs): | |||
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": | |||
"""Invokes every callable in the observer list, passing in the args and | |||
kwargs. Exceptions thrown by observers are logged but ignored. It is | |||
not an error to fire a signal with no observers. | |||
@@ -13,10 +13,14 @@ | |||
# limitations under the License. | |||
import queue | |||
from typing import BinaryIO, Optional, Union, cast | |||
from twisted.internet import threads | |||
from twisted.internet.defer import Deferred | |||
from twisted.internet.interfaces import IPullProducer, IPushProducer | |||
from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.types import ISynapseReactor | |||
class BackgroundFileConsumer: | |||
@@ -24,9 +28,9 @@ class BackgroundFileConsumer: | |||
and pull producers | |||
Args: | |||
file_obj (file): The file like object to write to. Closed when | |||
file_obj: The file like object to write to. Closed when | |||
finished. | |||
reactor (twisted.internet.reactor): the Twisted reactor to use | |||
reactor: the Twisted reactor to use | |||
""" | |||
# For PushProducers pause if we have this many unwritten slices | |||
@@ -34,13 +38,13 @@ class BackgroundFileConsumer: | |||
# And resume once the size of the queue is less than this | |||
_RESUME_ON_QUEUE_SIZE = 2 | |||
def __init__(self, file_obj, reactor): | |||
self._file_obj = file_obj | |||
def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None: | |||
self._file_obj: BinaryIO = file_obj | |||
self._reactor = reactor | |||
self._reactor: ISynapseReactor = reactor | |||
# Producer we're registered with | |||
self._producer = None | |||
self._producer: Optional[Union[IPushProducer, IPullProducer]] = None | |||
# True if PushProducer, false if PullProducer | |||
self.streaming = False | |||
@@ -51,20 +55,22 @@ class BackgroundFileConsumer: | |||
# Queue of slices of bytes to be written. When producer calls | |||
# unregister a final None is sent. | |||
self._bytes_queue = queue.Queue() | |||
self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() | |||
# Deferred that is resolved when finished writing | |||
self._finished_deferred = None | |||
self._finished_deferred: Optional[Deferred[None]] = None | |||
# If the _writer thread throws an exception it gets stored here. | |||
self._write_exception = None | |||
self._write_exception: Optional[Exception] = None | |||
def registerProducer(self, producer, streaming): | |||
def registerProducer( | |||
self, producer: Union[IPushProducer, IPullProducer], streaming: bool | |||
) -> None: | |||
"""Part of IConsumer interface | |||
Args: | |||
producer (IProducer) | |||
streaming (bool): True if push based producer, False if pull | |||
producer | |||
streaming: True if push based producer, False if pull | |||
based. | |||
""" | |||
if self._producer: | |||
@@ -81,29 +87,33 @@ class BackgroundFileConsumer: | |||
if not streaming: | |||
self._producer.resumeProducing() | |||
def unregisterProducer(self): | |||
def unregisterProducer(self) -> None: | |||
"""Part of IProducer interface""" | |||
self._producer = None | |||
assert self._finished_deferred is not None | |||
if not self._finished_deferred.called: | |||
self._bytes_queue.put_nowait(None) | |||
def write(self, bytes): | |||
def write(self, write_bytes: bytes) -> None: | |||
"""Part of IProducer interface""" | |||
if self._write_exception: | |||
raise self._write_exception | |||
assert self._finished_deferred is not None | |||
if self._finished_deferred.called: | |||
raise Exception("consumer has closed") | |||
self._bytes_queue.put_nowait(bytes) | |||
self._bytes_queue.put_nowait(write_bytes) | |||
# If this is a PushProducer and the queue is getting behind | |||
# then we pause the producer. | |||
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: | |||
self._paused_producer = True | |||
self._producer.pauseProducing() | |||
assert self._producer is not None | |||
# cast safe because `streaming` means this is an IPushProducer | |||
cast(IPushProducer, self._producer).pauseProducing() | |||
def _writer(self): | |||
def _writer(self) -> None: | |||
"""This is run in a background thread to write to the file.""" | |||
try: | |||
while self._producer or not self._bytes_queue.empty(): | |||
@@ -130,11 +140,11 @@ class BackgroundFileConsumer: | |||
finally: | |||
self._file_obj.close() | |||
def wait(self): | |||
def wait(self) -> "Deferred[None]": | |||
"""Returns a deferred that resolves when finished writing to file""" | |||
return make_deferred_yieldable(self._finished_deferred) | |||
def _resume_paused_producer(self): | |||
def _resume_paused_producer(self) -> None: | |||
"""Gets called if we should resume producing after being paused""" | |||
if self._paused_producer and self._producer: | |||
self._paused_producer = False | |||
@@ -11,11 +11,12 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any | |||
from frozendict import frozendict | |||
def freeze(o): | |||
def freeze(o: Any) -> Any: | |||
if isinstance(o, dict): | |||
return frozendict({k: freeze(v) for k, v in o.items()}) | |||
@@ -33,7 +34,7 @@ def freeze(o): | |||
return o | |||
def unfreeze(o): | |||
def unfreeze(o: Any) -> Any: | |||
if isinstance(o, (dict, frozendict)): | |||
return {k: unfreeze(v) for k, v in o.items()} | |||
@@ -13,42 +13,43 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict | |||
from twisted.web.resource import NoResource | |||
from twisted.web.resource import NoResource, Resource | |||
logger = logging.getLogger(__name__) | |||
def create_resource_tree(desired_tree, root_resource): | |||
def create_resource_tree( | |||
desired_tree: Dict[str, Resource], root_resource: Resource | |||
) -> Resource: | |||
"""Create the resource tree for this homeserver. | |||
This in unduly complicated because Twisted does not support putting | |||
child resources more than 1 level deep at a time. | |||
Args: | |||
web_client (bool): True to enable the web client. | |||
root_resource (twisted.web.resource.Resource): The root | |||
resource to add the tree to. | |||
desired_tree: Dict from desired paths to desired resources. | |||
root_resource: The root resource to add the tree to. | |||
Returns: | |||
twisted.web.resource.Resource: the ``root_resource`` with a tree of | |||
child resources added to it. | |||
The ``root_resource`` with a tree of child resources added to it. | |||
""" | |||
# ideally we'd just use getChild and putChild but getChild doesn't work | |||
# unless you give it a Request object IN ADDITION to the name :/ So | |||
# instead, we'll store a copy of this mapping so we can actually add | |||
# extra resources to existing nodes. See self._resource_id for the key. | |||
resource_mappings = {} | |||
for full_path, res in desired_tree.items(): | |||
resource_mappings: Dict[str, Resource] = {} | |||
for full_path_str, res in desired_tree.items(): | |||
# twisted requires all resources to be bytes | |||
full_path = full_path.encode("utf-8") | |||
full_path = full_path_str.encode("utf-8") | |||
logger.info("Attaching %s to path %s", res, full_path) | |||
last_resource = root_resource | |||
for path_seg in full_path.split(b"/")[1:-1]: | |||
if path_seg not in last_resource.listNames(): | |||
# resource doesn't exist, so make a "dummy resource" | |||
child_resource = NoResource() | |||
child_resource: Resource = NoResource() | |||
last_resource.putChild(path_seg, child_resource) | |||
res_id = _resource_id(last_resource, path_seg) | |||
resource_mappings[res_id] = child_resource | |||
@@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource): | |||
return root_resource | |||
def _resource_id(resource, path_seg): | |||
def _resource_id(resource: Resource, path_seg: bytes) -> str: | |||
"""Construct an arbitrary resource ID so you can retrieve the mapping | |||
later. | |||
@@ -96,4 +97,4 @@ def _resource_id(resource, path_seg): | |||
Returns: | |||
str: A unique string which can be a key to the child Resource. | |||
""" | |||
return "%s-%s" % (resource, path_seg) | |||
return "%s-%r" % (resource, path_seg) |
@@ -74,7 +74,7 @@ class ListNode(Generic[P]): | |||
new_node._refs_insert_after(node) | |||
return new_node | |||
def remove_from_list(self): | |||
def remove_from_list(self) -> None: | |||
"""Remove this node from the list.""" | |||
with self._LOCK: | |||
self._refs_remove_node_from_list() | |||
@@ -84,7 +84,7 @@ class ListNode(Generic[P]): | |||
# immediately rather than at the next GC. | |||
self.cache_entry = None | |||
def move_after(self, node: "ListNode"): | |||
def move_after(self, node: "ListNode") -> None: | |||
"""Move this node from its current location in the list to after the | |||
given node. | |||
""" | |||
@@ -103,7 +103,7 @@ class ListNode(Generic[P]): | |||
# Insert self back into the list, after target node | |||
self._refs_insert_after(node) | |||
def _refs_remove_node_from_list(self): | |||
def _refs_remove_node_from_list(self) -> None: | |||
"""Internal method to *just* remove the node from the list, without | |||
e.g. clearing out the cache entry. | |||
""" | |||
@@ -122,7 +122,7 @@ class ListNode(Generic[P]): | |||
self.prev_node = None | |||
self.next_node = None | |||
def _refs_insert_after(self, node: "ListNode"): | |||
def _refs_insert_after(self, node: "ListNode") -> None: | |||
"""Internal method to insert the node after the given node.""" | |||
# This method should only be called when we're not already in the list. | |||
@@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N | |||
should be considered expired. Normally the current time. | |||
""" | |||
def verify_expiry_caveat(caveat: str): | |||
def verify_expiry_caveat(caveat: str) -> bool: | |||
time_msec = get_time_ms() | |||
prefix = "time < " | |||
if not caveat.startswith(prefix): | |||
@@ -15,6 +15,7 @@ | |||
import inspect | |||
import sys | |||
import traceback | |||
from typing import Any, Dict, Optional | |||
from twisted.conch import manhole_ssh | |||
from twisted.conch.insults import insults | |||
@@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter | |||
from twisted.conch.ssh.keys import Key | |||
from twisted.cred import checkers, portal | |||
from twisted.internet import defer | |||
from twisted.internet.protocol import Factory | |||
from synapse.config.server import ManholeConfig | |||
PUBLIC_KEY = ( | |||
"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" | |||
@@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= | |||
-----END RSA PRIVATE KEY-----""" | |||
def manhole(settings, globals): | |||
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: | |||
"""Starts a ssh listener with password authentication using | |||
the given username and password. Clients connecting to the ssh | |||
listener will find themselves in a colored python shell with | |||
the supplied globals. | |||
Args: | |||
username(str): The username ssh clients should auth with. | |||
password(str): The password ssh clients should auth with. | |||
globals(dict): The variables to expose in the shell. | |||
username: The username ssh clients should auth with. | |||
password: The password ssh clients should auth with. | |||
globals: The variables to expose in the shell. | |||
Returns: | |||
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` | |||
A factory to pass to ``listenTCP`` | |||
""" | |||
username = settings.username | |||
password = settings.password | |||
password = settings.password.encode("ascii") | |||
priv_key = settings.priv_key | |||
if priv_key is None: | |||
priv_key = Key.fromString(PRIVATE_KEY) | |||
@@ -84,19 +88,22 @@ def manhole(settings, globals): | |||
if pub_key is None: | |||
pub_key = Key.fromString(PUBLIC_KEY) | |||
if not isinstance(password, bytes): | |||
password = password.encode("ascii") | |||
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) | |||
rlm = manhole_ssh.TerminalRealm() | |||
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( | |||
# mypy ignored here because: | |||
# - can't deduce types of lambdas | |||
# - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol] | |||
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment] | |||
SynapseManhole, dict(globals, __name__="__console__") | |||
) | |||
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) | |||
factory.privateKeys[b"ssh-rsa"] = priv_key | |||
factory.publicKeys[b"ssh-rsa"] = pub_key | |||
# conch has the wrong type on these dicts (says bytes to bytes, | |||
# should be bytes to Keys judging by how it's used). | |||
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] | |||
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] | |||
return factory | |||
@@ -104,7 +111,7 @@ def manhole(settings, globals): | |||
class SynapseManhole(ColoredManhole): | |||
"""Overrides connectionMade to create our own ManholeInterpreter""" | |||
def connectionMade(self): | |||
def connectionMade(self) -> None: | |||
super().connectionMade() | |||
# replace the manhole interpreter with our own impl | |||
@@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole): | |||
class SynapseManholeInterpreter(ManholeInterpreter): | |||
def showsyntaxerror(self, filename=None): | |||
def showsyntaxerror(self, filename: Optional[str] = None) -> None: | |||
"""Display the syntax error that just occurred. | |||
Overrides the base implementation, ignoring sys.excepthook. We always want | |||
any syntax errors to be sent to the terminal, rather than sentry. | |||
""" | |||
type, value, tb = sys.exc_info() | |||
assert value is not None | |||
sys.last_type = type | |||
sys.last_value = value | |||
sys.last_traceback = tb | |||
@@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): | |||
lines = traceback.format_exception_only(type, value) | |||
self.write("".join(lines)) | |||
def showtraceback(self): | |||
def showtraceback(self) -> None: | |||
"""Display the exception that just occurred. | |||
Overrides the base implementation, ignoring sys.excepthook. We always want | |||
@@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter): | |||
""" | |||
sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() | |||
sys.last_traceback = last_tb | |||
assert last_tb is not None | |||
try: | |||
# We remove the first stack item because it is our own code. | |||
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) | |||
self.write("".join(lines)) | |||
finally: | |||
last_tb = ei = None | |||
def displayhook(self, obj): | |||
# On the line below, last_tb and ei appear to be dead. | |||
# It's unclear whether there is a reason behind this line. | |||
# It conceivably could be because an exception raised in this block | |||
# will keep the local frame (containing these local variables) around. | |||
# This was adapted taken from CPython's Lib/code.py; see here: | |||
# https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150 | |||
last_tb = ei = None # type: ignore | |||
def displayhook(self, obj: Any) -> None: | |||
""" | |||
We override the displayhook so that we automatically convert coroutines | |||
into Deferreds. (Our superclass' displayhook will take care of the rest, | |||
@@ -24,7 +24,7 @@ from twisted.python.failure import Failure | |||
_already_patched = False | |||
def do_patch(): | |||
def do_patch() -> None: | |||
""" | |||
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit | |||
""" | |||
@@ -107,7 +107,7 @@ def do_patch(): | |||
_already_patched = True | |||
def _check_yield_points(f: Callable, changes: List[str]): | |||
def _check_yield_points(f: Callable, changes: List[str]) -> Callable: | |||
"""Wraps a generator that is about to be passed to defer.inlineCallbacks | |||
checking that after every yield the log contexts are correct. | |||
@@ -15,33 +15,36 @@ | |||
import collections | |||
import contextlib | |||
import logging | |||
import typing | |||
from typing import Any, DefaultDict, Iterator, List, Set | |||
from twisted.internet import defer | |||
from synapse.api.errors import LimitExceededError | |||
from synapse.config.ratelimiting import FederationRateLimitConfig | |||
from synapse.logging.context import ( | |||
PreserveLoggingContext, | |||
make_deferred_yieldable, | |||
run_in_background, | |||
) | |||
from synapse.util import Clock | |||
if typing.TYPE_CHECKING: | |||
from contextlib import _GeneratorContextManager | |||
logger = logging.getLogger(__name__) | |||
class FederationRateLimiter: | |||
def __init__(self, clock, config): | |||
""" | |||
Args: | |||
clock (Clock) | |||
config (FederationRateLimitConfig) | |||
""" | |||
def new_limiter(): | |||
def __init__(self, clock: Clock, config: FederationRateLimitConfig): | |||
def new_limiter() -> "_PerHostRatelimiter": | |||
return _PerHostRatelimiter(clock=clock, config=config) | |||
self.ratelimiters = collections.defaultdict(new_limiter) | |||
self.ratelimiters: DefaultDict[ | |||
str, "_PerHostRatelimiter" | |||
] = collections.defaultdict(new_limiter) | |||
def ratelimit(self, host): | |||
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]": | |||
"""Used to ratelimit an incoming request from a given host | |||
Example usage: | |||
@@ -60,11 +63,11 @@ class FederationRateLimiter: | |||
class _PerHostRatelimiter: | |||
def __init__(self, clock, config): | |||
def __init__(self, clock: Clock, config: FederationRateLimitConfig): | |||
""" | |||
Args: | |||
clock (Clock) | |||
config (FederationRateLimitConfig) | |||
clock | |||
config | |||
""" | |||
self.clock = clock | |||
@@ -75,21 +78,23 @@ class _PerHostRatelimiter: | |||
self.concurrent_requests = config.concurrent | |||
# request_id objects for requests which have been slept | |||
self.sleeping_requests = set() | |||
self.sleeping_requests: Set[object] = set() | |||
# map from request_id object to Deferred for requests which are ready | |||
# for processing but have been queued | |||
self.ready_request_queue = collections.OrderedDict() | |||
self.ready_request_queue: collections.OrderedDict[ | |||
object, defer.Deferred[None] | |||
] = collections.OrderedDict() | |||
# request id objects for requests which are in progress | |||
self.current_processing = set() | |||
self.current_processing: Set[object] = set() | |||
# times at which we have recently (within the last window_size ms) | |||
# received requests. | |||
self.request_times = [] | |||
self.request_times: List[int] = [] | |||
@contextlib.contextmanager | |||
def ratelimit(self): | |||
def ratelimit(self) -> "Iterator[defer.Deferred[None]]": | |||
# `contextlib.contextmanager` takes a generator and turns it into a | |||
# context manager. The generator should only yield once with a value | |||
# to be returned by manager. | |||
@@ -102,7 +107,7 @@ class _PerHostRatelimiter: | |||
finally: | |||
self._on_exit(request_id) | |||
def _on_enter(self, request_id): | |||
def _on_enter(self, request_id: object) -> "defer.Deferred[None]": | |||
time_now = self.clock.time_msec() | |||
# remove any entries from request_times which aren't within the window | |||
@@ -120,9 +125,9 @@ class _PerHostRatelimiter: | |||
self.request_times.append(time_now) | |||
def queue_request(): | |||
def queue_request() -> "defer.Deferred[None]": | |||
if len(self.current_processing) >= self.concurrent_requests: | |||
queue_defer = defer.Deferred() | |||
queue_defer: defer.Deferred[None] = defer.Deferred() | |||
self.ready_request_queue[request_id] = queue_defer | |||
logger.info( | |||
"Ratelimiter: queueing request (queue now %i items)", | |||
@@ -145,7 +150,7 @@ class _PerHostRatelimiter: | |||
self.sleeping_requests.add(request_id) | |||
def on_wait_finished(_): | |||
def on_wait_finished(_: Any) -> "defer.Deferred[None]": | |||
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) | |||
self.sleeping_requests.discard(request_id) | |||
queue_defer = queue_request() | |||
@@ -155,19 +160,19 @@ class _PerHostRatelimiter: | |||
else: | |||
ret_defer = queue_request() | |||
def on_start(r): | |||
def on_start(r: object) -> object: | |||
logger.debug("Ratelimit [%s]: Processing req", id(request_id)) | |||
self.current_processing.add(request_id) | |||
return r | |||
def on_err(r): | |||
def on_err(r: object) -> object: | |||
# XXX: why is this necessary? this is called before we start | |||
# processing the request so why would the request be in | |||
# current_processing? | |||
self.current_processing.discard(request_id) | |||
return r | |||
def on_both(r): | |||
def on_both(r: object) -> object: | |||
# Ensure that we've properly cleaned up. | |||
self.sleeping_requests.discard(request_id) | |||
self.ready_request_queue.pop(request_id, None) | |||
@@ -177,7 +182,7 @@ class _PerHostRatelimiter: | |||
ret_defer.addBoth(on_both) | |||
return make_deferred_yieldable(ret_defer) | |||
def _on_exit(self, request_id): | |||
def _on_exit(self, request_id: object) -> None: | |||
logger.debug("Ratelimit [%s]: Processed req", id(request_id)) | |||
self.current_processing.discard(request_id) | |||
try: | |||
@@ -13,9 +13,13 @@ | |||
# limitations under the License. | |||
import logging | |||
import random | |||
from types import TracebackType | |||
from typing import Any, Optional, Type | |||
import synapse.logging.context | |||
from synapse.api.errors import CodeMessageException | |||
from synapse.storage import DataStore | |||
from synapse.util import Clock | |||
logger = logging.getLogger(__name__) | |||
@@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62 | |||
class NotRetryingDestination(Exception): | |||
def __init__(self, retry_last_ts, retry_interval, destination): | |||
def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): | |||
"""Raised by the limiter (and federation client) to indicate that we are | |||
are deliberately not attempting to contact a given server. | |||
Args: | |||
retry_last_ts (int): the unix ts in milliseconds of our last attempt | |||
retry_last_ts: the unix ts in milliseconds of our last attempt | |||
to contact the server. 0 indicates that the last attempt was | |||
successful or that we've never actually attempted to connect. | |||
retry_interval (int): the time in milliseconds to wait until the next | |||
retry_interval: the time in milliseconds to wait until the next | |||
attempt. | |||
destination (str): the domain in question | |||
destination: the domain in question | |||
""" | |||
msg = "Not retrying server %s." % (destination,) | |||
@@ -51,7 +55,13 @@ class NotRetryingDestination(Exception): | |||
self.destination = destination | |||
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): | |||
async def get_retry_limiter( | |||
destination: str, | |||
clock: Clock, | |||
store: DataStore, | |||
ignore_backoff: bool = False, | |||
**kwargs: Any, | |||
) -> "RetryDestinationLimiter": | |||
"""For a given destination check if we have previously failed to | |||
send a request there and are waiting before retrying the destination. | |||
If we are not ready to retry the destination, this will raise a | |||
@@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k | |||
CodeMessageException with code < 500) | |||
Args: | |||
destination (str): name of homeserver | |||
clock (synapse.util.clock): timing source | |||
store (synapse.storage.transactions.TransactionStore): datastore | |||
ignore_backoff (bool): true to ignore the historical backoff data and | |||
destination: name of homeserver | |||
clock: timing source | |||
store: datastore | |||
ignore_backoff: true to ignore the historical backoff data and | |||
try the request anyway. We will still reset the retry_interval on success. | |||
Example usage: | |||
@@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k | |||
class RetryDestinationLimiter: | |||
def __init__( | |||
self, | |||
destination, | |||
clock, | |||
store, | |||
failure_ts, | |||
retry_interval, | |||
backoff_on_404=False, | |||
backoff_on_failure=True, | |||
destination: str, | |||
clock: Clock, | |||
store: DataStore, | |||
failure_ts: Optional[int], | |||
retry_interval: int, | |||
backoff_on_404: bool = False, | |||
backoff_on_failure: bool = True, | |||
): | |||
"""Marks the destination as "down" if an exception is thrown in the | |||
context, except for CodeMessageException with code < 500. | |||
@@ -128,17 +138,17 @@ class RetryDestinationLimiter: | |||
If no exception is raised, marks the destination as "up". | |||
Args: | |||
destination (str) | |||
clock (Clock) | |||
store (DataStore) | |||
failure_ts (int|None): when this destination started failing (in ms since | |||
destination | |||
clock | |||
store | |||
failure_ts: when this destination started failing (in ms since | |||
the epoch), or zero if the last request was successful | |||
retry_interval (int): The next retry interval taken from the | |||
retry_interval: The next retry interval taken from the | |||
database in milliseconds, or zero if the last request was | |||
successful. | |||
backoff_on_404 (bool): Back off if we get a 404 | |||
backoff_on_404: Back off if we get a 404 | |||
backoff_on_failure (bool): set to False if we should not increase the | |||
backoff_on_failure: set to False if we should not increase the | |||
retry interval on a failure. | |||
""" | |||
self.clock = clock | |||
@@ -150,10 +160,15 @@ class RetryDestinationLimiter: | |||
self.backoff_on_404 = backoff_on_404 | |||
self.backoff_on_failure = backoff_on_failure | |||
def __enter__(self): | |||
def __enter__(self) -> None: | |||
pass | |||
def __exit__(self, exc_type, exc_val, exc_tb): | |||
def __exit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
exc_val: Optional[BaseException], | |||
exc_tb: Optional[TracebackType], | |||
) -> None: | |||
valid_err_code = False | |||
if exc_type is None: | |||
valid_err_code = True | |||
@@ -161,7 +176,7 @@ class RetryDestinationLimiter: | |||
# avoid treating exceptions which don't derive from Exception as | |||
# failures; this is mostly so as not to catch defer._DefGen. | |||
valid_err_code = True | |||
elif issubclass(exc_type, CodeMessageException): | |||
elif isinstance(exc_val, CodeMessageException): | |||
# Some error codes are perfectly fine for some APIs, whereas other | |||
# APIs may expect to never received e.g. a 404. It's important to | |||
# handle 404 as some remote servers will return a 404 when the HS | |||
@@ -216,7 +231,7 @@ class RetryDestinationLimiter: | |||
if self.failure_ts is None: | |||
self.failure_ts = retry_last_ts | |||
async def store_retry_timings(): | |||
async def store_retry_timings() -> None: | |||
try: | |||
await self.store.set_destination_retry_timings( | |||
self.destination, | |||
@@ -18,7 +18,7 @@ import resource | |||
logger = logging.getLogger("synapse.app.homeserver") | |||
def change_resource_limit(soft_file_no): | |||
def change_resource_limit(soft_file_no: int) -> None: | |||
try: | |||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) | |||
@@ -16,7 +16,7 @@ | |||
import time | |||
import urllib.parse | |||
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union | |||
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union | |||
import jinja2 | |||
@@ -25,9 +25,9 @@ if TYPE_CHECKING: | |||
def build_jinja_env( | |||
template_search_directories: Iterable[str], | |||
template_search_directories: Sequence[str], | |||
config: "HomeServerConfig", | |||
autoescape: Union[bool, Callable[[str], bool], None] = None, | |||
autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, | |||
) -> jinja2.Environment: | |||
"""Set up a Jinja2 environment to load templates from the given search path | |||
@@ -110,5 +110,5 @@ def _create_mxc_to_http_filter( | |||
return mxc_to_http_filter | |||
def _format_ts_filter(value: int, format: str): | |||
def _format_ts_filter(value: int, format: str) -> str: | |||
return time.strftime(format, time.localtime(value / 1000)) |
@@ -14,6 +14,10 @@ | |||
import logging | |||
import re | |||
import typing | |||
if typing.TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -28,13 +32,13 @@ logger = logging.getLogger(__name__) | |||
MAX_EMAIL_ADDRESS_LENGTH = 500 | |||
def check_3pid_allowed(hs, medium, address): | |||
def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: | |||
"""Checks whether a given format of 3PID is allowed to be used on this HS | |||
Args: | |||
hs (synapse.server.HomeServer): server | |||
medium (str): 3pid medium - e.g. email, msisdn | |||
address (str): address within that medium (e.g. "wotan@matrix.org") | |||
hs: server | |||
medium: 3pid medium - e.g. email, msisdn | |||
address: address within that medium (e.g. "wotan@matrix.org") | |||
msisdns need to first have been canonicalised | |||
Returns: | |||
bool: whether the 3PID medium/address is allowed to be added to this HS | |||
@@ -19,7 +19,7 @@ import subprocess | |||
logger = logging.getLogger(__name__) | |||
def get_version_string(module): | |||
def get_version_string(module) -> str: | |||
"""Given a module calculate a git-aware version string for it. | |||
If called on a module not in a git checkout will return `__verison__`. | |||
@@ -11,38 +11,41 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Generic, List, TypeVar | |||
T = TypeVar("T") | |||
class _Entry: | |||
class _Entry(Generic[T]): | |||
__slots__ = ["end_key", "queue"] | |||
def __init__(self, end_key): | |||
self.end_key = end_key | |||
self.queue = [] | |||
def __init__(self, end_key: int) -> None: | |||
self.end_key: int = end_key | |||
self.queue: List[T] = [] | |||
class WheelTimer: | |||
class WheelTimer(Generic[T]): | |||
"""Stores arbitrary objects that will be returned after their timers have | |||
expired. | |||
""" | |||
def __init__(self, bucket_size=5000): | |||
def __init__(self, bucket_size: int = 5000) -> None: | |||
""" | |||
Args: | |||
bucket_size (int): Size of buckets in ms. Corresponds roughly to the | |||
bucket_size: Size of buckets in ms. Corresponds roughly to the | |||
accuracy of the timer. | |||
""" | |||
self.bucket_size = bucket_size | |||
self.entries = [] | |||
self.current_tick = 0 | |||
self.bucket_size: int = bucket_size | |||
self.entries: List[_Entry[T]] = [] | |||
self.current_tick: int = 0 | |||
def insert(self, now, obj, then): | |||
def insert(self, now: int, obj: T, then: int) -> None: | |||
"""Inserts object into timer. | |||
Args: | |||
now (int): Current time in msec | |||
obj (object): Object to be inserted | |||
then (int): When to return the object strictly after. | |||
now: Current time in msec | |||
obj: Object to be inserted | |||
then: When to return the object strictly after. | |||
""" | |||
then_key = int(then / self.bucket_size) + 1 | |||
@@ -70,7 +73,7 @@ class WheelTimer: | |||
self.entries[-1].queue.append(obj) | |||
def fetch(self, now): | |||
def fetch(self, now: int) -> List[T]: | |||
"""Fetch any objects that have timed out | |||
Args: | |||
@@ -87,5 +90,5 @@ class WheelTimer: | |||
return ret | |||
def __len__(self): | |||
def __len__(self) -> int: | |||
return sum(len(entry.queue) for entry in self.entries) |
@@ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource): | |||
FederationRateLimitConfig( | |||
window_size=1, | |||
sleep_limit=1, | |||
sleep_msec=1, | |||
sleep_delay=1, | |||
reject_limit=1000, | |||
concurrent_requests=1000, | |||
concurrent=1000, | |||
), | |||
) | |||