@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -33,12 +33,6 @@ exclude = (?x) | |||
|synapse/storage/schema/ | |||
)$ | |||
[mypy-synapse.federation.transport.client] | |||
disallow_untyped_defs = False | |||
[mypy-synapse.http.matrixfederationclient] | |||
disallow_untyped_defs = False | |||
[mypy-synapse.metrics._reactor_metrics] | |||
disallow_untyped_defs = False | |||
# This module imports select.epoll. That exists on Linux, but doesn't on macOS. | |||
@@ -280,15 +280,11 @@ class FederationClient(FederationBase): | |||
logger.debug("backfill transaction_data=%r", transaction_data) | |||
if not isinstance(transaction_data, dict): | |||
# TODO we probably want an exception type specific to federation | |||
# client validation. | |||
raise TypeError("Backfill transaction_data is not a dict.") | |||
raise InvalidResponseError("Backfill transaction_data is not a dict.") | |||
transaction_data_pdus = transaction_data.get("pdus") | |||
if not isinstance(transaction_data_pdus, list): | |||
# TODO we probably want an exception type specific to federation | |||
# client validation. | |||
raise TypeError("transaction_data.pdus is not a list.") | |||
raise InvalidResponseError("transaction_data.pdus is not a list.") | |||
room_version = await self.store.get_room_version(room_id) | |||
@@ -16,6 +16,7 @@ | |||
import logging | |||
import urllib | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Callable, | |||
Collection, | |||
@@ -42,18 +43,21 @@ from synapse.api.urls import ( | |||
) | |||
from synapse.events import EventBase, make_event_from_dict | |||
from synapse.federation.units import Transaction | |||
from synapse.http.matrixfederationclient import ByteParser | |||
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser | |||
from synapse.http.types import QueryParams | |||
from synapse.types import JsonDict | |||
from synapse.util import ExceptionBundle | |||
if TYPE_CHECKING: | |||
from synapse.app.homeserver import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class TransportLayerClient: | |||
"""Sends federation HTTP requests to other servers""" | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.server_name = hs.hostname | |||
self.client = hs.get_federation_http_client() | |||
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled | |||
@@ -133,7 +137,7 @@ class TransportLayerClient: | |||
async def backfill( | |||
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int | |||
) -> Optional[JsonDict]: | |||
) -> Optional[Union[JsonDict, list]]: | |||
"""Requests `limit` previous PDUs in a given context before list of | |||
PDUs. | |||
@@ -388,6 +392,7 @@ class TransportLayerClient: | |||
# server was just having a momentary blip, the room will be out of | |||
# sync. | |||
ignore_backoff=True, | |||
parser=LegacyJsonSendParser(), | |||
) | |||
async def send_leave_v2( | |||
@@ -445,7 +450,11 @@ class TransportLayerClient: | |||
path = _create_v1_path("/invite/%s/%s", room_id, event_id) | |||
return await self.client.put_json( | |||
destination=destination, path=path, data=content, ignore_backoff=True | |||
destination=destination, | |||
path=path, | |||
data=content, | |||
ignore_backoff=True, | |||
parser=LegacyJsonSendParser(), | |||
) | |||
async def send_invite_v2( | |||
@@ -17,7 +17,6 @@ import codecs | |||
import logging | |||
import random | |||
import sys | |||
import typing | |||
import urllib.parse | |||
from http import HTTPStatus | |||
from io import BytesIO, StringIO | |||
@@ -30,9 +29,11 @@ from typing import ( | |||
Generic, | |||
List, | |||
Optional, | |||
TextIO, | |||
Tuple, | |||
TypeVar, | |||
Union, | |||
cast, | |||
overload, | |||
) | |||
@@ -183,20 +184,61 @@ class MatrixFederationRequest: | |||
return self.json | |||
class JsonParser(ByteParser[Union[JsonDict, list]]): | |||
class _BaseJsonParser(ByteParser[T]): | |||
"""A parser that buffers the response and tries to parse it as JSON.""" | |||
CONTENT_TYPE = "application/json" | |||
def __init__(self) -> None: | |||
def __init__( | |||
self, validator: Optional[Callable[[Optional[object]], bool]] = None | |||
) -> None: | |||
""" | |||
Args: | |||
validator: A callable which takes the parsed JSON value and returns | |||
true if the value is valid. | |||
""" | |||
self._buffer = StringIO() | |||
self._binary_wrapper = BinaryIOWrapper(self._buffer) | |||
self._validator = validator | |||
def write(self, data: bytes) -> int: | |||
return self._binary_wrapper.write(data) | |||
def finish(self) -> Union[JsonDict, list]: | |||
return json_decoder.decode(self._buffer.getvalue()) | |||
def finish(self) -> T: | |||
result = json_decoder.decode(self._buffer.getvalue()) | |||
if self._validator is not None and not self._validator(result): | |||
raise ValueError( | |||
f"Received incorrect JSON value: {result.__class__.__name__}" | |||
) | |||
return result | |||
class JsonParser(_BaseJsonParser[JsonDict]): | |||
"""A parser that buffers the response and tries to parse it as a JSON object.""" | |||
def __init__(self) -> None: | |||
super().__init__(self._validate) | |||
@staticmethod | |||
def _validate(v: Any) -> bool: | |||
return isinstance(v, dict) | |||
class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]): | |||
"""Ensure the legacy responses of /send_join & /send_leave are correct.""" | |||
def __init__(self) -> None: | |||
super().__init__(self._validate) | |||
@staticmethod | |||
def _validate(v: Any) -> bool: | |||
# Match [integer, JSON dict] | |||
return ( | |||
isinstance(v, list) | |||
and len(v) == 2 | |||
and type(v[0]) == int | |||
and isinstance(v[1], dict) | |||
) | |||
async def _handle_response( | |||
@@ -313,9 +355,7 @@ async def _handle_response( | |||
class BinaryIOWrapper: | |||
"""A wrapper for a TextIO which converts from bytes on the fly.""" | |||
def __init__( | |||
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict" | |||
): | |||
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"): | |||
self.decoder = codecs.getincrementaldecoder(encoding)(errors) | |||
self.file = file | |||
@@ -793,7 +833,7 @@ class MatrixFederationHttpClient: | |||
backoff_on_404: bool = False, | |||
try_trailing_slash_on_400: bool = False, | |||
parser: Literal[None] = None, | |||
) -> Union[JsonDict, list]: | |||
) -> JsonDict: | |||
... | |||
@overload | |||
@@ -825,8 +865,8 @@ class MatrixFederationHttpClient: | |||
ignore_backoff: bool = False, | |||
backoff_on_404: bool = False, | |||
try_trailing_slash_on_400: bool = False, | |||
parser: Optional[ByteParser] = None, | |||
): | |||
parser: Optional[ByteParser[T]] = None, | |||
) -> Union[JsonDict, T]: | |||
"""Sends the specified json data using PUT | |||
Args: | |||
@@ -902,7 +942,7 @@ class MatrixFederationHttpClient: | |||
_sec_timeout = self.default_timeout | |||
if parser is None: | |||
parser = JsonParser() | |||
parser = cast(ByteParser[T], JsonParser()) | |||
body = await _handle_response( | |||
self.reactor, | |||
@@ -924,7 +964,7 @@ class MatrixFederationHttpClient: | |||
timeout: Optional[int] = None, | |||
ignore_backoff: bool = False, | |||
args: Optional[QueryParams] = None, | |||
) -> Union[JsonDict, list]: | |||
) -> JsonDict: | |||
"""Sends the specified json data using POST | |||
Args: | |||
@@ -998,7 +1038,7 @@ class MatrixFederationHttpClient: | |||
ignore_backoff: bool = False, | |||
try_trailing_slash_on_400: bool = False, | |||
parser: Literal[None] = None, | |||
) -> Union[JsonDict, list]: | |||
) -> JsonDict: | |||
... | |||
@overload | |||
@@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient: | |||
timeout: Optional[int] = None, | |||
ignore_backoff: bool = False, | |||
try_trailing_slash_on_400: bool = False, | |||
parser: Optional[ByteParser] = None, | |||
): | |||
parser: Optional[ByteParser[T]] = None, | |||
) -> Union[JsonDict, T]: | |||
"""GETs some json from the given host homeserver and path | |||
Args: | |||
@@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient: | |||
_sec_timeout = self.default_timeout | |||
if parser is None: | |||
parser = JsonParser() | |||
parser = cast(ByteParser[T], JsonParser()) | |||
body = await _handle_response( | |||
self.reactor, | |||
@@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient: | |||
timeout: Optional[int] = None, | |||
ignore_backoff: bool = False, | |||
args: Optional[QueryParams] = None, | |||
) -> Union[JsonDict, list]: | |||
) -> JsonDict: | |||
"""Send a DELETE request to the remote expecting some json response | |||
Args: | |||
@@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
) | |||
@@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
) | |||
@@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
) | |||
@@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
) | |||
@@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
) | |||
@@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel | |||
from synapse.api.errors import RequestSendFailed | |||
from synapse.http.matrixfederationclient import ( | |||
JsonParser, | |||
ByteParser, | |||
MatrixFederationHttpClient, | |||
MatrixFederationRequest, | |||
) | |||
@@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase): | |||
while not test_d.called: | |||
protocol.dataReceived(b"a" * chunk_size) | |||
sent += chunk_size | |||
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE) | |||
self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE) | |||
self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE) | |||
self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE) | |||
f = self.failureResultOf(test_d) | |||
self.assertIsInstance(f.value, RequestSendFailed) | |||