Fixes #16417tags/v1.96.0rc1
@@ -0,0 +1 @@ | |||||
Allow multiple workers to write to receipts stream. |
@@ -358,9 +358,9 @@ class WorkerConfig(Config): | |||||
"Must only specify one instance to handle `account_data` messages." | "Must only specify one instance to handle `account_data` messages." | ||||
) | ) | ||||
if len(self.writers.receipts) != 1: | |||||
if len(self.writers.receipts) == 0: | |||||
raise ConfigError( | raise ConfigError( | ||||
"Must only specify one instance to handle `receipts` messages." | |||||
"Must specify at least one instance to handle `receipts` messages." | |||||
) | ) | ||||
if len(self.writers.events) == 0: | if len(self.writers.events) == 0: | ||||
@@ -47,6 +47,7 @@ from synapse.types import ( | |||||
DeviceListUpdates, | DeviceListUpdates, | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | JsonMapping, | ||||
MultiWriterStreamToken, | |||||
RoomAlias, | RoomAlias, | ||||
RoomStreamToken, | RoomStreamToken, | ||||
StreamKeyType, | StreamKeyType, | ||||
@@ -217,7 +218,7 @@ class ApplicationServicesHandler: | |||||
def notify_interested_services_ephemeral( | def notify_interested_services_ephemeral( | ||||
self, | self, | ||||
stream_key: StreamKeyType, | stream_key: StreamKeyType, | ||||
new_token: Union[int, RoomStreamToken], | |||||
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], | |||||
users: Collection[Union[str, UserID]], | users: Collection[Union[str, UserID]], | ||||
) -> None: | ) -> None: | ||||
""" | """ | ||||
@@ -259,19 +260,6 @@ class ApplicationServicesHandler: | |||||
): | ): | ||||
return | return | ||||
# Assert that new_token is an integer (and not a RoomStreamToken). | |||||
# All of the supported streams that this function handles use an | |||||
# integer to track progress (rather than a RoomStreamToken - a | |||||
# vector clock implementation) as they don't support multiple | |||||
# stream writers. | |||||
# | |||||
# As a result, we simply assert that new_token is an integer. | |||||
# If we do end up needing to pass a RoomStreamToken down here | |||||
# in the future, using RoomStreamToken.stream (the minimum stream | |||||
# position) to convert to an ascending integer value should work. | |||||
# Additional context: https://github.com/matrix-org/synapse/pull/11137 | |||||
assert isinstance(new_token, int) | |||||
# Ignore to-device messages if the feature flag is not enabled | # Ignore to-device messages if the feature flag is not enabled | ||||
if ( | if ( | ||||
stream_key == StreamKeyType.TO_DEVICE | stream_key == StreamKeyType.TO_DEVICE | ||||
@@ -286,6 +274,9 @@ class ApplicationServicesHandler: | |||||
): | ): | ||||
return | return | ||||
# We know we're not a `RoomStreamToken` at this point. | |||||
assert not isinstance(new_token, RoomStreamToken) | |||||
# Check whether there are any appservices which have registered to receive | # Check whether there are any appservices which have registered to receive | ||||
# ephemeral events. | # ephemeral events. | ||||
# | # | ||||
@@ -327,7 +318,7 @@ class ApplicationServicesHandler: | |||||
self, | self, | ||||
services: List[ApplicationService], | services: List[ApplicationService], | ||||
stream_key: StreamKeyType, | stream_key: StreamKeyType, | ||||
new_token: int, | |||||
new_token: Union[int, MultiWriterStreamToken], | |||||
users: Collection[Union[str, UserID]], | users: Collection[Union[str, UserID]], | ||||
) -> None: | ) -> None: | ||||
logger.debug("Checking interested services for %s", stream_key) | logger.debug("Checking interested services for %s", stream_key) | ||||
@@ -340,6 +331,7 @@ class ApplicationServicesHandler: | |||||
# | # | ||||
# Instead we simply grab the latest typing updates in _handle_typing | # Instead we simply grab the latest typing updates in _handle_typing | ||||
# and, if they apply to this application service, send it off. | # and, if they apply to this application service, send it off. | ||||
assert isinstance(new_token, int) | |||||
events = await self._handle_typing(service, new_token) | events = await self._handle_typing(service, new_token) | ||||
if events: | if events: | ||||
self.scheduler.enqueue_for_appservice(service, ephemeral=events) | self.scheduler.enqueue_for_appservice(service, ephemeral=events) | ||||
@@ -350,15 +342,23 @@ class ApplicationServicesHandler: | |||||
(service.id, stream_key) | (service.id, stream_key) | ||||
): | ): | ||||
if stream_key == StreamKeyType.RECEIPT: | if stream_key == StreamKeyType.RECEIPT: | ||||
assert isinstance(new_token, MultiWriterStreamToken) | |||||
# We store appservice tokens as integers, so we ignore | |||||
# the `instance_map` components and instead simply | |||||
# follow the base stream position. | |||||
new_token = MultiWriterStreamToken(stream=new_token.stream) | |||||
events = await self._handle_receipts(service, new_token) | events = await self._handle_receipts(service, new_token) | ||||
self.scheduler.enqueue_for_appservice(service, ephemeral=events) | self.scheduler.enqueue_for_appservice(service, ephemeral=events) | ||||
# Persist the latest handled stream token for this appservice | # Persist the latest handled stream token for this appservice | ||||
await self.store.set_appservice_stream_type_pos( | await self.store.set_appservice_stream_type_pos( | ||||
service, "read_receipt", new_token | |||||
service, "read_receipt", new_token.stream | |||||
) | ) | ||||
elif stream_key == StreamKeyType.PRESENCE: | elif stream_key == StreamKeyType.PRESENCE: | ||||
assert isinstance(new_token, int) | |||||
events = await self._handle_presence(service, users, new_token) | events = await self._handle_presence(service, users, new_token) | ||||
self.scheduler.enqueue_for_appservice(service, ephemeral=events) | self.scheduler.enqueue_for_appservice(service, ephemeral=events) | ||||
@@ -368,6 +368,7 @@ class ApplicationServicesHandler: | |||||
) | ) | ||||
elif stream_key == StreamKeyType.TO_DEVICE: | elif stream_key == StreamKeyType.TO_DEVICE: | ||||
assert isinstance(new_token, int) | |||||
# Retrieve a list of to-device message events, as well as the | # Retrieve a list of to-device message events, as well as the | ||||
# maximum stream token of the messages we were able to retrieve. | # maximum stream token of the messages we were able to retrieve. | ||||
to_device_messages = await self._get_to_device_messages( | to_device_messages = await self._get_to_device_messages( | ||||
@@ -383,6 +384,7 @@ class ApplicationServicesHandler: | |||||
) | ) | ||||
elif stream_key == StreamKeyType.DEVICE_LIST: | elif stream_key == StreamKeyType.DEVICE_LIST: | ||||
assert isinstance(new_token, int) | |||||
device_list_summary = await self._get_device_list_summary( | device_list_summary = await self._get_device_list_summary( | ||||
service, new_token | service, new_token | ||||
) | ) | ||||
@@ -432,7 +434,7 @@ class ApplicationServicesHandler: | |||||
return typing | return typing | ||||
async def _handle_receipts( | async def _handle_receipts( | ||||
self, service: ApplicationService, new_token: int | |||||
self, service: ApplicationService, new_token: MultiWriterStreamToken | |||||
) -> List[JsonMapping]: | ) -> List[JsonMapping]: | ||||
""" | """ | ||||
Return the latest read receipts that the given application service should receive. | Return the latest read receipts that the given application service should receive. | ||||
@@ -455,15 +457,17 @@ class ApplicationServicesHandler: | |||||
from_key = await self.store.get_type_stream_id_for_appservice( | from_key = await self.store.get_type_stream_id_for_appservice( | ||||
service, "read_receipt" | service, "read_receipt" | ||||
) | ) | ||||
if new_token is not None and new_token <= from_key: | |||||
if new_token is not None and new_token.stream <= from_key: | |||||
logger.debug( | logger.debug( | ||||
"Rejecting token lower than or equal to stored: %s" % (new_token,) | "Rejecting token lower than or equal to stored: %s" % (new_token,) | ||||
) | ) | ||||
return [] | return [] | ||||
from_token = MultiWriterStreamToken(stream=from_key) | |||||
receipts_source = self.event_sources.sources.receipt | receipts_source = self.event_sources.sources.receipt | ||||
receipts, _ = await receipts_source.get_new_events_as( | receipts, _ = await receipts_source.get_new_events_as( | ||||
service=service, from_key=from_key, to_key=new_token | |||||
service=service, from_key=from_token, to_key=new_token | |||||
) | ) | ||||
return receipts | return receipts | ||||
@@ -145,7 +145,7 @@ class InitialSyncHandler: | |||||
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] | joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] | ||||
receipt = await self.store.get_linearized_receipts_for_rooms( | receipt = await self.store.get_linearized_receipts_for_rooms( | ||||
joined_rooms, | joined_rooms, | ||||
to_key=int(now_token.receipt_key), | |||||
to_key=now_token.receipt_key, | |||||
) | ) | ||||
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) | receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) | ||||
@@ -20,6 +20,7 @@ from synapse.streams import EventSource | |||||
from synapse.types import ( | from synapse.types import ( | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | JsonMapping, | ||||
MultiWriterStreamToken, | |||||
ReadReceipt, | ReadReceipt, | ||||
StreamKeyType, | StreamKeyType, | ||||
UserID, | UserID, | ||||
@@ -200,7 +201,7 @@ class ReceiptsHandler: | |||||
await self.federation_sender.send_read_receipt(receipt) | await self.federation_sender.send_read_receipt(receipt) | ||||
class ReceiptEventSource(EventSource[int, JsonMapping]): | |||||
class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]): | |||||
def __init__(self, hs: "HomeServer"): | def __init__(self, hs: "HomeServer"): | ||||
self.store = hs.get_datastores().main | self.store = hs.get_datastores().main | ||||
self.config = hs.config | self.config = hs.config | ||||
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||||
async def get_new_events( | async def get_new_events( | ||||
self, | self, | ||||
user: UserID, | user: UserID, | ||||
from_key: int, | |||||
from_key: MultiWriterStreamToken, | |||||
limit: int, | limit: int, | ||||
room_ids: Iterable[str], | room_ids: Iterable[str], | ||||
is_guest: bool, | is_guest: bool, | ||||
explicit_room_id: Optional[str] = None, | explicit_room_id: Optional[str] = None, | ||||
) -> Tuple[List[JsonMapping], int]: | |||||
from_key = int(from_key) | |||||
) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: | |||||
to_key = self.get_current_key() | to_key = self.get_current_key() | ||||
if from_key == to_key: | if from_key == to_key: | ||||
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||||
return events, to_key | return events, to_key | ||||
async def get_new_events_as( | async def get_new_events_as( | ||||
self, from_key: int, to_key: int, service: ApplicationService | |||||
) -> Tuple[List[JsonMapping], int]: | |||||
self, | |||||
from_key: MultiWriterStreamToken, | |||||
to_key: MultiWriterStreamToken, | |||||
service: ApplicationService, | |||||
) -> Tuple[List[JsonMapping], MultiWriterStreamToken]: | |||||
"""Returns a set of new read receipt events that an appservice | """Returns a set of new read receipt events that an appservice | ||||
may be interested in. | may be interested in. | ||||
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||||
appservice may be interested in. | appservice may be interested in. | ||||
* The current read receipt stream token. | * The current read receipt stream token. | ||||
""" | """ | ||||
from_key = int(from_key) | |||||
if from_key == to_key: | if from_key == to_key: | ||||
return [], to_key | return [], to_key | ||||
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]): | |||||
return events, to_key | return events, to_key | ||||
def get_current_key(self) -> int: | |||||
def get_current_key(self) -> MultiWriterStreamToken: | |||||
return self.store.get_max_receipt_stream_id() | return self.store.get_max_receipt_stream_id() |
@@ -57,6 +57,7 @@ from synapse.types import ( | |||||
DeviceListUpdates, | DeviceListUpdates, | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | JsonMapping, | ||||
MultiWriterStreamToken, | |||||
MutableStateMap, | MutableStateMap, | ||||
Requester, | Requester, | ||||
RoomStreamToken, | RoomStreamToken, | ||||
@@ -477,7 +478,11 @@ class SyncHandler: | |||||
event_copy = {k: v for (k, v) in event.items() if k != "room_id"} | event_copy = {k: v for (k, v) in event.items() if k != "room_id"} | ||||
ephemeral_by_room.setdefault(room_id, []).append(event_copy) | ephemeral_by_room.setdefault(room_id, []).append(event_copy) | ||||
receipt_key = since_token.receipt_key if since_token else 0 | |||||
receipt_key = ( | |||||
since_token.receipt_key | |||||
if since_token | |||||
else MultiWriterStreamToken(stream=0) | |||||
) | |||||
receipt_source = self.event_sources.sources.receipt | receipt_source = self.event_sources.sources.receipt | ||||
receipts, receipt_key = await receipt_source.get_new_events( | receipts, receipt_key = await receipt_source.get_new_events( | ||||
@@ -21,11 +21,13 @@ from typing import ( | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
List, | List, | ||||
Literal, | |||||
Optional, | Optional, | ||||
Set, | Set, | ||||
Tuple, | Tuple, | ||||
TypeVar, | TypeVar, | ||||
Union, | Union, | ||||
overload, | |||||
) | ) | ||||
import attr | import attr | ||||
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge | |||||
from synapse.streams.config import PaginationConfig | from synapse.streams.config import PaginationConfig | ||||
from synapse.types import ( | from synapse.types import ( | ||||
JsonDict, | JsonDict, | ||||
MultiWriterStreamToken, | |||||
PersistedEventPosition, | PersistedEventPosition, | ||||
RoomStreamToken, | RoomStreamToken, | ||||
StrCollection, | StrCollection, | ||||
@@ -127,7 +130,7 @@ class _NotifierUserStream: | |||||
def notify( | def notify( | ||||
self, | self, | ||||
stream_key: StreamKeyType, | stream_key: StreamKeyType, | ||||
stream_id: Union[int, RoomStreamToken], | |||||
stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken], | |||||
time_now_ms: int, | time_now_ms: int, | ||||
) -> None: | ) -> None: | ||||
"""Notify any listeners for this user of a new event from an | """Notify any listeners for this user of a new event from an | ||||
@@ -452,10 +455,48 @@ class Notifier: | |||||
except Exception: | except Exception: | ||||
logger.exception("Error pusher pool of event") | logger.exception("Error pusher pool of event") | ||||
@overload | |||||
def on_new_event( | |||||
self, | |||||
stream_key: Literal[StreamKeyType.ROOM], | |||||
new_token: RoomStreamToken, | |||||
users: Optional[Collection[Union[str, UserID]]] = None, | |||||
rooms: Optional[StrCollection] = None, | |||||
) -> None: | |||||
... | |||||
@overload | |||||
def on_new_event( | |||||
self, | |||||
stream_key: Literal[StreamKeyType.RECEIPT], | |||||
new_token: MultiWriterStreamToken, | |||||
users: Optional[Collection[Union[str, UserID]]] = None, | |||||
rooms: Optional[StrCollection] = None, | |||||
) -> None: | |||||
... | |||||
@overload | |||||
def on_new_event( | |||||
self, | |||||
stream_key: Literal[ | |||||
StreamKeyType.ACCOUNT_DATA, | |||||
StreamKeyType.DEVICE_LIST, | |||||
StreamKeyType.PRESENCE, | |||||
StreamKeyType.PUSH_RULES, | |||||
StreamKeyType.TO_DEVICE, | |||||
StreamKeyType.TYPING, | |||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS, | |||||
], | |||||
new_token: int, | |||||
users: Optional[Collection[Union[str, UserID]]] = None, | |||||
rooms: Optional[StrCollection] = None, | |||||
) -> None: | |||||
... | |||||
def on_new_event( | def on_new_event( | ||||
self, | self, | ||||
stream_key: StreamKeyType, | stream_key: StreamKeyType, | ||||
new_token: Union[int, RoomStreamToken], | |||||
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], | |||||
users: Optional[Collection[Union[str, UserID]]] = None, | users: Optional[Collection[Union[str, UserID]]] = None, | ||||
rooms: Optional[StrCollection] = None, | rooms: Optional[StrCollection] = None, | ||||
) -> None: | ) -> None: | ||||
@@ -126,8 +126,9 @@ class ReplicationDataHandler: | |||||
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] | StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] | ||||
) | ) | ||||
elif stream_name == ReceiptsStream.NAME: | elif stream_name == ReceiptsStream.NAME: | ||||
new_token = self.store.get_max_receipt_stream_id() | |||||
self.notifier.on_new_event( | self.notifier.on_new_event( | ||||
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] | |||||
StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows] | |||||
) | ) | ||||
await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) | await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) | ||||
elif stream_name == ToDeviceStream.NAME: | elif stream_name == ToDeviceStream.NAME: | ||||
@@ -28,6 +28,8 @@ from typing import ( | |||||
cast, | cast, | ||||
) | ) | ||||
from immutabledict import immutabledict | |||||
from synapse.api.constants import EduTypes | from synapse.api.constants import EduTypes | ||||
from synapse.replication.tcp.streams import ReceiptsStream | from synapse.replication.tcp.streams import ReceiptsStream | ||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | ||||
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import ( | |||||
MultiWriterIdGenerator, | MultiWriterIdGenerator, | ||||
StreamIdGenerator, | StreamIdGenerator, | ||||
) | ) | ||||
from synapse.types import JsonDict, JsonMapping | |||||
from synapse.types import ( | |||||
JsonDict, | |||||
JsonMapping, | |||||
MultiWriterStreamToken, | |||||
PersistedPosition, | |||||
) | |||||
from synapse.util import json_encoder | from synapse.util import json_encoder | ||||
from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList | ||||
from synapse.util.caches.stream_change_cache import StreamChangeCache | from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
"receipts_linearized", | "receipts_linearized", | ||||
entity_column="room_id", | entity_column="room_id", | ||||
stream_column="stream_id", | stream_column="stream_id", | ||||
max_value=max_receipts_stream_id, | |||||
max_value=max_receipts_stream_id.stream, | |||||
limit=10000, | limit=10000, | ||||
) | ) | ||||
self._receipts_stream_cache = StreamChangeCache( | self._receipts_stream_cache = StreamChangeCache( | ||||
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
prefilled_cache=receipts_stream_prefill, | prefilled_cache=receipts_stream_prefill, | ||||
) | ) | ||||
def get_max_receipt_stream_id(self) -> int: | |||||
def get_max_receipt_stream_id(self) -> MultiWriterStreamToken: | |||||
"""Get the current max stream ID for receipts stream""" | """Get the current max stream ID for receipts stream""" | ||||
return self._receipts_id_gen.get_current_token() | |||||
min_pos = self._receipts_id_gen.get_current_token() | |||||
positions = {} | |||||
if isinstance(self._receipts_id_gen, MultiWriterIdGenerator): | |||||
# The `min_pos` is the minimum position that we know all instances | |||||
# have finished persisting to, so we only care about instances whose | |||||
# positions are ahead of that. (Instance positions can be behind the | |||||
# min position as there are times we can work out that the minimum | |||||
# position is ahead of the naive minimum across all current | |||||
# positions. See MultiWriterIdGenerator for details) | |||||
positions = { | |||||
i: p | |||||
for i, p in self._receipts_id_gen.get_positions().items() | |||||
if p > min_pos | |||||
} | |||||
return MultiWriterStreamToken( | |||||
stream=min_pos, instance_map=immutabledict(positions) | |||||
) | |||||
def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: | |||||
return self._receipts_id_gen.get_current_token_for_writer(instance_name) | |||||
def get_last_unthreaded_receipt_for_user_txn( | def get_last_unthreaded_receipt_for_user_txn( | ||||
self, | self, | ||||
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
} | } | ||||
async def get_linearized_receipts_for_rooms( | async def get_linearized_receipts_for_rooms( | ||||
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None | |||||
self, | |||||
room_ids: Iterable[str], | |||||
to_key: MultiWriterStreamToken, | |||||
from_key: Optional[MultiWriterStreamToken] = None, | |||||
) -> List[JsonMapping]: | ) -> List[JsonMapping]: | ||||
"""Get receipts for multiple rooms for sending to clients. | """Get receipts for multiple rooms for sending to clients. | ||||
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
# Only ask the database about rooms where there have been new | # Only ask the database about rooms where there have been new | ||||
# receipts added since `from_key` | # receipts added since `from_key` | ||||
room_ids = self._receipts_stream_cache.get_entities_changed( | room_ids = self._receipts_stream_cache.get_entities_changed( | ||||
room_ids, from_key | |||||
room_ids, from_key.stream | |||||
) | ) | ||||
results = await self._get_linearized_receipts_for_rooms( | results = await self._get_linearized_receipts_for_rooms( | ||||
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
return [ev for res in results.values() for ev in res] | return [ev for res in results.values() for ev in res] | ||||
async def get_linearized_receipts_for_room( | async def get_linearized_receipts_for_room( | ||||
self, room_id: str, to_key: int, from_key: Optional[int] = None | |||||
self, | |||||
room_id: str, | |||||
to_key: MultiWriterStreamToken, | |||||
from_key: Optional[MultiWriterStreamToken] = None, | |||||
) -> Sequence[JsonMapping]: | ) -> Sequence[JsonMapping]: | ||||
"""Get receipts for a single room for sending to clients. | """Get receipts for a single room for sending to clients. | ||||
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
if from_key is not None: | if from_key is not None: | ||||
# Check the cache first to see if any new receipts have been added | # Check the cache first to see if any new receipts have been added | ||||
# since`from_key`. If not we can no-op. | # since`from_key`. If not we can no-op. | ||||
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): | |||||
if not self._receipts_stream_cache.has_entity_changed( | |||||
room_id, from_key.stream | |||||
): | |||||
return [] | return [] | ||||
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) | return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) | ||||
@cached(tree=True) | @cached(tree=True) | ||||
async def _get_linearized_receipts_for_room( | async def _get_linearized_receipts_for_room( | ||||
self, room_id: str, to_key: int, from_key: Optional[int] = None | |||||
self, | |||||
room_id: str, | |||||
to_key: MultiWriterStreamToken, | |||||
from_key: Optional[MultiWriterStreamToken] = None, | |||||
) -> Sequence[JsonMapping]: | ) -> Sequence[JsonMapping]: | ||||
"""See get_linearized_receipts_for_room""" | """See get_linearized_receipts_for_room""" | ||||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: | def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: | ||||
if from_key: | if from_key: | ||||
sql = ( | |||||
"SELECT receipt_type, user_id, event_id, data" | |||||
" FROM receipts_linearized WHERE" | |||||
" room_id = ? AND stream_id > ? AND stream_id <= ?" | |||||
) | |||||
sql = """ | |||||
SELECT stream_id, instance_name, receipt_type, user_id, event_id, data | |||||
FROM receipts_linearized | |||||
WHERE room_id = ? AND stream_id > ? AND stream_id <= ? | |||||
""" | |||||
txn.execute(sql, (room_id, from_key, to_key)) | |||||
else: | |||||
sql = ( | |||||
"SELECT receipt_type, user_id, event_id, data" | |||||
" FROM receipts_linearized WHERE" | |||||
" room_id = ? AND stream_id <= ?" | |||||
txn.execute( | |||||
sql, (room_id, from_key.stream, to_key.get_max_stream_pos()) | |||||
) | ) | ||||
else: | |||||
sql = """ | |||||
SELECT stream_id, instance_name, receipt_type, user_id, event_id, data | |||||
FROM receipts_linearized WHERE | |||||
room_id = ? AND stream_id <= ? | |||||
""" | |||||
txn.execute(sql, (room_id, to_key)) | |||||
txn.execute(sql, (room_id, to_key.get_max_stream_pos())) | |||||
return cast(List[Tuple[str, str, str, str]], txn.fetchall()) | |||||
return [ | |||||
(receipt_type, user_id, event_id, data) | |||||
for stream_id, instance_name, receipt_type, user_id, event_id, data in txn | |||||
if MultiWriterStreamToken.is_stream_position_in_range( | |||||
from_key, to_key, instance_name, stream_id | |||||
) | |||||
] | |||||
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) | rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) | ||||
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
num_args=3, | num_args=3, | ||||
) | ) | ||||
async def _get_linearized_receipts_for_rooms( | async def _get_linearized_receipts_for_rooms( | ||||
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None | |||||
self, | |||||
room_ids: Collection[str], | |||||
to_key: MultiWriterStreamToken, | |||||
from_key: Optional[MultiWriterStreamToken] = None, | |||||
) -> Mapping[str, Sequence[JsonMapping]]: | ) -> Mapping[str, Sequence[JsonMapping]]: | ||||
if not room_ids: | if not room_ids: | ||||
return {} | return {} | ||||
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
) -> List[Tuple[str, str, str, str, Optional[str], str]]: | ) -> List[Tuple[str, str, str, str, Optional[str], str]]: | ||||
if from_key: | if from_key: | ||||
sql = """ | sql = """ | ||||
SELECT room_id, receipt_type, user_id, event_id, thread_id, data | |||||
SELECT stream_id, instance_name, room_id, receipt_type, | |||||
user_id, event_id, thread_id, data | |||||
FROM receipts_linearized WHERE | FROM receipts_linearized WHERE | ||||
stream_id > ? AND stream_id <= ? AND | stream_id > ? AND stream_id <= ? AND | ||||
""" | """ | ||||
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
self.database_engine, "room_id", room_ids | self.database_engine, "room_id", room_ids | ||||
) | ) | ||||
txn.execute(sql + clause, [from_key, to_key] + list(args)) | |||||
txn.execute( | |||||
sql + clause, | |||||
[from_key.stream, to_key.get_max_stream_pos()] + list(args), | |||||
) | |||||
else: | else: | ||||
sql = """ | sql = """ | ||||
SELECT room_id, receipt_type, user_id, event_id, thread_id, data | |||||
SELECT stream_id, instance_name, room_id, receipt_type, | |||||
user_id, event_id, thread_id, data | |||||
FROM receipts_linearized WHERE | FROM receipts_linearized WHERE | ||||
stream_id <= ? AND | stream_id <= ? AND | ||||
""" | """ | ||||
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
self.database_engine, "room_id", room_ids | self.database_engine, "room_id", room_ids | ||||
) | ) | ||||
txn.execute(sql + clause, [to_key] + list(args)) | |||||
txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) | |||||
return cast( | |||||
List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall() | |||||
) | |||||
return [ | |||||
(room_id, receipt_type, user_id, event_id, thread_id, data) | |||||
for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn | |||||
if MultiWriterStreamToken.is_stream_position_in_range( | |||||
from_key, to_key, instance_name, stream_id | |||||
) | |||||
] | |||||
txn_results = await self.db_pool.runInteraction( | txn_results = await self.db_pool.runInteraction( | ||||
"_get_linearized_receipts_for_rooms", f | "_get_linearized_receipts_for_rooms", f | ||||
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
num_args=2, | num_args=2, | ||||
) | ) | ||||
async def get_linearized_receipts_for_all_rooms( | async def get_linearized_receipts_for_all_rooms( | ||||
self, to_key: int, from_key: Optional[int] = None | |||||
self, | |||||
to_key: MultiWriterStreamToken, | |||||
from_key: Optional[MultiWriterStreamToken] = None, | |||||
) -> Mapping[str, JsonMapping]: | ) -> Mapping[str, JsonMapping]: | ||||
"""Get receipts for all rooms between two stream_ids, up | """Get receipts for all rooms between two stream_ids, up | ||||
to a limit of the latest 100 read receipts. | to a limit of the latest 100 read receipts. | ||||
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: | def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: | ||||
if from_key: | if from_key: | ||||
sql = """ | sql = """ | ||||
SELECT room_id, receipt_type, user_id, event_id, data | |||||
SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data | |||||
FROM receipts_linearized WHERE | FROM receipts_linearized WHERE | ||||
stream_id > ? AND stream_id <= ? | stream_id > ? AND stream_id <= ? | ||||
ORDER BY stream_id DESC | ORDER BY stream_id DESC | ||||
LIMIT 100 | LIMIT 100 | ||||
""" | """ | ||||
txn.execute(sql, [from_key, to_key]) | |||||
txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()]) | |||||
else: | else: | ||||
sql = """ | sql = """ | ||||
SELECT room_id, receipt_type, user_id, event_id, data | |||||
SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data | |||||
FROM receipts_linearized WHERE | FROM receipts_linearized WHERE | ||||
stream_id <= ? | stream_id <= ? | ||||
ORDER BY stream_id DESC | ORDER BY stream_id DESC | ||||
LIMIT 100 | LIMIT 100 | ||||
""" | """ | ||||
txn.execute(sql, [to_key]) | |||||
txn.execute(sql, [to_key.get_max_stream_pos()]) | |||||
return cast(List[Tuple[str, str, str, str, str]], txn.fetchall()) | |||||
return [ | |||||
(room_id, receipt_type, user_id, event_id, data) | |||||
for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn | |||||
if MultiWriterStreamToken.is_stream_position_in_range( | |||||
from_key, to_key, instance_name, stream_id | |||||
) | |||||
] | |||||
txn_results = await self.db_pool.runInteraction( | txn_results = await self.db_pool.runInteraction( | ||||
"get_linearized_receipts_for_all_rooms", f | "get_linearized_receipts_for_all_rooms", f | ||||
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data | SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data | ||||
FROM receipts_linearized | FROM receipts_linearized | ||||
WHERE ? < stream_id AND stream_id <= ? | WHERE ? < stream_id AND stream_id <= ? | ||||
AND instance_name = ? | |||||
ORDER BY stream_id ASC | ORDER BY stream_id ASC | ||||
LIMIT ? | LIMIT ? | ||||
""" | """ | ||||
txn.execute(sql, (last_id, current_id, limit)) | |||||
txn.execute(sql, (last_id, current_id, instance_name, limit)) | |||||
updates = cast( | updates = cast( | ||||
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], | List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], | ||||
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
keyvalues=keyvalues, | keyvalues=keyvalues, | ||||
values={ | values={ | ||||
"stream_id": stream_id, | "stream_id": stream_id, | ||||
"instance_name": self._instance_name, | |||||
"event_id": event_id, | "event_id": event_id, | ||||
"event_stream_ordering": stream_ordering, | "event_stream_ordering": stream_ordering, | ||||
"data": json_encoder.encode(data), | "data": json_encoder.encode(data), | ||||
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
event_ids: List[str], | event_ids: List[str], | ||||
thread_id: Optional[str], | thread_id: Optional[str], | ||||
data: dict, | data: dict, | ||||
) -> Optional[int]: | |||||
) -> Optional[PersistedPosition]: | |||||
"""Insert a receipt, either from local client or remote server. | """Insert a receipt, either from local client or remote server. | ||||
Automatically does conversion between linearized and graph | Automatically does conversion between linearized and graph | ||||
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
data, | data, | ||||
) | ) | ||||
return stream_id | |||||
return PersistedPosition(self._instance_name, stream_id) | |||||
async def _insert_graph_receipt( | async def _insert_graph_receipt( | ||||
self, | self, | ||||
@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import ( | |||||
generate_pagination_where_clause, | generate_pagination_where_clause, | ||||
) | ) | ||||
from synapse.storage.engines import PostgresEngine | from synapse.storage.engines import PostgresEngine | ||||
from synapse.types import JsonDict, StreamKeyType, StreamToken | |||||
from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken | |||||
from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList | ||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||||
room_key=next_key, | room_key=next_key, | ||||
presence_key=0, | presence_key=0, | ||||
typing_key=0, | typing_key=0, | ||||
receipt_key=0, | |||||
receipt_key=MultiWriterStreamToken(stream=0), | |||||
account_data_key=0, | account_data_key=0, | ||||
push_rules_key=0, | push_rules_key=0, | ||||
to_device_key=0, | to_device_key=0, | ||||
@@ -0,0 +1,17 @@ | |||||
/* Copyright 2023 The Matrix.org Foundation C.I.C | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
-- This already exists on Postgres. | |||||
ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; |
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource | |||||
from synapse.handlers.typing import TypingNotificationEventSource | from synapse.handlers.typing import TypingNotificationEventSource | ||||
from synapse.logging.opentracing import trace | from synapse.logging.opentracing import trace | ||||
from synapse.streams import EventSource | from synapse.streams import EventSource | ||||
from synapse.types import StreamKeyType, StreamToken | |||||
from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken | |||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
@@ -111,7 +111,7 @@ class EventSources: | |||||
room_key=await self.sources.room.get_current_key_for_room(room_id), | room_key=await self.sources.room.get_current_key_for_room(room_id), | ||||
presence_key=0, | presence_key=0, | ||||
typing_key=0, | typing_key=0, | ||||
receipt_key=0, | |||||
receipt_key=MultiWriterStreamToken(stream=0), | |||||
account_data_key=0, | account_data_key=0, | ||||
push_rules_key=0, | push_rules_key=0, | ||||
to_device_key=0, | to_device_key=0, | ||||
@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): | |||||
return "s%d" % (self.stream,) | return "s%d" % (self.stream,) | ||||
@attr.s(frozen=True, slots=True, order=False) | |||||
class MultiWriterStreamToken(AbstractMultiWriterStreamToken): | |||||
"""A basic stream token class for streams that supports multiple writers.""" | |||||
@classmethod | |||||
async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken": | |||||
try: | |||||
if string[0].isdigit(): | |||||
return cls(stream=int(string)) | |||||
if string[0] == "m": | |||||
parts = string[1:].split("~") | |||||
stream = int(parts[0]) | |||||
instance_map = {} | |||||
for part in parts[1:]: | |||||
key, value = part.split(".") | |||||
instance_id = int(key) | |||||
pos = int(value) | |||||
instance_name = await store.get_name_from_instance_id(instance_id) | |||||
instance_map[instance_name] = pos | |||||
return cls( | |||||
stream=stream, | |||||
instance_map=immutabledict(instance_map), | |||||
) | |||||
except CancelledError: | |||||
raise | |||||
except Exception: | |||||
pass | |||||
raise SynapseError(400, "Invalid stream token %r" % (string,)) | |||||
async def to_string(self, store: "DataStore") -> str: | |||||
if self.instance_map: | |||||
entries = [] | |||||
for name, pos in self.instance_map.items(): | |||||
if pos <= self.stream: | |||||
# Ignore instances who are below the minimum stream position | |||||
# (we might know they've advanced without seeing a recent | |||||
# write from them). | |||||
continue | |||||
instance_id = await store.get_id_for_instance(name) | |||||
entries.append(f"{instance_id}.{pos}") | |||||
encoded_map = "~".join(entries) | |||||
return f"m{self.stream}~{encoded_map}" | |||||
else: | |||||
return str(self.stream) | |||||
@staticmethod | |||||
def is_stream_position_in_range( | |||||
low: Optional["AbstractMultiWriterStreamToken"], | |||||
high: Optional["AbstractMultiWriterStreamToken"], | |||||
instance_name: Optional[str], | |||||
pos: int, | |||||
) -> bool: | |||||
"""Checks if a given persisted position is between the two given tokens. | |||||
If `instance_name` is None then the row was persisted before multi | |||||
writer support. | |||||
""" | |||||
if low: | |||||
if instance_name: | |||||
low_stream = low.instance_map.get(instance_name, low.stream) | |||||
else: | |||||
low_stream = low.stream | |||||
if pos <= low_stream: | |||||
return False | |||||
if high: | |||||
if instance_name: | |||||
high_stream = high.instance_map.get(instance_name, high.stream) | |||||
else: | |||||
high_stream = high.stream | |||||
if high_stream < pos: | |||||
return False | |||||
return True | |||||
class StreamKeyType(Enum): | class StreamKeyType(Enum): | ||||
"""Known stream types. | """Known stream types. | ||||
@@ -776,7 +860,9 @@ class StreamToken: | |||||
) | ) | ||||
presence_key: int | presence_key: int | ||||
typing_key: int | typing_key: int | ||||
receipt_key: int | |||||
receipt_key: MultiWriterStreamToken = attr.ib( | |||||
validator=attr.validators.instance_of(MultiWriterStreamToken) | |||||
) | |||||
account_data_key: int | account_data_key: int | ||||
push_rules_key: int | push_rules_key: int | ||||
to_device_key: int | to_device_key: int | ||||
@@ -799,8 +885,31 @@ class StreamToken: | |||||
while len(keys) < len(attr.fields(cls)): | while len(keys) < len(attr.fields(cls)): | ||||
# i.e. old token from before receipt_key | # i.e. old token from before receipt_key | ||||
keys.append("0") | keys.append("0") | ||||
( | |||||
room_key, | |||||
presence_key, | |||||
typing_key, | |||||
receipt_key, | |||||
account_data_key, | |||||
push_rules_key, | |||||
to_device_key, | |||||
device_list_key, | |||||
groups_key, | |||||
un_partial_stated_rooms_key, | |||||
) = keys | |||||
return cls( | return cls( | ||||
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) | |||||
room_key=await RoomStreamToken.parse(store, room_key), | |||||
presence_key=int(presence_key), | |||||
typing_key=int(typing_key), | |||||
receipt_key=await MultiWriterStreamToken.parse(store, receipt_key), | |||||
account_data_key=int(account_data_key), | |||||
push_rules_key=int(push_rules_key), | |||||
to_device_key=int(to_device_key), | |||||
device_list_key=int(device_list_key), | |||||
groups_key=int(groups_key), | |||||
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), | |||||
) | ) | ||||
except CancelledError: | except CancelledError: | ||||
raise | raise | ||||
@@ -813,7 +922,7 @@ class StreamToken: | |||||
await self.room_key.to_string(store), | await self.room_key.to_string(store), | ||||
str(self.presence_key), | str(self.presence_key), | ||||
str(self.typing_key), | str(self.typing_key), | ||||
str(self.receipt_key), | |||||
await self.receipt_key.to_string(store), | |||||
str(self.account_data_key), | str(self.account_data_key), | ||||
str(self.push_rules_key), | str(self.push_rules_key), | ||||
str(self.to_device_key), | str(self.to_device_key), | ||||
@@ -841,6 +950,11 @@ class StreamToken: | |||||
StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) | StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) | ||||
) | ) | ||||
return new_token | return new_token | ||||
elif key == StreamKeyType.RECEIPT: | |||||
new_token = self.copy_and_replace( | |||||
StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value) | |||||
) | |||||
return new_token | |||||
new_token = self.copy_and_replace(key, new_value) | new_token = self.copy_and_replace(key, new_value) | ||||
new_id = new_token.get_field(key) | new_id = new_token.get_field(key) | ||||
@@ -858,6 +972,10 @@ class StreamToken: | |||||
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: | def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: | ||||
... | ... | ||||
@overload | |||||
def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken: | |||||
... | |||||
@overload | @overload | ||||
def get_field( | def get_field( | ||||
self, | self, | ||||
@@ -866,7 +984,6 @@ class StreamToken: | |||||
StreamKeyType.DEVICE_LIST, | StreamKeyType.DEVICE_LIST, | ||||
StreamKeyType.PRESENCE, | StreamKeyType.PRESENCE, | ||||
StreamKeyType.PUSH_RULES, | StreamKeyType.PUSH_RULES, | ||||
StreamKeyType.RECEIPT, | |||||
StreamKeyType.TO_DEVICE, | StreamKeyType.TO_DEVICE, | ||||
StreamKeyType.TYPING, | StreamKeyType.TYPING, | ||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS, | StreamKeyType.UN_PARTIAL_STATED_ROOMS, | ||||
@@ -875,15 +992,21 @@ class StreamToken: | |||||
... | ... | ||||
@overload | @overload | ||||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: | |||||
def get_field( | |||||
self, key: StreamKeyType | |||||
) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: | |||||
... | ... | ||||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: | |||||
def get_field( | |||||
self, key: StreamKeyType | |||||
) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: | |||||
"""Returns the stream ID for the given key.""" | """Returns the stream ID for the given key.""" | ||||
return getattr(self, key.value) | return getattr(self, key.value) | ||||
StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0) | |||||
StreamToken.START = StreamToken( | |||||
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0 | |||||
) | |||||
@attr.s(slots=True, frozen=True, auto_attribs=True) | @attr.s(slots=True, frozen=True, auto_attribs=True) | ||||
@@ -31,7 +31,12 @@ from synapse.appservice import ( | |||||
from synapse.handlers.appservice import ApplicationServicesHandler | from synapse.handlers.appservice import ApplicationServicesHandler | ||||
from synapse.rest.client import login, receipts, register, room, sendtodevice | from synapse.rest.client import login, receipts, register, room, sendtodevice | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType | |||||
from synapse.types import ( | |||||
JsonDict, | |||||
MultiWriterStreamToken, | |||||
RoomStreamToken, | |||||
StreamKeyType, | |||||
) | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from synapse.util.stringutils import random_string | from synapse.util.stringutils import random_string | ||||
@@ -305,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||||
) | ) | ||||
self.handler.notify_interested_services_ephemeral( | self.handler.notify_interested_services_ephemeral( | ||||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] | |||||
StreamKeyType.RECEIPT, | |||||
MultiWriterStreamToken(stream=580), | |||||
["@fakerecipient:example.com"], | |||||
) | ) | ||||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | ||||
interested_service, ephemeral=[event] | interested_service, ephemeral=[event] | ||||
@@ -333,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||||
) | ) | ||||
self.handler.notify_interested_services_ephemeral( | self.handler.notify_interested_services_ephemeral( | ||||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] | |||||
StreamKeyType.RECEIPT, | |||||
MultiWriterStreamToken(stream=580), | |||||
["@fakerecipient:example.com"], | |||||
) | ) | ||||
# This method will be called, but with an empty list of events | # This method will be called, but with an empty list of events | ||||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | ||||
@@ -636,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||||
self.hs.get_application_service_handler()._notify_interested_services_ephemeral( | self.hs.get_application_service_handler()._notify_interested_services_ephemeral( | ||||
services=[interested_appservice], | services=[interested_appservice], | ||||
stream_key=StreamKeyType.RECEIPT, | stream_key=StreamKeyType.RECEIPT, | ||||
new_token=stream_token, | |||||
new_token=MultiWriterStreamToken(stream=stream_token), | |||||
users=[self.exclusive_as_user], | users=[self.exclusive_as_user], | ||||
) | ) | ||||
) | ) | ||||
@@ -0,0 +1,243 @@ | |||||
# Copyright 2020 The Matrix.org Foundation C.I.C. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import logging | |||||
from twisted.test.proto_helpers import MemoryReactor | |||||
from synapse.api.constants import ReceiptTypes | |||||
from synapse.rest import admin | |||||
from synapse.rest.client import login, receipts, room, sync | |||||
from synapse.server import HomeServer | |||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator | |||||
from synapse.types import StreamToken | |||||
from synapse.util import Clock | |||||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||||
from tests.server import make_request | |||||
logger = logging.getLogger(__name__) | |||||
class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase): | |||||
"""Checks receipts sharding works""" | |||||
servlets = [ | |||||
admin.register_servlets_for_client_rest_resource, | |||||
room.register_servlets, | |||||
login.register_servlets, | |||||
sync.register_servlets, | |||||
receipts.register_servlets, | |||||
] | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||||
# Register a user who sends a message that we'll get notified about | |||||
self.other_user_id = self.register_user("otheruser", "pass") | |||||
self.other_access_token = self.login("otheruser", "pass") | |||||
self.room_creator = self.hs.get_room_creation_handler() | |||||
self.store = hs.get_datastores().main | |||||
def default_config(self) -> dict: | |||||
conf = super().default_config() | |||||
conf["stream_writers"] = {"receipts": ["worker1", "worker2"]} | |||||
conf["instance_map"] = { | |||||
"main": {"host": "testserv", "port": 8765}, | |||||
"worker1": {"host": "testserv", "port": 1001}, | |||||
"worker2": {"host": "testserv", "port": 1002}, | |||||
} | |||||
return conf | |||||
def test_basic(self) -> None: | |||||
"""Simple test to ensure that receipts can be sent on multiple | |||||
workers. | |||||
""" | |||||
worker1 = self.make_worker_hs( | |||||
"synapse.app.generic_worker", | |||||
{"worker_name": "worker1"}, | |||||
) | |||||
worker1_site = self._hs_to_site[worker1] | |||||
worker2 = self.make_worker_hs( | |||||
"synapse.app.generic_worker", | |||||
{"worker_name": "worker2"}, | |||||
) | |||||
worker2_site = self._hs_to_site[worker2] | |||||
user_id = self.register_user("user", "pass") | |||||
access_token = self.login("user", "pass") | |||||
# Create a room | |||||
room_id = self.helper.create_room_as(user_id, tok=access_token) | |||||
# The other user joins | |||||
self.helper.join( | |||||
room=room_id, user=self.other_user_id, tok=self.other_access_token | |||||
) | |||||
# First user sends a message, the other users sends a receipt. | |||||
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||||
event_id = response["event_id"] | |||||
channel = make_request( | |||||
reactor=self.reactor, | |||||
site=worker1_site, | |||||
method="POST", | |||||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", | |||||
access_token=access_token, | |||||
content={}, | |||||
) | |||||
self.assertEqual(200, channel.code) | |||||
# Now we do it again using the second worker | |||||
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||||
event_id = response["event_id"] | |||||
channel = make_request( | |||||
reactor=self.reactor, | |||||
site=worker2_site, | |||||
method="POST", | |||||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}", | |||||
access_token=access_token, | |||||
content={}, | |||||
) | |||||
self.assertEqual(200, channel.code) | |||||
def test_vector_clock_token(self) -> None: | |||||
"""Tests that using a stream token with a vector clock component works | |||||
correctly with basic /sync usage. | |||||
""" | |||||
worker_hs1 = self.make_worker_hs( | |||||
"synapse.app.generic_worker", | |||||
{"worker_name": "worker1"}, | |||||
) | |||||
worker1_site = self._hs_to_site[worker_hs1] | |||||
worker_hs2 = self.make_worker_hs( | |||||
"synapse.app.generic_worker", | |||||
{"worker_name": "worker2"}, | |||||
) | |||||
worker2_site = self._hs_to_site[worker_hs2] | |||||
sync_hs = self.make_worker_hs( | |||||
"synapse.app.generic_worker", | |||||
{"worker_name": "sync"}, | |||||
) | |||||
sync_hs_site = self._hs_to_site[sync_hs] | |||||
user_id = self.register_user("user", "pass") | |||||
access_token = self.login("user", "pass") | |||||
store = self.hs.get_datastores().main | |||||
room_id = self.helper.create_room_as(user_id, tok=access_token) | |||||
# The other user joins | |||||
self.helper.join( | |||||
room=room_id, user=self.other_user_id, tok=self.other_access_token | |||||
) | |||||
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||||
first_event = response["event_id"] | |||||
# Do an initial sync so that we're up to date. | |||||
channel = make_request( | |||||
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token | |||||
) | |||||
next_batch = channel.json_body["next_batch"] | |||||
# We now gut wrench into the events stream MultiWriterIdGenerator on | |||||
# worker2 to mimic it getting stuck persisting a receipt. This ensures | |||||
# that when we send an event on worker1 we end up in a state where | |||||
# worker2 events stream position lags that on worker1, resulting in a | |||||
# receipts token with a non-empty instance map component. | |||||
# | |||||
# Worker2's receipts stream position will not advance until we call | |||||
# __aexit__ again. | |||||
worker_store2 = worker_hs2.get_datastores().main | |||||
assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator) | |||||
actx = worker_store2._receipts_id_gen.get_next() | |||||
self.get_success(actx.__aenter__()) | |||||
channel = make_request( | |||||
reactor=self.reactor, | |||||
site=worker1_site, | |||||
method="POST", | |||||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}", | |||||
access_token=access_token, | |||||
content={}, | |||||
) | |||||
self.assertEqual(200, channel.code) | |||||
# Assert that the current stream token has an instance map component, as | |||||
# we are trying to test vector clock tokens. | |||||
receipts_token = store.get_max_receipt_stream_id() | |||||
self.assertGreater(len(receipts_token.instance_map), 0) | |||||
# Check that syncing still gets the new receipt, despite the gap in the | |||||
# stream IDs. | |||||
channel = make_request( | |||||
self.reactor, | |||||
sync_hs_site, | |||||
"GET", | |||||
f"/sync?since={next_batch}", | |||||
access_token=access_token, | |||||
) | |||||
# We should only see the new event and nothing else | |||||
self.assertIn(room_id, channel.json_body["rooms"]["join"]) | |||||
events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] | |||||
self.assertEqual(len(events), 1) | |||||
self.assertIn(first_event, events[0]["content"]) | |||||
# Get the next batch and makes sure its a vector clock style token. | |||||
vector_clock_token = channel.json_body["next_batch"] | |||||
parsed_token = self.get_success( | |||||
StreamToken.from_string(store, vector_clock_token) | |||||
) | |||||
self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1) | |||||
# Now that we've got a vector clock token we finish the fake persisting | |||||
# a receipt we started above. | |||||
self.get_success(actx.__aexit__(None, None, None)) | |||||
# Now try and send another receipts to the other worker. | |||||
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token) | |||||
second_event = response["event_id"] | |||||
channel = make_request( | |||||
reactor=self.reactor, | |||||
site=worker2_site, | |||||
method="POST", | |||||
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}", | |||||
access_token=access_token, | |||||
content={}, | |||||
) | |||||
channel = make_request( | |||||
self.reactor, | |||||
sync_hs_site, | |||||
"GET", | |||||
f"/sync?since={vector_clock_token}", | |||||
access_token=access_token, | |||||
) | |||||
self.assertIn(room_id, channel.json_body["rooms"]["join"]) | |||||
events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"] | |||||
self.assertEqual(len(events), 1) | |||||
self.assertIn(second_event, events[0]["content"]) |