Twisted trunk makes a change to the `TLSMemoryBIOFactory` where
the underlying protocol is changed from `TLSMemoryBIOProtocol` to
`BufferingTLSTransport` to improve performance of TLS code (see
https://github.com/twisted/twisted/issues/11989).
In order to properly hook this code up in tests we need to pass the test
reactor's clock into `TLSMemoryBIOFactory` to avoid the global (trial)
reactor being used by default.
Twisted does something similar internally for tests:
157cd8e659/src/twisted/web/test/test_agent.py (L871-L874)
tags/v1.96.0rc1
@@ -0,0 +1 @@ | |||
Fix running unit tests on Twisted trunk. |
@@ -15,14 +15,20 @@ import os.path | |||
import subprocess | |||
from typing import List | |||
from incremental import Version | |||
from zope.interface import implementer | |||
import twisted | |||
from OpenSSL import SSL | |||
from OpenSSL.SSL import Connection | |||
from twisted.internet.address import IPv4Address | |||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator | |||
from twisted.internet.interfaces import ( | |||
IOpenSSLServerConnectionCreator, | |||
IProtocolFactory, | |||
IReactorTime, | |||
) | |||
from twisted.internet.ssl import Certificate, trustRootFromCertificates | |||
from twisted.protocols.tls import TLSMemoryBIOProtocol | |||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 | |||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 | |||
@@ -153,6 +159,33 @@ class TestServerTLSConnectionFactory: | |||
return Connection(ctx, None) | |||
def wrap_server_factory_for_tls( | |||
factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes] | |||
) -> TLSMemoryBIOFactory: | |||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory | |||
The resultant factory will create a TLS server which presents a certificate | |||
signed by our test CA, valid for the domains in `sanlist` | |||
Args: | |||
factory: protocol factory to wrap | |||
sanlist: list of domains the cert should be valid for | |||
Returns: | |||
interfaces.IProtocolFactory | |||
""" | |||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) | |||
# Twisted > 23.8.0 has a different API that accepts a clock. | |||
if twisted.version <= Version("Twisted", 23, 8, 0): | |||
return TLSMemoryBIOFactory( | |||
connection_creator, isClient=False, wrappedFactory=factory | |||
) | |||
else: | |||
return TLSMemoryBIOFactory( | |||
connection_creator, isClient=False, wrappedFactory=factory, clock=clock # type: ignore[call-arg] | |||
) | |||
# A dummy address, useful for tests that use FakeTransport and don't care about where | |||
# packets are going to/coming from. | |||
dummy_address = IPv4Address("TCP", "127.0.0.1", 80) |
@@ -31,7 +31,7 @@ from twisted.internet.interfaces import ( | |||
IProtocolFactory, | |||
) | |||
from twisted.internet.protocol import Factory, Protocol | |||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |||
from twisted.protocols.tls import TLSMemoryBIOProtocol | |||
from twisted.web._newclient import ResponseNeverReceived | |||
from twisted.web.client import Agent | |||
from twisted.web.http import HTTPChannel, Request | |||
@@ -57,11 +57,7 @@ from synapse.types import ISynapseReactor | |||
from synapse.util.caches.ttlcache import TTLCache | |||
from tests import unittest | |||
from tests.http import ( | |||
TestServerTLSConnectionFactory, | |||
dummy_address, | |||
get_test_ca_cert_file, | |||
) | |||
from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls | |||
from tests.server import FakeTransport, ThreadedMemoryReactorClock | |||
from tests.utils import checked_cast, default_config | |||
@@ -125,7 +121,18 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
# build the test server | |||
server_factory = _get_test_protocol_factory() | |||
if ssl: | |||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) | |||
server_factory = wrap_server_factory_for_tls( | |||
server_factory, | |||
self.reactor, | |||
tls_sanlist | |||
or [ | |||
b"DNS:testserv", | |||
b"DNS:target-server", | |||
b"DNS:xn--bcher-kva.com", | |||
b"IP:1.2.3.4", | |||
b"IP:::1", | |||
], | |||
) | |||
server_protocol = server_factory.buildProtocol(dummy_address) | |||
assert server_protocol is not None | |||
@@ -435,8 +442,16 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
request.finish() | |||
# now we make another test server to act as the upstream HTTP server. | |||
server_ssl_protocol = _wrap_server_factory_for_tls( | |||
_get_test_protocol_factory() | |||
server_ssl_protocol = wrap_server_factory_for_tls( | |||
_get_test_protocol_factory(), | |||
self.reactor, | |||
sanlist=[ | |||
b"DNS:testserv", | |||
b"DNS:target-server", | |||
b"DNS:xn--bcher-kva.com", | |||
b"IP:1.2.3.4", | |||
b"IP:::1", | |||
], | |||
).buildProtocol(dummy_address) | |||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport. | |||
@@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None: | |||
raise AssertionError("Expected logcontext %s but was %s" % (context, current)) | |||
def _wrap_server_factory_for_tls( | |||
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None | |||
) -> TLSMemoryBIOFactory: | |||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory | |||
The resultant factory will create a TLS server which presents a certificate | |||
signed by our test CA, valid for the domains in `sanlist` | |||
Args: | |||
factory: protocol factory to wrap | |||
sanlist: list of domains the cert should be valid for | |||
Returns: | |||
interfaces.IProtocolFactory | |||
""" | |||
if sanlist is None: | |||
sanlist = [ | |||
b"DNS:testserv", | |||
b"DNS:target-server", | |||
b"DNS:xn--bcher-kva.com", | |||
b"IP:1.2.3.4", | |||
b"IP:::1", | |||
] | |||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) | |||
return TLSMemoryBIOFactory( | |||
connection_creator, isClient=False, wrappedFactory=factory | |||
) | |||
def _get_test_protocol_factory() -> IProtocolFactory: | |||
"""Get a protocol Factory which will build an HTTPChannel | |||
Returns: | |||
@@ -29,18 +29,14 @@ from twisted.internet.endpoints import ( | |||
) | |||
from twisted.internet.interfaces import IProtocol, IProtocolFactory | |||
from twisted.internet.protocol import Factory, Protocol | |||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |||
from twisted.protocols.tls import TLSMemoryBIOProtocol | |||
from twisted.web.http import HTTPChannel | |||
from synapse.http.client import BlocklistingReactorWrapper | |||
from synapse.http.connectproxyclient import BasicProxyCredentials | |||
from synapse.http.proxyagent import ProxyAgent, parse_proxy | |||
from tests.http import ( | |||
TestServerTLSConnectionFactory, | |||
dummy_address, | |||
get_test_https_policy, | |||
) | |||
from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls | |||
from tests.server import FakeTransport, ThreadedMemoryReactorClock | |||
from tests.unittest import TestCase | |||
from tests.utils import checked_cast | |||
@@ -272,7 +268,9 @@ class MatrixFederationAgentTests(TestCase): | |||
the server Protocol returned by server_factory | |||
""" | |||
if ssl: | |||
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) | |||
server_factory = wrap_server_factory_for_tls( | |||
server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"] | |||
) | |||
server_protocol = server_factory.buildProtocol(dummy_address) | |||
assert server_protocol is not None | |||
@@ -639,8 +637,8 @@ class MatrixFederationAgentTests(TestCase): | |||
request.finish() | |||
# now we make another test server to act as the upstream HTTP server. | |||
server_ssl_protocol = _wrap_server_factory_for_tls( | |||
_get_test_protocol_factory() | |||
server_ssl_protocol = wrap_server_factory_for_tls( | |||
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"] | |||
).buildProtocol(dummy_address) | |||
# Tell the HTTP server to send outgoing traffic back via the proxy's transport. | |||
@@ -806,7 +804,9 @@ class MatrixFederationAgentTests(TestCase): | |||
request.finish() | |||
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel | |||
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) | |||
ssl_factory = wrap_server_factory_for_tls( | |||
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"] | |||
) | |||
ssl_protocol = ssl_factory.buildProtocol(dummy_address) | |||
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol) | |||
http_server = ssl_protocol.wrappedProtocol | |||
@@ -870,30 +870,6 @@ class MatrixFederationAgentTests(TestCase): | |||
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888) | |||
def _wrap_server_factory_for_tls( | |||
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None | |||
) -> TLSMemoryBIOFactory: | |||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory | |||
The resultant factory will create a TLS server which presents a certificate | |||
signed by our test CA, valid for the domains in `sanlist` | |||
Args: | |||
factory: protocol factory to wrap | |||
sanlist: list of domains the cert should be valid for | |||
Returns: | |||
interfaces.IProtocolFactory | |||
""" | |||
if sanlist is None: | |||
sanlist = [b"DNS:test.com"] | |||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) | |||
return TLSMemoryBIOFactory( | |||
connection_creator, isClient=False, wrappedFactory=factory | |||
) | |||
def _get_test_protocol_factory() -> IProtocolFactory: | |||
"""Get a protocol Factory which will build an HTTPChannel | |||
@@ -15,9 +15,7 @@ import logging | |||
import os | |||
from typing import Any, Optional, Tuple | |||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator | |||
from twisted.internet.protocol import Factory | |||
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.web.http import HTTPChannel | |||
from twisted.web.server import Request | |||
@@ -27,7 +25,11 @@ from synapse.rest.client import login | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file | |||
from tests.http import ( | |||
TestServerTLSConnectionFactory, | |||
get_test_ca_cert_file, | |||
wrap_server_factory_for_tls, | |||
) | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
from tests.server import FakeChannel, FakeTransport, make_request | |||
from tests.test_utils import SMALL_PNG | |||
@@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): | |||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop() | |||
# build the test server | |||
server_tls_protocol = _build_test_server(get_connection_factory()) | |||
server_factory = Factory.forProtocol(HTTPChannel) | |||
# Request.finish expects the factory to have a 'log' method. | |||
server_factory.log = _log_request | |||
server_tls_protocol = wrap_server_factory_for_tls( | |||
server_factory, self.reactor, sanlist=[b"DNS:example.com"] | |||
).buildProtocol(None) | |||
# now, tell the client protocol factory to build the client protocol (it will be a | |||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an | |||
@@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): | |||
) | |||
# fish the test server back out of the server-side TLS protocol. | |||
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment] | |||
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol | |||
# give the reactor a pump to get the TLS juices flowing. | |||
self.reactor.pump((0.1,)) | |||
@@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): | |||
return sum(len(files) for _, _, files in os.walk(path)) | |||
def get_connection_factory() -> TestServerTLSConnectionFactory: | |||
# this needs to happen once, but not until we are ready to run the first test | |||
global test_server_connection_factory | |||
if test_server_connection_factory is None: | |||
test_server_connection_factory = TestServerTLSConnectionFactory( | |||
sanlist=[b"DNS:example.com"] | |||
) | |||
return test_server_connection_factory | |||
def _build_test_server( | |||
connection_creator: IOpenSSLServerConnectionCreator, | |||
) -> TLSMemoryBIOProtocol: | |||
"""Construct a test server | |||
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol | |||
Args: | |||
connection_creator: thing to build SSL connections | |||
Returns: | |||
TLSMemoryBIOProtocol | |||
""" | |||
server_factory = Factory.forProtocol(HTTPChannel) | |||
# Request.finish expects the factory to have a 'log' method. | |||
server_factory.log = _log_request | |||
server_tls_factory = TLSMemoryBIOFactory( | |||
connection_creator, isClient=False, wrappedFactory=server_factory | |||
) | |||
return server_tls_factory.buildProtocol(None) | |||
def _log_request(request: Request) -> None: | |||
"""Implements Factory.log, which is expected by Request.finish""" | |||
logger.info("Completed request %s", request) |
@@ -43,9 +43,11 @@ from typing import ( | |||
from unittest.mock import Mock | |||
import attr | |||
from incremental import Version | |||
from typing_extensions import ParamSpec | |||
from zope.interface import implementer | |||
import twisted | |||
from twisted.internet import address, tcp, threads, udp | |||
from twisted.internet._resolver import SimpleResolverComplexifier | |||
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed | |||
@@ -474,6 +476,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): | |||
return fail(DNSLookupError("OH NO: unknown %s" % (name,))) | |||
return succeed(lookups[name]) | |||
# In order for the TLS protocol tests to work, modify _get_default_clock | |||
# on newer Twisted versions to use the test reactor's clock. | |||
# | |||
# This is *super* dirty since it is never undone and relies on the next | |||
# test to overwrite it. | |||
if twisted.version > Version("Twisted", 23, 8, 0): | |||
from twisted.protocols import tls | |||
tls._get_default_clock = lambda: self # type: ignore[attr-defined] | |||
self.nameResolver = SimpleResolverComplexifier(FakeResolver()) | |||
super().__init__() | |||