* Tweak http types in Synapse AFACIS these are correct, and they make mypy happier on tests.http. * Type hints for test_proxyagent * type hints for test_srv_resolver * test_matrix_federation_agent * tests.http.server._base * tests.http.__init__ * tests.http.test_additional_resource * tests.http.test_client * tests.http.test_endpoint * tests.http.test_matrixfederationclient * tests.http.test_servlet * tests.http.test_simple_client * tests.http.test_site * One fixup in tests.server * Untyped defs * Changelog * Fixup syntax for Python 3.7 * Fix olddeps syntax * Use a twisted IPv4 addr for dummy_address * Fix typo, thanks Sean Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> * Remove redundant `Optional` --------- Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>tags/v1.77.0rc1
@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -32,9 +32,6 @@ exclude = (?x) | |||
|synapse/storage/databases/main/cache.py | |||
|synapse/storage/schema/ | |||
|tests/http/federation/test_matrix_federation_agent.py | |||
|tests/http/federation/test_srv_resolver.py | |||
|tests/http/test_proxyagent.py | |||
|tests/module_api/test_api.py | |||
|tests/rest/media/v1/test_media_storage.py | |||
|tests/server.py | |||
@@ -92,6 +89,9 @@ disallow_untyped_defs = True | |||
[mypy-tests.handlers.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.http.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.logging.*] | |||
disallow_untyped_defs = True | |||
@@ -44,6 +44,7 @@ from twisted.internet.interfaces import ( | |||
IAddress, | |||
IDelayedCall, | |||
IHostResolution, | |||
IReactorCore, | |||
IReactorPluggableNameResolver, | |||
IReactorTime, | |||
IResolutionReceiver, | |||
@@ -226,7 +227,9 @@ class _IPBlacklistingResolver: | |||
return recv | |||
@implementer(ISynapseReactor) | |||
# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer | |||
# of IReactorCore seems to keep mypy-zope happier. | |||
@implementer(IReactorCore, ISynapseReactor) | |||
class BlacklistingReactorWrapper: | |||
""" | |||
A Reactor wrapper which will prevent DNS resolution to blacklisted IP | |||
@@ -38,7 +38,6 @@ from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse | |||
from synapse.http import redact_uri | |||
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials | |||
from synapse.types import ISynapseReactor | |||
logger = logging.getLogger(__name__) | |||
@@ -84,7 +83,7 @@ class ProxyAgent(_AgentBase): | |||
def __init__( | |||
self, | |||
reactor: IReactorCore, | |||
proxy_reactor: Optional[ISynapseReactor] = None, | |||
proxy_reactor: Optional[IReactorCore] = None, | |||
contextFactory: Optional[IPolicyForHTTPS] = None, | |||
connectTimeout: Optional[float] = None, | |||
bindAddress: Optional[bytes] = None, | |||
@@ -19,13 +19,15 @@ from zope.interface import implementer | |||
from OpenSSL import SSL | |||
from OpenSSL.SSL import Connection | |||
from twisted.internet.address import IPv4Address | |||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator | |||
from twisted.internet.ssl import Certificate, trustRootFromCertificates | |||
from twisted.protocols.tls import TLSMemoryBIOProtocol | |||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 | |||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 | |||
def get_test_https_policy(): | |||
def get_test_https_policy() -> BrowserLikePolicyForHTTPS: | |||
"""Get a test IPolicyForHTTPS which trusts the test CA cert | |||
Returns: | |||
@@ -39,7 +41,7 @@ def get_test_https_policy(): | |||
return BrowserLikePolicyForHTTPS(trustRoot=trust_root) | |||
def get_test_ca_cert_file(): | |||
def get_test_ca_cert_file() -> str: | |||
"""Get the path to the test CA cert | |||
The keypair is generated with: | |||
@@ -51,7 +53,7 @@ def get_test_ca_cert_file(): | |||
return os.path.join(os.path.dirname(__file__), "ca.crt") | |||
def get_test_key_file(): | |||
def get_test_key_file() -> str: | |||
"""get the path to the test key | |||
The key file is made with: | |||
@@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory: | |||
"""An SSL connection creator which returns connections which present a certificate | |||
signed by our test CA.""" | |||
def __init__(self, sanlist): | |||
def __init__(self, sanlist: List[bytes]): | |||
""" | |||
Args: | |||
sanlist: list[bytes]: a list of subjectAltName values for the cert | |||
sanlist: a list of subjectAltName values for the cert | |||
""" | |||
self._cert_file = create_test_cert_file(sanlist) | |||
def serverConnectionForTLS(self, tlsProtocol): | |||
def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection: | |||
ctx = SSL.Context(SSL.SSLv23_METHOD) | |||
ctx.use_certificate_file(self._cert_file) | |||
ctx.use_privatekey_file(get_test_key_file()) | |||
return Connection(ctx, None) | |||
# A dummy address, useful for tests that use FakeTransport and don't care about where | |||
# packets are going to/coming from. | |||
dummy_address = IPv4Address("TCP", "127.0.0.1", 80) |
@@ -14,7 +14,7 @@ | |||
import base64 | |||
import logging | |||
import os | |||
from typing import Iterable, Optional | |||
from typing import Any, Awaitable, Callable, Generator, List, Optional, cast | |||
from unittest.mock import Mock, patch | |||
import treq | |||
@@ -24,14 +24,19 @@ from zope.interface import implementer | |||
from twisted.internet import defer | |||
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions | |||
from twisted.internet.interfaces import IProtocolFactory | |||
from twisted.internet.defer import Deferred | |||
from twisted.internet.endpoints import _WrappingProtocol | |||
from twisted.internet.interfaces import ( | |||
IOpenSSLClientConnectionCreator, | |||
IProtocolFactory, | |||
) | |||
from twisted.internet.protocol import Factory | |||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |||
from twisted.web._newclient import ResponseNeverReceived | |||
from twisted.web.client import Agent | |||
from twisted.web.http import HTTPChannel, Request | |||
from twisted.web.http_headers import Headers | |||
from twisted.web.iweb import IPolicyForHTTPS | |||
from twisted.web.iweb import IPolicyForHTTPS, IResponse | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.crypto.context_factory import FederationPolicyForHTTPS | |||
@@ -42,11 +47,21 @@ from synapse.http.federation.well_known_resolver import ( | |||
WellKnownResolver, | |||
_cache_period_from_headers, | |||
) | |||
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context | |||
from synapse.logging.context import ( | |||
SENTINEL_CONTEXT, | |||
LoggingContext, | |||
LoggingContextOrSentinel, | |||
current_context, | |||
) | |||
from synapse.types import ISynapseReactor | |||
from synapse.util.caches.ttlcache import TTLCache | |||
from tests import unittest | |||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file | |||
from tests.http import ( | |||
TestServerTLSConnectionFactory, | |||
dummy_address, | |||
get_test_ca_cert_file, | |||
) | |||
from tests.server import FakeTransport, ThreadedMemoryReactorClock | |||
from tests.utils import default_config | |||
@@ -54,15 +69,17 @@ logger = logging.getLogger(__name__) | |||
# Once Async Mocks or lambdas are supported this can go away. | |||
def generate_resolve_service(result): | |||
async def resolve_service(_): | |||
def generate_resolve_service( | |||
result: List[Server], | |||
) -> Callable[[Any], Awaitable[List[Server]]]: | |||
async def resolve_service(_: Any) -> List[Server]: | |||
return result | |||
return resolve_service | |||
class MatrixFederationAgentTests(unittest.TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.reactor = ThreadedMemoryReactorClock() | |||
self.mock_resolver = Mock() | |||
@@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.tls_factory = FederationPolicyForHTTPS(config) | |||
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) | |||
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) | |||
self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache( | |||
"test_cache", timer=self.reactor.seconds | |||
) | |||
self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache( | |||
"test_cache", timer=self.reactor.seconds | |||
) | |||
self.well_known_resolver = WellKnownResolver( | |||
self.reactor, | |||
Agent(self.reactor, contextFactory=self.tls_factory), | |||
@@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self, | |||
client_factory: IProtocolFactory, | |||
ssl: bool = True, | |||
expected_sni: bytes = None, | |||
tls_sanlist: Optional[Iterable[bytes]] = None, | |||
expected_sni: Optional[bytes] = None, | |||
tls_sanlist: Optional[List[bytes]] = None, | |||
) -> HTTPChannel: | |||
"""Builds a test server, and completes the outgoing client connection | |||
Args: | |||
@@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
if ssl: | |||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) | |||
server_protocol = server_factory.buildProtocol(None) | |||
server_protocol = server_factory.buildProtocol(dummy_address) | |||
assert server_protocol is not None | |||
# now, tell the client protocol factory to build the client protocol (it will be a | |||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an | |||
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via | |||
@@ -125,7 +146,8 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
# | |||
# Normally this would be done by the TCP socket code in Twisted, but we are | |||
# stubbing that out here. | |||
client_protocol = client_factory.buildProtocol(None) | |||
client_protocol = client_factory.buildProtocol(dummy_address) | |||
assert isinstance(client_protocol, _WrappingProtocol) | |||
client_protocol.makeConnection( | |||
FakeTransport(server_protocol, self.reactor, client_protocol) | |||
) | |||
@@ -136,6 +158,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
) | |||
if ssl: | |||
assert isinstance(server_protocol, TLSMemoryBIOProtocol) | |||
# fish the test server back out of the server-side TLS protocol. | |||
http_protocol = server_protocol.wrappedProtocol | |||
# grab a hold of the TLS connection, in case it gets torn down | |||
@@ -144,6 +167,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
http_protocol = server_protocol | |||
tls_connection = None | |||
assert isinstance(http_protocol, HTTPChannel) | |||
# give the reactor a pump to get the TLS juices flowing (if needed) | |||
self.reactor.advance(0) | |||
@@ -159,12 +183,14 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
return http_protocol | |||
@defer.inlineCallbacks | |||
def _make_get_request(self, uri: bytes): | |||
def _make_get_request( | |||
self, uri: bytes | |||
) -> Generator["Deferred[object]", object, IResponse]: | |||
""" | |||
Sends a simple GET request via the agent, and checks its logcontext management | |||
""" | |||
with LoggingContext("one") as context: | |||
fetch_d = self.agent.request(b"GET", uri) | |||
fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri) | |||
# Nothing happened yet | |||
self.assertNoResult(fetch_d) | |||
@@ -172,8 +198,9 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
# should have reset logcontext to the sentinel | |||
_check_logcontext(SENTINEL_CONTEXT) | |||
fetch_res: IResponse | |||
try: | |||
fetch_res = yield fetch_d | |||
fetch_res = yield fetch_d # type: ignore[misc, assignment] | |||
return fetch_res | |||
except Exception as e: | |||
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e) | |||
@@ -216,7 +243,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
request: Request, | |||
content: bytes, | |||
headers: Optional[dict] = None, | |||
): | |||
) -> None: | |||
"""Check that an incoming request looks like a valid .well-known request, and | |||
send back the response. | |||
""" | |||
@@ -237,16 +264,16 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
because it is created too early during setUp | |||
""" | |||
return MatrixFederationAgent( | |||
reactor=self.reactor, | |||
reactor=cast(ISynapseReactor, self.reactor), | |||
tls_client_options_factory=self.tls_factory, | |||
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. | |||
user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. | |||
ip_whitelist=IPSet(), | |||
ip_blacklist=IPSet(), | |||
_srv_resolver=self.mock_resolver, | |||
_well_known_resolver=self.well_known_resolver, | |||
) | |||
def test_get(self): | |||
def test_get(self) -> None: | |||
"""happy-path test of a GET request with an explicit port""" | |||
self._do_get() | |||
@@ -254,11 +281,11 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
os.environ, | |||
{"https_proxy": "proxy.com", "no_proxy": "testserv"}, | |||
) | |||
def test_get_bypass_proxy(self): | |||
def test_get_bypass_proxy(self) -> None: | |||
"""test of a GET request with an explicit port and bypass proxy""" | |||
self._do_get() | |||
def _do_get(self): | |||
def _do_get(self) -> None: | |||
"""test of a GET request with an explicit port""" | |||
self.agent = self._make_agent() | |||
@@ -318,7 +345,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
@patch.dict( | |||
os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} | |||
) | |||
def test_get_via_http_proxy(self): | |||
def test_get_via_http_proxy(self) -> None: | |||
"""test for federation request through a http proxy""" | |||
self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) | |||
@@ -326,7 +353,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
os.environ, | |||
{"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, | |||
) | |||
def test_get_via_http_proxy_with_auth(self): | |||
def test_get_via_http_proxy_with_auth(self) -> None: | |||
"""test for federation request through a http proxy with authentication""" | |||
self._do_get_via_proxy( | |||
expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" | |||
@@ -335,7 +362,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
@patch.dict( | |||
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} | |||
) | |||
def test_get_via_https_proxy(self): | |||
def test_get_via_https_proxy(self) -> None: | |||
"""test for federation request through a https proxy""" | |||
self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) | |||
@@ -343,7 +370,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
os.environ, | |||
{"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, | |||
) | |||
def test_get_via_https_proxy_with_auth(self): | |||
def test_get_via_https_proxy_with_auth(self) -> None: | |||
"""test for federation request through a https proxy with authentication""" | |||
self._do_get_via_proxy( | |||
expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" | |||
@@ -353,7 +380,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self, | |||
expect_proxy_ssl: bool = False, | |||
expected_auth_credentials: Optional[bytes] = None, | |||
): | |||
) -> None: | |||
"""Send a https federation request via an agent and check that it is correctly | |||
received at the proxy and client. The proxy can use either http or https. | |||
Args: | |||
@@ -418,10 +445,12 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
# now we make another test server to act as the upstream HTTP server. | |||
server_ssl_protocol = _wrap_server_factory_for_tls( | |||
_get_test_protocol_factory() | |||
).buildProtocol(None) | |||
).buildProtocol(dummy_address) | |||
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) | |||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport. | |||
proxy_server_transport = proxy_server.transport | |||
assert proxy_server_transport is not None | |||
server_ssl_protocol.makeConnection(proxy_server_transport) | |||
# ... and replace the protocol on the proxy's transport with the | |||
@@ -451,6 +480,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
# now there should be a pending request | |||
http_server = server_ssl_protocol.wrappedProtocol | |||
assert isinstance(http_server, HTTPChannel) | |||
self.assertEqual(len(http_server.requests), 1) | |||
request = http_server.requests[0] | |||
@@ -491,7 +521,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
json = self.successResultOf(treq.json_content(response)) | |||
self.assertEqual(json, {"a": 1}) | |||
def test_get_ip_address(self): | |||
def test_get_ip_address(self) -> None: | |||
""" | |||
Test the behaviour when the server name contains an explicit IP (with no port) | |||
""" | |||
@@ -526,7 +556,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_get_ipv6_address(self): | |||
def test_get_ipv6_address(self) -> None: | |||
""" | |||
Test the behaviour when the server name contains an explicit IPv6 address | |||
(with no port) | |||
@@ -562,7 +592,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_get_ipv6_address_with_port(self): | |||
def test_get_ipv6_address_with_port(self) -> None: | |||
""" | |||
Test the behaviour when the server name contains an explicit IPv6 address | |||
(with explicit port) | |||
@@ -598,7 +628,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_get_hostname_bad_cert(self): | |||
def test_get_hostname_bad_cert(self) -> None: | |||
""" | |||
Test the behaviour when the certificate on the server doesn't match the hostname | |||
""" | |||
@@ -651,7 +681,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
failure_reason = e.value.reasons[0] | |||
self.assertIsInstance(failure_reason.value, VerificationError) | |||
def test_get_ip_address_bad_cert(self): | |||
def test_get_ip_address_bad_cert(self) -> None: | |||
""" | |||
Test the behaviour when the server name contains an explicit IP, but | |||
the server cert doesn't cover it | |||
@@ -684,7 +714,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
failure_reason = e.value.reasons[0] | |||
self.assertIsInstance(failure_reason.value, VerificationError) | |||
def test_get_no_srv_no_well_known(self): | |||
def test_get_no_srv_no_well_known(self) -> None: | |||
""" | |||
Test the behaviour when the server name has no port, no SRV, and no well-known | |||
""" | |||
@@ -740,7 +770,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_get_well_known(self): | |||
def test_get_well_known(self) -> None: | |||
"""Test the behaviour when the .well-known delegates elsewhere""" | |||
self.agent = self._make_agent() | |||
@@ -802,7 +832,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.well_known_cache.expire() | |||
self.assertNotIn(b"testserv", self.well_known_cache) | |||
def test_get_well_known_redirect(self): | |||
def test_get_well_known_redirect(self) -> None: | |||
"""Test the behaviour when the server name has no port and no SRV record, but | |||
the .well-known has a 300 redirect | |||
""" | |||
@@ -892,7 +922,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.well_known_cache.expire() | |||
self.assertNotIn(b"testserv", self.well_known_cache) | |||
def test_get_invalid_well_known(self): | |||
def test_get_invalid_well_known(self) -> None: | |||
""" | |||
Test the behaviour when the server name has an *invalid* well-known (and no SRV) | |||
""" | |||
@@ -945,7 +975,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_get_well_known_unsigned_cert(self): | |||
def test_get_well_known_unsigned_cert(self) -> None: | |||
"""Test the behaviour when the .well-known server presents a cert | |||
not signed by a CA | |||
""" | |||
@@ -969,7 +999,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
ip_blacklist=IPSet(), | |||
_srv_resolver=self.mock_resolver, | |||
_well_known_resolver=WellKnownResolver( | |||
self.reactor, | |||
cast(ISynapseReactor, self.reactor), | |||
Agent(self.reactor, contextFactory=tls_factory), | |||
b"test-agent", | |||
well_known_cache=self.well_known_cache, | |||
@@ -999,7 +1029,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
b"_matrix._tcp.testserv" | |||
) | |||
def test_get_hostname_srv(self): | |||
def test_get_hostname_srv(self) -> None: | |||
""" | |||
Test the behaviour when there is a single SRV record | |||
""" | |||
@@ -1041,7 +1071,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_get_well_known_srv(self): | |||
def test_get_well_known_srv(self) -> None: | |||
"""Test the behaviour when the .well-known redirects to a place where there | |||
is a SRV. | |||
""" | |||
@@ -1101,7 +1131,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_idna_servername(self): | |||
def test_idna_servername(self) -> None: | |||
"""test the behaviour when the server name has idna chars in""" | |||
self.agent = self._make_agent() | |||
@@ -1163,7 +1193,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_idna_srv_target(self): | |||
def test_idna_srv_target(self) -> None: | |||
"""test the behaviour when the target of a SRV record has idna chars""" | |||
self.agent = self._make_agent() | |||
@@ -1206,7 +1236,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
def test_well_known_cache(self): | |||
def test_well_known_cache(self) -> None: | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
fetch_d = defer.ensureDeferred( | |||
@@ -1262,7 +1292,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
r = self.successResultOf(fetch_d) | |||
self.assertEqual(r.delegated_server, b"other-server") | |||
def test_well_known_cache_with_temp_failure(self): | |||
def test_well_known_cache_with_temp_failure(self) -> None: | |||
"""Test that we refetch well-known before the cache expires, and that | |||
it ignores transient errors. | |||
""" | |||
@@ -1341,7 +1371,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
r = self.successResultOf(fetch_d) | |||
self.assertEqual(r.delegated_server, None) | |||
def test_well_known_too_large(self): | |||
def test_well_known_too_large(self) -> None: | |||
"""A well-known query that returns a result which is too large should be rejected.""" | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
@@ -1367,7 +1397,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
r = self.successResultOf(fetch_d) | |||
self.assertIsNone(r.delegated_server) | |||
def test_srv_fallbacks(self): | |||
def test_srv_fallbacks(self) -> None: | |||
"""Test that other SRV results are tried if the first one fails.""" | |||
self.agent = self._make_agent() | |||
@@ -1427,7 +1457,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
class TestCachePeriodFromHeaders(unittest.TestCase): | |||
def test_cache_control(self): | |||
def test_cache_control(self) -> None: | |||
# uppercase | |||
self.assertEqual( | |||
_cache_period_from_headers( | |||
@@ -1464,7 +1494,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): | |||
0, | |||
) | |||
def test_expires(self): | |||
def test_expires(self) -> None: | |||
self.assertEqual( | |||
_cache_period_from_headers( | |||
Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), | |||
@@ -1491,14 +1521,14 @@ class TestCachePeriodFromHeaders(unittest.TestCase): | |||
self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) | |||
def _check_logcontext(context): | |||
def _check_logcontext(context: LoggingContextOrSentinel) -> None: | |||
current = current_context() | |||
if current is not context: | |||
raise AssertionError("Expected logcontext %s but was %s" % (context, current)) | |||
def _wrap_server_factory_for_tls( | |||
factory: IProtocolFactory, sanlist: Iterable[bytes] = None | |||
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None | |||
) -> IProtocolFactory: | |||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory | |||
The resultant factory will create a TLS server which presents a certificate | |||
@@ -1537,7 +1567,7 @@ def _get_test_protocol_factory() -> IProtocolFactory: | |||
return server_factory | |||
def _log_request(request: str): | |||
def _log_request(request: str) -> None: | |||
"""Implements Factory.log, which is expected by Request.finish""" | |||
logger.info(f"Completed request {request}") | |||
@@ -1547,6 +1577,8 @@ class TrustingTLSPolicyForHTTPS: | |||
"""An IPolicyForHTTPS which checks that the certificate belongs to the | |||
right server, but doesn't check the certificate chain.""" | |||
def creatorForNetloc(self, hostname, port): | |||
def creatorForNetloc( | |||
self, hostname: bytes, port: int | |||
) -> IOpenSSLClientConnectionCreator: | |||
certificateOptions = OpenSSLCertificateOptions() | |||
return ClientTLSOptions(hostname, certificateOptions.getContext()) |
@@ -12,7 +12,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Dict, Generator, List, Tuple, cast | |||
from unittest.mock import Mock | |||
from twisted.internet import defer | |||
@@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred | |||
from twisted.internet.error import ConnectError | |||
from twisted.names import dns, error | |||
from synapse.http.federation.srv_resolver import SrvResolver | |||
from synapse.http.federation.srv_resolver import Server, SrvResolver | |||
from synapse.logging.context import LoggingContext, current_context | |||
from tests import unittest | |||
@@ -28,7 +28,7 @@ from tests.utils import MockClock | |||
class SrvResolverTestCase(unittest.TestCase): | |||
def test_resolve(self): | |||
def test_resolve(self) -> None: | |||
dns_client_mock = Mock() | |||
service_name = b"test_service.example.com" | |||
@@ -38,18 +38,19 @@ class SrvResolverTestCase(unittest.TestCase): | |||
type=dns.SRV, payload=dns.Record_SRV(target=host_name) | |||
) | |||
result_deferred = Deferred() | |||
result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() | |||
dns_client_mock.lookupService.return_value = result_deferred | |||
cache = {} | |||
cache: Dict[bytes, List[Server]] = {} | |||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) | |||
@defer.inlineCallbacks | |||
def do_lookup(): | |||
def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: | |||
with LoggingContext("one") as ctx: | |||
resolve_d = resolver.resolve_service(service_name) | |||
result = yield defer.ensureDeferred(resolve_d) | |||
result: List[Server] | |||
result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment] | |||
# should have restored our context | |||
self.assertIs(current_context(), ctx) | |||
@@ -70,7 +71,9 @@ class SrvResolverTestCase(unittest.TestCase): | |||
self.assertEqual(servers[0].host, host_name) | |||
@defer.inlineCallbacks | |||
def test_from_cache_expired_and_dns_fail(self): | |||
def test_from_cache_expired_and_dns_fail( | |||
self, | |||
) -> Generator["Deferred[object]", object, None]: | |||
dns_client_mock = Mock() | |||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) | |||
@@ -81,10 +84,13 @@ class SrvResolverTestCase(unittest.TestCase): | |||
entry.priority = 0 | |||
entry.weight = 0 | |||
cache = {service_name: [entry]} | |||
cache = {service_name: [cast(Server, entry)]} | |||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) | |||
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) | |||
servers: List[Server] | |||
servers = yield defer.ensureDeferred( | |||
resolver.resolve_service(service_name) | |||
) # type: ignore[assignment] | |||
dns_client_mock.lookupService.assert_called_once_with(service_name) | |||
@@ -92,7 +98,7 @@ class SrvResolverTestCase(unittest.TestCase): | |||
self.assertEqual(servers, cache[service_name]) | |||
@defer.inlineCallbacks | |||
def test_from_cache(self): | |||
def test_from_cache(self) -> Generator["Deferred[object]", object, None]: | |||
clock = MockClock() | |||
dns_client_mock = Mock(spec_set=["lookupService"]) | |||
@@ -105,12 +111,15 @@ class SrvResolverTestCase(unittest.TestCase): | |||
entry.priority = 0 | |||
entry.weight = 0 | |||
cache = {service_name: [entry]} | |||
cache = {service_name: [cast(Server, entry)]} | |||
resolver = SrvResolver( | |||
dns_client=dns_client_mock, cache=cache, get_time=clock.time | |||
) | |||
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) | |||
servers: List[Server] | |||
servers = yield defer.ensureDeferred( | |||
resolver.resolve_service(service_name) | |||
) # type: ignore[assignment] | |||
self.assertFalse(dns_client_mock.lookupService.called) | |||
@@ -118,45 +127,48 @@ class SrvResolverTestCase(unittest.TestCase): | |||
self.assertEqual(servers, cache[service_name]) | |||
@defer.inlineCallbacks | |||
def test_empty_cache(self): | |||
def test_empty_cache(self) -> Generator["Deferred[object]", object, None]: | |||
dns_client_mock = Mock() | |||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) | |||
service_name = b"test_service.example.com" | |||
cache = {} | |||
cache: Dict[bytes, List[Server]] = {} | |||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) | |||
with self.assertRaises(error.DNSServerError): | |||
yield defer.ensureDeferred(resolver.resolve_service(service_name)) | |||
@defer.inlineCallbacks | |||
def test_name_error(self): | |||
def test_name_error(self) -> Generator["Deferred[object]", object, None]: | |||
dns_client_mock = Mock() | |||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) | |||
service_name = b"test_service.example.com" | |||
cache = {} | |||
cache: Dict[bytes, List[Server]] = {} | |||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) | |||
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) | |||
servers: List[Server] | |||
servers = yield defer.ensureDeferred( | |||
resolver.resolve_service(service_name) | |||
) # type: ignore[assignment] | |||
self.assertEqual(len(servers), 0) | |||
self.assertEqual(len(cache), 0) | |||
def test_disabled_service(self): | |||
def test_disabled_service(self) -> None: | |||
""" | |||
test the behaviour when there is a single record which is ".". | |||
""" | |||
service_name = b"test_service.example.com" | |||
lookup_deferred = Deferred() | |||
lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() | |||
dns_client_mock = Mock() | |||
dns_client_mock.lookupService.return_value = lookup_deferred | |||
cache = {} | |||
cache: Dict[bytes, List[Server]] = {} | |||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) | |||
# Old versions of Twisted don't have an ensureDeferred in failureResultOf. | |||
@@ -173,16 +185,16 @@ class SrvResolverTestCase(unittest.TestCase): | |||
self.failureResultOf(resolve_d, ConnectError) | |||
def test_non_srv_answer(self): | |||
def test_non_srv_answer(self) -> None: | |||
""" | |||
test the behaviour when the dns server gives us a spurious non-SRV response | |||
""" | |||
service_name = b"test_service.example.com" | |||
lookup_deferred = Deferred() | |||
lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred() | |||
dns_client_mock = Mock() | |||
dns_client_mock.lookupService.return_value = lookup_deferred | |||
cache = {} | |||
cache: Dict[bytes, List[Server]] = {} | |||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) | |||
# Old versions of Twisted don't have an ensureDeferred in successResultOf. | |||
@@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str: | |||
return method_name | |||
def _hash_stack(stack: List[inspect.FrameInfo]): | |||
def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]: | |||
"""Turns a stack into a hashable value that can be put into a set.""" | |||
return tuple(_format_stack_frame(frame) for frame in stack) |
@@ -11,28 +11,34 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any | |||
from twisted.web.server import Request | |||
from synapse.http.additional_resource import AdditionalResource | |||
from synapse.http.server import respond_with_json | |||
from synapse.http.site import SynapseRequest | |||
from synapse.types import JsonDict | |||
from tests.server import FakeSite, make_request | |||
from tests.unittest import HomeserverTestCase | |||
class _AsyncTestCustomEndpoint: | |||
def __init__(self, config, module_api): | |||
def __init__(self, config: JsonDict, module_api: Any) -> None: | |||
pass | |||
async def handle_request(self, request): | |||
async def handle_request(self, request: Request) -> None: | |||
assert isinstance(request, SynapseRequest) | |||
respond_with_json(request, 200, {"some_key": "some_value_async"}) | |||
class _SyncTestCustomEndpoint: | |||
def __init__(self, config, module_api): | |||
def __init__(self, config: JsonDict, module_api: Any) -> None: | |||
pass | |||
async def handle_request(self, request): | |||
async def handle_request(self, request: Request) -> None: | |||
assert isinstance(request, SynapseRequest) | |||
respond_with_json(request, 200, {"some_key": "some_value_sync"}) | |||
@@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase): | |||
and async handlers. | |||
""" | |||
def test_async(self): | |||
def test_async(self) -> None: | |||
handler = _AsyncTestCustomEndpoint({}, None).handle_request | |||
resource = AdditionalResource(self.hs, handler) | |||
@@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase): | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) | |||
def test_sync(self): | |||
def test_sync(self) -> None: | |||
handler = _SyncTestCustomEndpoint({}, None).handle_request | |||
resource = AdditionalResource(self.hs, handler) | |||
@@ -13,10 +13,12 @@ | |||
# limitations under the License. | |||
from io import BytesIO | |||
from typing import Tuple, Union | |||
from unittest.mock import Mock | |||
from netaddr import IPSet | |||
from twisted.internet.defer import Deferred | |||
from twisted.internet.error import DNSLookupError | |||
from twisted.python.failure import Failure | |||
from twisted.test.proto_helpers import AccumulatingProtocol | |||
@@ -28,6 +30,7 @@ from synapse.http.client import ( | |||
BlacklistingAgentWrapper, | |||
BlacklistingReactorWrapper, | |||
BodyExceededMaxSize, | |||
_DiscardBodyWithMaxSizeProtocol, | |||
read_body_with_max_size, | |||
) | |||
@@ -36,7 +39,9 @@ from tests.unittest import TestCase | |||
class ReadBodyWithMaxSizeTests(TestCase): | |||
def _build_response(self, length=UNKNOWN_LENGTH): | |||
def _build_response( | |||
self, length: Union[int, str] = UNKNOWN_LENGTH | |||
) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: | |||
"""Start reading the body, returns the response, result and proto""" | |||
response = Mock(length=length) | |||
result = BytesIO() | |||
@@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase): | |||
return result, deferred, protocol | |||
def _assert_error(self, deferred, protocol): | |||
def _assert_error( | |||
self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol | |||
) -> None: | |||
"""Ensure that the expected error is received.""" | |||
self.assertIsInstance(deferred.result, Failure) | |||
assert isinstance(deferred.result, Failure) | |||
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) | |||
protocol.transport.abortConnection.assert_called_once() | |||
assert protocol.transport is not None | |||
# type-ignore: presumably abortConnection has been replaced with a Mock. | |||
protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined] | |||
def _cleanup_error(self, deferred): | |||
def _cleanup_error(self, deferred: "Deferred[int]") -> None: | |||
"""Ensure that the error in the Deferred is handled gracefully.""" | |||
called = [False] | |||
def errback(f): | |||
def errback(f: Failure) -> None: | |||
called[0] = True | |||
deferred.addErrback(errback) | |||
self.assertTrue(called[0]) | |||
def test_no_error(self): | |||
def test_no_error(self) -> None: | |||
"""A response that is NOT too large.""" | |||
result, deferred, protocol = self._build_response() | |||
@@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase): | |||
self.assertEqual(result.getvalue(), b"12345") | |||
self.assertEqual(deferred.result, 5) | |||
def test_too_large(self): | |||
def test_too_large(self) -> None: | |||
"""A response which is too large raises an exception.""" | |||
result, deferred, protocol = self._build_response() | |||
@@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase): | |||
self._assert_error(deferred, protocol) | |||
self._cleanup_error(deferred) | |||
def test_multiple_packets(self): | |||
def test_multiple_packets(self) -> None: | |||
"""Data should be accumulated through mutliple packets.""" | |||
result, deferred, protocol = self._build_response() | |||
@@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase): | |||
self.assertEqual(result.getvalue(), b"1234") | |||
self.assertEqual(deferred.result, 4) | |||
def test_additional_data(self): | |||
def test_additional_data(self) -> None: | |||
"""A connection can receive data after being closed.""" | |||
result, deferred, protocol = self._build_response() | |||
@@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase): | |||
self._assert_error(deferred, protocol) | |||
self._cleanup_error(deferred) | |||
def test_content_length(self): | |||
def test_content_length(self) -> None: | |||
"""The body shouldn't be read (at all) if the Content-Length header is too large.""" | |||
result, deferred, protocol = self._build_response(length=10) | |||
@@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase): | |||
class BlacklistingAgentTest(TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.reactor, self.clock = get_clock() | |||
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" | |||
@@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase): | |||
self.ip_whitelist = IPSet([self.allowed_ip.decode()]) | |||
self.ip_blacklist = IPSet(["5.0.0.0/8"]) | |||
def test_reactor(self): | |||
def test_reactor(self) -> None: | |||
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" | |||
agent = Agent( | |||
BlacklistingReactorWrapper( | |||
@@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase): | |||
response = self.successResultOf(d) | |||
self.assertEqual(response.code, 200) | |||
def test_agent(self): | |||
def test_agent(self) -> None: | |||
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" | |||
agent = BlacklistingAgentWrapper( | |||
Agent(self.reactor), | |||
@@ -17,7 +17,7 @@ from tests import unittest | |||
class ServerNameTestCase(unittest.TestCase): | |||
def test_parse_server_name(self): | |||
def test_parse_server_name(self) -> None: | |||
test_data = { | |||
"localhost": ("localhost", None), | |||
"my-example.com:1234": ("my-example.com", 1234), | |||
@@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase): | |||
for i, o in test_data.items(): | |||
self.assertEqual(parse_server_name(i), o) | |||
def test_validate_bad_server_names(self): | |||
def test_validate_bad_server_names(self) -> None: | |||
test_data = [ | |||
"", # empty | |||
"localhost:http", # non-numeric port | |||
@@ -11,16 +11,16 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Generator | |||
from unittest.mock import Mock | |||
from netaddr import IPSet | |||
from parameterized import parameterized | |||
from twisted.internet import defer | |||
from twisted.internet.defer import TimeoutError | |||
from twisted.internet.defer import Deferred, TimeoutError | |||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError | |||
from twisted.test.proto_helpers import StringTransport | |||
from twisted.test.proto_helpers import MemoryReactor, StringTransport | |||
from twisted.web.client import ResponseNeverReceived | |||
from twisted.web.http import HTTPChannel | |||
@@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import ( | |||
MatrixFederationHttpClient, | |||
MatrixFederationRequest, | |||
) | |||
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context | |||
from synapse.logging.context import ( | |||
SENTINEL_CONTEXT, | |||
LoggingContext, | |||
LoggingContextOrSentinel, | |||
current_context, | |||
) | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.server import FakeTransport | |||
from tests.unittest import HomeserverTestCase | |||
def check_logcontext(context): | |||
def check_logcontext(context: LoggingContextOrSentinel) -> None: | |||
current = current_context() | |||
if current is not context: | |||
raise AssertionError("Expected logcontext %s but was %s" % (context, current)) | |||
class FederationClientTests(HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver(reactor=reactor, clock=clock) | |||
return hs | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.cl = MatrixFederationHttpClient(self.hs, None) | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
def test_client_get(self): | |||
def test_client_get(self) -> None: | |||
""" | |||
happy-path test of a GET request | |||
""" | |||
@defer.inlineCallbacks | |||
def do_request(): | |||
def do_request() -> Generator["Deferred[object]", object, object]: | |||
with LoggingContext("one") as context: | |||
fetch_d = defer.ensureDeferred( | |||
self.cl.get_json("testserv:8008", "foo/bar") | |||
@@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase): | |||
# check the response is as expected | |||
self.assertEqual(res, {"a": 1}) | |||
def test_dns_error(self): | |||
def test_dns_error(self) -> None: | |||
""" | |||
If the DNS lookup returns an error, it will bubble up. | |||
""" | |||
@@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase): | |||
self.assertIsInstance(f.value, RequestSendFailed) | |||
self.assertIsInstance(f.value.inner_exception, DNSLookupError) | |||
def test_client_connection_refused(self): | |||
def test_client_connection_refused(self) -> None: | |||
d = defer.ensureDeferred( | |||
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) | |||
) | |||
@@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase): | |||
self.assertIsInstance(f.value, RequestSendFailed) | |||
self.assertIs(f.value.inner_exception, e) | |||
def test_client_never_connect(self): | |||
def test_client_never_connect(self) -> None: | |||
""" | |||
If the HTTP request is not connected and is timed out, it'll give a | |||
ConnectingCancelledError or TimeoutError. | |||
@@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase): | |||
f.value.inner_exception, (ConnectingCancelledError, TimeoutError) | |||
) | |||
def test_client_connect_no_response(self): | |||
def test_client_connect_no_response(self) -> None: | |||
""" | |||
If the HTTP request is connected, but gets no response before being | |||
timed out, it'll give a ResponseNeverReceived. | |||
@@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase): | |||
self.assertIsInstance(f.value, RequestSendFailed) | |||
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived) | |||
def test_client_ip_range_blacklist(self): | |||
def test_client_ip_range_blacklist(self) -> None: | |||
"""Ensure that Synapse does not try to connect to blacklisted IPs""" | |||
# Set up the ip_range blacklist | |||
@@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase): | |||
f = self.failureResultOf(d, RequestSendFailed) | |||
self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError) | |||
def test_client_gets_headers(self): | |||
def test_client_gets_headers(self) -> None: | |||
""" | |||
Once the client gets the headers, _request returns successfully. | |||
""" | |||
@@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase): | |||
self.assertEqual(r.code, 200) | |||
@parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) | |||
def test_timeout_reading_body(self, method_name: str): | |||
def test_timeout_reading_body(self, method_name: str) -> None: | |||
""" | |||
If the HTTP request is connected, but gets no response before being | |||
timed out, it'll give a RequestSendFailed with can_retry. | |||
@@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase): | |||
self.assertTrue(f.value.can_retry) | |||
self.assertIsInstance(f.value.inner_exception, defer.TimeoutError) | |||
def test_client_requires_trailing_slashes(self): | |||
def test_client_requires_trailing_slashes(self) -> None: | |||
""" | |||
If a connection is made to a client but the client rejects it due to | |||
requiring a trailing slash. We need to retry the request with a | |||
@@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase): | |||
r = self.successResultOf(d) | |||
self.assertEqual(r, {}) | |||
def test_client_does_not_retry_on_400_plus(self): | |||
def test_client_does_not_retry_on_400_plus(self) -> None: | |||
""" | |||
Another test for trailing slashes but now test that we don't retry on | |||
trailing slashes on a non-400/M_UNRECOGNIZED response. | |||
@@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase): | |||
# We should get a 404 failure response | |||
self.failureResultOf(d) | |||
def test_client_sends_body(self): | |||
def test_client_sends_body(self) -> None: | |||
defer.ensureDeferred( | |||
self.cl.post_json( | |||
"testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} | |||
@@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase): | |||
content = request.content.read() | |||
self.assertEqual(content, b'{"a":"b"}') | |||
def test_closes_connection(self): | |||
def test_closes_connection(self) -> None: | |||
"""Check that the client closes unused HTTP connections""" | |||
d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) | |||
@@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase): | |||
self.assertTrue(conn.disconnecting) | |||
@parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)]) | |||
def test_json_error(self, return_value): | |||
def test_json_error(self, return_value: bytes) -> None: | |||
""" | |||
Test what happens if invalid JSON is returned from the remote endpoint. | |||
""" | |||
@@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase): | |||
f = self.failureResultOf(test_d) | |||
self.assertIsInstance(f.value, RequestSendFailed) | |||
def test_too_big(self): | |||
def test_too_big(self) -> None: | |||
""" | |||
Test what happens if a huge response is returned from the remote endpoint. | |||
""" | |||
@@ -14,7 +14,7 @@ | |||
import base64 | |||
import logging | |||
import os | |||
from typing import Iterable, Optional | |||
from typing import List, Optional | |||
from unittest.mock import patch | |||
import treq | |||
@@ -22,7 +22,11 @@ from netaddr import IPSet | |||
from parameterized import parameterized | |||
from twisted.internet import interfaces # noqa: F401 | |||
from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint | |||
from twisted.internet.endpoints import ( | |||
HostnameEndpoint, | |||
_WrapperEndpoint, | |||
_WrappingProtocol, | |||
) | |||
from twisted.internet.interfaces import IProtocol, IProtocolFactory | |||
from twisted.internet.protocol import Factory | |||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |||
@@ -32,7 +36,11 @@ from synapse.http.client import BlacklistingReactorWrapper | |||
from synapse.http.connectproxyclient import ProxyCredentials | |||
from synapse.http.proxyagent import ProxyAgent, parse_proxy | |||
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy | |||
from tests.http import ( | |||
TestServerTLSConnectionFactory, | |||
dummy_address, | |||
get_test_https_policy, | |||
) | |||
from tests.server import FakeTransport, ThreadedMemoryReactorClock | |||
from tests.unittest import TestCase | |||
@@ -183,7 +191,7 @@ class ProxyParserTests(TestCase): | |||
expected_hostname: bytes, | |||
expected_port: int, | |||
expected_credentials: Optional[bytes], | |||
): | |||
) -> None: | |||
""" | |||
Tests that a given proxy URL will be broken into the components. | |||
Args: | |||
@@ -209,7 +217,7 @@ class ProxyParserTests(TestCase): | |||
class MatrixFederationAgentTests(TestCase): | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.reactor = ThreadedMemoryReactorClock() | |||
def _make_connection( | |||
@@ -218,7 +226,7 @@ class MatrixFederationAgentTests(TestCase): | |||
server_factory: IProtocolFactory, | |||
ssl: bool = False, | |||
expected_sni: Optional[bytes] = None, | |||
tls_sanlist: Optional[Iterable[bytes]] = None, | |||
tls_sanlist: Optional[List[bytes]] = None, | |||
) -> IProtocol: | |||
"""Builds a test server, and completes the outgoing client connection | |||
@@ -244,7 +252,8 @@ class MatrixFederationAgentTests(TestCase): | |||
if ssl: | |||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) | |||
server_protocol = server_factory.buildProtocol(None) | |||
server_protocol = server_factory.buildProtocol(dummy_address) | |||
assert server_protocol is not None | |||
# now, tell the client protocol factory to build the client protocol, | |||
# and wire the output of said protocol up to the server via | |||
@@ -252,7 +261,8 @@ class MatrixFederationAgentTests(TestCase): | |||
# | |||
# Normally this would be done by the TCP socket code in Twisted, but we are | |||
# stubbing that out here. | |||
client_protocol = client_factory.buildProtocol(None) | |||
client_protocol = client_factory.buildProtocol(dummy_address) | |||
assert client_protocol is not None | |||
client_protocol.makeConnection( | |||
FakeTransport(server_protocol, self.reactor, client_protocol) | |||
) | |||
@@ -263,6 +273,7 @@ class MatrixFederationAgentTests(TestCase): | |||
) | |||
if ssl: | |||
assert isinstance(server_protocol, TLSMemoryBIOProtocol) | |||
http_protocol = server_protocol.wrappedProtocol | |||
tls_connection = server_protocol._tlsConnection | |||
else: | |||
@@ -288,7 +299,7 @@ class MatrixFederationAgentTests(TestCase): | |||
scheme: bytes, | |||
hostname: bytes, | |||
path: bytes, | |||
): | |||
) -> None: | |||
"""Runs a test case for a direct connection not going through a proxy. | |||
Args: | |||
@@ -319,6 +330,7 @@ class MatrixFederationAgentTests(TestCase): | |||
ssl=is_https, | |||
expected_sni=hostname if is_https else None, | |||
) | |||
assert isinstance(http_server, HTTPChannel) | |||
# the FakeTransport is async, so we need to pump the reactor | |||
self.reactor.advance(0) | |||
@@ -339,34 +351,34 @@ class MatrixFederationAgentTests(TestCase): | |||
body = self.successResultOf(treq.content(resp)) | |||
self.assertEqual(body, b"result") | |||
def test_http_request(self): | |||
def test_http_request(self) -> None: | |||
agent = ProxyAgent(self.reactor) | |||
self._test_request_direct_connection(agent, b"http", b"test.com", b"") | |||
def test_https_request(self): | |||
def test_https_request(self) -> None: | |||
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) | |||
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") | |||
def test_http_request_use_proxy_empty_environment(self): | |||
def test_http_request_use_proxy_empty_environment(self) -> None: | |||
agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self._test_request_direct_connection(agent, b"http", b"test.com", b"") | |||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"}) | |||
def test_http_request_via_uppercase_no_proxy(self): | |||
def test_http_request_via_uppercase_no_proxy(self) -> None: | |||
agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self._test_request_direct_connection(agent, b"http", b"test.com", b"") | |||
@patch.dict( | |||
os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"} | |||
) | |||
def test_http_request_via_no_proxy(self): | |||
def test_http_request_via_no_proxy(self) -> None: | |||
agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self._test_request_direct_connection(agent, b"http", b"test.com", b"") | |||
@patch.dict( | |||
os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"} | |||
) | |||
def test_https_request_via_no_proxy(self): | |||
def test_https_request_via_no_proxy(self) -> None: | |||
agent = ProxyAgent( | |||
self.reactor, | |||
contextFactory=get_test_https_policy(), | |||
@@ -375,12 +387,12 @@ class MatrixFederationAgentTests(TestCase): | |||
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") | |||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"}) | |||
def test_http_request_via_no_proxy_star(self): | |||
def test_http_request_via_no_proxy_star(self) -> None: | |||
agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self._test_request_direct_connection(agent, b"http", b"test.com", b"") | |||
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"}) | |||
def test_https_request_via_no_proxy_star(self): | |||
def test_https_request_via_no_proxy_star(self) -> None: | |||
agent = ProxyAgent( | |||
self.reactor, | |||
contextFactory=get_test_https_policy(), | |||
@@ -389,7 +401,7 @@ class MatrixFederationAgentTests(TestCase): | |||
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") | |||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) | |||
def test_http_request_via_proxy(self): | |||
def test_http_request_via_proxy(self) -> None: | |||
""" | |||
Tests that requests can be made through a proxy. | |||
""" | |||
@@ -401,7 +413,7 @@ class MatrixFederationAgentTests(TestCase): | |||
os.environ, | |||
{"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, | |||
) | |||
def test_http_request_via_proxy_with_auth(self): | |||
def test_http_request_via_proxy_with_auth(self) -> None: | |||
""" | |||
Tests that authenticated requests can be made through a proxy. | |||
""" | |||
@@ -412,7 +424,7 @@ class MatrixFederationAgentTests(TestCase): | |||
@patch.dict( | |||
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} | |||
) | |||
def test_http_request_via_https_proxy(self): | |||
def test_http_request_via_https_proxy(self) -> None: | |||
self._do_http_request_via_proxy( | |||
expect_proxy_ssl=True, expected_auth_credentials=None | |||
) | |||
@@ -424,13 +436,13 @@ class MatrixFederationAgentTests(TestCase): | |||
"no_proxy": "unused.com", | |||
}, | |||
) | |||
def test_http_request_via_https_proxy_with_auth(self): | |||
def test_http_request_via_https_proxy_with_auth(self) -> None: | |||
self._do_http_request_via_proxy( | |||
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" | |||
) | |||
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) | |||
def test_https_request_via_proxy(self): | |||
def test_https_request_via_proxy(self) -> None: | |||
"""Tests that TLS-encrypted requests can be made through a proxy""" | |||
self._do_https_request_via_proxy( | |||
expect_proxy_ssl=False, expected_auth_credentials=None | |||
@@ -440,7 +452,7 @@ class MatrixFederationAgentTests(TestCase): | |||
os.environ, | |||
{"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, | |||
) | |||
def test_https_request_via_proxy_with_auth(self): | |||
def test_https_request_via_proxy_with_auth(self) -> None: | |||
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy""" | |||
self._do_https_request_via_proxy( | |||
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" | |||
@@ -449,7 +461,7 @@ class MatrixFederationAgentTests(TestCase): | |||
@patch.dict( | |||
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} | |||
) | |||
def test_https_request_via_https_proxy(self): | |||
def test_https_request_via_https_proxy(self) -> None: | |||
"""Tests that TLS-encrypted requests can be made through a proxy""" | |||
self._do_https_request_via_proxy( | |||
expect_proxy_ssl=True, expected_auth_credentials=None | |||
@@ -459,7 +471,7 @@ class MatrixFederationAgentTests(TestCase): | |||
os.environ, | |||
{"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, | |||
) | |||
def test_https_request_via_https_proxy_with_auth(self): | |||
def test_https_request_via_https_proxy_with_auth(self) -> None: | |||
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy""" | |||
self._do_https_request_via_proxy( | |||
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" | |||
@@ -469,7 +481,7 @@ class MatrixFederationAgentTests(TestCase): | |||
self, | |||
expect_proxy_ssl: bool = False, | |||
expected_auth_credentials: Optional[bytes] = None, | |||
): | |||
) -> None: | |||
"""Send a http request via an agent and check that it is correctly received at | |||
the proxy. The proxy can use either http or https. | |||
Args: | |||
@@ -501,6 +513,7 @@ class MatrixFederationAgentTests(TestCase): | |||
tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, | |||
expected_sni=b"proxy.com" if expect_proxy_ssl else None, | |||
) | |||
assert isinstance(http_server, HTTPChannel) | |||
# the FakeTransport is async, so we need to pump the reactor | |||
self.reactor.advance(0) | |||
@@ -542,7 +555,7 @@ class MatrixFederationAgentTests(TestCase): | |||
self, | |||
expect_proxy_ssl: bool = False, | |||
expected_auth_credentials: Optional[bytes] = None, | |||
): | |||
) -> None: | |||
"""Send a https request via an agent and check that it is correctly received at | |||
the proxy and client. The proxy can use either http or https. | |||
Args: | |||
@@ -606,10 +619,12 @@ class MatrixFederationAgentTests(TestCase): | |||
# now we make another test server to act as the upstream HTTP server. | |||
server_ssl_protocol = _wrap_server_factory_for_tls( | |||
_get_test_protocol_factory() | |||
).buildProtocol(None) | |||
).buildProtocol(dummy_address) | |||
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) | |||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport. | |||
proxy_server_transport = proxy_server.transport | |||
assert proxy_server_transport is not None | |||
server_ssl_protocol.makeConnection(proxy_server_transport) | |||
# ... and replace the protocol on the proxy's transport with the | |||
@@ -644,6 +659,7 @@ class MatrixFederationAgentTests(TestCase): | |||
# now there should be a pending request | |||
http_server = server_ssl_protocol.wrappedProtocol | |||
assert isinstance(http_server, HTTPChannel) | |||
self.assertEqual(len(http_server.requests), 1) | |||
request = http_server.requests[0] | |||
@@ -667,7 +683,7 @@ class MatrixFederationAgentTests(TestCase): | |||
self.assertEqual(body, b"result") | |||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) | |||
def test_http_request_via_proxy_with_blacklist(self): | |||
def test_http_request_via_proxy_with_blacklist(self) -> None: | |||
# The blacklist includes the configured proxy IP. | |||
agent = ProxyAgent( | |||
BlacklistingReactorWrapper( | |||
@@ -691,6 +707,7 @@ class MatrixFederationAgentTests(TestCase): | |||
http_server = self._make_connection( | |||
client_factory, _get_test_protocol_factory() | |||
) | |||
assert isinstance(http_server, HTTPChannel) | |||
# the FakeTransport is async, so we need to pump the reactor | |||
self.reactor.advance(0) | |||
@@ -712,7 +729,7 @@ class MatrixFederationAgentTests(TestCase): | |||
self.assertEqual(body, b"result") | |||
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"}) | |||
def test_https_request_via_uppercase_proxy_with_blacklist(self): | |||
def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None: | |||
# The blacklist includes the configured proxy IP. | |||
agent = ProxyAgent( | |||
BlacklistingReactorWrapper( | |||
@@ -737,11 +754,15 @@ class MatrixFederationAgentTests(TestCase): | |||
proxy_server = self._make_connection( | |||
client_factory, _get_test_protocol_factory() | |||
) | |||
assert isinstance(proxy_server, HTTPChannel) | |||
# fish the transports back out so that we can do the old switcheroo | |||
s2c_transport = proxy_server.transport | |||
assert isinstance(s2c_transport, FakeTransport) | |||
client_protocol = s2c_transport.other | |||
assert isinstance(client_protocol, _WrappingProtocol) | |||
c2s_transport = client_protocol.transport | |||
assert isinstance(c2s_transport, FakeTransport) | |||
# the FakeTransport is async, so we need to pump the reactor | |||
self.reactor.advance(0) | |||
@@ -762,8 +783,10 @@ class MatrixFederationAgentTests(TestCase): | |||
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel | |||
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) | |||
ssl_protocol = ssl_factory.buildProtocol(None) | |||
ssl_protocol = ssl_factory.buildProtocol(dummy_address) | |||
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol) | |||
http_server = ssl_protocol.wrappedProtocol | |||
assert isinstance(http_server, HTTPChannel) | |||
ssl_protocol.makeConnection( | |||
FakeTransport(client_protocol, self.reactor, ssl_protocol) | |||
@@ -797,28 +820,28 @@ class MatrixFederationAgentTests(TestCase): | |||
self.assertEqual(body, b"result") | |||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) | |||
def test_proxy_with_no_scheme(self): | |||
def test_proxy_with_no_scheme(self) -> None: | |||
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) | |||
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) | |||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") | |||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) | |||
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) | |||
def test_proxy_with_unsupported_scheme(self): | |||
def test_proxy_with_unsupported_scheme(self) -> None: | |||
with self.assertRaises(ValueError): | |||
ProxyAgent(self.reactor, use_proxy=True) | |||
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) | |||
def test_proxy_with_http_scheme(self): | |||
def test_proxy_with_http_scheme(self) -> None: | |||
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) | |||
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) | |||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") | |||
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) | |||
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) | |||
def test_proxy_with_https_scheme(self): | |||
def test_proxy_with_https_scheme(self) -> None: | |||
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) | |||
self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) | |||
assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) | |||
self.assertEqual( | |||
https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" | |||
) | |||
@@ -828,7 +851,7 @@ class MatrixFederationAgentTests(TestCase): | |||
def _wrap_server_factory_for_tls( | |||
factory: IProtocolFactory, sanlist: Iterable[bytes] = None | |||
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None | |||
) -> IProtocolFactory: | |||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory | |||
@@ -865,6 +888,6 @@ def _get_test_protocol_factory() -> IProtocolFactory: | |||
return server_factory | |||
def _log_request(request: str): | |||
def _log_request(request: str) -> None: | |||
"""Implements Factory.log, which is expected by Request.finish""" | |||
logger.info(f"Completed request {request}") |
@@ -14,7 +14,7 @@ | |||
import json | |||
from http import HTTPStatus | |||
from io import BytesIO | |||
from typing import Tuple | |||
from typing import Tuple, Union | |||
from unittest.mock import Mock | |||
from synapse.api.errors import Codes, SynapseError | |||
@@ -33,7 +33,7 @@ from tests import unittest | |||
from tests.http.server._base import test_disconnect | |||
def make_request(content): | |||
def make_request(content: Union[bytes, JsonDict]) -> Mock: | |||
"""Make an object that acts enough like a request.""" | |||
request = Mock(spec=["method", "uri", "content"]) | |||
@@ -47,7 +47,7 @@ def make_request(content): | |||
class TestServletUtils(unittest.TestCase): | |||
def test_parse_json_value(self): | |||
def test_parse_json_value(self) -> None: | |||
"""Basic tests for parse_json_value_from_request.""" | |||
# Test round-tripping. | |||
obj = {"foo": 1} | |||
@@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase): | |||
with self.assertRaises(SynapseError): | |||
parse_json_value_from_request(make_request(b'{"foo": Infinity}')) | |||
def test_parse_json_object(self): | |||
def test_parse_json_object(self) -> None: | |||
"""Basic tests for parse_json_object_from_request.""" | |||
# Test empty. | |||
result = parse_json_object_from_request( | |||
@@ -17,22 +17,24 @@ from netaddr import IPSet | |||
from twisted.internet import defer | |||
from twisted.internet.error import DNSLookupError | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.http import RequestTimedOutError | |||
from synapse.http.client import SimpleHttpClient | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class SimpleHttpClientTests(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs: "HomeServer"): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None: | |||
# Add a DNS entry for a test server | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
self.cl = hs.get_simple_http_client() | |||
def test_dns_error(self): | |||
def test_dns_error(self) -> None: | |||
""" | |||
If the DNS lookup returns an error, it will bubble up. | |||
""" | |||
@@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase): | |||
f = self.failureResultOf(d) | |||
self.assertIsInstance(f.value, DNSLookupError) | |||
def test_client_connection_refused(self): | |||
def test_client_connection_refused(self) -> None: | |||
d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) | |||
self.pump() | |||
@@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase): | |||
self.assertIs(f.value, e) | |||
def test_client_never_connect(self): | |||
def test_client_never_connect(self) -> None: | |||
""" | |||
If the HTTP request is not connected and is timed out, it'll give a | |||
ConnectingCancelledError or TimeoutError. | |||
@@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase): | |||
self.assertIsInstance(f.value, RequestTimedOutError) | |||
def test_client_connect_no_response(self): | |||
def test_client_connect_no_response(self) -> None: | |||
""" | |||
If the HTTP request is connected, but gets no response before being | |||
timed out, it'll give a ResponseNeverReceived. | |||
@@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase): | |||
self.assertIsInstance(f.value, RequestTimedOutError) | |||
def test_client_ip_range_blacklist(self): | |||
def test_client_ip_range_blacklist(self) -> None: | |||
"""Ensure that Synapse does not try to connect to blacklisted IPs""" | |||
# Add some DNS entries we'll blacklist | |||
@@ -13,18 +13,20 @@ | |||
# limitations under the License. | |||
from twisted.internet.address import IPv6Address | |||
from twisted.test.proto_helpers import StringTransport | |||
from twisted.test.proto_helpers import MemoryReactor, StringTransport | |||
from synapse.app.homeserver import SynapseHomeServer | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase | |||
class SynapseRequestTestCase(HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer) | |||
def test_large_request(self): | |||
def test_large_request(self) -> None: | |||
"""overlarge HTTP requests should be rejected""" | |||
self.hs.start_listening() | |||
@@ -70,7 +70,7 @@ from synapse.logging.context import ContextResourceUsage | |||
from synapse.server import HomeServer | |||
from synapse.storage import DataStore | |||
from synapse.storage.engines import PostgresEngine, create_engine | |||
from synapse.types import JsonDict | |||
from synapse.types import ISynapseReactor, JsonDict | |||
from synapse.util import Clock | |||
from tests.utils import ( | |||
@@ -401,7 +401,9 @@ def make_request( | |||
return channel | |||
@implementer(IReactorPluggableNameResolver) | |||
# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly | |||
# marking this as an implementer of the latter seems to keep mypy-zope happier. | |||
@implementer(IReactorPluggableNameResolver, ISynapseReactor) | |||
class ThreadedMemoryReactorClock(MemoryReactorClock): | |||
""" | |||
A MemoryReactorClock that supports callFromThread. | |||