Fixes #16417
@@ -0,0 +1 @@ | |||
Allow multiple workers to write to receipts stream. |
@@ -358,9 +358,9 @@ class WorkerConfig(Config): | |||
"Must only specify one instance to handle `account_data` messages." | |||
) | |||
if len(self.writers.receipts) != 1: | |||
if len(self.writers.receipts) == 0: | |||
raise ConfigError( | |||
"Must only specify one instance to handle `receipts` messages." | |||
"Must specify at least one instance to handle `receipts` messages." | |||
) | |||
if len(self.writers.events) == 0: | |||
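With the check relaxed to require at least one receipts writer, `stream_writers.receipts` may now name several workers. As an illustration, such a configuration looks like the sketch below, in the dict form the test fixture later in this diff uses (hostnames and ports are placeholders taken from that fixture; in a real deployment the same keys live in the homeserver config):

```python
# Sketch of a multi-writer receipts configuration (values mirror the test
# fixture further down; adjust worker names and ports for a real deployment).
conf = {
    "stream_writers": {"receipts": ["worker1", "worker2"]},
    "instance_map": {
        "main": {"host": "testserv", "port": 8765},
        "worker1": {"host": "testserv", "port": 1001},
        "worker2": {"host": "testserv", "port": 1002},
    },
}
```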
@@ -47,6 +47,7 @@ from synapse.types import ( | |||
DeviceListUpdates, | |||
JsonDict, | |||
JsonMapping, | |||
MultiWriterStreamToken, | |||
RoomAlias, | |||
RoomStreamToken, | |||
StreamKeyType, | |||
@@ -217,7 +218,7 @@ class ApplicationServicesHandler: | |||
def notify_interested_services_ephemeral( | |||
self, | |||
stream_key: StreamKeyType, | |||
new_token: Union[int, RoomStreamToken], | |||
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], | |||
users: Collection[Union[str, UserID]], | |||
) -> None: | |||
""" | |||
@@ -259,19 +260,6 @@ class ApplicationServicesHandler: | |||
): | |||
return | |||
# Assert that new_token is an integer (and not a RoomStreamToken). | |||
# All of the supported streams that this function handles use an | |||
# integer to track progress (rather than a RoomStreamToken - a | |||
# vector clock implementation) as they don't support multiple | |||
# stream writers. | |||
# | |||
# As a result, we simply assert that new_token is an integer. | |||
# If we do end up needing to pass a RoomStreamToken down here | |||
# in the future, using RoomStreamToken.stream (the minimum stream | |||
# position) to convert to an ascending integer value should work. | |||
# Additional context: https://github.com/matrix-org/synapse/pull/11137 | |||
assert isinstance(new_token, int) | |||
# Ignore to-device messages if the feature flag is not enabled | |||
if ( | |||
stream_key == StreamKeyType.TO_DEVICE | |||
@@ -286,6 +274,9 @@ class ApplicationServicesHandler: | |||
): | |||
return | |||
# We know `new_token` is not a `RoomStreamToken` at this point.
assert not isinstance(new_token, RoomStreamToken) | |||
# Check whether there are any appservices which have registered to receive | |||
# ephemeral events. | |||
# | |||
@@ -327,7 +318,7 @@ class ApplicationServicesHandler: | |||
self, | |||
services: List[ApplicationService], | |||
stream_key: StreamKeyType, | |||
new_token: int, | |||
new_token: Union[int, MultiWriterStreamToken], | |||
users: Collection[Union[str, UserID]], | |||
) -> None: | |||
logger.debug("Checking interested services for %s", stream_key) | |||
@@ -340,6 +331,7 @@ class ApplicationServicesHandler: | |||
# | |||
# Instead we simply grab the latest typing updates in _handle_typing | |||
# and, if they apply to this application service, send it off. | |||
assert isinstance(new_token, int) | |||
events = await self._handle_typing(service, new_token) | |||
if events: | |||
self.scheduler.enqueue_for_appservice(service, ephemeral=events) | |||
@@ -350,15 +342,23 @@ class ApplicationServicesHandler: | |||
(service.id, stream_key) | |||
): | |||
if stream_key == StreamKeyType.RECEIPT: | |||
assert isinstance(new_token, MultiWriterStreamToken) | |||
# We store appservice tokens as integers, so we ignore | |||
# the `instance_map` components and instead simply | |||
# follow the base stream position. | |||
new_token = MultiWriterStreamToken(stream=new_token.stream) | |||
events = await self._handle_receipts(service, new_token) | |||
self.scheduler.enqueue_for_appservice(service, ephemeral=events) | |||
# Persist the latest handled stream token for this appservice | |||
await self.store.set_appservice_stream_type_pos( | |||
service, "read_receipt", new_token | |||
service, "read_receipt", new_token.stream | |||
) | |||
elif stream_key == StreamKeyType.PRESENCE: | |||
assert isinstance(new_token, int) | |||
events = await self._handle_presence(service, users, new_token) | |||
self.scheduler.enqueue_for_appservice(service, ephemeral=events) | |||
@@ -368,6 +368,7 @@ class ApplicationServicesHandler: | |||
) | |||
elif stream_key == StreamKeyType.TO_DEVICE: | |||
assert isinstance(new_token, int) | |||
# Retrieve a list of to-device message events, as well as the | |||
# maximum stream token of the messages we were able to retrieve. | |||
to_device_messages = await self._get_to_device_messages( | |||
@@ -383,6 +384,7 @@ class ApplicationServicesHandler: | |||
) | |||
elif stream_key == StreamKeyType.DEVICE_LIST: | |||
assert isinstance(new_token, int) | |||
device_list_summary = await self._get_device_list_summary( | |||
service, new_token | |||
) | |||
@@ -432,7 +434,7 @@ class ApplicationServicesHandler: | |||
return typing | |||
async def _handle_receipts( | |||
self, service: ApplicationService, new_token: int | |||
self, service: ApplicationService, new_token: MultiWriterStreamToken | |||
) -> List[JsonMapping]: | |||
""" | |||
Return the latest read receipts that the given application service should receive. | |||
@@ -455,15 +457,17 @@ class ApplicationServicesHandler: | |||
from_key = await self.store.get_type_stream_id_for_appservice( | |||
service, "read_receipt" | |||
) | |||
if new_token is not None and new_token <= from_key: | |||
if new_token is not None and new_token.stream <= from_key: | |||
logger.debug( | |||
"Rejecting token lower than or equal to stored: %s" % (new_token,) | |||
) | |||
return [] | |||
from_token = MultiWriterStreamToken(stream=from_key) | |||
receipts_source = self.event_sources.sources.receipt | |||
receipts, _ = await receipts_source.get_new_events_as( | |||
service=service, from_key=from_key, to_key=new_token | |||
service=service, from_key=from_token, to_key=new_token | |||
) | |||
return receipts | |||
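Appservices keep persisting a plain integer for their receipts position, so the handler widens it into a token before querying the receipts source and narrows a notified token back to its base stream position when storing. A minimal sketch of that round-trip with made-up positions (imports shown for completeness; `worker2` is hypothetical):

```python
from immutabledict import immutabledict

from synapse.types import MultiWriterStreamToken

stored_pos = 7  # integer kept via set_appservice_stream_type_pos
from_token = MultiWriterStreamToken(stream=stored_pos)  # widened for get_new_events_as

# A token coming from the notifier may carry an instance map...
observed = MultiWriterStreamToken(stream=9, instance_map=immutabledict({"worker2": 11}))
# ...but only its base stream position is persisted for the appservice.
assert observed.stream == 9
```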
@@ -145,7 +145,7 @@ class InitialSyncHandler: | |||
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] | |||
receipt = await self.store.get_linearized_receipts_for_rooms( | |||
joined_rooms, | |||
to_key=int(now_token.receipt_key), | |||
to_key=now_token.receipt_key, | |||
) | |||
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) | |||
@@ -20,6 +20,7 @@ from synapse.streams import EventSource | |||
from synapse.types import ( | |||
JsonDict, | |||
JsonMapping, | |||
MultiWriterStreamToken, | |||
ReadReceipt, | |||
StreamKeyType, | |||
UserID, | |||
@@ -200,7 +201,7 @@ class ReceiptsHandler: | |||
await self.federation_sender.send_read_receipt(receipt) | |||
class ReceiptEventSource(EventSource[int, JsonMapping]): | |||
class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]): | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastores().main | |||
self.config = hs.config | |||
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||
async def get_new_events( | |||
self, | |||
user: UserID, | |||
from_key: int, | |||
from_key: MultiWriterStreamToken, | |||
limit: int, | |||
room_ids: Iterable[str], | |||
is_guest: bool, | |||
explicit_room_id: Optional[str] = None, | |||
) -> Tuple[List[JsonMapping], int]: | |||
from_key = int(from_key) | |||
) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: | |||
to_key = self.get_current_key() | |||
if from_key == to_key: | |||
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||
return events, to_key | |||
async def get_new_events_as( | |||
self, from_key: int, to_key: int, service: ApplicationService | |||
) -> Tuple[List[JsonMapping], int]: | |||
self, | |||
from_key: MultiWriterStreamToken, | |||
to_key: MultiWriterStreamToken, | |||
service: ApplicationService, | |||
) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: | |||
"""Returns a set of new read receipt events that an appservice | |||
may be interested in. | |||
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||
appservice may be interested in. | |||
* The current read receipt stream token. | |||
""" | |||
from_key = int(from_key) | |||
if from_key == to_key: | |||
return [], to_key | |||
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||
return events, to_key | |||
def get_current_key(self) -> int: | |||
def get_current_key(self) -> MultiWriterStreamToken: | |||
return self.store.get_max_receipt_stream_id() |
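Because `MultiWriterStreamToken` is a frozen attrs class, equality is structural, so the existing `from_key == to_key` early-return in `get_new_events` and `get_new_events_as` still short-circuits when nothing has changed. A tiny standalone check (values are made up):

```python
from synapse.types import MultiWriterStreamToken

a = MultiWriterStreamToken(stream=5)
b = MultiWriterStreamToken(stream=5)
assert a == b  # no new receipts since the last token: the source can return early
```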
@@ -57,6 +57,7 @@ from synapse.types import ( | |||
DeviceListUpdates, | |||
JsonDict, | |||
JsonMapping, | |||
MultiWriterStreamToken, | |||
MutableStateMap, | |||
Requester, | |||
RoomStreamToken, | |||
@@ -477,7 +478,11 @@ class SyncHandler: | |||
event_copy = {k: v for (k, v) in event.items() if k != "room_id"} | |||
ephemeral_by_room.setdefault(room_id, []).append(event_copy) | |||
receipt_key = since_token.receipt_key if since_token else 0 | |||
receipt_key = ( | |||
since_token.receipt_key | |||
if since_token | |||
else MultiWriterStreamToken(stream=0) | |||
) | |||
receipt_source = self.event_sources.sources.receipt | |||
receipts, receipt_key = await receipt_source.get_new_events( | |||
@@ -21,11 +21,13 @@ from typing import ( | |||
Dict, | |||
Iterable, | |||
List, | |||
Literal, | |||
Optional, | |||
Set, | |||
Tuple, | |||
TypeVar, | |||
Union, | |||
overload, | |||
) | |||
import attr | |||
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge | |||
from synapse.streams.config import PaginationConfig | |||
from synapse.types import ( | |||
JsonDict, | |||
MultiWriterStreamToken, | |||
PersistedEventPosition, | |||
RoomStreamToken, | |||
StrCollection, | |||
@@ -127,7 +130,7 @@ class _NotifierUserStream: | |||
def notify( | |||
self, | |||
stream_key: StreamKeyType, | |||
stream_id: Union[int, RoomStreamToken], | |||
stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken], | |||
time_now_ms: int, | |||
) -> None: | |||
"""Notify any listeners for this user of a new event from an | |||
@@ -452,10 +455,48 @@ class Notifier: | |||
except Exception: | |||
logger.exception("Error pusher pool of event") | |||
@overload | |||
def on_new_event( | |||
self, | |||
stream_key: Literal[StreamKeyType.ROOM], | |||
new_token: RoomStreamToken, | |||
users: Optional[Collection[Union[str, UserID]]] = None, | |||
rooms: Optional[StrCollection] = None, | |||
) -> None: | |||
... | |||
@overload | |||
def on_new_event( | |||
self, | |||
stream_key: Literal[StreamKeyType.RECEIPT], | |||
new_token: MultiWriterStreamToken, | |||
users: Optional[Collection[Union[str, UserID]]] = None, | |||
rooms: Optional[StrCollection] = None, | |||
) -> None: | |||
... | |||
@overload | |||
def on_new_event( | |||
self, | |||
stream_key: Literal[ | |||
StreamKeyType.ACCOUNT_DATA, | |||
StreamKeyType.DEVICE_LIST, | |||
StreamKeyType.PRESENCE, | |||
StreamKeyType.PUSH_RULES, | |||
StreamKeyType.TO_DEVICE, | |||
StreamKeyType.TYPING, | |||
StreamKeyType.UN_PARTIAL_STATED_ROOMS, | |||
], | |||
new_token: int, | |||
users: Optional[Collection[Union[str, UserID]]] = None, | |||
rooms: Optional[StrCollection] = None, | |||
) -> None: | |||
... | |||
def on_new_event( | |||
self, | |||
stream_key: StreamKeyType, | |||
new_token: Union[int, RoomStreamToken], | |||
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], | |||
users: Optional[Collection[Union[str, UserID]]] = None, | |||
rooms: Optional[StrCollection] = None, | |||
) -> None: | |||
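The `Literal`-keyed overloads pair each stream key with its token type, so call sites are checked by mypy. A hedged usage sketch (assumes `notifier` is a `Notifier` instance; the room id is made up):

```python
from synapse.types import MultiWriterStreamToken, StreamKeyType

# Receipts now take a MultiWriterStreamToken...
notifier.on_new_event(
    StreamKeyType.RECEIPT,
    MultiWriterStreamToken(stream=580),
    rooms=["!room:example.com"],
)
# ...while single-writer streams such as typing keep plain integer tokens.
notifier.on_new_event(StreamKeyType.TYPING, 42, rooms=["!room:example.com"])
```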
@@ -126,8 +126,9 @@ class ReplicationDataHandler: | |||
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] | |||
) | |||
elif stream_name == ReceiptsStream.NAME: | |||
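# The replication stream only carries an integer position, but the notifier's
# receipt stream now expects a MultiWriterStreamToken, so re-read the store's
# current multi-writer position, which already reflects the rows just processed.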
new_token = self.store.get_max_receipt_stream_id() | |||
self.notifier.on_new_event( | |||
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] | |||
StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows] | |||
) | |||
await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) | |||
elif stream_name == ToDeviceStream.NAME: | |||
@@ -28,6 +28,8 @@ from typing import ( | |||
cast, | |||
) | |||
from immutabledict import immutabledict | |||
from synapse.api.constants import EduTypes | |||
from synapse.replication.tcp.streams import ReceiptsStream | |||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | |||
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import ( | |||
MultiWriterIdGenerator, | |||
StreamIdGenerator, | |||
) | |||
from synapse.types import JsonDict, JsonMapping | |||
from synapse.types import ( | |||
JsonDict, | |||
JsonMapping, | |||
MultiWriterStreamToken, | |||
PersistedPosition, | |||
) | |||
from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
"receipts_linearized", | |||
entity_column="room_id", | |||
stream_column="stream_id", | |||
max_value=max_receipts_stream_id, | |||
max_value=max_receipts_stream_id.stream, | |||
limit=10000, | |||
) | |||
self._receipts_stream_cache = StreamChangeCache( | |||
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
prefilled_cache=receipts_stream_prefill, | |||
) | |||
def get_max_receipt_stream_id(self) -> int: | |||
def get_max_receipt_stream_id(self) -> MultiWriterStreamToken: | |||
"""Get the current max stream ID for receipts stream""" | |||
return self._receipts_id_gen.get_current_token() | |||
min_pos = self._receipts_id_gen.get_current_token() | |||
positions = {} | |||
if isinstance(self._receipts_id_gen, MultiWriterIdGenerator): | |||
# The `min_pos` is the minimum position that we know all instances | |||
# have finished persisting to, so we only care about instances whose | |||
# positions are ahead of that. (Instance positions can be behind the | |||
# min position as there are times we can work out that the minimum | |||
# position is ahead of the naive minimum across all current | |||
# positions. See MultiWriterIdGenerator for details) | |||
positions = { | |||
i: p | |||
for i, p in self._receipts_id_gen.get_positions().items() | |||
if p > min_pos | |||
} | |||
return MultiWriterStreamToken( | |||
stream=min_pos, instance_map=immutabledict(positions) | |||
) | |||
def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: | |||
return self._receipts_id_gen.get_current_token_for_writer(instance_name) | |||
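A minimal sketch of what the new `get_max_receipt_stream_id` produces when one writer lags; the positions and instance names are made up, and the serialised form mentioned in the comment assumes `worker1` resolves to instance id 1:

```python
from immutabledict import immutabledict

from synapse.types import MultiWriterStreamToken

min_pos = 5  # position every writer is known to have persisted up to
positions = {"worker1": 7, "worker2": 5}  # current per-writer positions
ahead = {name: pos for name, pos in positions.items() if pos > min_pos}

token = MultiWriterStreamToken(stream=min_pos, instance_map=immutabledict(ahead))
# token.to_string(store) would then emit something like "m5~1.7"; when no
# writer is ahead of the minimum it degrades to a plain "5".
```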
def get_last_unthreaded_receipt_for_user_txn( | |||
self, | |||
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
} | |||
async def get_linearized_receipts_for_rooms( | |||
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None | |||
self, | |||
room_ids: Iterable[str], | |||
to_key: MultiWriterStreamToken, | |||
from_key: Optional[MultiWriterStreamToken] = None, | |||
) -> List[JsonMapping]: | |||
"""Get receipts for multiple rooms for sending to clients. | |||
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
# Only ask the database about rooms where there have been new | |||
# receipts added since `from_key` | |||
room_ids = self._receipts_stream_cache.get_entities_changed( | |||
room_ids, from_key | |||
room_ids, from_key.stream | |||
) | |||
results = await self._get_linearized_receipts_for_rooms( | |||
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
return [ev for res in results.values() for ev in res] | |||
async def get_linearized_receipts_for_room( | |||
self, room_id: str, to_key: int, from_key: Optional[int] = None | |||
self, | |||
room_id: str, | |||
to_key: MultiWriterStreamToken, | |||
from_key: Optional[MultiWriterStreamToken] = None, | |||
) -> Sequence[JsonMapping]: | |||
"""Get receipts for a single room for sending to clients. | |||
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
if from_key is not None: | |||
# Check the cache first to see if any new receipts have been added | |||
# since `from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): | |||
if not self._receipts_stream_cache.has_entity_changed( | |||
room_id, from_key.stream | |||
): | |||
return [] | |||
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) | |||
@cached(tree=True) | |||
async def _get_linearized_receipts_for_room( | |||
self, room_id: str, to_key: int, from_key: Optional[int] = None | |||
self, | |||
room_id: str, | |||
to_key: MultiWriterStreamToken, | |||
from_key: Optional[MultiWriterStreamToken] = None, | |||
) -> Sequence[JsonMapping]: | |||
"""See get_linearized_receipts_for_room""" | |||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: | |||
if from_key: | |||
sql = ( | |||
"SELECT receipt_type, user_id, event_id, data" | |||
" FROM receipts_linearized WHERE" | |||
" room_id = ? AND stream_id > ? AND stream_id <= ?" | |||
) | |||
sql = """ | |||
SELECT stream_id, instance_name, receipt_type, user_id, event_id, data | |||
FROM receipts_linearized | |||
WHERE room_id = ? AND stream_id > ? AND stream_id <= ? | |||
""" | |||
txn.execute(sql, (room_id, from_key, to_key)) | |||
else: | |||
sql = ( | |||
"SELECT receipt_type, user_id, event_id, data" | |||
" FROM receipts_linearized WHERE" | |||
" room_id = ? AND stream_id <= ?" | |||
txn.execute( | |||
sql, (room_id, from_key.stream, to_key.get_max_stream_pos()) | |||
) | |||
else: | |||
sql = """ | |||
SELECT stream_id, instance_name, receipt_type, user_id, event_id, data | |||
FROM receipts_linearized WHERE | |||
room_id = ? AND stream_id <= ? | |||
""" | |||
txn.execute(sql, (room_id, to_key)) | |||
txn.execute(sql, (room_id, to_key.get_max_stream_pos())) | |||
return cast(List[Tuple[str, str, str, str]], txn.fetchall()) | |||
return [ | |||
(receipt_type, user_id, event_id, data) | |||
for stream_id, instance_name, receipt_type, user_id, event_id, data in txn | |||
if MultiWriterStreamToken.is_stream_position_in_range( | |||
from_key, to_key, instance_name, stream_id | |||
) | |||
] | |||
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) | |||
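The pattern in these storage methods is to bound the SQL query by `to_key.get_max_stream_pos()` (the highest position referenced by the token) and then apply the precise per-instance check in Python. A standalone sketch with made-up rows shows why the second step matters:

```python
from immutabledict import immutabledict

from synapse.types import MultiWriterStreamToken

to_key = MultiWriterStreamToken(stream=5, instance_map=immutabledict({"worker1": 7}))

rows = [
    ("worker1", 7),  # persisted by worker1 before the token was taken
    ("worker2", 6),  # worker2 had only reached 5 when the token was taken
]
kept = [
    (instance_name, stream_id)
    for instance_name, stream_id in rows
    if MultiWriterStreamToken.is_stream_position_in_range(
        None, to_key, instance_name, stream_id
    )
]
assert kept == [("worker1", 7)]  # the later worker2 write is excluded
```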
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
num_args=3, | |||
) | |||
async def _get_linearized_receipts_for_rooms( | |||
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None | |||
self, | |||
room_ids: Collection[str], | |||
to_key: MultiWriterStreamToken, | |||
from_key: Optional[MultiWriterStreamToken] = None, | |||
) -> Mapping[str, Sequence[JsonMapping]]: | |||
if not room_ids: | |||
return {} | |||
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
) -> List[Tuple[str, str, str, str, Optional[str], str]]: | |||
if from_key: | |||
sql = """ | |||
SELECT room_id, receipt_type, user_id, event_id, thread_id, data | |||
SELECT stream_id, instance_name, room_id, receipt_type, | |||
user_id, event_id, thread_id, data | |||
FROM receipts_linearized WHERE | |||
stream_id > ? AND stream_id <= ? AND | |||
""" | |||
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
self.database_engine, "room_id", room_ids | |||
) | |||
txn.execute(sql + clause, [from_key, to_key] + list(args)) | |||
txn.execute( | |||
sql + clause, | |||
[from_key.stream, to_key.get_max_stream_pos()] + list(args), | |||
) | |||
else: | |||
sql = """ | |||
SELECT room_id, receipt_type, user_id, event_id, thread_id, data | |||
SELECT stream_id, instance_name, room_id, receipt_type, | |||
user_id, event_id, thread_id, data | |||
FROM receipts_linearized WHERE | |||
stream_id <= ? AND | |||
""" | |||
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
self.database_engine, "room_id", room_ids | |||
) | |||
txn.execute(sql + clause, [to_key] + list(args)) | |||
txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) | |||
return cast( | |||
List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall() | |||
) | |||
return [ | |||
(room_id, receipt_type, user_id, event_id, thread_id, data) | |||
for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn | |||
if MultiWriterStreamToken.is_stream_position_in_range( | |||
from_key, to_key, instance_name, stream_id | |||
) | |||
] | |||
txn_results = await self.db_pool.runInteraction( | |||
"_get_linearized_receipts_for_rooms", f | |||
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
num_args=2, | |||
) | |||
async def get_linearized_receipts_for_all_rooms( | |||
self, to_key: int, from_key: Optional[int] = None | |||
self, | |||
to_key: MultiWriterStreamToken, | |||
from_key: Optional[MultiWriterStreamToken] = None, | |||
) -> Mapping[str, JsonMapping]: | |||
"""Get receipts for all rooms between two stream_ids, up | |||
to a limit of the latest 100 read receipts. | |||
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: | |||
if from_key: | |||
sql = """ | |||
SELECT room_id, receipt_type, user_id, event_id, data | |||
SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data | |||
FROM receipts_linearized WHERE | |||
stream_id > ? AND stream_id <= ? | |||
ORDER BY stream_id DESC | |||
LIMIT 100 | |||
""" | |||
txn.execute(sql, [from_key, to_key]) | |||
txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()]) | |||
else: | |||
sql = """ | |||
SELECT room_id, receipt_type, user_id, event_id, data | |||
SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data | |||
FROM receipts_linearized WHERE | |||
stream_id <= ? | |||
ORDER BY stream_id DESC | |||
LIMIT 100 | |||
""" | |||
txn.execute(sql, [to_key]) | |||
txn.execute(sql, [to_key.get_max_stream_pos()]) | |||
return cast(List[Tuple[str, str, str, str, str]], txn.fetchall()) | |||
return [ | |||
(room_id, receipt_type, user_id, event_id, data) | |||
for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn | |||
if MultiWriterStreamToken.is_stream_position_in_range( | |||
from_key, to_key, instance_name, stream_id | |||
) | |||
] | |||
txn_results = await self.db_pool.runInteraction( | |||
"get_linearized_receipts_for_all_rooms", f | |||
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data | |||
FROM receipts_linearized | |||
WHERE ? < stream_id AND stream_id <= ? | |||
AND instance_name = ? | |||
ORDER BY stream_id ASC | |||
LIMIT ? | |||
""" | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
txn.execute(sql, (last_id, current_id, instance_name, limit)) | |||
updates = cast( | |||
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], | |||
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
keyvalues=keyvalues, | |||
values={ | |||
"stream_id": stream_id, | |||
"instance_name": self._instance_name, | |||
"event_id": event_id, | |||
"event_stream_ordering": stream_ordering, | |||
"data": json_encoder.encode(data), | |||
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
event_ids: List[str], | |||
thread_id: Optional[str], | |||
data: dict, | |||
) -> Optional[int]: | |||
) -> Optional[PersistedPosition]: | |||
"""Insert a receipt, either from local client or remote server. | |||
Automatically does conversion between linearized and graph | |||
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
data, | |||
) | |||
return stream_id | |||
return PersistedPosition(self._instance_name, stream_id) | |||
async def _insert_graph_receipt( | |||
self, | |||
@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import ( | |||
generate_pagination_where_clause, | |||
) | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.types import JsonDict, StreamKeyType, StreamToken | |||
from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
if TYPE_CHECKING: | |||
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
room_key=next_key, | |||
presence_key=0, | |||
typing_key=0, | |||
receipt_key=0, | |||
receipt_key=MultiWriterStreamToken(stream=0), | |||
account_data_key=0, | |||
push_rules_key=0, | |||
to_device_key=0, | |||
@@ -0,0 +1,17 @@ | |||
/* Copyright 2023 The Matrix.org Foundation C.I.C | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
-- This already exists on Postgres. | |||
ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; |
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource | |||
from synapse.handlers.typing import TypingNotificationEventSource | |||
from synapse.logging.opentracing import trace | |||
from synapse.streams import EventSource | |||
from synapse.types import StreamKeyType, StreamToken | |||
from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -111,7 +111,7 @@ class EventSources: | |||
room_key=await self.sources.room.get_current_key_for_room(room_id), | |||
presence_key=0, | |||
typing_key=0, | |||
receipt_key=0, | |||
receipt_key=MultiWriterStreamToken(stream=0), | |||
account_data_key=0, | |||
push_rules_key=0, | |||
to_device_key=0, | |||
@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): | |||
return "s%d" % (self.stream,) | |||
@attr.s(frozen=True, slots=True, order=False) | |||
class MultiWriterStreamToken(AbstractMultiWriterStreamToken): | |||
"""A basic stream token class for streams that supports multiple writers.""" | |||
@classmethod | |||
async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken": | |||
try: | |||
if string[0].isdigit(): | |||
return cls(stream=int(string)) | |||
if string[0] == "m": | |||
parts = string[1:].split("~") | |||
stream = int(parts[0]) | |||
instance_map = {} | |||
for part in parts[1:]: | |||
key, value = part.split(".") | |||
instance_id = int(key) | |||
pos = int(value) | |||
instance_name = await store.get_name_from_instance_id(instance_id) | |||
instance_map[instance_name] = pos | |||
return cls( | |||
stream=stream, | |||
instance_map=immutabledict(instance_map), | |||
) | |||
except CancelledError: | |||
raise | |||
except Exception: | |||
pass | |||
raise SynapseError(400, "Invalid stream token %r" % (string,)) | |||
async def to_string(self, store: "DataStore") -> str: | |||
if self.instance_map: | |||
entries = [] | |||
for name, pos in self.instance_map.items(): | |||
if pos <= self.stream: | |||
# Ignore instances who are below the minimum stream position | |||
# (we might know they've advanced without seeing a recent | |||
# write from them). | |||
continue | |||
instance_id = await store.get_id_for_instance(name) | |||
entries.append(f"{instance_id}.{pos}") | |||
encoded_map = "~".join(entries) | |||
return f"m{self.stream}~{encoded_map}" | |||
else: | |||
return str(self.stream) | |||
@staticmethod | |||
def is_stream_position_in_range( | |||
low: Optional["AbstractMultiWriterStreamToken"], | |||
high: Optional["AbstractMultiWriterStreamToken"], | |||
instance_name: Optional[str], | |||
pos: int, | |||
) -> bool: | |||
"""Checks if a given persisted position is between the two given tokens. | |||
If `instance_name` is None then the row was persisted before multi | |||
writer support. | |||
""" | |||
if low: | |||
if instance_name: | |||
low_stream = low.instance_map.get(instance_name, low.stream) | |||
else: | |||
low_stream = low.stream | |||
if pos <= low_stream: | |||
return False | |||
if high: | |||
if instance_name: | |||
high_stream = high.instance_map.get(instance_name, high.stream) | |||
else: | |||
high_stream = high.stream | |||
if high_stream < pos: | |||
return False | |||
return True | |||
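For reference, the wire format handled by `parse` and `to_string` above is either a bare integer or `m<stream>~<instance_id>.<pos>~…`, where instance ids are the per-server integers resolved through the store. A standalone decoder sketch mirroring that logic (it skips the id-to-name lookup, so it is illustrative only):

```python
from typing import Dict, Tuple


def decode_receipt_token(string: str) -> Tuple[int, Dict[int, int]]:
    """Mirror of MultiWriterStreamToken.parse, minus the instance-id lookup."""
    if string[0].isdigit():
        return int(string), {}
    assert string[0] == "m"
    stream, *rest = string[1:].split("~")
    return int(stream), {int(k): int(v) for k, v in (part.split(".") for part in rest)}


assert decode_receipt_token("5") == (5, {})
assert decode_receipt_token("m5~1.7~3.9") == (5, {1: 7, 3: 9})
```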
class StreamKeyType(Enum): | |||
"""Known stream types. | |||
@@ -776,7 +860,9 @@ class StreamToken: | |||
) | |||
presence_key: int | |||
typing_key: int | |||
receipt_key: int | |||
receipt_key: MultiWriterStreamToken = attr.ib( | |||
validator=attr.validators.instance_of(MultiWriterStreamToken) | |||
) | |||
account_data_key: int | |||
push_rules_key: int | |||
to_device_key: int | |||
@@ -799,8 +885,31 @@ class StreamToken: | |||
while len(keys) < len(attr.fields(cls)): | |||
# i.e. old token from before receipt_key | |||
keys.append("0") | |||
( | |||
room_key, | |||
presence_key, | |||
typing_key, | |||
receipt_key, | |||
account_data_key, | |||
push_rules_key, | |||
to_device_key, | |||
device_list_key, | |||
groups_key, | |||
un_partial_stated_rooms_key, | |||
) = keys | |||
return cls( | |||
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) | |||
room_key=await RoomStreamToken.parse(store, room_key), | |||
presence_key=int(presence_key), | |||
typing_key=int(typing_key), | |||
receipt_key=await MultiWriterStreamToken.parse(store, receipt_key), | |||
account_data_key=int(account_data_key), | |||
push_rules_key=int(push_rules_key), | |||
to_device_key=int(to_device_key), | |||
device_list_key=int(device_list_key), | |||
groups_key=int(groups_key), | |||
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), | |||
) | |||
except CancelledError: | |||
raise | |||
@@ -813,7 +922,7 @@ class StreamToken: | |||
await self.room_key.to_string(store), | |||
str(self.presence_key), | |||
str(self.typing_key), | |||
str(self.receipt_key), | |||
await self.receipt_key.to_string(store), | |||
str(self.account_data_key), | |||
str(self.push_rules_key), | |||
str(self.to_device_key), | |||
@@ -841,6 +950,11 @@ class StreamToken: | |||
StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) | |||
) | |||
return new_token | |||
elif key == StreamKeyType.RECEIPT: | |||
new_token = self.copy_and_replace( | |||
StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value) | |||
) | |||
return new_token | |||
new_token = self.copy_and_replace(key, new_value) | |||
new_id = new_token.get_field(key) | |||
@@ -858,6 +972,10 @@ class StreamToken: | |||
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: | |||
... | |||
@overload | |||
def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken: | |||
... | |||
@overload | |||
def get_field( | |||
self, | |||
@@ -866,7 +984,6 @@ class StreamToken: | |||
StreamKeyType.DEVICE_LIST, | |||
StreamKeyType.PRESENCE, | |||
StreamKeyType.PUSH_RULES, | |||
StreamKeyType.RECEIPT, | |||
StreamKeyType.TO_DEVICE, | |||
StreamKeyType.TYPING, | |||
StreamKeyType.UN_PARTIAL_STATED_ROOMS, | |||
@@ -875,15 +992,21 @@ class StreamToken: | |||
... | |||
@overload | |||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: | |||
def get_field( | |||
self, key: StreamKeyType | |||
) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: | |||
... | |||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: | |||
def get_field( | |||
self, key: StreamKeyType | |||
) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: | |||
"""Returns the stream ID for the given key.""" | |||
return getattr(self, key.value) | |||
StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0) | |||
StreamToken.START = StreamToken( | |||
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0 | |||
) | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
@@ -31,7 +31,12 @@ from synapse.appservice import ( | |||
from synapse.handlers.appservice import ApplicationServicesHandler | |||
from synapse.rest.client import login, receipts, register, room, sendtodevice | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType | |||
from synapse.types import ( | |||
JsonDict, | |||
MultiWriterStreamToken, | |||
RoomStreamToken, | |||
StreamKeyType, | |||
) | |||
from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
@@ -305,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
) | |||
self.handler.notify_interested_services_ephemeral( | |||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] | |||
StreamKeyType.RECEIPT, | |||
MultiWriterStreamToken(stream=580), | |||
["@fakerecipient:example.com"], | |||
) | |||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | |||
interested_service, ephemeral=[event] | |||
@@ -333,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
) | |||
self.handler.notify_interested_services_ephemeral( | |||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] | |||
StreamKeyType.RECEIPT, | |||
MultiWriterStreamToken(stream=580), | |||
["@fakerecipient:example.com"], | |||
) | |||
# This method will be called, but with an empty list of events | |||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | |||
@@ -636,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
self.hs.get_application_service_handler()._notify_interested_services_ephemeral( | |||
services=[interested_appservice], | |||
stream_key=StreamKeyType.RECEIPT, | |||
new_token=stream_token, | |||
new_token=MultiWriterStreamToken(stream=stream_token), | |||
users=[self.exclusive_as_user], | |||
) | |||
) | |||
@@ -0,0 +1,243 @@ | |||
# Copyright 2020 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import ReceiptTypes | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, receipts, room, sync | |||
from synapse.server import HomeServer | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator | |||
from synapse.types import StreamToken | |||
from synapse.util import Clock | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
from tests.server import make_request | |||
logger = logging.getLogger(__name__) | |||
class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase): | |||
"""Checks receipts sharding works""" | |||
servlets = [ | |||
admin.register_servlets_for_client_rest_resource, | |||
room.register_servlets, | |||
login.register_servlets, | |||
sync.register_servlets, | |||
receipts.register_servlets, | |||
] | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
# Register a user who sends a message that we'll get notified about | |||
self.other_user_id = self.register_user("otheruser", "pass") | |||
self.other_access_token = self.login("otheruser", "pass") | |||
self.room_creator = self.hs.get_room_creation_handler() | |||
self.store = hs.get_datastores().main | |||
def default_config(self) -> dict: | |||
conf = super().default_config() | |||
conf["stream_writers"] = {"receipts": ["worker1", "worker2"]} | |||
conf["instance_map"] = { | |||
"main": {"host": "testserv", "port": 8765}, | |||
"worker1": {"host": "testserv", "port": 1001}, | |||
"worker2": {"host": "testserv", "port": 1002}, | |||
} | |||
return conf | |||
def test_basic(self) -> None: | |||
"""Simple test to ensure that receipts can be sent on multiple | |||
workers. | |||
""" | |||
worker1 = self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
{"worker_name": "worker1"}, | |||
) | |||
worker1_site = self._hs_to_site[worker1] | |||
worker2 = self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
{"worker_name": "worker2"}, | |||
) | |||
worker2_site = self._hs_to_site[worker2] | |||
user_id = self.register_user("user", "pass") | |||
access_token = self.login("user", "pass") | |||
# Create a room | |||
room_id = self.helper.create_room_as(user_id, tok=access_token) | |||
# The other user joins | |||
self.helper.join( | |||
room=room_id, user=self.other_user_id, tok=self.other_access_token | |||
) | |||
# The other user sends a message; the first user sends a read receipt.
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||
event_id = response["event_id"] | |||
channel = make_request( | |||
reactor=self.reactor, | |||
site=worker1_site, | |||
method="POST", | |||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", | |||
access_token=access_token, | |||
content={}, | |||
) | |||
self.assertEqual(200, channel.code) | |||
# Now we do it again using the second worker | |||
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||
event_id = response["event_id"] | |||
channel = make_request( | |||
reactor=self.reactor, | |||
site=worker2_site, | |||
method="POST", | |||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", | |||
access_token=access_token, | |||
content={}, | |||
) | |||
self.assertEqual(200, channel.code) | |||
def test_vector_clock_token(self) -> None: | |||
"""Tests that using a stream token with a vector clock component works | |||
correctly with basic /sync usage. | |||
""" | |||
worker_hs1 = self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
{"worker_name": "worker1"}, | |||
) | |||
worker1_site = self._hs_to_site[worker_hs1] | |||
worker_hs2 = self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
{"worker_name": "worker2"}, | |||
) | |||
worker2_site = self._hs_to_site[worker_hs2] | |||
sync_hs = self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
{"worker_name": "sync"}, | |||
) | |||
sync_hs_site = self._hs_to_site[sync_hs] | |||
user_id = self.register_user("user", "pass") | |||
access_token = self.login("user", "pass") | |||
store = self.hs.get_datastores().main | |||
room_id = self.helper.create_room_as(user_id, tok=access_token) | |||
# The other user joins | |||
self.helper.join( | |||
room=room_id, user=self.other_user_id, tok=self.other_access_token | |||
) | |||
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||
first_event = response["event_id"] | |||
# Do an initial sync so that we're up to date. | |||
channel = make_request( | |||
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token | |||
) | |||
next_batch = channel.json_body["next_batch"] | |||
# We now gut wrench into the receipts stream MultiWriterIdGenerator on
# worker2 to mimic it getting stuck persisting a receipt. This ensures
# that when we send a receipt via worker1 we end up in a state where
# worker2's receipts stream position lags that of worker1, resulting in a
# receipts token with a non-empty instance map component.
#
# Worker2's receipts stream position will not advance until we call
# __aexit__ on the context below.
worker_store2 = worker_hs2.get_datastores().main | |||
assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator) | |||
actx = worker_store2._receipts_id_gen.get_next() | |||
self.get_success(actx.__aenter__()) | |||
channel = make_request( | |||
reactor=self.reactor, | |||
site=worker1_site, | |||
method="POST", | |||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}", | |||
access_token=access_token, | |||
content={}, | |||
) | |||
self.assertEqual(200, channel.code) | |||
# Assert that the current stream token has an instance map component, as | |||
# we are trying to test vector clock tokens. | |||
receipts_token = store.get_max_receipt_stream_id() | |||
self.assertGreater(len(receipts_token.instance_map), 0) | |||
# Check that syncing still gets the new receipt, despite the gap in the | |||
# stream IDs. | |||
channel = make_request( | |||
self.reactor, | |||
sync_hs_site, | |||
"GET", | |||
f"/sync?since={next_batch}", | |||
access_token=access_token, | |||
) | |||
# We should only see the new event and nothing else | |||
self.assertIn(room_id, channel.json_body["rooms"]["join"]) | |||
events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] | |||
self.assertEqual(len(events), 1) | |||
self.assertIn(first_event, events[0]["content"]) | |||
# Get the next batch and make sure it's a vector-clock-style token.
vector_clock_token = channel.json_body["next_batch"] | |||
parsed_token = self.get_success( | |||
StreamToken.from_string(store, vector_clock_token) | |||
) | |||
self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1) | |||
# Now that we've got a vector clock token, finish the fake receipt persist
# we started above.
self.get_success(actx.__aexit__(None, None, None)) | |||
# Now send another receipt via the other worker.
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||
second_event = response["event_id"] | |||
channel = make_request( | |||
reactor=self.reactor, | |||
site=worker2_site, | |||
method="POST", | |||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}", | |||
access_token=access_token, | |||
content={}, | |||
) | |||
channel = make_request( | |||
self.reactor, | |||
sync_hs_site, | |||
"GET", | |||
f"/sync?since={vector_clock_token}", | |||
access_token=access_token, | |||
) | |||
self.assertIn(room_id, channel.json_body["rooms"]["join"]) | |||
events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] | |||
self.assertEqual(len(events), 1) | |||
self.assertIn(second_event, events[0]["content"]) |