@@ -0,0 +1 @@ | |||
Precompute joined hosts and store in Redis. |
@@ -15,11 +15,21 @@ | |||
"""Contains *incomplete* type hints for txredisapi. | |||
""" | |||
from typing import List, Optional, Type, Union | |||
from typing import Any, List, Optional, Type, Union | |||
class RedisProtocol: | |||
def publish(self, channel: str, message: bytes): ... | |||
async def ping(self) -> None: ... | |||
async def set( | |||
self, | |||
key: str, | |||
value: Any, | |||
expire: Optional[int] = None, | |||
pexpire: Optional[int] = None, | |||
only_if_not_exists: bool = False, | |||
only_if_exists: bool = False, | |||
) -> None: ... | |||
async def get(self, key: str) -> Any: ... | |||
class SubscriberProtocol(RedisProtocol): | |||
def __init__(self, *args, **kwargs): ... | |||
@@ -18,6 +18,7 @@ from synapse.config import ( | |||
password_auth_providers, | |||
push, | |||
ratelimiting, | |||
redis, | |||
registration, | |||
repository, | |||
room_directory, | |||
@@ -79,6 +80,7 @@ class RootConfig: | |||
roomdirectory: room_directory.RoomDirectoryConfig | |||
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig | |||
tracer: tracer.TracerConfig | |||
redis: redis.RedisConfig | |||
config_classes: List = ... | |||
def __init__(self) -> None: ... | |||
@@ -142,6 +142,8 @@ class FederationSender: | |||
self._wake_destinations_needing_catchup, | |||
) | |||
self._external_cache = hs.get_external_cache() | |||
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: | |||
"""Get or create a PerDestinationQueue for the given destination | |||
@@ -197,22 +199,40 @@ class FederationSender: | |||
if not event.internal_metadata.should_proactively_send(): | |||
return | |||
try: | |||
# Get the state from before the event. | |||
# We need to make sure that this is the state from before | |||
# the event and not from after it. | |||
# Otherwise if the last member on a server in a room is | |||
# banned then it won't receive the event because it won't | |||
# be in the room after the ban. | |||
destinations = await self.state.get_hosts_in_room_at_events( | |||
event.room_id, event_ids=event.prev_event_ids() | |||
) | |||
except Exception: | |||
logger.exception( | |||
"Failed to calculate hosts in room for event: %s", | |||
event.event_id, | |||
destinations = None # type: Optional[Set[str]] | |||
if not event.prev_event_ids(): | |||
# If there are no prev event IDs then the state is empty | |||
# and so no remote servers in the room | |||
destinations = set() | |||
else: | |||
# We check the external cache for the destinations, which is | |||
# stored per state group. | |||
sg = await self._external_cache.get( | |||
"event_to_prev_state_group", event.event_id | |||
) | |||
return | |||
if sg: | |||
destinations = await self._external_cache.get( | |||
"get_joined_hosts", str(sg) | |||
) | |||
if destinations is None: | |||
try: | |||
# Get the state from before the event. | |||
# We need to make sure that this is the state from before | |||
# the event and not from after it. | |||
# Otherwise if the last member on a server in a room is | |||
# banned then it won't receive the event because it won't | |||
# be in the room after the ban. | |||
destinations = await self.state.get_hosts_in_room_at_events( | |||
event.room_id, event_ids=event.prev_event_ids() | |||
) | |||
except Exception: | |||
logger.exception( | |||
"Failed to calculate hosts in room for event: %s", | |||
event.event_id, | |||
) | |||
return | |||
destinations = { | |||
d | |||
@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler): | |||
if event.type == EventTypes.GuestAccess and not context.rejected: | |||
await self.maybe_kick_guest_users(event) | |||
# If we are going to send this event over federation we precalculate | |||
# the joined hosts. | |||
if event.internal_metadata.get_send_on_behalf_of(): | |||
await self.event_creation_handler.cache_joined_hosts_for_event(event) | |||
return context | |||
async def _check_for_soft_fail( | |||
@@ -432,6 +432,8 @@ class EventCreationHandler: | |||
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages | |||
self._external_cache = hs.get_external_cache() | |||
async def create_event( | |||
self, | |||
requester: Requester, | |||
@@ -939,6 +941,8 @@ class EventCreationHandler: | |||
await self.action_generator.handle_push_actions_for_event(event, context) | |||
await self.cache_joined_hosts_for_event(event) | |||
try: | |||
# If we're a worker we need to hit out to the master. | |||
writer_instance = self._events_shard_config.get_instance(event.room_id) | |||
@@ -978,6 +982,44 @@ class EventCreationHandler: | |||
await self.store.remove_push_actions_from_staging(event.event_id) | |||
raise | |||
async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
    """Store the joined hosts for the state before `event` in the external
    cache, so that federation senders can look them up rather than
    recomputing them.

    Does nothing when the external cache is disabled, or when the state
    resolved across the event's prev events has no state group.

    Args:
        event: The event whose prev-event state should be cached.
    """
    if not self._external_cache.is_enabled():
        return

    # Two mappings are written rather than one:
    #
    #   event ID    -> prev state group
    #   state group -> joined hosts
    #
    # which takes far less space than a direct event ID -> joined hosts
    # mapping. The first mapping has to live in the cache because the DB
    # does not record it.
    state_entry = await self.state.resolve_state_groups_for_events(
        event.room_id, event_ids=event.prev_event_ids()
    )

    prev_state_group = state_entry.state_group
    if not prev_state_group:
        return

    joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)

    # Both entries live for one hour.
    one_hour_ms = 60 * 60 * 1000

    await self._external_cache.set(
        "event_to_prev_state_group",
        event.event_id,
        prev_state_group,
        expiry_ms=one_hour_ms,
    )

    # The joined-hosts entry is rewritten even if it is already present, so
    # that its expiry time is pushed back.
    await self._external_cache.set(
        "get_joined_hosts",
        str(prev_state_group),
        list(joined_hosts),
        expiry_ms=one_hour_ms,
    )
async def _validate_canonical_alias( | |||
self, directory_handler, room_alias_str: str, expected_room_id: str | |||
) -> None: | |||
@@ -0,0 +1,105 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright 2021 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Any, Optional | |||
from prometheus_client import Counter | |||
from synapse.logging.context import make_deferred_yieldable | |||
from synapse.util import json_decoder, json_encoder | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
set_counter = Counter( | |||
"synapse_external_cache_set", | |||
"Number of times we set a cache", | |||
labelnames=["cache_name"], | |||
) | |||
get_counter = Counter( | |||
"synapse_external_cache_get", | |||
"Number of times we get a cache", | |||
labelnames=["cache_name", "hit"], | |||
) | |||
logger = logging.getLogger(__name__) | |||
class ExternalCache:
    """A cache backed by an external Redis. Does nothing if no Redis is
    configured.

    Values are JSON-encoded before being written, and entries are stored
    under keys of the form `cache_v1:<cache_name>:<key>`.
    """

    def __init__(self, hs: "HomeServer"):
        # May be `None`, in which case every method below is a no-op.
        self._redis_connection = hs.get_outbound_redis_connection()

    def _get_redis_key(self, cache_name: str, key: str) -> str:
        """Build the namespaced Redis key for `key` in the named cache.
        """
        return "cache_v1:%s:%s" % (cache_name, key)

    def is_enabled(self) -> bool:
        """Whether the external cache is used or not.

        It's safe to use the cache when this returns false, the methods will
        just no-op, but the function is useful to avoid doing unnecessary work.
        """
        return self._redis_connection is not None

    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
        """Add the key/value to the named cache, with the expiry time given.

        Args:
            cache_name: Name of the cache; also used as the metrics label.
            key: Key within the named cache.
            value: The value to store; it is JSON-encoded before being sent.
            expiry_ms: Lifetime of the entry, in milliseconds.
        """

        if self._redis_connection is None:
            return

        set_counter.labels(cache_name).inc()

        # txredisapi requires the value to be string, bytes or numbers, so we
        # encode stuff in JSON.
        encoded_value = json_encoder.encode(value)

        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)

        return await make_deferred_yieldable(
            self._redis_connection.set(
                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
            )
        )

    async def get(self, cache_name: str, key: str) -> Optional[Any]:
        """Look up a key/value in the named cache.

        Args:
            cache_name: Name of the cache; also used as the metrics label.
            key: Key within the named cache.

        Returns:
            The cached value, or `None` if there is no usable entry or no
            Redis connection is configured.
        """

        if self._redis_connection is None:
            return None

        result = await make_deferred_yieldable(
            self._redis_connection.get(self._get_redis_key(cache_name, key))
        )

        logger.debug("Got cache result %s %s: %r", cache_name, key, result)

        get_counter.labels(cache_name, result is not None).inc()

        if result is None:
            return None

        # Integers come back from the connection already converted to `int`
        # rather than as their JSON text, so they must bypass the decoder.
        # This is checked before the emptiness test below so that a cached
        # `0` is returned as a hit instead of being mistaken for a miss.
        if isinstance(result, int):
            return result

        # An empty payload cannot be valid JSON; treat it as a miss.
        if not result:
            return None

        return json_decoder.decode(result)
@@ -286,13 +286,6 @@ class ReplicationCommandHandler: | |||
if hs.config.redis.redis_enabled: | |||
from synapse.replication.tcp.redis import ( | |||
RedisDirectTcpReplicationClientFactory, | |||
lazyConnection, | |||
) | |||
logger.info( | |||
"Connecting to redis (host=%r port=%r)", | |||
hs.config.redis_host, | |||
hs.config.redis_port, | |||
) | |||
# First let's ensure that we have a ReplicationStreamer started. | |||
@@ -303,13 +296,7 @@ class ReplicationCommandHandler: | |||
# connection after SUBSCRIBE is called). | |||
# First create the connection for sending commands. | |||
outbound_redis_connection = lazyConnection( | |||
hs=hs, | |||
host=hs.config.redis_host, | |||
port=hs.config.redis_port, | |||
password=hs.config.redis.redis_password, | |||
reconnect=True, | |||
) | |||
outbound_redis_connection = hs.get_outbound_redis_connection() | |||
# Now create the factory/connection for the subscription stream. | |||
self._factory = RedisDirectTcpReplicationClientFactory( | |||
@@ -103,6 +103,7 @@ from synapse.notifier import Notifier | |||
from synapse.push.action_generator import ActionGenerator | |||
from synapse.push.pusherpool import PusherPool | |||
from synapse.replication.tcp.client import ReplicationDataHandler | |||
from synapse.replication.tcp.external_cache import ExternalCache | |||
from synapse.replication.tcp.handler import ReplicationCommandHandler | |||
from synapse.replication.tcp.resource import ReplicationStreamer | |||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream | |||
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string | |||
logger = logging.getLogger(__name__) | |||
if TYPE_CHECKING: | |||
from txredisapi import RedisProtocol | |||
from synapse.handlers.oidc_handler import OidcHandler | |||
from synapse.handlers.saml_handler import SamlHandler | |||
@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
def get_account_data_handler(self) -> AccountDataHandler: | |||
return AccountDataHandler(self) | |||
@cache_in_self
def get_external_cache(self) -> ExternalCache:
    """Build the `ExternalCache` for this homeserver.
    """
    external_cache = ExternalCache(self)
    return external_cache
@cache_in_self
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
    """Create the outbound Redis connection for this homeserver.

    Returns:
        The result of `lazyConnection` for the configured host and port, or
        `None` if Redis is not enabled in the config.
    """
    config = self.config
    if not config.redis.redis_enabled:
        return None

    # `txredisapi` is an optional dependency, so the redis module is only
    # imported once we know Redis is actually in use.
    from synapse.replication.tcp.redis import lazyConnection

    redis_host = config.redis_host
    redis_port = config.redis_port

    logger.info(
        "Connecting to redis (host=%r port=%r) for external cache",
        redis_host,
        redis_port,
    )

    return lazyConnection(
        hs=self,
        host=redis_host,
        port=redis_port,
        password=config.redis.redis_password,
        reconnect=True,
    )
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
    """Remove a pusher by delegating to the pusher pool.

    Returns:
        Whatever the pusher pool's `remove_pusher` returns.
    """
    pusher_pool = self.get_pusherpool()
    return await pusher_pool.remove_pusher(app_id, push_key, user_id)
@@ -310,6 +310,7 @@ class StateHandler: | |||
state_group_before_event = None | |||
state_group_before_event_prev_group = None | |||
deltas_to_state_group_before_event = None | |||
entry = None | |||
else: | |||
# otherwise, we'll need to resolve the state across the prev_events. | |||
@@ -340,9 +341,13 @@ class StateHandler: | |||
current_state_ids=state_ids_before_event, | |||
) | |||
# XXX: can we update the state cache entry for the new state group? or | |||
# could we set a flag on resolve_state_groups_for_events to tell it to | |||
# always make a state group? | |||
# Assign the new state group to the cached state entry. | |||
# | |||
# Note that this can race in that we could generate multiple state | |||
# groups for the same state entry, but that is just inefficient | |||
# rather than dangerous. | |||
if entry and entry.state_group is None: | |||
entry.state_group = state_group_before_event | |||
# | |||
# now if it's not a state event, we're done | |||
@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): | |||
# Fake in memory Redis server that servers can connect to. | |||
self._redis_server = FakeRedisPubSubServer() | |||
# We may have an attempt to connect to redis for the external cache already. | |||
self.connect_any_redis_attempts() | |||
store = self.hs.get_datastore() | |||
self.database_pool = store.db_pool | |||
@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): | |||
fake one. | |||
""" | |||
clients = self.reactor.tcpClients | |||
self.assertEqual(len(clients), 1) | |||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) | |||
self.assertEqual(host, "localhost") | |||
self.assertEqual(port, 6379) | |||
while clients: | |||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) | |||
self.assertEqual(host, "localhost") | |||
self.assertEqual(port, 6379) | |||
client_protocol = client_factory.buildProtocol(None) | |||
server_protocol = self._redis_server.buildProtocol(None) | |||
client_protocol = client_factory.buildProtocol(None) | |||
server_protocol = self._redis_server.buildProtocol(None) | |||
client_to_server_transport = FakeTransport( | |||
server_protocol, self.reactor, client_protocol | |||
) | |||
client_protocol.makeConnection(client_to_server_transport) | |||
server_to_client_transport = FakeTransport( | |||
client_protocol, self.reactor, server_protocol | |||
) | |||
server_protocol.makeConnection(server_to_client_transport) | |||
client_to_server_transport = FakeTransport( | |||
server_protocol, self.reactor, client_protocol | |||
) | |||
client_protocol.makeConnection(client_to_server_transport) | |||
return client_to_server_transport, server_to_client_transport | |||
server_to_client_transport = FakeTransport( | |||
client_protocol, self.reactor, server_protocol | |||
) | |||
server_protocol.makeConnection(server_to_client_transport) | |||
class TestReplicationDataHandler(GenericWorkerReplicationHandler): | |||
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol): | |||
(channel,) = args | |||
self._server.add_subscriber(self) | |||
self.send(["subscribe", channel, 1]) | |||
# Since we use SET/GET to cache things we can safely no-op them. | |||
elif command == b"SET": | |||
self.send("OK") | |||
elif command == b"GET": | |||
self.send(None) | |||
else: | |||
raise Exception("Unknown command") | |||
@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol): | |||
# We assume bytes are just unicode strings. | |||
obj = obj.decode("utf-8") | |||
if obj is None: | |||
return "$-1\r\n" | |||
if isinstance(obj, str): | |||
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) | |||
if isinstance(obj, int): | |||