@@ -0,0 +1 @@ | |||
Add type annotations to increase the number of modules passing `disallow-untyped-defs`. |
@@ -119,9 +119,18 @@ disallow_untyped_defs = True | |||
[mypy-synapse.federation.transport.client] | |||
disallow_untyped_defs = False | |||
[mypy-synapse.groups.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.handlers.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.federation.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.request_metrics] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.server] | |||
disallow_untyped_defs = True | |||
@@ -196,12 +205,27 @@ disallow_untyped_defs = True | |||
[mypy-synapse.storage.databases.main.state_deltas] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.databases.main.stream] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.databases.main.transactions] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.databases.main.user_erasure_store] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.prepare_database] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.persist_events] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.state] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.types] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.util.*] | |||
disallow_untyped_defs = True | |||
@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
# Before deleting the group lets kick everyone out of it | |||
users = await self.store.get_users_in_group(group_id, include_private=True) | |||
async def _kick_user_from_group(user_id): | |||
async def _kick_user_from_group(user_id: str) -> None: | |||
if self.hs.is_mine_id(user_id): | |||
groups_local = self.hs.get_groups_local_handler() | |||
assert isinstance( | |||
@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl | |||
from twisted.internet.address import IPv4Address, IPv6Address | |||
from twisted.internet.interfaces import ( | |||
IAddress, | |||
IDelayedCall, | |||
IHostResolution, | |||
IReactorPluggableNameResolver, | |||
IReactorTime, | |||
IResolutionReceiver, | |||
ITCPTransport, | |||
) | |||
@@ -121,13 +123,15 @@ def check_against_blacklist( | |||
_EPSILON = 0.00000001 | |||
def _make_scheduler(reactor): | |||
def _make_scheduler( | |||
reactor: IReactorTime, | |||
) -> Callable[[Callable[[], object]], IDelayedCall]: | |||
"""Makes a schedular suitable for a Cooperator using the given reactor. | |||
(This is effectively just a copy from `twisted.internet.task`) | |||
""" | |||
def _scheduler(x): | |||
def _scheduler(x: Callable[[], object]) -> IDelayedCall: | |||
return reactor.callLater(_EPSILON, x) | |||
return _scheduler | |||
@@ -775,7 +779,7 @@ class SimpleHttpClient: | |||
) | |||
def _timeout_to_request_timed_out_error(f: Failure): | |||
def _timeout_to_request_timed_out_error(f: Failure) -> Failure: | |||
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError): | |||
# The TCP connection has its own timeout (set by the 'connectTimeout' param | |||
# on the Agent), which raises twisted_error.TimeoutError exception. | |||
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): | |||
def __init__(self, deferred: defer.Deferred): | |||
self.deferred = deferred | |||
def _maybe_fail(self): | |||
def _maybe_fail(self) -> None: | |||
""" | |||
Report a max size exceed error and disconnect the first time this is called. | |||
""" | |||
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): | |||
Do not use this since it allows an attacker to intercept your communications. | |||
""" | |||
def __init__(self): | |||
def __init__(self) -> None: | |||
self._context = SSL.Context(SSL.SSLv23_METHOD) | |||
self._context.set_verify(VERIFY_NONE, lambda *_: False) | |||
def getContext(self, hostname=None, port=None): | |||
return self._context | |||
def creatorForNetloc(self, hostname, port): | |||
def creatorForNetloc(self, hostname: bytes, port: int): | |||
return self |
@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory: | |||
self._srv_resolver = srv_resolver | |||
def endpointForURI(self, parsed_uri: URI): | |||
def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint": | |||
return MatrixHostnameEndpoint( | |||
self._reactor, | |||
self._proxy_reactor, | |||
@@ -16,7 +16,7 @@ | |||
import logging | |||
import random | |||
import time | |||
from typing import Callable, Dict, List | |||
from typing import Any, Callable, Dict, List | |||
import attr | |||
@@ -109,7 +109,7 @@ class SrvResolver: | |||
def __init__( | |||
self, | |||
dns_client=client, | |||
dns_client: Any = client, | |||
cache: Dict[bytes, List[Server]] = SERVER_CACHE, | |||
get_time: Callable[[], float] = time.time, | |||
): | |||
@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known") | |||
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known") | |||
@attr.s(slots=True, frozen=True) | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class WellKnownLookupResult: | |||
delegated_server = attr.ib() | |||
delegated_server: Optional[bytes] | |||
class WellKnownResolver: | |||
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: | |||
class _FetchWellKnownFailure(Exception): | |||
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be | |||
# a temporary failure. | |||
temporary = attr.ib() | |||
temporary: bool = attr.ib() |
@@ -23,6 +23,8 @@ from http import HTTPStatus | |||
from io import BytesIO, StringIO | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
BinaryIO, | |||
Callable, | |||
Dict, | |||
Generic, | |||
@@ -44,7 +46,7 @@ from typing_extensions import Literal | |||
from twisted.internet import defer | |||
from twisted.internet.error import DNSLookupError | |||
from twisted.internet.interfaces import IReactorTime | |||
from twisted.internet.task import _EPSILON, Cooperator | |||
from twisted.internet.task import Cooperator | |||
from twisted.web.client import ResponseFailed | |||
from twisted.web.http_headers import Headers | |||
from twisted.web.iweb import IBodyProducer, IResponse | |||
@@ -58,11 +60,13 @@ from synapse.api.errors import ( | |||
RequestSendFailed, | |||
SynapseError, | |||
) | |||
from synapse.crypto.context_factory import FederationPolicyForHTTPS | |||
from synapse.http import QuieterFileBodyProducer | |||
from synapse.http.client import ( | |||
BlacklistingAgentWrapper, | |||
BodyExceededMaxSize, | |||
ByteWriteable, | |||
_make_scheduler, | |||
encode_query_args, | |||
read_body_with_max_size, | |||
) | |||
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]): | |||
CONTENT_TYPE = "application/json" | |||
def __init__(self): | |||
def __init__(self) -> None: | |||
self._buffer = StringIO() | |||
self._binary_wrapper = BinaryIOWrapper(self._buffer) | |||
@@ -299,7 +303,9 @@ async def _handle_response( | |||
class BinaryIOWrapper: | |||
"""A wrapper for a TextIO which converts from bytes on the fly.""" | |||
def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"): | |||
def __init__( | |||
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict" | |||
): | |||
self.decoder = codecs.getincrementaldecoder(encoding)(errors) | |||
self.file = file | |||
@@ -317,7 +323,11 @@ class MatrixFederationHttpClient: | |||
requests. | |||
""" | |||
def __init__(self, hs: "HomeServer", tls_client_options_factory): | |||
def __init__( | |||
self, | |||
hs: "HomeServer", | |||
tls_client_options_factory: Optional[FederationPolicyForHTTPS], | |||
): | |||
self.hs = hs | |||
self.signing_key = hs.signing_key | |||
self.server_name = hs.hostname | |||
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient: | |||
self.version_string_bytes = hs.version_string.encode("ascii") | |||
self.default_timeout = 60 | |||
def schedule(x): | |||
self.reactor.callLater(_EPSILON, x) | |||
self._cooperator = Cooperator(scheduler=schedule) | |||
self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) | |||
self._sleeper = AwakenableSleeper(self.reactor) | |||
@@ -364,7 +371,7 @@ class MatrixFederationHttpClient: | |||
self, | |||
request: MatrixFederationRequest, | |||
try_trailing_slash_on_400: bool = False, | |||
**send_request_args, | |||
**send_request_args: Any, | |||
) -> IResponse: | |||
"""Wrapper for _send_request which can optionally retry the request | |||
upon receiving a combination of a 400 HTTP response code and a | |||
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient: | |||
self, | |||
destination: str, | |||
path: str, | |||
output_stream, | |||
output_stream: BinaryIO, | |||
args: Optional[QueryParams] = None, | |||
retry_on_dns_fail: bool = True, | |||
max_size: Optional[int] = None, | |||
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient: | |||
return length, headers | |||
def _flatten_response_never_received(e): | |||
def _flatten_response_never_received(e: BaseException) -> str: | |||
if hasattr(e, "reasons"): | |||
reasons = ", ".join( | |||
_flatten_response_never_received(f.value) for f in e.reasons | |||
_flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined] | |||
) | |||
return "%s:[%s]" % (type(e).__name__, reasons) | |||
@@ -162,7 +162,7 @@ class RequestMetrics: | |||
with _in_flight_requests_lock: | |||
_in_flight_requests.add(self) | |||
def stop(self, time_sec, response_code, sent_bytes): | |||
def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None: | |||
with _in_flight_requests_lock: | |||
_in_flight_requests.discard(self) | |||
@@ -186,13 +186,13 @@ class RequestMetrics: | |||
) | |||
return | |||
response_code = str(response_code) | |||
response_code_str = str(response_code) | |||
outgoing_responses_counter.labels(self.method, response_code).inc() | |||
outgoing_responses_counter.labels(self.method, response_code_str).inc() | |||
response_count.labels(self.method, self.name, tag).inc() | |||
response_timer.labels(self.method, self.name, tag, response_code).observe( | |||
response_timer.labels(self.method, self.name, tag, response_code_str).observe( | |||
time_sec - self.start_ts | |||
) | |||
@@ -221,7 +221,7 @@ class RequestMetrics: | |||
# flight. | |||
self.update_metrics() | |||
def update_metrics(self): | |||
def update_metrics(self) -> None: | |||
"""Updates the in flight metrics with values from this request.""" | |||
if not self.start_context: | |||
logger.error( | |||
@@ -31,6 +31,7 @@ from typing import ( | |||
List, | |||
Optional, | |||
Tuple, | |||
Type, | |||
TypeVar, | |||
cast, | |||
overload, | |||
@@ -41,6 +42,7 @@ from prometheus_client import Histogram | |||
from typing_extensions import Concatenate, Literal, ParamSpec | |||
from twisted.enterprise import adbapi | |||
from twisted.internet.interfaces import IReactorCore | |||
from synapse.api.errors import StoreError | |||
from synapse.config.database import DatabaseConnectionConfig | |||
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { | |||
def make_pool( | |||
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine | |||
reactor: IReactorCore, | |||
db_config: DatabaseConnectionConfig, | |||
engine: BaseDatabaseEngine, | |||
) -> adbapi.ConnectionPool: | |||
"""Get the connection pool for the database.""" | |||
@@ -101,7 +105,7 @@ def make_pool( | |||
db_args = dict(db_config.config.get("args", {})) | |||
db_args.setdefault("cp_reconnect", True) | |||
def _on_new_connection(conn): | |||
def _on_new_connection(conn: Connection) -> None: | |||
# Ensure we have a logging context so we can correctly track queries, | |||
# etc. | |||
with LoggingContext("db.on_new_connection"): | |||
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection: | |||
default_txn_name: str | |||
def cursor( | |||
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None | |||
self, | |||
*, | |||
txn_name: Optional[str] = None, | |||
after_callbacks: Optional[List["_CallbackListEntry"]] = None, | |||
exception_callbacks: Optional[List["_CallbackListEntry"]] = None, | |||
) -> "LoggingTransaction": | |||
if not txn_name: | |||
txn_name = self.default_txn_name | |||
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection: | |||
self.conn.__enter__() | |||
return self | |||
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: | |||
def __exit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
exc_value: Optional[BaseException], | |||
traceback: Optional[types.TracebackType], | |||
) -> Optional[bool]: | |||
return self.conn.__exit__(exc_type, exc_value, traceback) | |||
# Proxy through any unknown lookups to the DB conn class. | |||
def __getattr__(self, name): | |||
def __getattr__(self, name: str) -> Any: | |||
return getattr(self.conn, name) | |||
@@ -391,17 +404,22 @@ class LoggingTransaction: | |||
def __enter__(self) -> "LoggingTransaction": | |||
return self | |||
def __exit__(self, exc_type, exc_value, traceback): | |||
def __exit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
exc_value: Optional[BaseException], | |||
traceback: Optional[types.TracebackType], | |||
) -> None: | |||
self.close() | |||
class PerformanceCounters: | |||
def __init__(self): | |||
self.current_counters = {} | |||
self.previous_counters = {} | |||
def __init__(self) -> None: | |||
self.current_counters: Dict[str, Tuple[int, float]] = {} | |||
self.previous_counters: Dict[str, Tuple[int, float]] = {} | |||
def update(self, key: str, duration_secs: float) -> None: | |||
count, cum_time = self.current_counters.get(key, (0, 0)) | |||
count, cum_time = self.current_counters.get(key, (0, 0.0)) | |||
count += 1 | |||
cum_time += duration_secs | |||
self.current_counters[key] = (count, cum_time) | |||
@@ -527,7 +545,7 @@ class DatabasePool: | |||
def start_profiling(self) -> None: | |||
self._previous_loop_ts = monotonic_time() | |||
def loop(): | |||
def loop() -> None: | |||
curr = self._current_txn_total_time | |||
prev = self._previous_txn_total_time | |||
self._previous_txn_total_time = curr | |||
@@ -1186,7 +1204,7 @@ class DatabasePool: | |||
if lock: | |||
self.engine.lock_table(txn, table) | |||
def _getwhere(key): | |||
def _getwhere(key: str) -> str: | |||
# If the value we're passing in is None (aka NULL), we need to use | |||
# IS, not =, as NULL = NULL equals NULL (False). | |||
if keyvalues[key] is None: | |||
@@ -2258,7 +2276,7 @@ class DatabasePool: | |||
term: Optional[str], | |||
col: str, | |||
retcols: Collection[str], | |||
desc="simple_search_list", | |||
desc: str = "simple_search_list", | |||
) -> Optional[List[Dict[str, Any]]]: | |||
"""Executes a SELECT query on the named table, which may return zero or | |||
more rows, returning the result as a list of dicts. | |||
@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.storage.databases.main.event_push_actions import ( | |||
EventPushActionsWorkerStore, | |||
) | |||
from synapse.storage.types import Cursor | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
self._last_user_visit_update = self._get_start_of_day() | |||
@wrap_as_background_process("read_forward_extremities") | |||
async def _read_forward_extremities(self): | |||
async def _read_forward_extremities(self) -> None: | |||
def fetch(txn): | |||
txn.execute( | |||
""" | |||
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
(x[0] - 1) * x[1] for x in res if x[1] | |||
) | |||
async def count_daily_e2ee_messages(self): | |||
async def count_daily_e2ee_messages(self) -> int: | |||
""" | |||
Returns an estimate of the number of messages sent in the last day. | |||
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) | |||
async def count_daily_sent_e2ee_messages(self): | |||
async def count_daily_sent_e2ee_messages(self) -> int: | |||
def _count_messages(txn): | |||
# This is good enough as if you have silly characters in your own | |||
# hostname then that's your own fault. | |||
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
"count_daily_sent_e2ee_messages", _count_messages | |||
) | |||
async def count_daily_active_e2ee_rooms(self): | |||
async def count_daily_active_e2ee_rooms(self) -> int: | |||
def _count(txn): | |||
sql = """ | |||
SELECT COUNT(DISTINCT room_id) FROM events | |||
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
"count_daily_active_e2ee_rooms", _count | |||
) | |||
async def count_daily_messages(self): | |||
async def count_daily_messages(self) -> int: | |||
""" | |||
Returns an estimate of the number of messages sent in the last day. | |||
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
return await self.db_pool.runInteraction("count_messages", _count_messages) | |||
async def count_daily_sent_messages(self): | |||
async def count_daily_sent_messages(self) -> int: | |||
def _count_messages(txn): | |||
# This is good enough as if you have silly characters in your own | |||
# hostname then that's your own fault. | |||
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
"count_daily_sent_messages", _count_messages | |||
) | |||
async def count_daily_active_rooms(self): | |||
async def count_daily_active_rooms(self) -> int: | |||
def _count(txn): | |||
sql = """ | |||
SELECT COUNT(DISTINCT room_id) FROM events | |||
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
"count_monthly_users", self._count_users, thirty_days_ago | |||
) | |||
def _count_users(self, txn, time_from): | |||
def _count_users(self, txn: Cursor, time_from: int) -> int: | |||
""" | |||
Returns number of users seen in the past time_from period | |||
""" | |||
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
) u | |||
""" | |||
txn.execute(sql, (time_from,)) | |||
(count,) = txn.fetchone() | |||
# Mypy knows that fetchone() might return None if there are no rows. | |||
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always | |||
# returns exactly one row. | |||
(count,) = txn.fetchone() # type: ignore[misc] | |||
return count | |||
async def count_r30_users(self) -> Dict[str, int]: | |||
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
"count_r30v2_users", _count_r30v2_users | |||
) | |||
def _get_start_of_day(self): | |||
def _get_start_of_day(self) -> int: | |||
""" | |||
Returns millisecond unixtime for start of UTC day. | |||
""" | |||
@@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
self, | |||
txn: LoggingTransaction, | |||
event_id: str, | |||
allow_none=False, | |||
) -> int: | |||
return self.db_pool.simple_select_one_onecol_txn( | |||
allow_none: bool = False, | |||
) -> Optional[int]: | |||
# Type ignore: we pass keyvalues a Dict[str, str]; the function wants | |||
# Dict[str, Any]. I think mypy is unhappy because Dict is invariant? | |||
return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload] | |||
txn=txn, | |||
table="events", | |||
keyvalues={"event_id": event_id}, | |||
@@ -25,6 +25,7 @@ from typing import ( | |||
Collection, | |||
Deque, | |||
Dict, | |||
Generator, | |||
Generic, | |||
Iterable, | |||
List, | |||
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): | |||
return res | |||
def _handle_queue(self, room_id): | |||
def _handle_queue(self, room_id: str) -> None: | |||
"""Attempts to handle the queue for a room if not already being handled. | |||
The queue's callback will be invoked with for each item in the queue, | |||
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): | |||
self._currently_persisting_rooms.add(room_id) | |||
async def handle_queue_loop(): | |||
async def handle_queue_loop() -> None: | |||
try: | |||
queue = self._get_drainining_queue(room_id) | |||
for item in queue: | |||
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]): | |||
with PreserveLoggingContext(): | |||
item.deferred.callback(ret) | |||
finally: | |||
queue = self._event_persist_queues.pop(room_id, None) | |||
if queue: | |||
self._event_persist_queues[room_id] = queue | |||
remaining_queue = self._event_persist_queues.pop(room_id, None) | |||
if remaining_queue: | |||
self._event_persist_queues[room_id] = remaining_queue | |||
self._currently_persisting_rooms.discard(room_id) | |||
# set handle_queue_loop off in the background | |||
run_as_background_process("persist_events", handle_queue_loop) | |||
def _get_drainining_queue(self, room_id): | |||
def _get_drainining_queue( | |||
self, room_id: str | |||
) -> Generator[_EventPersistQueueItem, None, None]: | |||
queue = self._event_persist_queues.setdefault(room_id, deque()) | |||
try: | |||
@@ -317,7 +320,9 @@ class EventsPersistenceStorage: | |||
for event, ctx in events_and_contexts: | |||
partitioned.setdefault(event.room_id, []).append((event, ctx)) | |||
async def enqueue(item): | |||
async def enqueue( | |||
item: Tuple[str, List[Tuple[EventBase, EventContext]]] | |||
) -> Dict[str, str]: | |||
room_id, evs_ctxs = item | |||
return await self._event_persist_queue.add_to_queue( | |||
room_id, evs_ctxs, backfilled=backfilled | |||
@@ -1102,7 +1107,7 @@ class EventsPersistenceStorage: | |||
return False | |||
async def _handle_potentially_left_users(self, user_ids: Set[str]): | |||
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None: | |||
"""Given a set of remote users check if the server still shares a room with | |||
them. If not then mark those users' device cache as stale. | |||
""" | |||
@@ -85,7 +85,7 @@ def prepare_database( | |||
database_engine: BaseDatabaseEngine, | |||
config: Optional[HomeServerConfig], | |||
databases: Collection[str] = ("main", "state"), | |||
): | |||
) -> None: | |||
"""Prepares a physical database for usage. Will either create all necessary tables | |||
or upgrade from an older schema version. | |||
@@ -62,7 +62,7 @@ class StateFilter: | |||
types: "frozendict[str, Optional[FrozenSet[str]]]" | |||
include_others: bool = False | |||
def __attrs_post_init__(self): | |||
def __attrs_post_init__(self) -> None: | |||
# If `include_others` is set we canonicalise the filter by removing | |||
# wildcards from the types dictionary | |||
if self.include_others: | |||
@@ -138,7 +138,9 @@ class StateFilter: | |||
) | |||
@staticmethod | |||
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): | |||
def freeze( | |||
types: Mapping[str, Optional[Collection[str]]], include_others: bool | |||
) -> "StateFilter": | |||
""" | |||
Returns a (frozen) StateFilter with the same contents as the parameters | |||
specified here, which can be made of mutable types. | |||
@@ -11,7 +11,8 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union | |||
from types import TracebackType | |||
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union | |||
from typing_extensions import Protocol | |||
@@ -86,5 +87,10 @@ class Connection(Protocol): | |||
def __enter__(self) -> "Connection": | |||
... | |||
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: | |||
def __exit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
exc_value: Optional[BaseException], | |||
traceback: Optional[TracebackType], | |||
) -> Optional[bool]: | |||
... |