ソースを参照

Finish type hints for federation client HTTP code. (#15465)

tags/v1.83.0rc1
Patrick Cloke 1年前
committed by GitHub
コミット
ea5c3ede4f
この署名に対応する既知のキーがデータベースに存在しません GPGキーID: 4AEE18F83AFDEB23
7個のファイルの変更82行の追加42行の削除
  1. +1
    -0
      changelog.d/15465.misc
  2. +0
    -6
      mypy.ini
  3. +2
    -6
      synapse/federation/federation_client.py
  4. +13
    -4
      synapse/federation/transport/client.py
  5. +58
    -18
      synapse/http/matrixfederationclient.py
  6. +5
    -5
      tests/federation/test_complexity.py
  7. +3
    -3
      tests/http/test_matrixfederationclient.py

+ 1
- 0
changelog.d/15465.misc ファイルの表示

@@ -0,0 +1 @@
Improve type hints.

+ 0
- 6
mypy.ini ファイルの表示

@@ -33,12 +33,6 @@ exclude = (?x)
|synapse/storage/schema/
)$

[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False

[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False

[mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.


+ 2
- 6
synapse/federation/federation_client.py ファイルの表示

@@ -280,15 +280,11 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%r", transaction_data)

if not isinstance(transaction_data, dict):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("Backfill transaction_data is not a dict.")
raise InvalidResponseError("Backfill transaction_data is not a dict.")

transaction_data_pdus = transaction_data.get("pdus")
if not isinstance(transaction_data_pdus, list):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("transaction_data.pdus is not a list.")
raise InvalidResponseError("transaction_data.pdus is not a list.")

room_version = await self.store.get_room_version(room_id)



+ 13
- 4
synapse/federation/transport/client.py ファイルの表示

@@ -16,6 +16,7 @@
import logging
import urllib
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -42,18 +43,21 @@ from synapse.api.urls import (
)
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.util import ExceptionBundle

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
@@ -133,7 +137,7 @@ class TransportLayerClient:

async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]:
) -> Optional[Union[JsonDict, list]]:
"""Requests `limit` previous PDUs in a given context before list of
PDUs.

@@ -388,6 +392,7 @@ class TransportLayerClient:
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
parser=LegacyJsonSendParser(),
)

async def send_leave_v2(
@@ -445,7 +450,11 @@ class TransportLayerClient:
path = _create_v1_path("/invite/%s/%s", room_id, event_id)

return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
destination=destination,
path=path,
data=content,
ignore_backoff=True,
parser=LegacyJsonSendParser(),
)

async def send_invite_v2(


+ 58
- 18
synapse/http/matrixfederationclient.py ファイルの表示

@@ -17,7 +17,6 @@ import codecs
import logging
import random
import sys
import typing
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
@@ -30,9 +29,11 @@ from typing import (
Generic,
List,
Optional,
TextIO,
Tuple,
TypeVar,
Union,
cast,
overload,
)

@@ -183,20 +184,61 @@ class MatrixFederationRequest:
return self.json


class JsonParser(ByteParser[Union[JsonDict, list]]):
class _BaseJsonParser(ByteParser[T]):
"""A parser that buffers the response and tries to parse it as JSON."""

CONTENT_TYPE = "application/json"

def __init__(self) -> None:
def __init__(
self, validator: Optional[Callable[[Optional[object]], bool]] = None
) -> None:
"""
Args:
validator: A callable which takes the parsed JSON value and returns
true if the value is valid.
"""
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
self._validator = validator

def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data)

def finish(self) -> Union[JsonDict, list]:
return json_decoder.decode(self._buffer.getvalue())
def finish(self) -> T:
result = json_decoder.decode(self._buffer.getvalue())
if self._validator is not None and not self._validator(result):
raise ValueError(
f"Received incorrect JSON value: {result.__class__.__name__}"
)
return result


class JsonParser(_BaseJsonParser[JsonDict]):
"""A parser that buffers the response and tries to parse it as a JSON object."""

def __init__(self) -> None:
super().__init__(self._validate)

@staticmethod
def _validate(v: Any) -> bool:
return isinstance(v, dict)


class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
"""Ensure the legacy responses of /send_join & /send_leave are correct."""

def __init__(self) -> None:
super().__init__(self._validate)

@staticmethod
def _validate(v: Any) -> bool:
# Match [integer, JSON dict]
return (
isinstance(v, list)
and len(v) == 2
and type(v[0]) == int
and isinstance(v[1], dict)
)


async def _handle_response(
@@ -313,9 +355,7 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""

def __init__(
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file

@@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
...

@overload
@@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
parser: Optional[ByteParser[T]] = None,
) -> Union[JsonDict, T]:
"""Sends the specified json data using PUT

Args:
@@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()
parser = cast(ByteParser[T], JsonParser())

body = await _handle_response(
self.reactor,
@@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
"""Sends the specified json data using POST

Args:
@@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
...

@overload
@@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
parser: Optional[ByteParser[T]] = None,
) -> Union[JsonDict, T]:
"""GETs some json from the given host homeserver and path

Args:
@@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()
parser = cast(ByteParser[T], JsonParser())

body = await _handle_response(
self.reactor,
@@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
"""Send a DELETE request to the remote expecting some json response

Args:


+ 5
- 5
tests/federation/test_complexity.py ファイルの表示

@@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
@@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
@@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
@@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
@@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)


+ 3
- 3
tests/http/test_matrixfederationclient.py ファイルの表示

@@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel

from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import (
JsonParser,
ByteParser,
MatrixFederationHttpClient,
MatrixFederationRequest,
)
@@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase):
while not test_d.called:
protocol.dataReceived(b"a" * chunk_size)
sent += chunk_size
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)

self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)

f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)


読み込み中…
キャンセル
保存