
Allow multiple workers to write to receipts stream. (#16432)

Fixes #16417
tags/v1.96.0rc1
Erik Johnston, 6 months ago
committed by GitHub
parent ba47fea528
No known key found in database for this signature. GPG key ID: 4AEE18F83AFDEB23
15 changed files with 604 additions and 89 deletions
  1. changelog.d/16432.feature (+1, -0)
  2. synapse/config/workers.py (+2, -2)
  3. synapse/handlers/appservice.py (+23, -19)
  4. synapse/handlers/initial_sync.py (+1, -1)
  5. synapse/handlers/receipts.py (+10, -9)
  6. synapse/handlers/sync.py (+6, -1)
  7. synapse/notifier.py (+43, -2)
  8. synapse/replication/tcp/client.py (+2, -1)
  9. synapse/storage/databases/main/receipts.py (+109, -39)
  10. synapse/storage/databases/main/relations.py (+2, -2)
  11. synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite (+17, -0)
  12. synapse/streams/events.py (+2, -2)
  13. synapse/types/__init__.py (+130, -7)
  14. tests/handlers/test_appservice.py (+13, -4)
  15. tests/replication/test_sharded_receipts.py (+243, -0)

changelog.d/16432.feature (+1, -0) View file

@@ -0,0 +1 @@
Allow multiple workers to write to receipts stream.

synapse/config/workers.py (+2, -2) View file

@@ -358,9 +358,9 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `account_data` messages."
             )
 
-        if len(self.writers.receipts) != 1:
+        if len(self.writers.receipts) == 0:
             raise ConfigError(
-                "Must only specify one instance to handle `receipts` messages."
+                "Must specify at least one instance to handle `receipts` messages."
             )
 
         if len(self.writers.events) == 0:
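
For context, the worker configuration that this relaxed check now permits looks like the following. This is a sketch that mirrors the default_config() override in tests/replication/test_sharded_receipts.py further down; in a real deployment the same shape lives under stream_writers in the YAML config, and the worker names and ports here are illustrative:

    conf = {
        "stream_writers": {"receipts": ["worker1", "worker2"]},
        "instance_map": {
            "main": {"host": "testserv", "port": 8765},
            "worker1": {"host": "testserv", "port": 1001},
            "worker2": {"host": "testserv", "port": 1002},
        },
    }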


synapse/handlers/appservice.py (+23, -19) View file

@@ -47,6 +47,7 @@ from synapse.types import (
     DeviceListUpdates,
     JsonDict,
     JsonMapping,
+    MultiWriterStreamToken,
     RoomAlias,
     RoomStreamToken,
     StreamKeyType,
@@ -217,7 +218,7 @@ class ApplicationServicesHandler:
     def notify_interested_services_ephemeral(
         self,
         stream_key: StreamKeyType,
-        new_token: Union[int, RoomStreamToken],
+        new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
         users: Collection[Union[str, UserID]],
     ) -> None:
         """
@@ -259,19 +260,6 @@ class ApplicationServicesHandler:
         ):
             return
 
-        # Assert that new_token is an integer (and not a RoomStreamToken).
-        # All of the supported streams that this function handles use an
-        # integer to track progress (rather than a RoomStreamToken - a
-        # vector clock implementation) as they don't support multiple
-        # stream writers.
-        #
-        # As a result, we simply assert that new_token is an integer.
-        # If we do end up needing to pass a RoomStreamToken down here
-        # in the future, using RoomStreamToken.stream (the minimum stream
-        # position) to convert to an ascending integer value should work.
-        # Additional context: https://github.com/matrix-org/synapse/pull/11137
-        assert isinstance(new_token, int)
-
         # Ignore to-device messages if the feature flag is not enabled
         if (
             stream_key == StreamKeyType.TO_DEVICE
@@ -286,6 +274,9 @@ class ApplicationServicesHandler:
         ):
             return
 
+        # We know we're not a `RoomStreamToken` at this point.
+        assert not isinstance(new_token, RoomStreamToken)
+
         # Check whether there are any appservices which have registered to receive
         # ephemeral events.
         #
@@ -327,7 +318,7 @@ class ApplicationServicesHandler:
         self,
         services: List[ApplicationService],
         stream_key: StreamKeyType,
-        new_token: int,
+        new_token: Union[int, MultiWriterStreamToken],
         users: Collection[Union[str, UserID]],
     ) -> None:
         logger.debug("Checking interested services for %s", stream_key)
@@ -340,6 +331,7 @@ class ApplicationServicesHandler:
             #
             # Instead we simply grab the latest typing updates in _handle_typing
             # and, if they apply to this application service, send it off.
+            assert isinstance(new_token, int)
             events = await self._handle_typing(service, new_token)
             if events:
                 self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -350,15 +342,23 @@ class ApplicationServicesHandler:
                 (service.id, stream_key)
             ):
                 if stream_key == StreamKeyType.RECEIPT:
+                    assert isinstance(new_token, MultiWriterStreamToken)
+
+                    # We store appservice tokens as integers, so we ignore
+                    # the `instance_map` components and instead simply
+                    # follow the base stream position.
+                    new_token = MultiWriterStreamToken(stream=new_token.stream)
+
                     events = await self._handle_receipts(service, new_token)
                     self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
                     # Persist the latest handled stream token for this appservice
                     await self.store.set_appservice_stream_type_pos(
-                        service, "read_receipt", new_token
+                        service, "read_receipt", new_token.stream
                     )
 
                 elif stream_key == StreamKeyType.PRESENCE:
+                    assert isinstance(new_token, int)
                     events = await self._handle_presence(service, users, new_token)
                     self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
@@ -368,6 +368,7 @@ class ApplicationServicesHandler:
                     )
 
                 elif stream_key == StreamKeyType.TO_DEVICE:
+                    assert isinstance(new_token, int)
                     # Retrieve a list of to-device message events, as well as the
                     # maximum stream token of the messages we were able to retrieve.
                     to_device_messages = await self._get_to_device_messages(
@@ -383,6 +384,7 @@ class ApplicationServicesHandler:
                 )
 
                 elif stream_key == StreamKeyType.DEVICE_LIST:
+                    assert isinstance(new_token, int)
                     device_list_summary = await self._get_device_list_summary(
                         service, new_token
                     )
@@ -432,7 +434,7 @@ class ApplicationServicesHandler:
         return typing
 
     async def _handle_receipts(
-        self, service: ApplicationService, new_token: int
+        self, service: ApplicationService, new_token: MultiWriterStreamToken
     ) -> List[JsonMapping]:
         """
         Return the latest read receipts that the given application service should receive.
@@ -455,15 +457,17 @@ class ApplicationServicesHandler:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
-        if new_token is not None and new_token <= from_key:
+        if new_token is not None and new_token.stream <= from_key:
             logger.debug(
                 "Rejecting token lower than or equal to stored: %s" % (new_token,)
             )
             return []
 
+        from_token = MultiWriterStreamToken(stream=from_key)
+
         receipts_source = self.event_sources.sources.receipt
         receipts, _ = await receipts_source.get_new_events_as(
-            service=service, from_key=from_key, to_key=new_token
+            service=service, from_key=from_token, to_key=new_token
        )
        return receipts
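
The RECEIPT branch above deliberately flattens the incoming token before the appservice code sees it, because appservice stream positions are persisted as bare integers. A small runnable sketch of what that flattening does (token values invented for illustration):

    from immutabledict import immutabledict
    from synapse.types import MultiWriterStreamToken

    # A token as the notifier might deliver it (instance positions invented).
    new_token = MultiWriterStreamToken(
        stream=580, instance_map=immutabledict({"worker2": 582})
    )

    # Keep only the base stream position; receipts in the (580, 582] gap are
    # simply picked up on a later pass, once the minimum position catches up.
    new_token = MultiWriterStreamToken(stream=new_token.stream)
    assert new_token.stream == 580 and not new_token.instance_map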




synapse/handlers/initial_sync.py (+1, -1) View file

@@ -145,7 +145,7 @@ class InitialSyncHandler:
         joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
         receipt = await self.store.get_linearized_receipts_for_rooms(
             joined_rooms,
-            to_key=int(now_token.receipt_key),
+            to_key=now_token.receipt_key,
         )
 
         receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)


synapse/handlers/receipts.py (+10, -9) View file

@@ -20,6 +20,7 @@ from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
     JsonMapping,
+    MultiWriterStreamToken,
     ReadReceipt,
     StreamKeyType,
     UserID,
@@ -200,7 +201,7 @@ class ReceiptsHandler:
         await self.federation_sender.send_read_receipt(receipt)
 
 
-class ReceiptEventSource(EventSource[int, JsonMapping]):
+class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.config = hs.config
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
     async def get_new_events(
         self,
         user: UserID,
-        from_key: int,
+        from_key: MultiWriterStreamToken,
         limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[JsonMapping], int]:
-        from_key = int(from_key)
+    ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
         to_key = self.get_current_key()
 
         if from_key == to_key:
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
         return events, to_key
 
     async def get_new_events_as(
-        self, from_key: int, to_key: int, service: ApplicationService
-    ) -> Tuple[List[JsonMapping], int]:
+        self,
+        from_key: MultiWriterStreamToken,
+        to_key: MultiWriterStreamToken,
+        service: ApplicationService,
+    ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
         """Returns a set of new read receipt events that an appservice
         may be interested in.
 
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
             appservice may be interested in.
             * The current read receipt stream token.
         """
-        from_key = int(from_key)
-
         if from_key == to_key:
             return [], to_key
 
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
 
         return events, to_key
 
-    def get_current_key(self) -> int:
+    def get_current_key(self) -> MultiWriterStreamToken:
         return self.store.get_max_receipt_stream_id()

synapse/handlers/sync.py (+6, -1) View file

@@ -57,6 +57,7 @@ from synapse.types import (
     DeviceListUpdates,
     JsonDict,
     JsonMapping,
+    MultiWriterStreamToken,
     MutableStateMap,
     Requester,
     RoomStreamToken,
@@ -477,7 +478,11 @@ class SyncHandler:
             event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
             ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
-        receipt_key = since_token.receipt_key if since_token else 0
+        receipt_key = (
+            since_token.receipt_key
+            if since_token
+            else MultiWriterStreamToken(stream=0)
+        )
 
         receipt_source = self.event_sources.sources.receipt
         receipts, receipt_key = await receipt_source.get_new_events(


synapse/notifier.py (+43, -2) View file

@@ -21,11 +21,13 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Literal,
     Optional,
     Set,
     Tuple,
     TypeVar,
     Union,
+    overload,
 )
 
 import attr
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
     JsonDict,
+    MultiWriterStreamToken,
     PersistedEventPosition,
     RoomStreamToken,
     StrCollection,
@@ -127,7 +130,7 @@ class _NotifierUserStream:
     def notify(
         self,
         stream_key: StreamKeyType,
-        stream_id: Union[int, RoomStreamToken],
+        stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken],
         time_now_ms: int,
     ) -> None:
         """Notify any listeners for this user of a new event from an
@@ -452,10 +455,48 @@ class Notifier:
         except Exception:
             logger.exception("Error pusher pool of event")
 
+    @overload
+    def on_new_event(
+        self,
+        stream_key: Literal[StreamKeyType.ROOM],
+        new_token: RoomStreamToken,
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[StrCollection] = None,
+    ) -> None:
+        ...
+
+    @overload
+    def on_new_event(
+        self,
+        stream_key: Literal[StreamKeyType.RECEIPT],
+        new_token: MultiWriterStreamToken,
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[StrCollection] = None,
+    ) -> None:
+        ...
+
+    @overload
+    def on_new_event(
+        self,
+        stream_key: Literal[
+            StreamKeyType.ACCOUNT_DATA,
+            StreamKeyType.DEVICE_LIST,
+            StreamKeyType.PRESENCE,
+            StreamKeyType.PUSH_RULES,
+            StreamKeyType.TO_DEVICE,
+            StreamKeyType.TYPING,
+            StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+        ],
+        new_token: int,
+        users: Optional[Collection[Union[str, UserID]]] = None,
+        rooms: Optional[StrCollection] = None,
+    ) -> None:
+        ...
+
     def on_new_event(
         self,
         stream_key: StreamKeyType,
-        new_token: Union[int, RoomStreamToken],
+        new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
         users: Optional[Collection[Union[str, UserID]]] = None,
         rooms: Optional[StrCollection] = None,
     ) -> None:


synapse/replication/tcp/client.py (+2, -1) View file

@@ -126,8 +126,9 @@ class ReplicationDataHandler:
                 StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
         elif stream_name == ReceiptsStream.NAME:
+            new_token = self.store.get_max_receipt_stream_id()
             self.notifier.on_new_event(
-                StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
+                StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows]
             )
             await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
         elif stream_name == ToDeviceStream.NAME:


synapse/storage/databases/main/receipts.py (+109, -39) View file

@@ -28,6 +28,8 @@ from typing import (
     cast,
 )
 
+from immutabledict import immutabledict
+
 from synapse.api.constants import EduTypes
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    MultiWriterStreamToken,
+    PersistedPosition,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "receipts_linearized",
             entity_column="room_id",
             stream_column="stream_id",
-            max_value=max_receipts_stream_id,
+            max_value=max_receipts_stream_id.stream,
             limit=10000,
         )
         self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
             prefilled_cache=receipts_stream_prefill,
         )
 
-    def get_max_receipt_stream_id(self) -> int:
+    def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
         """Get the current max stream ID for receipts stream"""
-        return self._receipts_id_gen.get_current_token()
+        min_pos = self._receipts_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._receipts_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return MultiWriterStreamToken(
+            stream=min_pos, instance_map=immutabledict(positions)
+        )
+
+    def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+        return self._receipts_id_gen.get_current_token_for_writer(instance_name)
 
     def get_last_unthreaded_receipt_for_user_txn(
         self,
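
To make the token returned by get_max_receipt_stream_id above concrete, here is a standalone sketch (writer names and positions invented): only writers that are ahead of the persisted-everywhere minimum get an entry in the instance map.

    from immutabledict import immutabledict

    # Invented positions as the id generator might report them.
    min_pos = 10  # everything up to here is persisted by all writers
    positions = {"worker1": 10, "worker2": 12}

    # Only writers ahead of the minimum need an explicit entry.
    instance_map = immutabledict(
        {name: pos for name, pos in positions.items() if pos > min_pos}
    )
    assert dict(instance_map) == {"worker2": 12}
    # Serialized via MultiWriterStreamToken.to_string (added later in this
    # commit) this becomes "m10~<id>.12", where <id> is worker2's numeric
    # instance id.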
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
 
     async def get_linearized_receipts_for_rooms(
-        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Iterable[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> List[JsonMapping]:
         """Get receipts for multiple rooms for sending to clients.
 
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
             room_ids = self._receipts_stream_cache.get_entities_changed(
-                room_ids, from_key
+                room_ids, from_key.stream
             )
 
         results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return [ev for res in results.values() for ev in res]
 
     async def get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """Get receipts for a single room for sending to clients.


@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since`from_key`. If not we can no-op.
-            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+            if not self._receipts_stream_cache.has_entity_changed(
+                room_id, from_key.stream
+            ):
                 return []
 
         return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
-        self, room_id: str, to_key: int, from_key: Optional[int] = None
+        self,
+        room_id: str,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Sequence[JsonMapping]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
             if from_key:
-                sql = (
-                    "SELECT receipt_type, user_id, event_id, data"
-                    " FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
-                )
-
-                txn.execute(sql, (room_id, from_key, to_key))
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized
+                    WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+                """
+
+                txn.execute(
+                    sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
+                )
             else:
-                sql = (
-                    "SELECT receipt_type, user_id, event_id, data"
-                    " FROM receipts_linearized WHERE"
-                    " room_id = ? AND stream_id <= ?"
-                )
+                sql = """
+                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
+                    room_id = ? AND stream_id <= ?
+                """
 
-                txn.execute(sql, (room_id, to_key))
+                txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
 
-            return cast(List[Tuple[str, str, str, str]], txn.fetchall())
+            return [
+                (receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
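
The pattern above, repeated in the other query helpers below, is to over-fetch rows up to the token's maximum stream position in SQL and then filter per writer in Python. A standalone sketch of why the filter is needed (writer names and positions invented):

    # Suppose the token is m10~2.12: everything up to 10 is visible, plus
    # worker2's writes up to 12. A row persisted by worker1 at stream_id 11
    # must NOT be returned, even though 11 <= 12 (the SQL bound), because
    # worker1 had only reached 10 when the token was taken.
    def in_range(low_stream, high, instance_name, pos):
        # Simplified form of MultiWriterStreamToken.is_stream_position_in_range.
        high_stream = high["instance_map"].get(instance_name, high["stream"])
        return low_stream < pos <= high_stream

    to_key = {"stream": 10, "instance_map": {"worker2": 12}}
    assert in_range(0, to_key, "worker2", 12)      # worker2's own write: visible
    assert not in_range(0, to_key, "worker1", 11)  # worker1 hasn't reached 11 yet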


@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=3,
     )
     async def _get_linearized_receipts_for_rooms(
-        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+        self,
+        room_ids: Collection[str],
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, Sequence[JsonMapping]]:
         if not room_ids:
             return {}
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
             if from_key:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
                     FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ? AND
                 """
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [from_key, to_key] + list(args))
+                txn.execute(
+                    sql + clause,
+                    [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+                )
             else:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type,
+                        user_id, event_id, thread_id, data
                     FROM receipts_linearized WHERE
                     stream_id <= ? AND
                 """
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     self.database_engine, "room_id", room_ids
                 )
 
-                txn.execute(sql + clause, [to_key] + list(args))
+                txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
 
-            return cast(
-                List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
-            )
+            return [
+                (room_id, receipt_type, user_id, event_id, thread_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         num_args=2,
     )
     async def get_linearized_receipts_for_all_rooms(
-        self, to_key: int, from_key: Optional[int] = None
+        self,
+        to_key: MultiWriterStreamToken,
+        from_key: Optional[MultiWriterStreamToken] = None,
     ) -> Mapping[str, JsonMapping]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
             if from_key:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                     FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
-                txn.execute(sql, [from_key, to_key])
+                txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
             else:
                 sql = """
-                    SELECT room_id, receipt_type, user_id, event_id, data
+                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                     FROM receipts_linearized WHERE
                     stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
                 """
 
-                txn.execute(sql, [to_key])
+                txn.execute(sql, [to_key.get_max_stream_pos()])
 
-            return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
+            return [
+                (room_id, receipt_type, user_id, event_id, data)
+                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+                if MultiWriterStreamToken.is_stream_position_in_range(
+                    from_key, to_key, instance_name, stream_id
+                )
+            ]
 
         txn_results = await self.db_pool.runInteraction(
             "get_linearized_receipts_for_all_rooms", f
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
+                AND instance_name = ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
-            txn.execute(sql, (last_id, current_id, limit))
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
 
             updates = cast(
                 List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         event_ids: List[str],
         thread_id: Optional[str],
         data: dict,
-    ) -> Optional[int]:
+    ) -> Optional[PersistedPosition]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             data,
         )
 
-        return stream_id
+        return PersistedPosition(self._instance_name, stream_id)
 
     async def _insert_graph_receipt(
         self,


synapse/storage/databases/main/relations.py (+2, -2) View file

@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import (
     generate_pagination_where_clause,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 room_key=next_key,
                 presence_key=0,
                 typing_key=0,
-                receipt_key=0,
+                receipt_key=MultiWriterStreamToken(stream=0),
                 account_data_key=0,
                 push_rules_key=0,
                 to_device_key=0,


synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite (+17, -0) View file

@@ -0,0 +1,17 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

-- This already exists on Postgres.
ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;

synapse/streams/events.py (+2, -2) View file

@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
-from synapse.types import StreamKeyType, StreamToken
+from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -111,7 +111,7 @@ class EventSources:
             room_key=await self.sources.room.get_current_key_for_room(room_id),
             presence_key=0,
             typing_key=0,
-            receipt_key=0,
+            receipt_key=MultiWriterStreamToken(stream=0),
             account_data_key=0,
             push_rules_key=0,
             to_device_key=0,


synapse/types/__init__.py (+130, -7) View file

@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         return "s%d" % (self.stream,)
 
 
+@attr.s(frozen=True, slots=True, order=False)
+class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
+    """A basic stream token class for streams that supports multiple writers."""
+
+    @classmethod
+    async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
+        try:
+            if string[0].isdigit():
+                return cls(stream=int(string))
+            if string[0] == "m":
+                parts = string[1:].split("~")
+                stream = int(parts[0])
+
+                instance_map = {}
+                for part in parts[1:]:
+                    key, value = part.split(".")
+                    instance_id = int(key)
+                    pos = int(value)
+
+                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_map[instance_name] = pos
+
+                return cls(
+                    stream=stream,
+                    instance_map=immutabledict(instance_map),
+                )
+        except CancelledError:
+            raise
+        except Exception:
+            pass
+        raise SynapseError(400, "Invalid stream token %r" % (string,))
+
+    async def to_string(self, store: "DataStore") -> str:
+        if self.instance_map:
+            entries = []
+            for name, pos in self.instance_map.items():
+                if pos <= self.stream:
+                    # Ignore instances who are below the minimum stream position
+                    # (we might know they've advanced without seeing a recent
+                    # write from them).
+                    continue
+
+                instance_id = await store.get_id_for_instance(name)
+                entries.append(f"{instance_id}.{pos}")
+
+            encoded_map = "~".join(entries)
+            return f"m{self.stream}~{encoded_map}"
+        else:
+            return str(self.stream)
+
+    @staticmethod
+    def is_stream_position_in_range(
+        low: Optional["AbstractMultiWriterStreamToken"],
+        high: Optional["AbstractMultiWriterStreamToken"],
+        instance_name: Optional[str],
+        pos: int,
+    ) -> bool:
+        """Checks if a given persisted position is between the two given tokens.
+
+        If `instance_name` is None then the row was persisted before multi
+        writer support.
+        """
+
+        if low:
+            if instance_name:
+                low_stream = low.instance_map.get(instance_name, low.stream)
+            else:
+                low_stream = low.stream
+
+            if pos <= low_stream:
+                return False
+
+        if high:
+            if instance_name:
+                high_stream = high.instance_map.get(instance_name, high.stream)
+            else:
+                high_stream = high.stream
+
+            if high_stream < pos:
+                return False
+
+        return True
+
+
 class StreamKeyType(Enum):
     """Known stream types.


@@ -776,7 +860,9 @@ class StreamToken:
     )
     presence_key: int
     typing_key: int
-    receipt_key: int
+    receipt_key: MultiWriterStreamToken = attr.ib(
+        validator=attr.validators.instance_of(MultiWriterStreamToken)
+    )
     account_data_key: int
     push_rules_key: int
     to_device_key: int
@@ -799,8 +885,31 @@ class StreamToken:
             while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
 
+            (
+                room_key,
+                presence_key,
+                typing_key,
+                receipt_key,
+                account_data_key,
+                push_rules_key,
+                to_device_key,
+                device_list_key,
+                groups_key,
+                un_partial_stated_rooms_key,
+            ) = keys
+
             return cls(
-                await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+                room_key=await RoomStreamToken.parse(store, room_key),
+                presence_key=int(presence_key),
+                typing_key=int(typing_key),
+                receipt_key=await MultiWriterStreamToken.parse(store, receipt_key),
+                account_data_key=int(account_data_key),
+                push_rules_key=int(push_rules_key),
+                to_device_key=int(to_device_key),
+                device_list_key=int(device_list_key),
+                groups_key=int(groups_key),
+                un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
             )
         except CancelledError:
             raise
@@ -813,7 +922,7 @@ class StreamToken:
                 await self.room_key.to_string(store),
                 str(self.presence_key),
                 str(self.typing_key),
-                str(self.receipt_key),
+                await self.receipt_key.to_string(store),
                 str(self.account_data_key),
                 str(self.push_rules_key),
                 str(self.to_device_key),
@@ -841,6 +950,11 @@ class StreamToken:
                 StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
             )
             return new_token
+        elif key == StreamKeyType.RECEIPT:
+            new_token = self.copy_and_replace(
+                StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
+            )
+            return new_token
 
         new_token = self.copy_and_replace(key, new_value)
         new_id = new_token.get_field(key)
@@ -858,6 +972,10 @@ class StreamToken:
     def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
         ...
 
+    @overload
+    def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken:
+        ...
+
     @overload
     def get_field(
         self,
@@ -866,7 +984,6 @@ class StreamToken:
             StreamKeyType.DEVICE_LIST,
             StreamKeyType.PRESENCE,
             StreamKeyType.PUSH_RULES,
-            StreamKeyType.RECEIPT,
             StreamKeyType.TO_DEVICE,
             StreamKeyType.TYPING,
             StreamKeyType.UN_PARTIAL_STATED_ROOMS,
@@ -875,15 +992,21 @@ class StreamToken:
         ...
 
     @overload
-    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+    def get_field(
+        self, key: StreamKeyType
+    ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
         ...
 
-    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+    def get_field(
+        self, key: StreamKeyType
+    ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
         """Returns the stream ID for the given key."""
         return getattr(self, key.value)
 
 
-StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(
+    RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
+)
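
Putting the pieces together: a serialized StreamToken is the "_"-joined list of the ten fields unpacked in from_string above, which is why the receipt component uses "~" and "." as its internal separators. A sketch with invented values:

    # Field order matches the tuple unpacked in StreamToken.from_string.
    fields = [
        "s2633508",            # room_key
        "17",                  # presence_key
        "338",                 # typing_key
        "m6732159~1.6732160",  # receipt_key: min stream 6732159, instance 1 at 6732160
        "12",                  # account_data_key
        "27",                  # push_rules_key
        "664",                 # to_device_key
        "51",                  # device_list_key
        "0",                   # groups_key (retired, kept for compatibility)
        "7",                   # un_partial_stated_rooms_key
    ]
    print("_".join(fields))
    # -> s2633508_17_338_m6732159~1.6732160_12_27_664_51_0_7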




@attr.s(slots=True, frozen=True, auto_attribs=True)


tests/handlers/test_appservice.py (+13, -4) View file

@@ -31,7 +31,12 @@ from synapse.appservice import (
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.rest.client import login, receipts, register, room, sendtodevice
 from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
+from synapse.types import (
+    JsonDict,
+    MultiWriterStreamToken,
+    RoomStreamToken,
+    StreamKeyType,
+)
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
@@ -305,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
 
         self.handler.notify_interested_services_ephemeral(
-            StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
+            StreamKeyType.RECEIPT,
+            MultiWriterStreamToken(stream=580),
+            ["@fakerecipient:example.com"],
         )
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, ephemeral=[event]
@@ -333,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
 
         self.handler.notify_interested_services_ephemeral(
-            StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
+            StreamKeyType.RECEIPT,
+            MultiWriterStreamToken(stream=580),
+            ["@fakerecipient:example.com"],
         )
         # This method will be called, but with an empty list of events
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -636,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
             self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
                 services=[interested_appservice],
                 stream_key=StreamKeyType.RECEIPT,
-                new_token=stream_token,
+                new_token=MultiWriterStreamToken(stream=stream_token),
                 users=[self.exclusive_as_user],
             )
         )


tests/replication/test_sharded_receipts.py (+243, -0) View file

@@ -0,0 +1,243 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import ReceiptTypes
from synapse.rest import admin
from synapse.rest.client import login, receipts, room, sync
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StreamToken
from synapse.util import Clock

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request

logger = logging.getLogger(__name__)


class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks receipts sharding works"""

servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
sync.register_servlets,
receipts.register_servlets,
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")

self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastores().main

def default_config(self) -> dict:
conf = super().default_config()
conf["stream_writers"] = {"receipts": ["worker1", "worker2"]}
conf["instance_map"] = {
"main": {"host": "testserv", "port": 8765},
"worker1": {"host": "testserv", "port": 1001},
"worker2": {"host": "testserv", "port": 1002},
}
return conf

def test_basic(self) -> None:
"""Simple test to ensure that receipts can be sent on multiple
workers.
"""

worker1 = self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker1"},
)
worker1_site = self._hs_to_site[worker1]

worker2 = self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker2"},
)
worker2_site = self._hs_to_site[worker2]

user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")

# Create a room
room_id = self.helper.create_room_as(user_id, tok=access_token)

# The other user joins
self.helper.join(
room=room_id, user=self.other_user_id, tok=self.other_access_token
)

        # The other user sends a message, and the first user sends a receipt.
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
event_id = response["event_id"]

channel = make_request(
reactor=self.reactor,
site=worker1_site,
method="POST",
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
access_token=access_token,
content={},
)
self.assertEqual(200, channel.code)

# Now we do it again using the second worker
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
event_id = response["event_id"]

channel = make_request(
reactor=self.reactor,
site=worker2_site,
method="POST",
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
access_token=access_token,
content={},
)
self.assertEqual(200, channel.code)

def test_vector_clock_token(self) -> None:
"""Tests that using a stream token with a vector clock component works
correctly with basic /sync usage.
"""

worker_hs1 = self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker1"},
)
worker1_site = self._hs_to_site[worker_hs1]

worker_hs2 = self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker2"},
)
worker2_site = self._hs_to_site[worker_hs2]

sync_hs = self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "sync"},
)
sync_hs_site = self._hs_to_site[sync_hs]

user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")

store = self.hs.get_datastores().main

room_id = self.helper.create_room_as(user_id, tok=access_token)

# The other user joins
self.helper.join(
room=room_id, user=self.other_user_id, tok=self.other_access_token
)

response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
first_event = response["event_id"]

# Do an initial sync so that we're up to date.
channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
next_batch = channel.json_body["next_batch"]

        # We now gut wrench into the receipts stream MultiWriterIdGenerator on
        # worker2 to mimic it getting stuck persisting a receipt. This ensures
        # that when we send a receipt on worker1 we end up in a state where
        # worker2's receipts stream position lags that on worker1, resulting in
        # a receipts token with a non-empty instance map component.
        #
        # Worker2's receipts stream position will not advance until we call
        # __aexit__ again.
worker_store2 = worker_hs2.get_datastores().main
assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator)

actx = worker_store2._receipts_id_gen.get_next()
self.get_success(actx.__aenter__())

channel = make_request(
reactor=self.reactor,
site=worker1_site,
method="POST",
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}",
access_token=access_token,
content={},
)
self.assertEqual(200, channel.code)

# Assert that the current stream token has an instance map component, as
# we are trying to test vector clock tokens.
receipts_token = store.get_max_receipt_stream_id()
self.assertGreater(len(receipts_token.instance_map), 0)

# Check that syncing still gets the new receipt, despite the gap in the
# stream IDs.
channel = make_request(
self.reactor,
sync_hs_site,
"GET",
f"/sync?since={next_batch}",
access_token=access_token,
)

# We should only see the new event and nothing else
self.assertIn(room_id, channel.json_body["rooms"]["join"])

events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"]
self.assertEqual(len(events), 1)
self.assertIn(first_event, events[0]["content"])

        # Get the next batch and make sure it's a vector clock style token.
vector_clock_token = channel.json_body["next_batch"]
parsed_token = self.get_success(
StreamToken.from_string(store, vector_clock_token)
)
self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1)

# Now that we've got a vector clock token we finish the fake persisting
# a receipt we started above.
self.get_success(actx.__aexit__(None, None, None))

        # Now try and send another receipt to the other worker.
response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
second_event = response["event_id"]

channel = make_request(
reactor=self.reactor,
site=worker2_site,
method="POST",
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}",
access_token=access_token,
content={},
)

channel = make_request(
self.reactor,
sync_hs_site,
"GET",
f"/sync?since={vector_clock_token}",
access_token=access_token,
)

self.assertIn(room_id, channel.json_body["rooms"]["join"])

events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"]
self.assertEqual(len(events), 1)
self.assertIn(second_event, events[0]["content"])
