@@ -0,0 +1 @@ | |||
Add type annotations to increase the number of modules passing `disallow-untyped-defs`. |
@@ -128,15 +128,30 @@ disallow_untyped_defs = True | |||
[mypy-synapse.http.federation.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.connectproxyclient] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.proxyagent] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.request_metrics] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.http.server] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.logging._remote] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.logging.context] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.logging.formatter] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.logging.handlers] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.metrics.*] | |||
disallow_untyped_defs = True | |||
@@ -166,6 +181,9 @@ disallow_untyped_defs = True | |||
[mypy-synapse.state.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.databases.background_updates] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.storage.databases.main.account_data] | |||
disallow_untyped_defs = True | |||
@@ -232,6 +250,9 @@ disallow_untyped_defs = True | |||
[mypy-synapse.streams.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.types] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.util.*] | |||
disallow_untyped_defs = True | |||
@@ -15,7 +15,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple | |||
import attr | |||
from canonicaljson import encode_canonical_json | |||
@@ -1105,22 +1105,19 @@ class E2eKeysHandler: | |||
# can request over federation | |||
raise NotFoundError("No %s key found for %s" % (key_type, user_id)) | |||
( | |||
key, | |||
key_id, | |||
verify_key, | |||
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) | |||
if key is None: | |||
cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user( | |||
user, key_type | |||
) | |||
if cross_signing_keys is None: | |||
raise NotFoundError("No %s key found for %s" % (key_type, user_id)) | |||
return key, key_id, verify_key | |||
return cross_signing_keys | |||
async def _retrieve_cross_signing_keys_for_remote_user( | |||
self, | |||
user: UserID, | |||
desired_key_type: str, | |||
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: | |||
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: | |||
"""Queries cross-signing keys for a remote user and saves them to the database | |||
Only the key specified by `key_type` will be returned, while all retrieved keys | |||
@@ -1146,12 +1143,10 @@ class E2eKeysHandler: | |||
type(e), | |||
e, | |||
) | |||
return None, None, None | |||
return None | |||
# Process each of the retrieved cross-signing keys | |||
desired_key = None | |||
desired_key_id = None | |||
desired_verify_key = None | |||
desired_key_data = None | |||
retrieved_device_ids = [] | |||
for key_type in ["master", "self_signing"]: | |||
key_content = remote_result.get(key_type + "_key") | |||
@@ -1196,9 +1191,7 @@ class E2eKeysHandler: | |||
# If this is the desired key type, save it and its ID/VerifyKey | |||
if key_type == desired_key_type: | |||
desired_key = key_content | |||
desired_verify_key = verify_key | |||
desired_key_id = key_id | |||
desired_key_data = key_content, key_id, verify_key | |||
# At the same time, store this key in the db for subsequent queries | |||
await self.store.set_e2e_cross_signing_key( | |||
@@ -1212,7 +1205,7 @@ class E2eKeysHandler: | |||
user.to_string(), retrieved_device_ids | |||
) | |||
return desired_key, desired_key_id, desired_verify_key | |||
return desired_key_data | |||
def _check_cross_signing_key( | |||
@@ -14,15 +14,22 @@ | |||
import base64 | |||
import logging | |||
from typing import Optional | |||
from typing import Optional, Union | |||
import attr | |||
from zope.interface import implementer | |||
from twisted.internet import defer, protocol | |||
from twisted.internet.error import ConnectError | |||
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint | |||
from twisted.internet.interfaces import ( | |||
IAddress, | |||
IConnector, | |||
IProtocol, | |||
IReactorCore, | |||
IStreamClientEndpoint, | |||
) | |||
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone | |||
from twisted.python.failure import Failure | |||
from twisted.web import http | |||
logger = logging.getLogger(__name__) | |||
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint: | |||
self._port = port | |||
self._proxy_creds = proxy_creds | |||
def __repr__(self): | |||
def __repr__(self) -> str: | |||
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) | |||
# Mypy encounters a false positive here: it complains that ClientFactory | |||
# is incompatible with IProtocolFactory. But ClientFactory inherits from | |||
# Factory, which implements IProtocolFactory. So I think this is a bug | |||
# in mypy-zope. | |||
def connect(self, protocolFactory: ClientFactory): # type: ignore[override] | |||
def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override] | |||
f = HTTPProxiedClientFactory( | |||
self._host, self._port, protocolFactory, self._proxy_creds | |||
) | |||
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): | |||
self.proxy_creds = proxy_creds | |||
self.on_connection: "defer.Deferred[None]" = defer.Deferred() | |||
def startedConnecting(self, connector): | |||
def startedConnecting(self, connector: IConnector) -> None: | |||
return self.wrapped_factory.startedConnecting(connector) | |||
def buildProtocol(self, addr): | |||
def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol": | |||
wrapped_protocol = self.wrapped_factory.buildProtocol(addr) | |||
if wrapped_protocol is None: | |||
raise TypeError("buildProtocol produced None instead of a Protocol") | |||
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): | |||
self.proxy_creds, | |||
) | |||
def clientConnectionFailed(self, connector, reason): | |||
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: | |||
logger.debug("Connection to proxy failed: %s", reason) | |||
if not self.on_connection.called: | |||
self.on_connection.errback(reason) | |||
return self.wrapped_factory.clientConnectionFailed(connector, reason) | |||
def clientConnectionLost(self, connector, reason): | |||
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: | |||
logger.debug("Connection to proxy lost: %s", reason) | |||
if not self.on_connection.called: | |||
self.on_connection.errback(reason) | |||
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol): | |||
) | |||
self.http_setup_client.on_connected.addCallback(self.proxyConnected) | |||
def connectionMade(self): | |||
def connectionMade(self) -> None: | |||
self.http_setup_client.makeConnection(self.transport) | |||
def connectionLost(self, reason=connectionDone): | |||
def connectionLost(self, reason: Failure = connectionDone) -> None: | |||
if self.wrapped_protocol.connected: | |||
self.wrapped_protocol.connectionLost(reason) | |||
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol): | |||
if not self.connected_deferred.called: | |||
self.connected_deferred.errback(reason) | |||
def proxyConnected(self, _): | |||
def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None: | |||
self.wrapped_protocol.makeConnection(self.transport) | |||
self.connected_deferred.callback(self.wrapped_protocol) | |||
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol): | |||
if buf: | |||
self.wrapped_protocol.dataReceived(buf) | |||
def dataReceived(self, data: bytes): | |||
def dataReceived(self, data: bytes) -> None: | |||
# if we've set up the HTTP protocol, we can send the data there | |||
if self.wrapped_protocol.connected: | |||
return self.wrapped_protocol.dataReceived(data) | |||
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient): | |||
self.proxy_creds = proxy_creds | |||
self.on_connected: "defer.Deferred[None]" = defer.Deferred() | |||
def connectionMade(self): | |||
def connectionMade(self) -> None: | |||
logger.debug("Connected to proxy, sending CONNECT") | |||
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) | |||
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient): | |||
self.endHeaders() | |||
def handleStatus(self, version: bytes, status: bytes, message: bytes): | |||
def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None: | |||
logger.debug("Got Status: %s %s %s", status, message, version) | |||
if status != b"200": | |||
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") | |||
def handleEndHeaders(self): | |||
def handleEndHeaders(self) -> None: | |||
logger.debug("End Headers") | |||
self.on_connected.callback(None) | |||
def handleResponse(self, body): | |||
def handleResponse(self, body: bytes) -> None: | |||
pass |
@@ -245,7 +245,7 @@ def http_proxy_endpoint( | |||
proxy: Optional[bytes], | |||
reactor: IReactorCore, | |||
tls_options_factory: Optional[IPolicyForHTTPS], | |||
**kwargs, | |||
**kwargs: object, | |||
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: | |||
"""Parses an http proxy setting and returns an endpoint for the proxy | |||
@@ -31,7 +31,11 @@ from twisted.internet.endpoints import ( | |||
TCP4ClientEndpoint, | |||
TCP6ClientEndpoint, | |||
) | |||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint | |||
from twisted.internet.interfaces import ( | |||
IPushProducer, | |||
IReactorTCP, | |||
IStreamClientEndpoint, | |||
) | |||
from twisted.internet.protocol import Factory, Protocol | |||
from twisted.internet.tcp import Connection | |||
from twisted.python.failure import Failure | |||
@@ -59,14 +63,14 @@ class LogProducer: | |||
_buffer: Deque[logging.LogRecord] | |||
_paused: bool = attr.ib(default=False, init=False) | |||
def pauseProducing(self): | |||
def pauseProducing(self) -> None: | |||
self._paused = True | |||
def stopProducing(self): | |||
def stopProducing(self) -> None: | |||
self._paused = True | |||
self._buffer = deque() | |||
def resumeProducing(self): | |||
def resumeProducing(self) -> None: | |||
# If we're already producing, nothing to do. | |||
self._paused = False | |||
@@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler): | |||
host: str, | |||
port: int, | |||
maximum_buffer: int = 1000, | |||
level=logging.NOTSET, | |||
_reactor=None, | |||
level: int = logging.NOTSET, | |||
_reactor: Optional[IReactorTCP] = None, | |||
): | |||
super().__init__(level=level) | |||
self.host = host | |||
@@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler): | |||
if _reactor is None: | |||
from twisted.internet import reactor | |||
_reactor = reactor | |||
_reactor = reactor # type: ignore[assignment] | |||
try: | |||
ip = ip_address(self.host) | |||
@@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler): | |||
self._stopping = False | |||
self._connect() | |||
def close(self): | |||
def close(self) -> None: | |||
self._stopping = True | |||
self._service.stopService() | |||
@@ -16,6 +16,8 @@ | |||
import logging | |||
import traceback | |||
from io import StringIO | |||
from types import TracebackType | |||
from typing import Optional, Tuple, Type | |||
class LogFormatter(logging.Formatter): | |||
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter): | |||
where it was caught are logged). | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
def formatException(self, ei): | |||
def formatException( | |||
self, | |||
ei: Tuple[ | |||
Optional[Type[BaseException]], | |||
Optional[BaseException], | |||
Optional[TracebackType], | |||
], | |||
) -> str: | |||
sio = StringIO() | |||
(typ, val, tb) = ei | |||
@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): | |||
) | |||
self._flushing_thread.start() | |||
def on_reactor_running(): | |||
def on_reactor_running() -> None: | |||
self._reactor_started = True | |||
reactor_to_use: IReactorCore | |||
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): | |||
else: | |||
return True | |||
def _flush_periodically(self): | |||
def _flush_periodically(self) -> None: | |||
""" | |||
Whilst this handler is active, flush the handler periodically. | |||
""" | |||
@@ -13,6 +13,8 @@ | |||
# limitations under the License.import logging | |||
import logging | |||
from types import TracebackType | |||
from typing import Optional, Type | |||
from opentracing import Scope, ScopeManager | |||
@@ -107,19 +109,26 @@ class _LogContextScope(Scope): | |||
and - if enter_logcontext was set - the logcontext is finished too. | |||
""" | |||
def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close): | |||
def __init__( | |||
self, | |||
manager: LogContextScopeManager, | |||
span, | |||
logcontext, | |||
enter_logcontext: bool, | |||
finish_on_close: bool, | |||
): | |||
""" | |||
Args: | |||
manager (LogContextScopeManager): | |||
manager: | |||
the manager that is responsible for this scope. | |||
span (Span): | |||
the opentracing span which this scope represents the local | |||
lifetime for. | |||
logcontext (LogContext): | |||
the logcontext to which this scope is attached. | |||
enter_logcontext (Boolean): | |||
enter_logcontext: | |||
if True the logcontext will be exited when the scope is finished | |||
finish_on_close (Boolean): | |||
finish_on_close: | |||
if True finish the span when the scope is closed | |||
""" | |||
super().__init__(manager, span) | |||
@@ -127,16 +136,21 @@ class _LogContextScope(Scope): | |||
self._finish_on_close = finish_on_close | |||
self._enter_logcontext = enter_logcontext | |||
def __exit__(self, exc_type, value, traceback): | |||
def __exit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
value: Optional[BaseException], | |||
traceback: Optional[TracebackType], | |||
) -> None: | |||
if exc_type == twisted.internet.defer._DefGen_Return: | |||
# filter out defer.returnValue() calls | |||
exc_type = value = traceback = None | |||
super().__exit__(exc_type, value, traceback) | |||
def __str__(self): | |||
def __str__(self) -> str: | |||
return f"Scope<{self.span}>" | |||
def close(self): | |||
def close(self) -> None: | |||
active_scope = self.manager.active | |||
if active_scope is not self: | |||
logger.error( | |||
@@ -12,20 +12,24 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from types import TracebackType | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
AsyncContextManager, | |||
Awaitable, | |||
Callable, | |||
Dict, | |||
Iterable, | |||
List, | |||
Optional, | |||
Type, | |||
) | |||
import attr | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage.types import Connection | |||
from synapse.storage.types import Connection, Cursor | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock, json_encoder | |||
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager: | |||
return self._update_duration_ms | |||
async def __aexit__(self, *exc) -> None: | |||
async def __aexit__( | |||
self, | |||
exc_type: Optional[Type[BaseException]], | |||
exc: Optional[BaseException], | |||
tb: Optional[TracebackType], | |||
) -> None: | |||
pass | |||
@@ -352,7 +361,7 @@ class BackgroundUpdater: | |||
True if we have finished running all the background updates, otherwise False | |||
""" | |||
def get_background_updates_txn(txn): | |||
def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: | |||
txn.execute( | |||
""" | |||
SELECT update_name, depends_on FROM background_updates | |||
@@ -469,7 +478,7 @@ class BackgroundUpdater: | |||
self, | |||
update_name: str, | |||
update_handler: Callable[[JsonDict, int], Awaitable[int]], | |||
): | |||
) -> None: | |||
"""Register a handler for doing a background update. | |||
The handler should take two arguments: | |||
@@ -603,7 +612,7 @@ class BackgroundUpdater: | |||
else: | |||
runner = create_index_sqlite | |||
async def updater(progress, batch_size): | |||
async def updater(progress: JsonDict, batch_size: int) -> int: | |||
if runner is not None: | |||
logger.info("Adding index %s to %s", index_name, table) | |||
await self.db_pool.runWithConnection(runner) | |||
@@ -24,6 +24,7 @@ from typing import ( | |||
Mapping, | |||
Match, | |||
MutableMapping, | |||
NoReturn, | |||
Optional, | |||
Set, | |||
Tuple, | |||
@@ -35,6 +36,7 @@ from typing import ( | |||
import attr | |||
from frozendict import frozendict | |||
from signedjson.key import decode_verify_key_bytes | |||
from signedjson.types import VerifyKey | |||
from typing_extensions import TypedDict | |||
from unpaddedbase64 import decode_base64 | |||
from zope.interface import Interface | |||
@@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name | |||
if TYPE_CHECKING: | |||
from synapse.appservice.api import ApplicationService | |||
from synapse.storage.databases.main import DataStore, PurgeEventsStore | |||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore | |||
# Define a state map type from type/state_key to T (usually an event ID or | |||
# event) | |||
@@ -114,7 +117,7 @@ class Requester: | |||
app_service: Optional["ApplicationService"] | |||
authenticated_entity: str | |||
def serialize(self): | |||
def serialize(self) -> Dict[str, Any]: | |||
"""Converts self to a type that can be serialized as JSON, and then | |||
deserialized by `deserialize` | |||
@@ -132,7 +135,9 @@ class Requester: | |||
} | |||
@staticmethod | |||
def deserialize(store, input): | |||
def deserialize( | |||
store: "ApplicationServiceWorkerStore", input: Dict[str, Any] | |||
) -> "Requester": | |||
"""Converts a dict that was produced by `serialize` back into a | |||
Requester. | |||
@@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta): | |||
domain: str | |||
# Because this is a frozen class, it is deeply immutable. | |||
def __copy__(self): | |||
def __copy__(self: DS) -> DS: | |||
return self | |||
def __deepcopy__(self, memo): | |||
def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: | |||
return self | |||
@classmethod | |||
@@ -729,12 +734,14 @@ class StreamToken: | |||
) | |||
@property | |||
def room_stream_id(self): | |||
def room_stream_id(self) -> int: | |||
return self.room_key.stream | |||
def copy_and_advance(self, key, new_value) -> "StreamToken": | |||
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": | |||
"""Advance the given key in the token to a new value if and only if the | |||
new value is after the old value. | |||
:raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. | |||
""" | |||
if key == "room_key": | |||
new_token = self.copy_and_replace( | |||
@@ -751,7 +758,7 @@ class StreamToken: | |||
else: | |||
return self | |||
def copy_and_replace(self, key, new_value) -> "StreamToken": | |||
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": | |||
return attr.evolve(self, **{key: new_value}) | |||
@@ -793,14 +800,14 @@ class ThirdPartyInstanceID: | |||
# Deny iteration because it will bite you if you try to create a singleton | |||
# set by: | |||
# users = set(user) | |||
def __iter__(self): | |||
def __iter__(self) -> NoReturn: | |||
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) | |||
# Because this class is a frozen class, it is deeply immutable. | |||
def __copy__(self): | |||
def __copy__(self) -> "ThirdPartyInstanceID": | |||
return self | |||
def __deepcopy__(self, memo): | |||
def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": | |||
return self | |||
@classmethod | |||
@@ -852,25 +859,28 @@ class DeviceListUpdates: | |||
return bool(self.changed or self.left) | |||
def get_verify_key_from_cross_signing_key(key_info): | |||
def get_verify_key_from_cross_signing_key( | |||
key_info: Mapping[str, Any] | |||
) -> Tuple[str, VerifyKey]: | |||
"""Get the key ID and signedjson verify key from a cross-signing key dict | |||
Args: | |||
key_info (dict): a cross-signing key dict, which must have a "keys" | |||
key_info: a cross-signing key dict, which must have a "keys" | |||
property that has exactly one item in it | |||
Returns: | |||
(str, VerifyKey): the key ID and verify key for the cross-signing key | |||
the key ID and verify key for the cross-signing key | |||
""" | |||
# make sure that exactly one key is provided | |||
# make sure that a `keys` field is provided | |||
if "keys" not in key_info: | |||
raise ValueError("Invalid key") | |||
keys = key_info["keys"] | |||
if len(keys) != 1: | |||
raise ValueError("Invalid key") | |||
# and return that one key | |||
for key_id, key_data in keys.items(): | |||
# and that it contains exactly one key | |||
if len(keys) == 1: | |||
key_id, key_data = next(iter(keys.items())) | |||
return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) | |||
else: | |||
raise ValueError("Invalid key") | |||
@attr.s(auto_attribs=True, frozen=True, slots=True) | |||