You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

1147 lines
37 KiB

  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import hashlib
  15. import ipaddress
  16. import json
  17. import logging
  18. import os
  19. import os.path
  20. import sqlite3
  21. import time
  22. import uuid
  23. import warnings
  24. from collections import deque
  25. from io import SEEK_END, BytesIO
  26. from typing import (
  27. Any,
  28. Awaitable,
  29. Callable,
  30. Deque,
  31. Dict,
  32. Iterable,
  33. List,
  34. MutableMapping,
  35. Optional,
  36. Sequence,
  37. Tuple,
  38. Type,
  39. TypeVar,
  40. Union,
  41. cast,
  42. )
  43. from unittest.mock import Mock
  44. import attr
  45. from incremental import Version
  46. from typing_extensions import ParamSpec
  47. from zope.interface import implementer
  48. import twisted
  49. from twisted.internet import address, tcp, threads, udp
  50. from twisted.internet._resolver import SimpleResolverComplexifier
  51. from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
  52. from twisted.internet.error import DNSLookupError
  53. from twisted.internet.interfaces import (
  54. IAddress,
  55. IConnector,
  56. IConsumer,
  57. IHostnameResolver,
  58. IListeningPort,
  59. IProducer,
  60. IProtocol,
  61. IPullProducer,
  62. IPushProducer,
  63. IReactorPluggableNameResolver,
  64. IReactorTime,
  65. IResolverSimple,
  66. ITransport,
  67. )
  68. from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory
  69. from twisted.python import threadpool
  70. from twisted.python.failure import Failure
  71. from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
  72. from twisted.web.http_headers import Headers
  73. from twisted.web.resource import IResource
  74. from twisted.web.server import Request, Site
  75. from synapse.config.database import DatabaseConnectionConfig
  76. from synapse.config.homeserver import HomeServerConfig
  77. from synapse.events.presence_router import load_legacy_presence_router
  78. from synapse.handlers.auth import load_legacy_password_auth_providers
  79. from synapse.http.site import SynapseRequest
  80. from synapse.logging.context import ContextResourceUsage
  81. from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
  82. from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
  83. load_legacy_third_party_event_rules,
  84. )
  85. from synapse.server import HomeServer
  86. from synapse.storage import DataStore
  87. from synapse.storage.database import LoggingDatabaseConnection
  88. from synapse.storage.engines import create_engine
  89. from synapse.storage.prepare_database import prepare_database
  90. from synapse.types import ISynapseReactor, JsonDict
  91. from synapse.util import Clock
  92. from tests.utils import (
  93. LEAVE_DB,
  94. POSTGRES_BASE_DB,
  95. POSTGRES_HOST,
  96. POSTGRES_PASSWORD,
  97. POSTGRES_PORT,
  98. POSTGRES_USER,
  99. SQLITE_PERSIST_DB,
  100. USE_POSTGRES_FOR_TESTS,
  101. MockClock,
  102. default_config,
  103. )
  104. logger = logging.getLogger(__name__)
  105. R = TypeVar("R")
  106. P = ParamSpec("P")
  107. # the type of thing that can be passed into `make_request` in the headers list
  108. CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
  109. # A pre-prepared SQLite DB that is used as a template when creating new SQLite
  110. # DB each test run. This dramatically speeds up test set up when using SQLite.
  111. PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None
  112. class TimedOutException(Exception):
  113. """
  114. A web query timed out.
  115. """
  116. @implementer(ITransport, IPushProducer, IConsumer)
  117. @attr.s(auto_attribs=True)
  118. class FakeChannel:
  119. """
  120. A fake Twisted Web Channel (the part that interfaces with the
  121. wire).
  122. See twisted.web.http.HTTPChannel.
  123. """
  124. site: Union[Site, "FakeSite"]
  125. _reactor: MemoryReactorClock
  126. result: dict = attr.Factory(dict)
  127. _ip: str = "127.0.0.1"
  128. _producer: Optional[Union[IPullProducer, IPushProducer]] = None
  129. resource_usage: Optional[ContextResourceUsage] = None
  130. _request: Optional[Request] = None
  131. @property
  132. def request(self) -> Request:
  133. assert self._request is not None
  134. return self._request
  135. @request.setter
  136. def request(self, request: Request) -> None:
  137. assert self._request is None
  138. self._request = request
  139. @property
  140. def json_body(self) -> JsonDict:
  141. body = json.loads(self.text_body)
  142. assert isinstance(body, dict)
  143. return body
  144. @property
  145. def json_list(self) -> List[JsonDict]:
  146. body = json.loads(self.text_body)
  147. assert isinstance(body, list)
  148. return body
  149. @property
  150. def text_body(self) -> str:
  151. """The body of the result, utf-8-decoded.
  152. Raises an exception if the request has not yet completed.
  153. """
  154. if not self.is_finished():
  155. raise Exception("Request not yet completed")
  156. return self.result["body"].decode("utf8")
  157. def is_finished(self) -> bool:
  158. """check if the response has been completely received"""
  159. return self.result.get("done", False)
  160. @property
  161. def code(self) -> int:
  162. if not self.result:
  163. raise Exception("No result yet.")
  164. return int(self.result["code"])
  165. @property
  166. def headers(self) -> Headers:
  167. if not self.result:
  168. raise Exception("No result yet.")
  169. h = Headers()
  170. for i in self.result["headers"]:
  171. h.addRawHeader(*i)
  172. return h
  173. def writeHeaders(
  174. self, version: bytes, code: bytes, reason: bytes, headers: Headers
  175. ) -> None:
  176. self.result["version"] = version
  177. self.result["code"] = code
  178. self.result["reason"] = reason
  179. self.result["headers"] = headers
  180. def write(self, data: bytes) -> None:
  181. assert isinstance(data, bytes), "Should be bytes! " + repr(data)
  182. if "body" not in self.result:
  183. self.result["body"] = b""
  184. self.result["body"] += data
  185. def writeSequence(self, data: Iterable[bytes]) -> None:
  186. for x in data:
  187. self.write(x)
  188. def loseConnection(self) -> None:
  189. self.unregisterProducer()
  190. self.transport.loseConnection()
  191. # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
  192. def registerProducer(self, producer: IProducer, streaming: bool) -> None:
  193. # TODO This should ensure that the IProducer is an IPushProducer or
  194. # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
  195. # implement those, but doesn't declare it.
  196. self._producer = cast(Union[IPushProducer, IPullProducer], producer)
  197. self.producerStreaming = streaming
  198. def _produce() -> None:
  199. if self._producer:
  200. self._producer.resumeProducing()
  201. self._reactor.callLater(0.1, _produce)
  202. if not streaming:
  203. self._reactor.callLater(0.0, _produce)
  204. def unregisterProducer(self) -> None:
  205. if self._producer is None:
  206. return
  207. self._producer = None
  208. def stopProducing(self) -> None:
  209. if self._producer is not None:
  210. self._producer.stopProducing()
  211. def pauseProducing(self) -> None:
  212. raise NotImplementedError()
  213. def resumeProducing(self) -> None:
  214. raise NotImplementedError()
  215. def requestDone(self, _self: Request) -> None:
  216. self.result["done"] = True
  217. if isinstance(_self, SynapseRequest):
  218. assert _self.logcontext is not None
  219. self.resource_usage = _self.logcontext.get_resource_usage()
  220. def getPeer(self) -> IAddress:
  221. # We give an address so that getClientAddress/getClientIP returns a non null entry,
  222. # causing us to record the MAU
  223. return address.IPv4Address("TCP", self._ip, 3423)
  224. def getHost(self) -> IAddress:
  225. # this is called by Request.__init__ to configure Request.host.
  226. return address.IPv4Address("TCP", "127.0.0.1", 8888)
  227. def isSecure(self) -> bool:
  228. return False
  229. @property
  230. def transport(self) -> "FakeChannel":
  231. return self
  232. def await_result(self, timeout_ms: int = 1000) -> None:
  233. """
  234. Wait until the request is finished.
  235. """
  236. end_time = self._reactor.seconds() + timeout_ms / 1000.0
  237. self._reactor.run()
  238. while not self.is_finished():
  239. # If there's a producer, tell it to resume producing so we get content
  240. if self._producer:
  241. self._producer.resumeProducing()
  242. if self._reactor.seconds() > end_time:
  243. raise TimedOutException("Timed out waiting for request to finish.")
  244. self._reactor.advance(0.1)
  245. def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
  246. """Process the contents of any Set-Cookie headers in the response
  247. Any cookines found are added to the given dict
  248. """
  249. headers = self.headers.getRawHeaders("Set-Cookie")
  250. if not headers:
  251. return
  252. for h in headers:
  253. parts = h.split(";")
  254. k, v = parts[0].split("=", maxsplit=1)
  255. cookies[k] = v
  256. class FakeSite:
  257. """
  258. A fake Twisted Web Site, with mocks of the extra things that
  259. Synapse adds.
  260. """
  261. server_version_string = b"1"
  262. site_tag = "test"
  263. access_logger = logging.getLogger("synapse.access.http.fake")
  264. def __init__(
  265. self,
  266. resource: IResource,
  267. reactor: IReactorTime,
  268. experimental_cors_msc3886: bool = False,
  269. ):
  270. """
  271. Args:
  272. resource: the resource to be used for rendering all requests
  273. """
  274. self._resource = resource
  275. self.reactor = reactor
  276. self.experimental_cors_msc3886 = experimental_cors_msc3886
  277. def getResourceFor(self, request: Request) -> IResource:
  278. return self._resource
  279. def make_request(
  280. reactor: MemoryReactorClock,
  281. site: Union[Site, FakeSite],
  282. method: Union[bytes, str],
  283. path: Union[bytes, str],
  284. content: Union[bytes, str, JsonDict] = b"",
  285. access_token: Optional[str] = None,
  286. request: Type[Request] = SynapseRequest,
  287. shorthand: bool = True,
  288. federation_auth_origin: Optional[bytes] = None,
  289. content_is_form: bool = False,
  290. await_result: bool = True,
  291. custom_headers: Optional[Iterable[CustomHeaderType]] = None,
  292. client_ip: str = "127.0.0.1",
  293. ) -> FakeChannel:
  294. """
  295. Make a web request using the given method, path and content, and render it
  296. Returns the fake Channel object which records the response to the request.
  297. Args:
  298. reactor:
  299. site: The twisted Site to use to render the request
  300. method: The HTTP request method ("verb").
  301. path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
  302. content: The body of the request. JSON-encoded, if a str of bytes.
  303. access_token: The access token to add as authorization for the request.
  304. request: The request class to create.
  305. shorthand: Whether to try and be helpful and prefix the given URL
  306. with the usual REST API path, if it doesn't contain it.
  307. federation_auth_origin: if set to not-None, we will add a fake
  308. Authorization header pretenting to be the given server name.
  309. content_is_form: Whether the content is URL encoded form data. Adds the
  310. 'Content-Type': 'application/x-www-form-urlencoded' header.
  311. await_result: whether to wait for the request to complete rendering. If true,
  312. will pump the reactor until the the renderer tells the channel the request
  313. is finished.
  314. custom_headers: (name, value) pairs to add as request headers
  315. client_ip: The IP to use as the requesting IP. Useful for testing
  316. ratelimiting.
  317. Returns:
  318. channel
  319. """
  320. if not isinstance(method, bytes):
  321. method = method.encode("ascii")
  322. if not isinstance(path, bytes):
  323. path = path.encode("ascii")
  324. # Decorate it to be the full path, if we're using shorthand
  325. if (
  326. shorthand
  327. and not path.startswith(b"/_matrix")
  328. and not path.startswith(b"/_synapse")
  329. ):
  330. if path.startswith(b"/"):
  331. path = path[1:]
  332. path = b"/_matrix/client/r0/" + path
  333. if not path.startswith(b"/"):
  334. path = b"/" + path
  335. if isinstance(content, dict):
  336. content = json.dumps(content).encode("utf8")
  337. if isinstance(content, str):
  338. content = content.encode("utf8")
  339. channel = FakeChannel(site, reactor, ip=client_ip)
  340. req = request(channel, site)
  341. channel.request = req
  342. req.content = BytesIO(content)
  343. # Twisted expects to be at the end of the content when parsing the request.
  344. req.content.seek(0, SEEK_END)
  345. # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
  346. # bodies if the Content-Length header is missing
  347. req.requestHeaders.addRawHeader(
  348. b"Content-Length", str(len(content)).encode("ascii")
  349. )
  350. if access_token:
  351. req.requestHeaders.addRawHeader(
  352. b"Authorization", b"Bearer " + access_token.encode("ascii")
  353. )
  354. if federation_auth_origin is not None:
  355. req.requestHeaders.addRawHeader(
  356. b"Authorization",
  357. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  358. )
  359. if content:
  360. if content_is_form:
  361. req.requestHeaders.addRawHeader(
  362. b"Content-Type", b"application/x-www-form-urlencoded"
  363. )
  364. else:
  365. # Assume the body is JSON
  366. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  367. if custom_headers:
  368. for k, v in custom_headers:
  369. req.requestHeaders.addRawHeader(k, v)
  370. req.parseCookies()
  371. req.requestReceived(method, path, b"1.1")
  372. if await_result:
  373. channel.await_result()
  374. return channel
# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
# marking this as an implementer of the latter seems to keep mypy-zope happier.
@implementer(IReactorPluggableNameResolver, ISynapseReactor)
class ThreadedMemoryReactorClock(MemoryReactorClock):
    """
    A MemoryReactorClock that supports callFromThread.

    On top of the base class it adds:
      * a queue of `callFromThread` callbacks, drained by `advance`;
      * a threadless `ThreadPool` (see `getThreadPool`);
      * fake DNS resolution driven by the `lookups` dict;
      * per-(host, port) callbacks fired from `connectTCP`.
    """

    def __init__(self) -> None:
        # Note: these attributes are set before `super().__init__()` runs.
        self.threadpool = ThreadPool(self)

        # (host, port) -> callback run when connectTCP targets that address.
        self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
        # UDP ports created via listenUDP, kept referenced here.
        self._udp: List[udp.Port] = []
        # hostname -> IP used by the fake resolver and by connectTCP validation.
        self.lookups: Dict[str, str] = {}
        # Callbacks queued by callFromThread; run during `advance`.
        self._thread_callbacks: Deque[Callable[..., R]] = deque()

        # Local alias so the nested resolver class can see the dict.
        lookups = self.lookups

        @implementer(IResolverSimple)
        class FakeResolver:
            def getHostByName(
                self, name: str, timeout: Optional[Sequence[int]] = None
            ) -> "Deferred[str]":
                # Resolve only names registered in `lookups`; fail otherwise.
                if name not in lookups:
                    return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                return succeed(lookups[name])

        # In order for the TLS protocol tests to work, modify _get_default_clock
        # on newer Twisted versions to use the test reactor's clock.
        #
        # This is *super* dirty since it is never undone and relies on the next
        # test to overwrite it.
        if twisted.version > Version("Twisted", 23, 8, 0):
            from twisted.protocols import tls

            tls._get_default_clock = lambda: self

        self.nameResolver = SimpleResolverComplexifier(FakeResolver())
        super().__init__()

    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
        raise NotImplementedError()

    def listenUDP(
        self,
        port: int,
        protocol: DatagramProtocol,
        interface: str = "",
        maxPacketSize: int = 8196,
    ) -> udp.Port:
        # Creates a real `udp.Port` with this reactor and starts it listening.
        # NOTE(review): Twisted's own listenUDP defaults maxPacketSize to 8192;
        # 8196 here may be a typo. Confirm before relying on it.
        p = udp.Port(port, protocol, interface, maxPacketSize, self)
        p.startListening()
        self._udp.append(p)
        return p

    def callFromThread(
        self, callable: Callable[..., Any], *args: object, **kwargs: object
    ) -> None:
        """
        Make the callback fire in the next reactor iteration.

        The call is queued and run by the next `advance`.
        """
        cb = lambda: callable(*args, **kwargs)
        # it's not safe to call callLater() here, so we append the callback to a
        # separate queue.
        self._thread_callbacks.append(cb)

    def callInThread(
        self, callable: Callable[..., Any], *args: object, **kwargs: object
    ) -> None:
        raise NotImplementedError()

    def suggestThreadPoolSize(self, size: int) -> None:
        raise NotImplementedError()

    def getThreadPool(self) -> "threadpool.ThreadPool":
        # Cast to match super-class.
        return cast(threadpool.ThreadPool, self.threadpool)

    def add_tcp_client_callback(
        self, host: str, port: int, callback: Callable[[], None]
    ) -> None:
        """Add a callback that will be invoked when we receive a connection
        attempt to the given IP/port using `connectTCP`.

        Note that the callback gets run before we return the connection to the
        client, which means callbacks cannot block while waiting for writes.
        A later registration for the same (host, port) replaces the earlier one.
        """
        self._tcp_callbacks[(host, port)] = callback

    def connectUNIX(
        self,
        address: str,
        factory: ClientFactory,
        timeout: float = 30,
        checkPID: int = 0,
    ) -> IConnector:
        """
        Unix sockets aren't supported for unit tests yet. Make it obvious to any
        developer trying it out that they will need to do some work before being able
        to use it in tests.
        """
        raise Exception("Unix sockets are not implemented for tests yet, sorry.")

    def listenUNIX(
        self,
        address: str,
        factory: Factory,
        backlog: int = 50,
        mode: int = 0o666,
        wantPID: int = 0,
    ) -> IListeningPort:
        """
        Unix sockets aren't supported for unit tests yet. Make it obvious to any
        developer trying it out that they will need to do some work before being able
        to use it in tests.
        """
        raise Exception("Unix sockets are not implemented for tests, sorry")

    def connectTCP(
        self,
        host: str,
        port: int,
        factory: ClientFactory,
        timeout: float = 30,
        bindAddress: Optional[Tuple[str, int]] = None,
    ) -> IConnector:
        """Fake L{IReactorTCP.connectTCP}.

        Records the connection with the base reactor. If `host` is in
        `lookups`, it checks the connector against the resolved IP, and it
        runs any callback registered via `add_tcp_client_callback`.
        """
        # NOTE(review): the caller's `bindAddress` is not forwarded; the base
        # reactor always receives None. Confirm this is intentional.
        conn = super().connectTCP(
            host, port, factory, timeout=timeout, bindAddress=None
        )
        if self.lookups and host in self.lookups:
            validate_connector(conn, self.lookups[host])

        callback = self._tcp_callbacks.get((host, port))
        if callback:
            callback()

        return conn

    def advance(self, amount: float) -> None:
        # first advance our reactor's time, and run any "callLater" callbacks that
        # makes ready
        super().advance(amount)

        # now run any "callFromThread" callbacks
        while True:
            try:
                callback = self._thread_callbacks.popleft()
            except IndexError:
                break
            callback()

            # check for more "callLater" callbacks added by the thread callback
            # This isn't required in a regular reactor, but it ends up meaning that
            # our database queries can complete in a single call to `advance` [1] which
            # simplifies tests.
            #
            # [1]: we replace the threadpool backing the db connection pool with a
            # mock ThreadPool which doesn't really use threads; but we still use
            # reactor.callFromThread to feed results back from the db functions to the
            # main thread.
            super().advance(0)
  514. def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
  515. """Try to validate the obtained connector as it would happen when
  516. synapse is running and the conection will be established.
  517. This method will raise a useful exception when necessary, else it will
  518. just do nothing.
  519. This is in order to help catch quirks related to reactor.connectTCP,
  520. since when called directly, the connector's destination will be of type
  521. IPv4Address, with the hostname as the literal host that was given (which
  522. could be an IPv6-only host or an IPv6 literal).
  523. But when called from reactor.connectTCP *through* e.g. an Endpoint, the
  524. connector's destination will contain the specific IP address with the
  525. correct network stack class.
  526. Note that testing code paths that use connectTCP directly should not be
  527. affected by this check, unless they specifically add a test with a
  528. matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
  529. type ThreadedMemoryReactorClock.
  530. For an example of implementing such tests, see test/handlers/send_email.py.
  531. """
  532. destination = connector.getDestination()
  533. # We use address.IPv{4,6}Address to check what the reactor thinks it is
  534. # is sending but check for validity with ipaddress.IPv{4,6}Address
  535. # because they fail with IPs on the wrong network stack.
  536. cls_mapping = {
  537. address.IPv4Address: ipaddress.IPv4Address,
  538. address.IPv6Address: ipaddress.IPv6Address,
  539. }
  540. cls = cls_mapping.get(destination.__class__)
  541. if cls is not None:
  542. try:
  543. cls(expected_ip)
  544. except Exception as exc:
  545. raise ValueError(
  546. "Invalid IP type and resolution for %s. Expected %s to be %s"
  547. % (destination, expected_ip, cls.__name__)
  548. ) from exc
  549. else:
  550. raise ValueError(
  551. "Unknown address type %s for %s"
  552. % (destination.__class__.__name__, destination)
  553. )
  554. class ThreadPool:
  555. """
  556. Threadless thread pool.
  557. See twisted.python.threadpool.ThreadPool
  558. """
  559. def __init__(self, reactor: IReactorTime):
  560. self._reactor = reactor
  561. def start(self) -> None:
  562. pass
  563. def stop(self) -> None:
  564. pass
  565. def callInThreadWithCallback(
  566. self,
  567. onResult: Callable[[bool, Union[Failure, R]], None],
  568. function: Callable[P, R],
  569. *args: P.args,
  570. **kwargs: P.kwargs,
  571. ) -> "Deferred[None]":
  572. def _(res: Any) -> None:
  573. if isinstance(res, Failure):
  574. onResult(False, res)
  575. else:
  576. onResult(True, res)
  577. d: "Deferred[None]" = Deferred()
  578. d.addCallback(lambda x: function(*args, **kwargs))
  579. d.addBoth(_)
  580. self._reactor.callLater(0, d.callback, True)
  581. return d
  582. def _make_test_homeserver_synchronous(server: HomeServer) -> None:
  583. """
  584. Make the given test homeserver's database interactions synchronous.
  585. """
  586. clock = server.get_clock()
  587. for database in server.get_datastores().databases:
  588. pool = database._db_pool
  589. def runWithConnection(
  590. func: Callable[..., R], *args: Any, **kwargs: Any
  591. ) -> Awaitable[R]:
  592. return threads.deferToThreadPool(
  593. pool._reactor,
  594. pool.threadpool,
  595. pool._runWithConnection,
  596. func,
  597. *args,
  598. **kwargs,
  599. )
  600. def runInteraction(
  601. desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
  602. ) -> Awaitable[R]:
  603. return threads.deferToThreadPool(
  604. pool._reactor,
  605. pool.threadpool,
  606. pool._runInteraction,
  607. desc,
  608. func,
  609. *args,
  610. **kwargs,
  611. )
  612. pool.runWithConnection = runWithConnection # type: ignore[method-assign]
  613. pool.runInteraction = runInteraction # type: ignore[assignment]
  614. # Replace the thread pool with a threadless 'thread' pool
  615. pool.threadpool = ThreadPool(clock._reactor)
  616. pool.running = True
  617. # We've just changed the Databases to run DB transactions on the same
  618. # thread, so we need to disable the dedicated thread behaviour.
  619. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
  620. def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
  621. clock = ThreadedMemoryReactorClock()
  622. hs_clock = Clock(clock)
  623. return clock, hs_clock
