浏览代码

Precompute joined hosts and store in Redis (#9198)

tags/v1.27.0rc1
Erik Johnston 3 年前
committed by GitHub
父节点
当前提交
dd8da8c5f6
找不到此签名对应的密钥 GPG 密钥 ID: 4AEE18F83AFDEB23
共有 11 个文件被更改，包括 265 次插入和 49 次删除
  1. +1
    -0
      changelog.d/9198.misc
  2. +11
    -1
      stubs/txredisapi.pyi
  3. +2
    -0
      synapse/config/_base.pyi
  4. +35
    -15
      synapse/federation/sender/__init__.py
  5. +5
    -0
      synapse/handlers/federation.py
  6. +42
    -0
      synapse/handlers/message.py
  7. +105
    -0
      synapse/replication/tcp/external_cache.py
  8. +1
    -14
      synapse/replication/tcp/handler.py
  9. +30
    -0
      synapse/server.py
  10. +8
    -3
      synapse/state/__init__.py
  11. +25
    -16
      tests/replication/_base.py

+ 1
- 0
changelog.d/9198.misc 查看文件

@@ -0,0 +1 @@
Precompute joined hosts and store in Redis.

+ 11
- 1
stubs/txredisapi.pyi 查看文件

@@ -15,11 +15,21 @@

"""Contains *incomplete* type hints for txredisapi.
"""
from typing import List, Optional, Type, Union
from typing import Any, List, Optional, Type, Union

class RedisProtocol:
    """Incomplete type hints for the txredisapi protocol object.

    Only a handful of commands are declared here; this is not the full
    txredisapi surface.
    """

    # Liveness check against the server.
    async def ping(self) -> None: ...

    # Pub/sub: send `message` to `channel`.
    def publish(self, channel: str, message: bytes): ...

    # Read the value stored under `key`.
    async def get(self, key: str) -> Any: ...

    # Write `value` under `key`. `pexpire` is an expiry in milliseconds
    # (presumably `expire` is the same in seconds -- confirm against
    # txredisapi). The two `only_if_*` flags make the write conditional.
    async def set(
        self,
        key: str,
        value: Any,
        expire: Optional[int] = None,
        pexpire: Optional[int] = None,
        only_if_not_exists: bool = False,
        only_if_exists: bool = False,
    ) -> None: ...

class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...


+ 2
- 0
synapse/config/_base.pyi 查看文件

@@ -18,6 +18,7 @@ from synapse.config import (
password_auth_providers,
push,
ratelimiting,
redis,
registration,
repository,
room_directory,
@@ -79,6 +80,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
redis: redis.RedisConfig

config_classes: List = ...
def __init__(self) -> None: ...


+ 35
- 15
synapse/federation/sender/__init__.py 查看文件

@@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)

self._external_cache = hs.get_external_cache()

def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination

@@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return

try:
# Get the state from before the event.
# We need to make sure that this is the state from before
# the event and not from after it.
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
"Failed to calculate hosts in room for event: %s",
event.event_id,
destinations = None # type: Optional[Set[str]]
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
destinations = set()
else:
# We check the external cache for the destinations, which is
# stored per state group.

sg = await self._external_cache.get(
"event_to_prev_state_group", event.event_id
)
return
if sg:
destinations = await self._external_cache.get(
"get_joined_hosts", str(sg)
)

if destinations is None:
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
# the event and not from after it.
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
"Failed to calculate hosts in room for event: %s",
event.event_id,
)
return

