Use Twisted HostnameEndpoint to connect to SMTP servers (instead of connectTCP/connectSSL) which properly supports IPv6-only servers.tags/v1.92.0rc1
@@ -0,0 +1 @@ | |||
Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch). |
@@ -23,9 +23,11 @@ from pkg_resources import parse_version | |||
import twisted | |||
from twisted.internet.defer import Deferred | |||
from twisted.internet.interfaces import IOpenSSLContextFactory | |||
from twisted.internet.endpoints import HostnameEndpoint | |||
from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory | |||
from twisted.internet.ssl import optionsForClientTLS | |||
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory | |||
from twisted.protocols.tls import TLSMemoryBIOFactory | |||
from synapse.logging.context import make_deferred_yieldable | |||
from synapse.types import ISynapseReactor | |||
@@ -97,6 +99,7 @@ async def _sendmail( | |||
**kwargs, | |||
) | |||
factory: IProtocolFactory | |||
if _is_old_twisted: | |||
# before twisted 21.2, we have to override the ESMTPSender protocol to disable | |||
# TLS | |||
@@ -110,22 +113,13 @@ async def _sendmail( | |||
factory = build_sender_factory(hostname=smtphost if enable_tls else None) | |||
if force_tls: | |||
reactor.connectSSL( | |||
smtphost, | |||
smtpport, | |||
factory, | |||
optionsForClientTLS(smtphost), | |||
timeout=30, | |||
bindAddress=None, | |||
) | |||
else: | |||
reactor.connectTCP( | |||
smtphost, | |||
smtpport, | |||
factory, | |||
timeout=30, | |||
bindAddress=None, | |||
) | |||
factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory) | |||
endpoint = HostnameEndpoint( | |||
reactor, smtphost, smtpport, timeout=30, bindAddress=None | |||
) | |||
await make_deferred_yieldable(endpoint.connect(factory)) | |||
await make_deferred_yieldable(d) | |||
@@ -13,19 +13,40 @@ | |||
# limitations under the License. | |||
from typing import Callable, List, Tuple | |||
from typing import Callable, List, Tuple, Type, Union | |||
from unittest.mock import patch | |||
from zope.interface import implementer | |||
from twisted.internet import defer | |||
from twisted.internet.address import IPv4Address | |||
from twisted.internet._sslverify import ClientTLSOptions | |||
from twisted.internet.address import IPv4Address, IPv6Address | |||
from twisted.internet.defer import ensureDeferred | |||
from twisted.internet.interfaces import IProtocolFactory | |||
from twisted.internet.ssl import ContextFactory | |||
from twisted.mail import interfaces, smtp | |||
from tests.server import FakeTransport | |||
from tests.unittest import HomeserverTestCase, override_config | |||
def TestingESMTPTLSClientFactory( | |||
contextFactory: ContextFactory, | |||
_connectWrapped: bool, | |||
wrappedProtocol: IProtocolFactory, | |||
) -> IProtocolFactory: | |||
"""We use this to pass through in testing without using TLS, but | |||
saving the context information to check that it would have happened. | |||
Note that this is what the MemoryReactor does on connectSSL. | |||
It only saves the contextFactory, but starts the connection with the | |||
underlying Factory. | |||
See: L{twisted.internet.testing.MemoryReactor.connectSSL}""" | |||
wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined] | |||
return wrappedProtocol | |||
@implementer(interfaces.IMessageDelivery) | |||
class _DummyMessageDelivery: | |||
def __init__(self) -> None: | |||
@@ -75,7 +96,13 @@ class _DummyMessage: | |||
pass | |||
class SendEmailHandlerTestCase(HomeserverTestCase): | |||
class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): | |||
ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address | |||
def setUp(self) -> None: | |||
super().setUp() | |||
self.reactor.lookups["localhost"] = "127.0.0.1" | |||
def test_send_email(self) -> None: | |||
"""Happy-path test that we can send email to a non-TLS server.""" | |||
h = self.hs.get_send_email_handler() | |||
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ | |||
0 | |||
] | |||
self.assertEqual(host, "localhost") | |||
self.assertEqual(host, self.reactor.lookups["localhost"]) | |||
self.assertEqual(port, 25) | |||
# wire it up to an SMTP server | |||
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
FakeTransport( | |||
client_protocol, | |||
self.reactor, | |||
peer_address=IPv4Address("TCP", "127.0.0.1", 1234), | |||
peer_address=self.ip_class( | |||
"TCP", self.reactor.lookups["localhost"], 1234 | |||
), | |||
) | |||
) | |||
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
self.assertEqual(str(user), "foo@bar.com") | |||
self.assertIn(b"Subject: test subject", msg) | |||
@patch( | |||
"synapse.handlers.send_email.TLSMemoryBIOFactory", | |||
TestingESMTPTLSClientFactory, | |||
) | |||
@override_config( | |||
{ | |||
"email": { | |||
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
) | |||
) | |||
# there should be an attempt to connect to localhost:465 | |||
self.assertEqual(len(self.reactor.sslClients), 1) | |||
self.assertEqual(len(self.reactor.tcpClients), 1) | |||
( | |||
host, | |||
port, | |||
client_factory, | |||
contextFactory, | |||
_timeout, | |||
_bindAddress, | |||
) = self.reactor.sslClients[0] | |||
self.assertEqual(host, "localhost") | |||
) = self.reactor.tcpClients[0] | |||
self.assertEqual(host, self.reactor.lookups["localhost"]) | |||
self.assertEqual(port, 465) | |||
# We need to make sure that TLS is happenning | |||
self.assertIsInstance( | |||
client_factory._wrappedFactory._testingContextFactory, | |||
ClientTLSOptions, | |||
) | |||
# And since we use endpoints, they go through reactor.connectTCP | |||
# which works differently to connectSSL on the testing reactor | |||
# wire it up to an SMTP server | |||
message_delivery = _DummyMessageDelivery() | |||
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
FakeTransport( | |||
client_protocol, | |||
self.reactor, | |||
peer_address=IPv4Address("TCP", "127.0.0.1", 1234), | |||
peer_address=self.ip_class( | |||
"TCP", self.reactor.lookups["localhost"], 1234 | |||
), | |||
) | |||
) | |||
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
user, msg = message_delivery.messages.pop() | |||
self.assertEqual(str(user), "foo@bar.com") | |||
self.assertIn(b"Subject: test subject", msg) | |||
class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): | |||
ip_class = IPv6Address | |||
def setUp(self) -> None: | |||
super().setUp() | |||
self.reactor.lookups["localhost"] = "::1" |
@@ -12,6 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import hashlib | |||
import ipaddress | |||
import json | |||
import logging | |||
import os | |||
@@ -45,7 +46,7 @@ import attr | |||
from typing_extensions import ParamSpec | |||
from zope.interface import implementer | |||
from twisted.internet import address, threads, udp | |||
from twisted.internet import address, tcp, threads, udp | |||
from twisted.internet._resolver import SimpleResolverComplexifier | |||
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed | |||
from twisted.internet.error import DNSLookupError | |||
@@ -567,6 +568,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): | |||
conn = super().connectTCP( | |||
host, port, factory, timeout=timeout, bindAddress=None | |||
) | |||
if self.lookups and host in self.lookups: | |||
validate_connector(conn, self.lookups[host]) | |||
callback = self._tcp_callbacks.get((host, port)) | |||
if callback: | |||
@@ -599,6 +602,55 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): | |||
super().advance(0) | |||
def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: | |||
"""Try to validate the obtained connector as it would happen when | |||
synapse is running and the conection will be established. | |||
This method will raise a useful exception when necessary, else it will | |||
just do nothing. | |||
This is in order to help catch quirks related to reactor.connectTCP, | |||
since when called directly, the connector's destination will be of type | |||
IPv4Address, with the hostname as the literal host that was given (which | |||
could be an IPv6-only host or an IPv6 literal). | |||
But when called from reactor.connectTCP *through* e.g. an Endpoint, the | |||
connector's destination will contain the specific IP address with the | |||
correct network stack class. | |||
Note that testing code paths that use connectTCP directly should not be | |||
affected by this check, unless they specifically add a test with a | |||
matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of | |||
type ThreadedMemoryReactorClock. | |||
For an example of implementing such tests, see test/handlers/send_email.py. | |||
""" | |||
destination = connector.getDestination() | |||
# We use address.IPv{4,6}Address to check what the reactor thinks it is | |||
# is sending but check for validity with ipaddress.IPv{4,6}Address | |||
# because they fail with IPs on the wrong network stack. | |||
cls_mapping = { | |||
address.IPv4Address: ipaddress.IPv4Address, | |||
address.IPv6Address: ipaddress.IPv6Address, | |||
} | |||
cls = cls_mapping.get(destination.__class__) | |||
if cls is not None: | |||
try: | |||
cls(expected_ip) | |||
except Exception as exc: | |||
raise ValueError( | |||
"Invalid IP type and resolution for %s. Expected %s to be %s" | |||
% (destination, expected_ip, cls.__name__) | |||
) from exc | |||
else: | |||
raise ValueError( | |||
"Unknown address type %s for %s" | |||
% (destination.__class__.__name__, destination) | |||
) | |||
class ThreadPool: | |||
""" | |||
Threadless thread pool. | |||
@@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase): | |||
servlets: List of servlet registration function. | |||
user_id (str): The user ID to assume if auth is hijacked. | |||
hijack_auth: Whether to hijack auth to return the user specified | |||
in user_id. | |||
in user_id. | |||
""" | |||
hijack_auth: ClassVar[bool] = True | |||