@implementer(ITransport)
@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
    """
    A twisted.internet.interfaces.ITransport implementation which sends all its data
    straight into an IProtocol object: it exists to connect two IProtocols together.

    To use it, instantiate it with the receiving IProtocol, and then pass it to the
    sending IProtocol's makeConnection method:

        server = HTTPChannel()
        client.makeConnection(FakeTransport(server, self.reactor))

    If you want bidirectional communication, you'll need two instances.

    Writes are buffered and delivered asynchronously: with `autoflush` set,
    each write schedules `flush` via `callLater(0.0, ...)`.
    """

    other: IProtocol
    """The Protocol object which will receive any data written to this transport.
    """

    _reactor: IReactorTime
    """Test reactor
    """

    _protocol: Optional[IProtocol] = None
    """The Protocol which is producing data for this transport. Optional, but if set
    will get called back for connectionLost() notifications etc.
    """

    _peer_address: IAddress = attr.Factory(
        lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
    )
    """The value to be returned by getPeer"""

    _host_address: IAddress = attr.Factory(
        lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
    )
    """The value to be returned by getHost"""

    # Connection-state flags. These are plain class attributes (unannotated),
    # not attrs fields, so they are not constructor arguments.
    disconnecting = False
    disconnected = False
    connected = True
    # Data written but not yet delivered to `other`.
    buffer: bytes = b""
    producer: Optional[IPushProducer] = None
    # If true, every write (and every partial flush) schedules a further flush.
    autoflush: bool = True

    def getPeer(self) -> IAddress:
        return self._peer_address

    def getHost(self) -> IAddress:
        return self._host_address

    def loseConnection(self) -> None:
        """Start disconnecting.

        `_protocol.connectionLost` is called immediately, even if buffered data
        remains. If the buffer is empty the transport is marked disconnected
        right away; otherwise `flush` completes the disconnect once the buffer
        drains.
        """
        if not self.disconnecting:
            logger.info("FakeTransport: loseConnection()")
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(
                    Failure(RuntimeError("FakeTransport.loseConnection()"))
                )

            # if we still have data to write, delay until that is done
            if self.buffer:
                logger.info(
                    "FakeTransport: Delaying disconnect until buffer is flushed"
                )
            else:
                self.connected = False
                self.disconnected = True

    def abortConnection(self) -> None:
        """Disconnect immediately, discarding any buffered data from delivery."""
        logger.info("FakeTransport: abortConnection()")

        if not self.disconnecting:
            self.disconnecting = True
            if self._protocol:
                # NOTE(review): passes None as the reason, not a Failure;
                # protocols that inspect `reason` may not expect this.
                self._protocol.connectionLost(None)  # type: ignore[arg-type]

        # NOTE(review): unlike loseConnection, `connected` is left True here.
        self.disconnected = True

    def pauseProducing(self) -> None:
        if not self.producer:
            return

        self.producer.pauseProducing()

    def resumeProducing(self) -> None:
        if not self.producer:
            return
        self.producer.resumeProducing()

    def unregisterProducer(self) -> None:
        if not self.producer:
            return

        self.producer = None

    def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
        """Register `producer`; for a non-streaming (pull) producer, poll its
        `resumeProducing` every 0.1s of reactor time until unregistered."""
        self.producer = producer
        self.producerStreaming = streaming

        def _produce() -> None:
            if not self.producer:
                # we've been unregistered
                return
            # some implementations of IProducer (for example, FileSender)
            # don't return a deferred.
            d = maybeDeferred(self.producer.resumeProducing)
            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def write(self, byt: bytes) -> None:
        """Buffer `byt` for delivery to `other`.

        Raises:
            Exception: if the transport is disconnecting.
        """
        if self.disconnecting:
            raise Exception("Writing to disconnecting FakeTransport")

        self.buffer = self.buffer + byt

        # always actually do the write asynchronously. Some protocols (notably the
        # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
        # still doing a write. Doing a callLater here breaks the cycle.
        if self.autoflush:
            self._reactor.callLater(0.0, self.flush)

    def writeSequence(self, seq: Iterable[bytes]) -> None:
        for x in seq:
            self.write(x)

    def flush(self, maxbytes: Optional[int] = None) -> None:
        """Deliver buffered data (at most `maxbytes`, if given) to `other`.

        If `other.dataReceived` raises, the error is logged, the buffer is left
        untouched and no further flush is scheduled.
        """
        if not self.buffer:
            # nothing to do. Don't write empty buffers: it upsets the
            # TLSMemoryBIOProtocol
            return

        if self.disconnected:
            return

        if maxbytes is not None:
            to_write = self.buffer[:maxbytes]
        else:
            to_write = self.buffer

        logger.info("%s->%s: %s", self._protocol, self.other, to_write)

        try:
            self.other.dataReceived(to_write)
        except Exception as e:
            logger.exception("Exception writing to protocol: %s", e)
            return

        self.buffer = self.buffer[len(to_write) :]
        if self.buffer and self.autoflush:
            self._reactor.callLater(0.0, self.flush)

        # Complete a disconnect that loseConnection deferred until the buffer
        # drained. NOTE(review): `connected` is not set to False on this path.
        if not self.buffer and self.disconnecting:
            logger.info("FakeTransport: Buffer now empty, completing disconnect")
            self.disconnected = True
  747. def connect_client(
  748. reactor: ThreadedMemoryReactorClock, client_id: int
  749. ) -> Tuple[IProtocol, AccumulatingProtocol]:
  750. """
  751. Connect a client to a fake TCP transport.
  752. Args:
  753. reactor
  754. factory: The connecting factory to build.
  755. """
  756. factory = reactor.tcpClients.pop(client_id)[2]
  757. client = factory.buildProtocol(None)
  758. server = AccumulatingProtocol()
  759. server.makeConnection(FakeTransport(client, reactor))
  760. client.makeConnection(FakeTransport(server, reactor))
  761. return client, server
  762. class TestHomeServer(HomeServer):
  763. DATASTORE_CLASS = DataStore # type: ignore[assignment]
  764. def setup_test_homeserver(
  765. cleanup_func: Callable[[Callable[[], None]], None],
  766. name: str = "test",
  767. config: Optional[HomeServerConfig] = None,
  768. reactor: Optional[ISynapseReactor] = None,
  769. homeserver_to_use: Type[HomeServer] = TestHomeServer,
  770. **kwargs: Any,
  771. ) -> HomeServer:
  772. """
  773. Setup a homeserver suitable for running tests against. Keyword arguments
  774. are passed to the Homeserver constructor.
  775. If no datastore is supplied, one is created and given to the homeserver.
  776. Args:
  777. cleanup_func : The function used to register a cleanup routine for
  778. after the test.
  779. Calling this method directly is deprecated: you should instead derive from
  780. HomeserverTestCase.
  781. """
  782. if reactor is None:
  783. from twisted.internet import reactor as _reactor
  784. reactor = cast(ISynapseReactor, _reactor)
  785. if config is None:
  786. config = default_config(name, parse=True)
  787. config.caches.resize_all_caches()
  788. if "clock" not in kwargs:
  789. kwargs["clock"] = MockClock()
  790. if USE_POSTGRES_FOR_TESTS:
  791. test_db = "synapse_test_%s" % uuid.uuid4().hex
  792. database_config = {
  793. "name": "psycopg2",
  794. "args": {
  795. "dbname": test_db,
  796. "host": POSTGRES_HOST,
  797. "password": POSTGRES_PASSWORD,
  798. "user": POSTGRES_USER,
  799. "port": POSTGRES_PORT,
  800. "cp_min": 1,
  801. "cp_max": 5,
  802. },
  803. }
  804. else:
  805. if SQLITE_PERSIST_DB:
  806. # The current working directory is in _trial_temp, so this gets created within that directory.
  807. test_db_location = os.path.abspath("test.db")
  808. logger.debug("Will persist db to %s", test_db_location)
  809. # Ensure each test gets a clean database.
  810. try:
  811. os.remove(test_db_location)
  812. except FileNotFoundError:
  813. pass
  814. else:
  815. logger.debug("Removed existing DB at %s", test_db_location)
  816. else:
  817. test_db_location = ":memory:"
  818. database_config = {
  819. "name": "sqlite3",
  820. "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
  821. }
  822. # Check if we have set up a DB that we can use as a template.
  823. global PREPPED_SQLITE_DB_CONN
  824. if PREPPED_SQLITE_DB_CONN is None:
  825. temp_engine = create_engine(database_config)
  826. PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection(
  827. sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN"
  828. )
  829. database = DatabaseConnectionConfig("master", database_config)
  830. config.database.databases = [database]
  831. prepare_database(
  832. PREPPED_SQLITE_DB_CONN, create_engine(database_config), config
  833. )
  834. database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
  835. if "db_txn_limit" in kwargs:
  836. database_config["txn_limit"] = kwargs["db_txn_limit"]
  837. database = DatabaseConnectionConfig("master", database_config)
  838. config.database.databases = [database]
  839. db_engine = create_engine(database.config)
  840. # Create the database before we actually try and connect to it, based off
  841. # the template database we generate in setupdb()
  842. if USE_POSTGRES_FOR_TESTS:
  843. db_conn = db_engine.module.connect(
  844. dbname=POSTGRES_BASE_DB,
  845. user=POSTGRES_USER,
  846. host=POSTGRES_HOST,
  847. port=POSTGRES_PORT,
  848. password=POSTGRES_PASSWORD,
  849. )
  850. db_engine.attempt_to_set_autocommit(db_conn, True)
  851. cur = db_conn.cursor()
  852. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  853. cur.execute(
  854. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  855. )
  856. cur.close()
  857. db_conn.close()
  858. hs = homeserver_to_use(
  859. name,
  860. config=config,
  861. version_string="Synapse/tests",
  862. reactor=reactor,
  863. )
  864. # Install @cache_in_self attributes
  865. for key, val in kwargs.items():
  866. setattr(hs, "_" + key, val)
  867. # Mock TLS
  868. hs.tls_server_context_factory = Mock()
  869. hs.setup()
  870. if USE_POSTGRES_FOR_TESTS:
  871. database_pool = hs.get_datastores().databases[0]
  872. # We need to do cleanup on PostgreSQL
  873. def cleanup() -> None:
  874. import psycopg2
  875. # Close all the db pools
  876. database_pool._db_pool.close()
  877. dropped = False
  878. # Drop the test database
  879. db_conn = db_engine.module.connect(
  880. dbname=POSTGRES_BASE_DB,
  881. user=POSTGRES_USER,
  882. host=POSTGRES_HOST,
  883. port=POSTGRES_PORT,
  884. password=POSTGRES_PASSWORD,
  885. )
  886. db_engine.attempt_to_set_autocommit(db_conn, True)
  887. cur = db_conn.cursor()
  888. # Try a few times to drop the DB. Some things may hold on to the
  889. # database for a few more seconds due to flakiness, preventing
  890. # us from dropping it when the test is over. If we can't drop
  891. # it, warn and move on.
  892. for _ in range(5):
  893. try:
  894. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  895. db_conn.commit()
  896. dropped = True
  897. except psycopg2.OperationalError as e:
  898. warnings.warn(
  899. "Couldn't drop old db: " + str(e),
  900. category=UserWarning,
  901. stacklevel=2,
  902. )
  903. time.sleep(0.5)
  904. cur.close()
  905. db_conn.close()
  906. if not dropped:
  907. warnings.warn(
  908. "Failed to drop old DB.",
  909. category=UserWarning,
  910. stacklevel=2,
  911. )
  912. if not LEAVE_DB:
  913. # Register the cleanup hook
  914. cleanup_func(cleanup)
  915. # bcrypt is far too slow to be doing in unit tests
  916. # Need to let the HS build an auth handler and then mess with it
  917. # because AuthHandler's constructor requires the HS, so we can't make one
  918. # beforehand and pass it in to the HS's constructor (chicken / egg)
  919. async def hash(p: str) -> str:
  920. return hashlib.md5(p.encode("utf8")).hexdigest()
  921. hs.get_auth_handler().hash = hash # type: ignore[assignment]
  922. async def validate_hash(p: str, h: str) -> bool:
  923. return hashlib.md5(p.encode("utf8")).hexdigest() == h
  924. hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
  925. # Make the threadpool and database transactions synchronous for testing.
  926. _make_test_homeserver_synchronous(hs)
  927. # Load any configured modules into the homeserver
  928. module_api = hs.get_module_api()
  929. for module, module_config in hs.config.modules.loaded_modules:
  930. module(config=module_config, api=module_api)
  931. load_legacy_spam_checkers(hs)
  932. load_legacy_third_party_event_rules(hs)
  933. load_legacy_presence_router(hs)
  934. load_legacy_password_auth_providers(hs)
  935. return hs