@@ -0,0 +1 @@ | |||
Fix incorrect type hints. |
@@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union | |||
from twisted.internet import protocol | |||
class RedisProtocol: | |||
class RedisProtocol(protocol.Protocol): | |||
def publish(self, channel: str, message: bytes): ... | |||
async def ping(self) -> None: ... | |||
async def set( | |||
@@ -45,7 +45,9 @@ from twisted.internet.interfaces import ( | |||
IHostResolution, | |||
IReactorPluggableNameResolver, | |||
IResolutionReceiver, | |||
ITCPTransport, | |||
) | |||
from twisted.internet.protocol import connectionDone | |||
from twisted.internet.task import Cooperator | |||
from twisted.python.failure import Failure | |||
from twisted.web._newclient import ResponseDone | |||
@@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception): | |||
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): | |||
"""A protocol which immediately errors upon receiving data.""" | |||
transport = None # type: Optional[ITCPTransport] | |||
def __init__(self, deferred: defer.Deferred): | |||
self.deferred = deferred | |||
@@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): | |||
self.deferred.errback(BodyExceededMaxSize()) | |||
# Close the connection (forcefully) since all the data will get | |||
# discarded anyway. | |||
assert self.transport is not None | |||
self.transport.abortConnection() | |||
def dataReceived(self, data: bytes) -> None: | |||
self._maybe_fail() | |||
def connectionLost(self, reason: Failure) -> None: | |||
def connectionLost(self, reason: Failure = connectionDone) -> None: | |||
self._maybe_fail() | |||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): | |||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" | |||
transport = None # type: Optional[ITCPTransport] | |||
def __init__( | |||
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] | |||
): | |||
@@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): | |||
self.deferred.errback(BodyExceededMaxSize()) | |||
# Close the connection (forcefully) since all the data will get | |||
# discarded anyway. | |||
assert self.transport is not None | |||
self.transport.abortConnection() | |||
def connectionLost(self, reason: Failure) -> None: | |||
def connectionLost(self, reason: Failure = connectionDone) -> None: | |||
# If the maximum size was already exceeded, there's nothing to do. | |||
if self.deferred.called: | |||
return | |||
@@ -302,7 +302,7 @@ class ReplicationCommandHandler: | |||
hs, outbound_redis_connection | |||
) | |||
hs.get_reactor().connectTCP( | |||
hs.config.redis.redis_host, | |||
hs.config.redis.redis_host.encode(), | |||
hs.config.redis.redis_port, | |||
self._factory, | |||
) | |||
@@ -311,7 +311,7 @@ class ReplicationCommandHandler: | |||
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) | |||
host = hs.config.worker_replication_host | |||
port = hs.config.worker_replication_port | |||
hs.get_reactor().connectTCP(host, port, self._factory) | |||
hs.get_reactor().connectTCP(host.encode(), port, self._factory) | |||
def get_streams(self) -> Dict[str, Stream]: | |||
"""Get a map from stream name to all streams.""" | |||
@@ -56,6 +56,7 @@ from prometheus_client import Counter | |||
from zope.interface import Interface, implementer | |||
from twisted.internet import task | |||
from twisted.internet.tcp import Connection | |||
from twisted.protocols.basic import LineOnlyReceiver | |||
from twisted.python.failure import Failure | |||
@@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
(if they send a `PING` command) | |||
""" | |||
# The transport is going to be an ITCPTransport, but that doesn't have the | |||
# (un)registerProducer methods, those are only on the implementation. | |||
transport = None # type: Connection | |||
delimiter = b"\n" | |||
# Valid commands we expect to receive | |||
@@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
connected_connections.append(self) # Register connection for metrics | |||
assert self.transport is not None | |||
self.transport.registerProducer(self, True) # For the *Producing callbacks | |||
self._send_pending_commands() | |||
@@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
logger.info( | |||
"[%s] Failed to close connection gracefully, aborting", self.id() | |||
) | |||
assert self.transport is not None | |||
self.transport.abortConnection() | |||
else: | |||
if now - self.last_sent_command >= PING_TIME: | |||
@@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
def close(self): | |||
logger.warning("[%s] Closing connection", self.id()) | |||
self.time_we_closed = self.clock.time_msec() | |||
assert self.transport is not None | |||
self.transport.loseConnection() | |||
self.on_connection_closed() | |||
@@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
def connectionLost(self, reason): | |||
logger.info("[%s] Replication connection closed: %r", self.id(), reason) | |||
if isinstance(reason, Failure): | |||
assert reason.type is not None | |||
connection_close_counter.labels(reason.type.__name__).inc() | |||
else: | |||
connection_close_counter.labels(reason.__class__.__name__).inc() | |||
@@ -365,6 +365,6 @@ def lazyConnection( | |||
factory.continueTrying = reconnect | |||
reactor = hs.get_reactor() | |||
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None) | |||
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) | |||
return factory.handler |
@@ -13,9 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, Callable, Dict, List, Optional, Tuple | |||
import attr | |||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type | |||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime | |||
from twisted.internet.protocol import Protocol | |||
@@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): | |||
# Set up client side protocol | |||
client_protocol = client_factory.buildProtocol(None) | |||
request_factory = OneShotRequestFactory() | |||
# Set up the server side protocol | |||
channel = _PushHTTPChannel(self.reactor, request_factory, self.site) | |||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) | |||
# Connect client to server and vice versa. | |||
client_to_server_transport = FakeTransport( | |||
@@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): | |||
server_to_client_transport.loseConnection() | |||
client_to_server_transport.loseConnection() | |||
return request_factory.request | |||
return channel.request | |||
def assert_request_is_get_repl_stream_updates( | |||
self, request: SynapseRequest, stream_name: str | |||
@@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): | |||
if self.hs.config.redis.redis_enabled: | |||
# Handle attempts to connect to fake redis server. | |||
self.reactor.add_tcp_client_callback( | |||
"localhost", | |||
b"localhost", | |||
6379, | |||
self.connect_any_redis_attempts, | |||
) | |||
@@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): | |||
# Set up client side protocol | |||
client_protocol = client_factory.buildProtocol(None) | |||
request_factory = OneShotRequestFactory() | |||
# Set up the server side protocol | |||
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs]) | |||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) | |||
# Connect client to server and vice versa. | |||
client_to_server_transport = FakeTransport( | |||
@@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): | |||
clients = self.reactor.tcpClients | |||
while clients: | |||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) | |||
self.assertEqual(host, "localhost") | |||
self.assertEqual(host, b"localhost") | |||
self.assertEqual(port, 6379) | |||
client_protocol = client_factory.buildProtocol(None) | |||
@@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): | |||
self.received_rdata_rows.append((stream_name, token, r)) | |||
@attr.s() | |||
class OneShotRequestFactory: | |||
"""A simple request factory that generates a single `SynapseRequest` and | |||
stores it for future use. Can only be used once. | |||
""" | |||
request = attr.ib(default=None) | |||
def __call__(self, *args, **kwargs): | |||
assert self.request is None | |||
self.request = SynapseRequest(*args, **kwargs) | |||
return self.request | |||
class _PushHTTPChannel(HTTPChannel): | |||
"""A HTTPChannel that wraps pull producers to push producers. | |||
@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel): | |||
""" | |||
def __init__( | |||
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site | |||
self, reactor: IReactorTime, request_factory: Type[Request], site: Site | |||
): | |||
super().__init__() | |||
self.reactor = reactor | |||
@@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel): | |||
request.responseHeaders.setRawHeaders(b"connection", [b"close"]) | |||
return False | |||
def requestDone(self, request): | |||
# Store the request for inspection. | |||
self.request = request | |||
super().requestDone(request) | |||
class _PullToPushProducer: | |||
"""A push producer that wraps a pull producer.""" | |||
@@ -597,6 +581,8 @@ class FakeRedisPubSubServer: | |||
class FakeRedisPubSubProtocol(Protocol): | |||
"""A connection from a client talking to the fake Redis server.""" | |||
transport = None # type: Optional[FakeTransport] | |||
def __init__(self, server: FakeRedisPubSubServer): | |||
self._server = server | |||
self._reader = hiredis.Reader() | |||
@@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol): | |||
def send(self, msg): | |||
"""Send a message back to the client.""" | |||
assert self.transport is not None | |||
raw = self.encode(msg).encode("utf-8") | |||
self.transport.write(raw) | |||
@@ -16,6 +16,7 @@ from twisted.internet.interfaces import ( | |||
IReactorPluggableNameResolver, | |||
IReactorTCP, | |||
IResolverSimple, | |||
ITransport, | |||
) | |||
from twisted.python.failure import Failure | |||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock | |||
@@ -467,6 +468,7 @@ def get_clock(): | |||
return clock, hs_clock | |||
@implementer(ITransport) | |||
@attr.s(cmp=False) | |||
class FakeTransport: | |||
""" | |||