You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

224 lines
7.8 KiB

  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, List, Tuple, Type, Union
  15. from unittest.mock import patch
  16. from zope.interface import implementer
  17. from twisted.internet import defer
  18. from twisted.internet._sslverify import ClientTLSOptions
  19. from twisted.internet.address import IPv4Address, IPv6Address
  20. from twisted.internet.defer import ensureDeferred
  21. from twisted.internet.interfaces import IProtocolFactory
  22. from twisted.internet.ssl import ContextFactory
  23. from twisted.mail import interfaces, smtp
  24. from tests.server import FakeTransport
  25. from tests.unittest import HomeserverTestCase, override_config
  26. def TestingESMTPTLSClientFactory(
  27. contextFactory: ContextFactory,
  28. _connectWrapped: bool,
  29. wrappedProtocol: IProtocolFactory,
  30. ) -> IProtocolFactory:
  31. """We use this to pass through in testing without using TLS, but
  32. saving the context information to check that it would have happened.
  33. Note that this is what the MemoryReactor does on connectSSL.
  34. It only saves the contextFactory, but starts the connection with the
  35. underlying Factory.
  36. See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
  37. wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
  38. return wrappedProtocol
  39. @implementer(interfaces.IMessageDelivery)
  40. class _DummyMessageDelivery:
  41. def __init__(self) -> None:
  42. # (recipient, message) tuples
  43. self.messages: List[Tuple[smtp.Address, bytes]] = []
  44. def receivedHeader(
  45. self,
  46. helo: Tuple[bytes, bytes],
  47. origin: smtp.Address,
  48. recipients: List[smtp.User],
  49. ) -> None:
  50. return None
  51. def validateFrom(
  52. self, helo: Tuple[bytes, bytes], origin: smtp.Address
  53. ) -> smtp.Address:
  54. return origin
  55. def record_message(self, recipient: smtp.Address, message: bytes) -> None:
  56. self.messages.append((recipient, message))
  57. def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
  58. return lambda: _DummyMessage(self, user)
  59. @implementer(interfaces.IMessageSMTP)
  60. class _DummyMessage:
  61. """IMessageSMTP implementation which saves the message delivered to it
  62. to the _DummyMessageDelivery object.
  63. """
  64. def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
  65. self._delivery = delivery
  66. self._user = user
  67. self._buffer: List[bytes] = []
  68. def lineReceived(self, line: bytes) -> None:
  69. self._buffer.append(line)
  70. def eomReceived(self) -> "defer.Deferred[bytes]":
  71. message = b"\n".join(self._buffer) + b"\n"
  72. self._delivery.record_message(self._user.dest, message)
  73. return defer.succeed(b"saved")
  74. def connectionLost(self) -> None:
  75. pass
  76. class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
  77. ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
  78. def setUp(self) -> None:
  79. super().setUp()
  80. self.reactor.lookups["localhost"] = "127.0.0.1"
  81. def test_send_email(self) -> None:
  82. """Happy-path test that we can send email to a non-TLS server."""
  83. h = self.hs.get_send_email_handler()
  84. d = ensureDeferred(
  85. h.send_email(
  86. "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
  87. )
  88. )
  89. # there should be an attempt to connect to localhost:25
  90. self.assertEqual(len(self.reactor.tcpClients), 1)
  91. (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
  92. 0
  93. ]
  94. self.assertEqual(host, self.reactor.lookups["localhost"])
  95. self.assertEqual(port, 25)
  96. # wire it up to an SMTP server
  97. message_delivery = _DummyMessageDelivery()
  98. server_protocol = smtp.ESMTP()
  99. server_protocol.delivery = message_delivery
  100. # make sure that the server uses the test reactor to set timeouts
  101. server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
  102. client_protocol = client_factory.buildProtocol(None)
  103. client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
  104. server_protocol.makeConnection(
  105. FakeTransport(
  106. client_protocol,
  107. self.reactor,
  108. peer_address=self.ip_class(
  109. "TCP", self.reactor.lookups["localhost"], 1234
  110. ),
  111. )
  112. )
  113. # the message should now get delivered
  114. self.get_success(d, by=0.1)
  115. # check it arrived
  116. self.assertEqual(len(message_delivery.messages), 1)
  117. user, msg = message_delivery.messages.pop()
  118. self.assertEqual(str(user), "foo@bar.com")
  119. self.assertIn(b"Subject: test subject", msg)
  120. @patch(
  121. "synapse.handlers.send_email.TLSMemoryBIOFactory",
  122. TestingESMTPTLSClientFactory,
  123. )
  124. @override_config(
  125. {
  126. "email": {
  127. "notif_from": "noreply@test",
  128. "force_tls": True,
  129. },
  130. }
  131. )
  132. def test_send_email_force_tls(self) -> None:
  133. """Happy-path test that we can send email to an Implicit TLS server."""
  134. h = self.hs.get_send_email_handler()
  135. d = ensureDeferred(
  136. h.send_email(
  137. "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
  138. )
  139. )
  140. # there should be an attempt to connect to localhost:465
  141. self.assertEqual(len(self.reactor.tcpClients), 1)
  142. (
  143. host,
  144. port,
  145. client_factory,
  146. _timeout,
  147. _bindAddress,
  148. ) = self.reactor.tcpClients[0]
  149. self.assertEqual(host, self.reactor.lookups["localhost"])
  150. self.assertEqual(port, 465)
  151. # We need to make sure that TLS is happenning
  152. self.assertIsInstance(
  153. client_factory._wrappedFactory._testingContextFactory,
  154. ClientTLSOptions,
  155. )
  156. # And since we use endpoints, they go through reactor.connectTCP
  157. # which works differently to connectSSL on the testing reactor
  158. # wire it up to an SMTP server
  159. message_delivery = _DummyMessageDelivery()
  160. server_protocol = smtp.ESMTP()
  161. server_protocol.delivery = message_delivery
  162. # make sure that the server uses the test reactor to set timeouts
  163. server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
  164. client_protocol = client_factory.buildProtocol(None)
  165. client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
  166. server_protocol.makeConnection(
  167. FakeTransport(
  168. client_protocol,
  169. self.reactor,
  170. peer_address=self.ip_class(
  171. "TCP", self.reactor.lookups["localhost"], 1234
  172. ),
  173. )
  174. )
  175. # the message should now get delivered
  176. self.get_success(d, by=0.1)
  177. # check it arrived
  178. self.assertEqual(len(message_delivery.messages), 1)
  179. user, msg = message_delivery.messages.pop()
  180. self.assertEqual(str(user), "foo@bar.com")
  181. self.assertIn(b"Subject: test subject", msg)
  182. class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
  183. ip_class = IPv6Address
  184. def setUp(self) -> None:
  185. super().setUp()
  186. self.reactor.lookups["localhost"] = "::1"