|
|
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol |
|
|
|
from twisted.internet.task import LoopingCall |
|
|
|
from twisted.web.http import HTTPChannel |
|
|
|
from twisted.web.resource import Resource |
|
|
|
from twisted.web.server import Request, Site |
|
|
|
|
|
|
|
from synapse.app.generic_worker import ( |
|
|
|
GenericWorkerReplicationHandler, |
|
|
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite |
|
|
|
from synapse.replication.http import ReplicationRestResource |
|
|
|
from synapse.replication.tcp.handler import ReplicationCommandHandler |
|
|
|
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol |
|
|
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory |
|
|
|
from synapse.replication.tcp.resource import ( |
|
|
|
ReplicationStreamProtocolFactory, |
|
|
|
ServerReplicationStreamProtocol, |
|
|
|
) |
|
|
|
from synapse.server import HomeServer |
|
|
|
from synapse.util import Clock |
|
|
|
|
|
|
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
# build a replication server |
|
|
|
server_factory = ReplicationStreamProtocolFactory(hs) |
|
|
|
self.streamer = hs.get_replication_streamer() |
|
|
|
self.server = server_factory.buildProtocol(None) |
|
|
|
self.server = server_factory.buildProtocol( |
|
|
|
None |
|
|
|
) # type: ServerReplicationStreamProtocol |
|
|
|
|
|
|
|
# Make a new HomeServer object for the worker |
|
|
|
self.reactor.lookups["testserv"] = "1.2.3.4" |
|
|
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
request_factory = OneShotRequestFactory() |
|
|
|
|
|
|
|
# Set up the server side protocol |
|
|
|
channel = _PushHTTPChannel(self.reactor) |
|
|
|
channel.requestFactory = request_factory |
|
|
|
channel.site = self.site |
|
|
|
channel = _PushHTTPChannel(self.reactor, request_factory, self.site) |
|
|
|
|
|
|
|
# Connect client to server and vice versa. |
|
|
|
client_to_server_transport = FakeTransport( |
|
|
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
fetching updates for given stream. |
|
|
|
""" |
|
|
|
|
|
|
|
path = request.path # type: bytes # type: ignore |
|
|
|
self.assertRegex( |
|
|
|
request.path, |
|
|
|
path, |
|
|
|
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" |
|
|
|
% (stream_name.encode("ascii"),), |
|
|
|
) |
|
|
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): |
|
|
|
request_factory = OneShotRequestFactory() |
|
|
|
|
|
|
|
# Set up the server side protocol |
|
|
|
channel = _PushHTTPChannel(self.reactor) |
|
|
|
channel.requestFactory = request_factory |
|
|
|
channel.site = self._hs_to_site[hs] |
|
|
|
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs]) |
|
|
|
|
|
|
|
# Connect client to server and vice versa. |
|
|
|
client_to_server_transport = FakeTransport( |
|
|
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel): |
|
|
|
makes it very hard to test. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, reactor: IReactorTime): |
|
|
|
def __init__( |
|
|
|
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.reactor = reactor |
|
|
|
self.requestFactory = request_factory |
|
|
|
self.site = site |
|
|
|
|
|
|
|
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] |
|
|
|
|
|
|
|