destinations = {
d


+ 5
- 0
synapse/handlers/federation.py 查看文件

@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precalculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

return context

async def _check_for_soft_fail(


+ 42
- 0
synapse/handlers/message.py 查看文件

@@ -432,6 +432,8 @@ class EventCreationHandler:

self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages

self._external_cache = hs.get_external_cache()

async def create_event(
self,
requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:

await self.action_generator.handle_push_actions_for_event(event, context)

await self.cache_joined_hosts_for_event(event)

try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise

async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
    """Store the joined hosts at `event` in the external cache.

    When Redis is in use this lets external federation senders look the
    hosts up instead of recalculating them. Nothing is done when the
    external cache is disabled, or when the state before the event has no
    state group.

    Args:
        event: The event whose prev events determine the state to cache.
    """
    cache = self._external_cache
    if not cache.is_enabled():
        return

    # Two mappings are written rather than a single event ID -> joined
    # hosts one, because that is far more space efficient:
    #
    #   * event ID -> prev state group (has to be cached, as it is not
    #     stored in the DB);
    #   * state group -> joined hosts (always rewritten, even if already
    #     present, so that its expiry time is reset).
    prev_state = await self.state.resolve_state_groups_for_events(
        event.room_id, event_ids=event.prev_event_ids()
    )

    prev_state_group = prev_state.state_group
    if not prev_state_group:
        return

    hosts = await self.store.get_joined_hosts(event.room_id, prev_state)

    # Both entries live for one hour.
    one_hour_ms = 60 * 60 * 1000

    await cache.set(
        "event_to_prev_state_group",
        event.event_id,
        prev_state_group,
        expiry_ms=one_hour_ms,
    )
    await cache.set(
        "get_joined_hosts",
        str(prev_state_group),
        list(hosts),
        expiry_ms=one_hour_ms,
    )

async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:


+ 105
- 0
synapse/replication/tcp/external_cache.py 查看文件

@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Optional

from prometheus_client import Counter

from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder

if TYPE_CHECKING:
from synapse.server import HomeServer

# Counts writes to the external cache, labelled by the logical cache name.
set_counter = Counter(
"synapse_external_cache_set",
"Number of times we set a cache",
labelnames=["cache_name"],
)

# Counts reads from the external cache, labelled by the logical cache name
# and by whether the lookup returned a value ("hit").
get_counter = Counter(
"synapse_external_cache_get",
"Number of times we get a cache",
labelnames=["cache_name", "hit"],
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ExternalCache:
    """A cache backed by an external Redis. Does nothing if no Redis is
    configured.
    """

    def __init__(self, hs: "HomeServer"):
        # Shared outbound connection; None when Redis is not enabled, in
        # which case every method on this class is a no-op.
        self._redis_connection = hs.get_outbound_redis_connection()

    def _get_redis_key(self, cache_name: str, key: str) -> str:
        """Build the namespaced Redis key for `key` in the named cache."""
        return "cache_v1:%s:%s" % (cache_name, key)

    def is_enabled(self) -> bool:
        """Whether the external cache is used or not.

        It's safe to use the cache when this returns false, the methods will
        just no-op, but the function is useful to avoid doing unnecessary work.
        """
        return self._redis_connection is not None

    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
        """Add the key/value to the named cache, with the expiry time given.

        Args:
            cache_name: Logical cache the entry belongs to.
            key: Key within that cache.
            value: Value to store; it is JSON-encoded before being written.
            expiry_ms: Lifetime of the entry, in milliseconds.
        """

        if self._redis_connection is None:
            return

        set_counter.labels(cache_name).inc()

        # txredisapi requires the value to be string, bytes or numbers, so we
        # encode stuff in JSON.
        encoded_value = json_encoder.encode(value)

        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)

        return await make_deferred_yieldable(
            self._redis_connection.set(
                self._get_redis_key(cache_name, key),
                encoded_value,
                pexpire=expiry_ms,
            )
        )

    async def get(self, cache_name: str, key: str) -> Optional[Any]:
        """Look up a key/value in the named cache.

        Args:
            cache_name: Logical cache the entry belongs to.
            key: Key within that cache.

        Returns:
            The cached value, or None if there is no entry (or no Redis is
            configured).
        """

        if self._redis_connection is None:
            return None

        result = await make_deferred_yieldable(
            self._redis_connection.get(self._get_redis_key(cache_name, key))
        )

        logger.debug("Got cache result %s %s: %r", cache_name, key, result)

        get_counter.labels(cache_name, result is not None).inc()

        if result is None:
            return None

        # Integers come back from the Redis client already converted, so
        # there is nothing left to decode. This must be checked before the
        # emptiness test below: otherwise a cached 0 would be reported as a
        # miss even though the counter above has recorded a hit.
        if isinstance(result, int):
            return result

        # An empty payload cannot be decoded as JSON; treat it as a miss.
        if not result:
            return None

        return json_decoder.decode(result)

+ 1
- 14
synapse/replication/tcp/handler.py 查看文件

@@ -286,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
lazyConnection,
)

logger.info(
"Connecting to redis (host=%r port=%r)",
hs.config.redis_host,
hs.config.redis_port,
)

# First let's ensure that we have a ReplicationStreamer started.
@@ -303,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).

# First create the connection for sending commands.
outbound_redis_connection = lazyConnection(
hs=hs,
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
reconnect=True,
)
outbound_redis_connection = hs.get_outbound_redis_connection()

# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(


+ 30
- 0
synapse/server.py 查看文件

@@ -103,6 +103,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from txredisapi import RedisProtocol

from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler

@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)

@cache_in_self
def get_external_cache(self) -> ExternalCache:
    """Build the external cache wrapper for this homeserver."""
    external_cache = ExternalCache(self)
    return external_cache

@cache_in_self
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
    """Create the outbound Redis connection.

    Returns:
        A lazily-established, auto-reconnecting connection, or None when
        Redis is not enabled in the config.
    """
    if self.config.redis.redis_enabled:
        # `txredisapi` is an optional dependency, so the redis module is
        # only imported once we know it is going to be used.
        from synapse.replication.tcp.redis import lazyConnection

        redis_host = self.config.redis_host
        redis_port = self.config.redis_port

        logger.info(
            "Connecting to redis (host=%r port=%r) for external cache",
            redis_host,
            redis_port,
        )

        return lazyConnection(
            hs=self,
            host=redis_host,
            port=redis_port,
            password=self.config.redis.redis_password,
            reconnect=True,
        )

    return None

async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)



+ 8
- 3
synapse/state/__init__.py 查看文件

@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
entry = None

else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)

# XXX: can we update the state cache entry for the new state group? or
# could we set a flag on resolve_state_groups_for_events to tell it to
# always make a state group?
# Assign the new state group to the cached state entry.
#
# Note that this can race in that we could generate multiple state
# groups for the same state entry, but that is just inefficient
# rather than dangerous.
if entry and entry.state_group is None:
entry.state_group = state_group_before_event

#
# now if it's not a state event, we're done


+ 25
- 16
tests/replication/_base.py 查看文件

@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()

# We may have an attempt to connect to redis for the external cache already.
self.connect_any_redis_attempts()

store = self.hs.get_datastore()
self.database_pool = store.db_pool

@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)

client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)

return client_to_server_transport, server_to_client_transport
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)


class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])

# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
self.send("OK")
elif command == b"GET":
self.send(None)
else:
raise Exception("Unknown command")

@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")

if obj is None:
return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):


正在加载...
取消
保存