Переглянути джерело

Fix remaining mypy issues due to Twisted upgrade. (#9608)

tags/v1.30.0rc1
Patrick Cloke 3 роки тому
committed by GitHub
джерело
коміт
d29b71aa50
Не вдалося знайти GPG ключ що відповідає даному підпису Ідентифікатор GPG ключа: 4AEE18F83AFDEB23
8 змінених файлів з 42 додано та 34 видалено
  1. +1
    -0
      changelog.d/9608.misc
  2. +1
    -1
      stubs/txredisapi.pyi
  3. +10
    -2
      synapse/http/client.py
  4. +2
    -2
      synapse/replication/tcp/handler.py
  5. +9
    -0
      synapse/replication/tcp/protocol.py
  6. +1
    -1
      synapse/replication/tcp/redis.py
  7. +16
    -28
      tests/replication/_base.py
  8. +2
    -0
      tests/server.py

+ 1
- 0
changelog.d/9608.misc Переглянути файл

@@ -0,0 +1 @@
Fix incorrect type hints.

+ 1
- 1
stubs/txredisapi.pyi Переглянути файл

@@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union

from twisted.internet import protocol

class RedisProtocol:
class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(


+ 10
- 2
synapse/http/client.py Переглянути файл

@@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
ITCPTransport,
)
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""

transport = None # type: Optional[ITCPTransport]

def __init__(self, deferred: defer.Deferred):
self.deferred = deferred

@@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()

def dataReceived(self, data: bytes) -> None:
self._maybe_fail()

def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""

transport = None # type: Optional[ITCPTransport]

def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()

def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return


+ 2
- 2
synapse/replication/tcp/handler.py Переглянути файл

@@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
hs.get_reactor().connectTCP(host.encode(), port, self._factory)

def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""


+ 9
- 0
synapse/replication/tcp/protocol.py Переглянути файл

@@ -56,6 +56,7 @@ from prometheus_client import Counter
from zope.interface import Interface, implementer

from twisted.internet import task
from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure

@@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""

# The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation.
transport = None # type: Connection

delimiter = b"\n"

# Valid commands we expect to receive
@@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

connected_connections.append(self) # Register connection for metrics

assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks

self._send_pending_commands()
@@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()

@@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()


+ 1
- 1
synapse/replication/tcp/redis.py Переглянути файл

@@ -365,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect

reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)

return factory.handler

+ 16
- 28
tests/replication/_base.py Переглянути файл

@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

import attr
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)

request_factory = OneShotRequestFactory()

# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()

return request_factory.request
return channel.request

def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
b"localhost",
6379,
self.connect_any_redis_attempts,
)
@@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)

request_factory = OneShotRequestFactory()

# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
@@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))


@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""

request = attr.ib(default=None)

def __call__(self, *args, **kwargs):
assert self.request is None

self.request = SynapseRequest(*args, **kwargs)
return self.request


class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.

@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""

def __init__(
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
@@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False

def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)


class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""

transport = None # type: Optional[FakeTransport]

def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):

def send(self, msg):
"""Send a message back to the client."""
assert self.transport is not None

raw = self.encode(msg).encode("utf-8")

self.transport.write(raw)


+ 2
- 0
tests/server.py Переглянути файл

@@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock


@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""


Завантаження…
Відмінити
Зберегти