|
|
@@ -22,20 +22,25 @@ import warnings |
|
|
|
from collections import deque |
|
|
|
from io import SEEK_END, BytesIO |
|
|
|
from typing import ( |
|
|
|
Any, |
|
|
|
Awaitable, |
|
|
|
Callable, |
|
|
|
Dict, |
|
|
|
Iterable, |
|
|
|
List, |
|
|
|
MutableMapping, |
|
|
|
Optional, |
|
|
|
Sequence, |
|
|
|
Tuple, |
|
|
|
Type, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
cast, |
|
|
|
) |
|
|
|
from unittest.mock import Mock |
|
|
|
|
|
|
|
import attr |
|
|
|
from typing_extensions import Deque |
|
|
|
from typing_extensions import Deque, ParamSpec |
|
|
|
from zope.interface import implementer |
|
|
|
|
|
|
|
from twisted.internet import address, threads, udp |
|
|
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed |
|
|
|
from twisted.internet.error import DNSLookupError |
|
|
|
from twisted.internet.interfaces import ( |
|
|
|
IAddress, |
|
|
|
IConnector, |
|
|
|
IConsumer, |
|
|
|
IHostnameResolver, |
|
|
|
IProducer, |
|
|
|
IProtocol, |
|
|
|
IPullProducer, |
|
|
|
IPushProducer, |
|
|
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import ( |
|
|
|
IResolverSimple, |
|
|
|
ITransport, |
|
|
|
) |
|
|
|
from twisted.internet.protocol import ClientFactory, DatagramProtocol |
|
|
|
from twisted.python import threadpool |
|
|
|
from twisted.python.failure import Failure |
|
|
|
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock |
|
|
|
from twisted.web.http_headers import Headers |
|
|
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource |
|
|
|
from twisted.web.server import Request, Site |
|
|
|
|
|
|
|
from synapse.config.database import DatabaseConnectionConfig |
|
|
|
from synapse.config.homeserver import HomeServerConfig |
|
|
|
from synapse.events.presence_router import load_legacy_presence_router |
|
|
|
from synapse.events.spamcheck import load_legacy_spam_checkers |
|
|
|
from synapse.events.third_party_rules import load_legacy_third_party_event_rules |
|
|
@@ -88,6 +98,9 @@ from tests.utils import ( |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
R = TypeVar("R") |
|
|
|
P = ParamSpec("P") |
|
|
|
|
|
|
|
# the type of thing that can be passed into `make_request` in the headers list |
|
|
|
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] |
|
|
|
|
|
|
@@ -98,12 +111,14 @@ class TimedOutException(Exception): |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
@implementer(IConsumer) |
|
|
|
@implementer(ITransport, IPushProducer, IConsumer) |
|
|
|
@attr.s(auto_attribs=True) |
|
|
|
class FakeChannel: |
|
|
|
""" |
|
|
|
A fake Twisted Web Channel (the part that interfaces with the |
|
|
|
wire). |
|
|
|
|
|
|
|
See twisted.web.http.HTTPChannel. |
|
|
|
""" |
|
|
|
|
|
|
|
site: Union[Site, "FakeSite"] |
|
|
@@ -142,7 +157,7 @@ class FakeChannel: |
|
|
|
|
|
|
|
Raises an exception if the request has not yet completed. |
|
|
|
""" |
|
|
|
if not self.is_finished: |
|
|
|
if not self.is_finished(): |
|
|
|
raise Exception("Request not yet completed") |
|
|
|
return self.result["body"].decode("utf8") |
|
|
|
|
|
|
@@ -165,27 +180,36 @@ class FakeChannel: |
|
|
|
h.addRawHeader(*i) |
|
|
|
return h |
|
|
|
|
|
|
|
def writeHeaders(self, version, code, reason, headers): |
|
|
|
def writeHeaders( |
|
|
|
self, version: bytes, code: bytes, reason: bytes, headers: Headers |
|
|
|
) -> None: |
|
|
|
self.result["version"] = version |
|
|
|
self.result["code"] = code |
|
|
|
self.result["reason"] = reason |
|
|
|
self.result["headers"] = headers |
|
|
|
|
|
|
|
def write(self, content: bytes) -> None: |
|
|
|
assert isinstance(content, bytes), "Should be bytes! " + repr(content) |
|
|
|
def write(self, data: bytes) -> None: |
|
|
|
assert isinstance(data, bytes), "Should be bytes! " + repr(data) |
|
|
|
|
|
|
|
if "body" not in self.result: |
|
|
|
self.result["body"] = b"" |
|
|
|
|
|
|
|
self.result["body"] += content |
|
|
|
self.result["body"] += data |
|
|
|
|
|
|
|
def writeSequence(self, data: Iterable[bytes]) -> None: |
|
|
|
for x in data: |
|
|
|
self.write(x) |
|
|
|
|
|
|
|
def loseConnection(self) -> None: |
|
|
|
self.unregisterProducer() |
|
|
|
self.transport.loseConnection() |
|
|
|
|
|
|
|
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer. |
|
|
|
def registerProducer( # type: ignore[override] |
|
|
|
self, |
|
|
|
producer: Union[IPullProducer, IPushProducer], |
|
|
|
streaming: bool, |
|
|
|
) -> None: |
|
|
|
self._producer = producer |
|
|
|
def registerProducer(self, producer: IProducer, streaming: bool) -> None: |
|
|
|
# TODO This should ensure that the IProducer is an IPushProducer or |
|
|
|
# IPullProducer, unfortunately twisted.protocols.basic.FileSender does |
|
|
|
# implement those, but doesn't declare it. |
|
|
|
self._producer = cast(Union[IPushProducer, IPullProducer], producer) |
|
|
|
self.producerStreaming = streaming |
|
|
|
|
|
|
|
def _produce() -> None: |
|
|
@@ -202,6 +226,16 @@ class FakeChannel: |
|
|
|
|
|
|
|
self._producer = None |
|
|
|
|
|
|
|
def stopProducing(self) -> None: |
|
|
|
if self._producer is not None: |
|
|
|
self._producer.stopProducing() |
|
|
|
|
|
|
|
def pauseProducing(self) -> None: |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
def resumeProducing(self) -> None: |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
def requestDone(self, _self: Request) -> None: |
|
|
|
self.result["done"] = True |
|
|
|
if isinstance(_self, SynapseRequest): |
|
|
@@ -281,12 +315,12 @@ class FakeSite: |
|
|
|
self.reactor = reactor |
|
|
|
self.experimental_cors_msc3886 = experimental_cors_msc3886 |
|
|
|
|
|
|
|
def getResourceFor(self, request): |
|
|
|
def getResourceFor(self, request: Request) -> IResource: |
|
|
|
return self._resource |
|
|
|
|
|
|
|
|
|
|
|
def make_request( |
|
|
|
reactor, |
|
|
|
reactor: MemoryReactorClock, |
|
|
|
site: Union[Site, FakeSite], |
|
|
|
method: Union[bytes, str], |
|
|
|
path: Union[bytes, str], |
|
|
@@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): |
|
|
|
A MemoryReactorClock that supports callFromThread. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
def __init__(self) -> None: |
|
|
|
self.threadpool = ThreadPool(self) |
|
|
|
|
|
|
|
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} |
|
|
|
self._udp = [] |
|
|
|
self._udp: List[udp.Port] = [] |
|
|
|
self.lookups: Dict[str, str] = {} |
|
|
|
self._thread_callbacks: Deque[Callable[[], None]] = deque() |
|
|
|
self._thread_callbacks: Deque[Callable[..., R]] = deque() |
|
|
|
|
|
|
|
lookups = self.lookups |
|
|
|
|
|
|
|
@implementer(IResolverSimple) |
|
|
|
class FakeResolver: |
|
|
|
def getHostByName(self, name, timeout=None): |
|
|
|
def getHostByName( |
|
|
|
self, name: str, timeout: Optional[Sequence[int]] = None |
|
|
|
) -> "Deferred[str]": |
|
|
|
if name not in lookups: |
|
|
|
return fail(DNSLookupError("OH NO: unknown %s" % (name,))) |
|
|
|
return succeed(lookups[name]) |
|
|
@@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): |
|
|
|
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): |
|
|
|
def listenUDP( |
|
|
|
self, |
|
|
|
port: int, |
|
|
|
protocol: DatagramProtocol, |
|
|
|
interface: str = "", |
|
|
|
maxPacketSize: int = 8196, |
|
|
|
) -> udp.Port: |
|
|
|
p = udp.Port(port, protocol, interface, maxPacketSize, self) |
|
|
|
p.startListening() |
|
|
|
self._udp.append(p) |
|
|
|
return p |
|
|
|
|
|
|
|
def callFromThread(self, callback, *args, **kwargs): |
|
|
|
def callFromThread( |
|
|
|
self, callable: Callable[..., Any], *args: object, **kwargs: object |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Make the callback fire in the next reactor iteration. |
|
|
|
""" |
|
|
|
cb = lambda: callback(*args, **kwargs) |
|
|
|
cb = lambda: callable(*args, **kwargs) |
|
|
|
# it's not safe to call callLater() here, so we append the callback to a |
|
|
|
# separate queue. |
|
|
|
self._thread_callbacks.append(cb) |
|
|
|
|
|
|
|
def getThreadPool(self): |
|
|
|
return self.threadpool |
|
|
|
def callInThread( |
|
|
|
self, callable: Callable[..., Any], *args: object, **kwargs: object |
|
|
|
) -> None: |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
def suggestThreadPoolSize(self, size: int) -> None: |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
def getThreadPool(self) -> "threadpool.ThreadPool": |
|
|
|
# Cast to match super-class. |
|
|
|
return cast(threadpool.ThreadPool, self.threadpool) |
|
|
|
|
|
|
|
def add_tcp_client_callback(self, host: str, port: int, callback: Callable): |
|
|
|
def add_tcp_client_callback( |
|
|
|
self, host: str, port: int, callback: Callable[[], None] |
|
|
|
) -> None: |
|
|
|
"""Add a callback that will be invoked when we receive a connection |
|
|
|
attempt to the given IP/port using `connectTCP`. |
|
|
|
|
|
|
@@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): |
|
|
|
""" |
|
|
|
self._tcp_callbacks[(host, port)] = callback |
|
|
|
|
|
|
|
def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): |
|
|
|
def connectTCP( |
|
|
|
self, |
|
|
|
host: str, |
|
|
|
port: int, |
|
|
|
factory: ClientFactory, |
|
|
|
timeout: float = 30, |
|
|
|
bindAddress: Optional[Tuple[str, int]] = None, |
|
|
|
) -> IConnector: |
|
|
|
"""Fake L{IReactorTCP.connectTCP}.""" |
|
|
|
|
|
|
|
conn = super().connectTCP( |
|
|
@@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): |
|
|
|
|
|
|
|
return conn |
|
|
|
|
|
|
|
def advance(self, amount): |
|
|
|
def advance(self, amount: float) -> None: |
|
|
|
# first advance our reactor's time, and run any "callLater" callbacks that |
|
|
|
# makes ready |
|
|
|
super().advance(amount) |
|
|
@@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): |
|
|
|
class ThreadPool: |
|
|
|
""" |
|
|
|
Threadless thread pool. |
|
|
|
|
|
|
|
See twisted.python.threadpool.ThreadPool |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, reactor): |
|
|
|
def __init__(self, reactor: IReactorTime): |
|
|
|
self._reactor = reactor |
|
|
|
|
|
|
|
def start(self): |
|
|
|
def start(self) -> None: |
|
|
|
pass |
|
|
|
|
|
|
|
def stop(self): |
|
|
|
def stop(self) -> None: |
|
|
|
pass |
|
|
|
|
|
|
|
def callInThreadWithCallback(self, onResult, function, *args, **kwargs): |
|
|
|
def _(res): |
|
|
|
def callInThreadWithCallback( |
|
|
|
self, |
|
|
|
onResult: Callable[[bool, Union[Failure, R]], None], |
|
|
|
function: Callable[P, R], |
|
|
|
*args: P.args, |
|
|
|
**kwargs: P.kwargs, |
|
|
|
) -> "Deferred[None]": |
|
|
|
def _(res: Any) -> None: |
|
|
|
if isinstance(res, Failure): |
|
|
|
onResult(False, res) |
|
|
|
else: |
|
|
|
onResult(True, res) |
|
|
|
|
|
|
|
d = Deferred() |
|
|
|
d: "Deferred[None]" = Deferred() |
|
|
|
d.addCallback(lambda x: function(*args, **kwargs)) |
|
|
|
d.addBoth(_) |
|
|
|
self._reactor.callLater(0, d.callback, True) |
|
|
@@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: |
|
|
|
for database in server.get_datastores().databases: |
|
|
|
pool = database._db_pool |
|
|
|
|
|
|
|
def runWithConnection(func, *args, **kwargs): |
|
|
|
def runWithConnection( |
|
|
|
func: Callable[..., R], *args: Any, **kwargs: Any |
|
|
|
) -> Awaitable[R]: |
|
|
|
return threads.deferToThreadPool( |
|
|
|
pool._reactor, |
|
|
|
pool.threadpool, |
|
|
@@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: |
|
|
|
**kwargs, |
|
|
|
) |
|
|
|
|
|
|
|
def runInteraction(interaction, *args, **kwargs): |
|
|
|
def runInteraction( |
|
|
|
desc: str, func: Callable[..., R], *args: Any, **kwargs: Any |
|
|
|
) -> Awaitable[R]: |
|
|
|
return threads.deferToThreadPool( |
|
|
|
pool._reactor, |
|
|
|
pool.threadpool, |
|
|
|
pool._runInteraction, |
|
|
|
interaction, |
|
|
|
desc, |
|
|
|
func, |
|
|
|
*args, |
|
|
|
**kwargs, |
|
|
|
) |
|
|
|
|
|
|
|
pool.runWithConnection = runWithConnection |
|
|
|
pool.runInteraction = runInteraction |
|
|
|
pool.runWithConnection = runWithConnection # type: ignore[assignment] |
|
|
|
pool.runInteraction = runInteraction # type: ignore[assignment] |
|
|
|
# Replace the thread pool with a threadless 'thread' pool |
|
|
|
pool.threadpool = ThreadPool(clock._reactor) |
|
|
|
pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment] |
|
|
|
pool.running = True |
|
|
|
|
|
|
|
# We've just changed the Databases to run DB transactions on the same |
|
|
@@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: |
|
|
|
|
|
|
|
|
|
|
|
@implementer(ITransport) |
|
|
|
@attr.s(cmp=False) |
|
|
|
@attr.s(cmp=False, auto_attribs=True) |
|
|
|
class FakeTransport: |
|
|
|
""" |
|
|
|
A twisted.internet.interfaces.ITransport implementation which sends all its data |
|
|
@@ -588,48 +663,50 @@ class FakeTransport: |
|
|
|
If you want bidirectional communication, you'll need two instances. |
|
|
|
""" |
|
|
|
|
|
|
|
other = attr.ib() |
|
|
|
other: IProtocol |
|
|
|
"""The Protocol object which will receive any data written to this transport. |
|
|
|
|
|
|
|
:type: twisted.internet.interfaces.IProtocol |
|
|
|
""" |
|
|
|
|
|
|
|
_reactor = attr.ib() |
|
|
|
_reactor: IReactorTime |
|
|
|
"""Test reactor |
|
|
|
|
|
|
|
:type: twisted.internet.interfaces.IReactorTime |
|
|
|
""" |
|
|
|
|
|
|
|
_protocol = attr.ib(default=None) |
|
|
|
_protocol: Optional[IProtocol] = None |
|
|
|
"""The Protocol which is producing data for this transport. Optional, but if set |
|
|
|
will get called back for connectionLost() notifications etc. |
|
|
|
""" |
|
|
|
|
|
|
|
_peer_address: Optional[IAddress] = attr.ib(default=None) |
|
|
|
_peer_address: IAddress = attr.Factory( |
|
|
|
lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) |
|
|
|
) |
|
|
|
"""The value to be returned by getPeer""" |
|
|
|
|
|
|
|
_host_address: Optional[IAddress] = attr.ib(default=None) |
|
|
|
_host_address: IAddress = attr.Factory( |
|
|
|
lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) |
|
|
|
) |
|
|
|
"""The value to be returned by getHost""" |
|
|
|
|
|
|
|
disconnecting = False |
|
|
|
disconnected = False |
|
|
|
connected = True |
|
|
|
buffer = attr.ib(default=b"") |
|
|
|
producer = attr.ib(default=None) |
|
|
|
autoflush = attr.ib(default=True) |
|
|
|
buffer: bytes = b"" |
|
|
|
producer: Optional[IPushProducer] = None |
|
|
|
autoflush: bool = True |
|
|
|
|
|
|
|
def getPeer(self) -> Optional[IAddress]: |
|
|
|
def getPeer(self) -> IAddress: |
|
|
|
return self._peer_address |
|
|
|
|
|
|
|
def getHost(self) -> Optional[IAddress]: |
|
|
|
def getHost(self) -> IAddress: |
|
|
|
return self._host_address |
|
|
|
|
|
|
|
def loseConnection(self, reason=None): |
|
|
|
def loseConnection(self) -> None: |
|
|
|
if not self.disconnecting: |
|
|
|
logger.info("FakeTransport: loseConnection(%s)", reason) |
|
|
|
logger.info("FakeTransport: loseConnection()") |
|
|
|
self.disconnecting = True |
|
|
|
if self._protocol: |
|
|
|
self._protocol.connectionLost(reason) |
|
|
|
self._protocol.connectionLost( |
|
|
|
Failure(RuntimeError("FakeTransport.loseConnection()")) |
|
|
|
) |
|
|
|
|
|
|
|
# if we still have data to write, delay until that is done |
|
|
|
if self.buffer: |
|
|
@@ -640,38 +717,38 @@ class FakeTransport: |
|
|
|
self.connected = False |
|
|
|
self.disconnected = True |
|
|
|
|
|
|
|
def abortConnection(self): |
|
|
|
def abortConnection(self) -> None: |
|
|
|
logger.info("FakeTransport: abortConnection()") |
|
|
|
|
|
|
|
if not self.disconnecting: |
|
|
|
self.disconnecting = True |
|
|
|
if self._protocol: |
|
|
|
self._protocol.connectionLost(None) |
|
|
|
self._protocol.connectionLost(None) # type: ignore[arg-type] |
|
|
|
|
|
|
|
self.disconnected = True |
|
|
|
|
|
|
|
def pauseProducing(self): |
|
|
|
def pauseProducing(self) -> None: |
|
|
|
if not self.producer: |
|
|
|
return |
|
|
|
|
|
|
|
self.producer.pauseProducing() |
|
|
|
|
|
|
|
def resumeProducing(self): |
|
|
|
def resumeProducing(self) -> None: |
|
|
|
if not self.producer: |
|
|
|
return |
|
|
|
self.producer.resumeProducing() |
|
|
|
|
|
|
|
def unregisterProducer(self): |
|
|
|
def unregisterProducer(self) -> None: |
|
|
|
if not self.producer: |
|
|
|
return |
|
|
|
|
|
|
|
self.producer = None |
|
|
|
|
|
|
|
def registerProducer(self, producer, streaming): |
|
|
|
def registerProducer(self, producer: IPushProducer, streaming: bool) -> None: |
|
|
|
self.producer = producer |
|
|
|
self.producerStreaming = streaming |
|
|
|
|
|
|
|
def _produce(): |
|
|
|
def _produce() -> None: |
|
|
|
if not self.producer: |
|
|
|
# we've been unregistered |
|
|
|
return |
|
|
@@ -683,7 +760,7 @@ class FakeTransport: |
|
|
|
if not streaming: |
|
|
|
self._reactor.callLater(0.0, _produce) |
|
|
|
|
|
|
|
def write(self, byt): |
|
|
|
def write(self, byt: bytes) -> None: |
|
|
|
if self.disconnecting: |
|
|
|
raise Exception("Writing to disconnecting FakeTransport") |
|
|
|
|
|
|
@@ -695,11 +772,11 @@ class FakeTransport: |
|
|
|
if self.autoflush: |
|
|
|
self._reactor.callLater(0.0, self.flush) |
|
|
|
|
|
|
|
def writeSequence(self, seq): |
|
|
|
def writeSequence(self, seq: Iterable[bytes]) -> None: |
|
|
|
for x in seq: |
|
|
|
self.write(x) |
|
|
|
|
|
|
|
def flush(self, maxbytes=None): |
|
|
|
def flush(self, maxbytes: Optional[int] = None) -> None: |
|
|
|
if not self.buffer: |
|
|
|
# nothing to do. Don't write empty buffers: it upsets the |
|
|
|
# TLSMemoryBIOProtocol |
|
|
@@ -750,17 +827,17 @@ def connect_client( |
|
|
|
|
|
|
|
|
|
|
|
class TestHomeServer(HomeServer): |
|
|
|
DATASTORE_CLASS = DataStore |
|
|
|
DATASTORE_CLASS = DataStore # type: ignore[assignment] |
|
|
|
|
|
|
|
|
|
|
|
def setup_test_homeserver( |
|
|
|
cleanup_func, |
|
|
|
name="test", |
|
|
|
config=None, |
|
|
|
reactor=None, |
|
|
|
cleanup_func: Callable[[Callable[[], None]], None], |
|
|
|
name: str = "test", |
|
|
|
config: Optional[HomeServerConfig] = None, |
|
|
|
reactor: Optional[ISynapseReactor] = None, |
|
|
|
homeserver_to_use: Type[HomeServer] = TestHomeServer, |
|
|
|
**kwargs, |
|
|
|
): |
|
|
|
**kwargs: Any, |
|
|
|
) -> HomeServer: |
|
|
|
""" |
|
|
|
Setup a homeserver suitable for running tests against. Keyword arguments |
|
|
|
are passed to the Homeserver constructor. |
|
|
@@ -775,13 +852,14 @@ def setup_test_homeserver( |
|
|
|
HomeserverTestCase. |
|
|
|
""" |
|
|
|
if reactor is None: |
|
|
|
from twisted.internet import reactor |
|
|
|
from twisted.internet import reactor as _reactor |
|
|
|
|
|
|
|
reactor = cast(ISynapseReactor, _reactor) |
|
|
|
|
|
|
|
if config is None: |
|
|
|
config = default_config(name, parse=True) |
|
|
|
|
|
|
|
config.caches.resize_all_caches() |
|
|
|
config.ldap_enabled = False |
|
|
|
|
|
|
|
if "clock" not in kwargs: |
|
|
|
kwargs["clock"] = MockClock() |
|
|
@@ -832,6 +910,8 @@ def setup_test_homeserver( |
|
|
|
# Create the database before we actually try and connect to it, based off |
|
|
|
# the template database we generate in setupdb() |
|
|
|
if isinstance(db_engine, PostgresEngine): |
|
|
|
import psycopg2.extensions |
|
|
|
|
|
|
|
db_conn = db_engine.module.connect( |
|
|
|
database=POSTGRES_BASE_DB, |
|
|
|
user=POSTGRES_USER, |
|
|
@@ -839,6 +919,7 @@ def setup_test_homeserver( |
|
|
|
port=POSTGRES_PORT, |
|
|
|
password=POSTGRES_PASSWORD, |
|
|
|
) |
|
|
|
assert isinstance(db_conn, psycopg2.extensions.connection) |
|
|
|
db_conn.autocommit = True |
|
|
|
cur = db_conn.cursor() |
|
|
|
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) |
|
|
@@ -867,14 +948,15 @@ def setup_test_homeserver( |
|
|
|
hs.setup_background_tasks() |
|
|
|
|
|
|
|
if isinstance(db_engine, PostgresEngine): |
|
|
|
database = hs.get_datastores().databases[0] |
|
|
|
database_pool = hs.get_datastores().databases[0] |
|
|
|
|
|
|
|
# We need to do cleanup on PostgreSQL |
|
|
|
def cleanup(): |
|
|
|
def cleanup() -> None: |
|
|
|
import psycopg2 |
|
|
|
import psycopg2.extensions |
|
|
|
|
|
|
|
# Close all the db pools |
|
|
|
database._db_pool.close() |
|
|
|
database_pool._db_pool.close() |
|
|
|
|
|
|
|
dropped = False |
|
|
|
|
|
|
@@ -886,6 +968,7 @@ def setup_test_homeserver( |
|
|
|
port=POSTGRES_PORT, |
|
|
|
password=POSTGRES_PASSWORD, |
|
|
|
) |
|
|
|
assert isinstance(db_conn, psycopg2.extensions.connection) |
|
|
|
db_conn.autocommit = True |
|
|
|
cur = db_conn.cursor() |
|
|
|
|
|
|
@@ -918,23 +1001,23 @@ def setup_test_homeserver( |
|
|
|
# Need to let the HS build an auth handler and then mess with it |
|
|
|
# because AuthHandler's constructor requires the HS, so we can't make one |
|
|
|
# beforehand and pass it in to the HS's constructor (chicken / egg) |
|
|
|
async def hash(p): |
|
|
|
async def hash(p: str) -> str: |
|
|
|
return hashlib.md5(p.encode("utf8")).hexdigest() |
|
|
|
|
|
|
|
hs.get_auth_handler().hash = hash |
|
|
|
hs.get_auth_handler().hash = hash # type: ignore[assignment] |
|
|
|
|
|
|
|
async def validate_hash(p, h): |
|
|
|
async def validate_hash(p: str, h: str) -> bool: |
|
|
|
return hashlib.md5(p.encode("utf8")).hexdigest() == h |
|
|
|
|
|
|
|
hs.get_auth_handler().validate_hash = validate_hash |
|
|
|
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] |
|
|
|
|
|
|
|
# Make the threadpool and database transactions synchronous for testing. |
|
|
|
_make_test_homeserver_synchronous(hs) |
|
|
|
|
|
|
|
# Load any configured modules into the homeserver |
|
|
|
module_api = hs.get_module_api() |
|
|
|
for module, config in hs.config.modules.loaded_modules: |
|
|
|
module(config=config, api=module_api) |
|
|
|
for module, module_config in hs.config.modules.loaded_modules: |
|
|
|
module(config=module_config, api=module_api) |
|
|
|
|
|
|
|
load_legacy_spam_checkers(hs) |
|
|
|
load_legacy_third_party_event_rules(hs) |
|
|
|