@@ -0,0 +1 @@
+Change `StreamToken.room_key` to be a `RoomStreamToken` instance.
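(For orientation before the hunks below: a minimal, hedged sketch of the token shape this PR threads through the codebase. It mirrors the attrs-based `RoomStreamToken` in `synapse/types.py` — `s<stream>` is a "live" token, `t<topological>-<stream>` a historical one — but it is simplified here and is not the PR's literal code.)

```python
import attr
from typing import Optional


@attr.s(slots=True, frozen=True)
class RoomStreamToken:
    """Simplified sketch of the room stream token used throughout this diff."""

    topological = attr.ib(type=Optional[int])
    stream = attr.ib(type=int)

    @classmethod
    def parse(cls, string: str) -> "RoomStreamToken":
        # "s632" is a live token; "t426-632" is a historical (topological) one.
        if string.startswith("s"):
            return cls(topological=None, stream=int(string[1:]))
        if string.startswith("t"):
            topo, _, stream = string[1:].partition("-")
            return cls(topological=int(topo), stream=int(stream))
        raise ValueError("Invalid room stream token %r" % (string,))

    def __str__(self) -> str:
        if self.topological is not None:
            return "t%d-%d" % (self.topological, self.stream)
        return "s%d" % (self.stream,)


assert str(RoomStreamToken.parse("s632")) == "s632"
assert RoomStreamToken.parse("t426-632").stream == 632
```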
@@ -46,10 +46,12 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/events.py,
+  synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
   synapse/storage/persist_events.py,
   synapse/storage/state.py,
   synapse/storage/util,
   synapse/streams,
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
             else:
                 stream_ordering = room.stream_ordering

-            from_key = str(RoomStreamToken(0, 0))
-            to_key = str(RoomStreamToken(None, stream_ordering))
+            from_key = RoomStreamToken(0, 0)
+            to_key = RoomStreamToken(None, stream_ordering)

            written_events = set()  # Events that we've processed in this room
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
                if not events:
                    break

-                from_key = events[-1].internal_metadata.after
+                from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)

                events = await filter_events_for_client(self.storage, user_id, events)
@@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
     RoomStreamToken,
+    StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id, from_token):
+    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
-
-        Args:
-            user_id (str)
-            from_token (StreamToken)
         """

         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = await self.store.get_room_events_max_id()
+        now_room_id = self.store.get_room_max_stream_ordering()
+        now_room_key = RoomStreamToken(None, now_room_id)

         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
         )
         rooms_changed.update(event.room_id for event in member_events)

-        stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+        stream_ordering = from_token.room_key.stream

         possibly_changed = set(changed)
         possibly_left = set()
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
                    self.state_handler.get_current_state, event.room_id
                )
            elif event.membership == Membership.LEAVE:
-                room_end_token = "s%d" % (event.stream_ordering,)
+                room_end_token = RoomStreamToken(None, event.stream_ordering,)
                deferred_room_state = run_in_background(
                    self.state_store.get_state_for_events, [event.event_id]
                )
@@ -973,6 +973,7 @@ class EventCreationHandler:
         This should only be run on the instance in charge of persisting events.
         """
         assert self._is_event_writer
+        assert self.storage.persistence is not None

         if ratelimit:
             # We check if this is a room admin redacting an event so that we
@@ -344,7 +344,7 @@ class PaginationHandler:
            # gets called.
            raise Exception("limit not set")

-        room_token = RoomStreamToken.parse(from_token.room_key)
+        room_token = from_token.room_key

        with await self.pagination_lock.read(room_id):
            (
@@ -381,7 +381,7 @@ class PaginationHandler:
                if leave_token.topological < max_topo:
                    from_token = from_token.copy_and_replace(
-                        "room_key", leave_token_str
+                        "room_key", leave_token
                    )

            await self.hs.get_handlers().federation_handler.maybe_backfill(
@@ -1091,20 +1091,19 @@ class RoomEventSource:
     async def get_new_events(
         self,
         user: UserID,
-        from_key: str,
+        from_key: RoomStreamToken,
         limit: int,
         room_ids: List[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         # We just ignore the key for now.

         to_key = self.get_current_key()

-        from_token = RoomStreamToken.parse(from_key)
-        if from_token.topological:
+        if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = "s%s" % (from_token.stream,)
+            from_key = RoomStreamToken(None, from_key.stream)

         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
@@ -1133,14 +1132,14 @@ class RoomEventSource:
            events[:] = events[:limit]

        if events:
-            end_key = events[-1].internal_metadata.after
+            end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
        else:
            end_key = to_key

        return (events, end_key)

-    def get_current_key(self) -> str:
-        return "s%d" % (self.store.get_room_max_stream_ordering(),)
+    def get_current_key(self) -> RoomStreamToken:
+        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())

    def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
        return self.store.get_room_events_max_id(room_id)
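(The `topological` guard in `get_new_events` above is the subtle bit: the live event stream only pages forward in stream order, so a historical `t...` key that leaks in — e.g. from a pagination token — is deliberately downgraded to its stream part. With the sketch from the top of this diff:)

```python
from_key = RoomStreamToken.parse("t426-632")  # historical key leaking into the stream
if from_key.topological:
    from_key = RoomStreamToken(None, from_key.stream)  # keep only the live part
assert str(from_key) == "s632"
```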
@@ -378,7 +378,7 @@ class SyncHandler:
        sync_config = sync_result_builder.sync_config

        with Measure(self.clock, "ephemeral_by_room"):
-            typing_key = since_token.typing_key if since_token else "0"
+            typing_key = since_token.typing_key if since_token else 0

            room_ids = sync_result_builder.joined_room_ids
@@ -402,7 +402,7 @@ class SyncHandler:
                    event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                    ephemeral_by_room.setdefault(room_id, []).append(event_copy)

-            receipt_key = since_token.receipt_key if since_token else "0"
+            receipt_key = since_token.receipt_key if since_token else 0

            receipt_source = self.event_sources.sources["receipt"]
            receipts, receipt_key = await receipt_source.get_new_events(
@@ -533,7 +533,7 @@ class SyncHandler:
            if len(recents) > timeline_limit:
                limited = True
                recents = recents[-timeline_limit:]
-                room_key = recents[0].internal_metadata.before
+                room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)

        prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@@ -1322,6 +1322,7 @@ class SyncHandler:
            is_guest=sync_config.is_guest,
            include_offline=include_offline,
        )
+        assert presence_key
        sync_result_builder.now_token = now_token.copy_and_replace(
            "presence_key", presence_key
        )
@@ -1484,7 +1485,7 @@ class SyncHandler:
        if rooms_changed:
            return True

-        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        stream_id = since_token.room_key.stream
        for room_id in sync_result_builder.joined_room_ids:
            if self.store.has_room_changed_since(room_id, stream_id):
                return True
@@ -1750,7 +1751,7 @@ class SyncHandler:
                continue

            leave_token = now_token.copy_and_replace(
-                "room_key", "s%d" % (event.stream_ordering,)
+                "room_key", RoomStreamToken(None, event.stream_ordering)
            )
            room_entries.append(
                RoomSyncResultBuilder(
@@ -25,6 +25,7 @@ from typing import (
     Set,
     Tuple,
     TypeVar,
+    Union,
 )

 from prometheus_client import Counter
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, StreamToken, UserID
+from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -111,7 +112,9 @@ class _NotifierUserStream:
        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

-    def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
+    def notify(
+        self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+    ):
        """Notify any listeners for this user of a new event from an
        event source.

        Args:
@@ -294,7 +297,12 @@ class Notifier:
                rooms.add(event.room_id)

        if users or rooms:
-            self.on_new_event("room_key", max_room_stream_id, users=users, rooms=rooms)
+            self.on_new_event(
+                "room_key",
+                RoomStreamToken(None, max_room_stream_id),
+                users=users,
+                rooms=rooms,
+            )
            self._on_updated_room_token(max_room_stream_id)

    def _on_updated_room_token(self, max_room_stream_id: int):
@@ -329,7 +337,7 @@ class Notifier:
    def on_new_event(
        self,
        stream_key: str,
-        new_token: int,
+        new_token: Union[int, RoomStreamToken],
        users: Collection[UserID] = [],
        rooms: Collection[str] = [],
    ):
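(With this widened signature, the room stream notifies with a full `RoomStreamToken` while every other stream still passes a plain int. The notifier body isn't part of this hunk, so the helper below is a hypothetical sketch of how the union can be normalised to the integer position that listeners compare against:)

```python
from typing import Union


def _as_stream_position(new_token: Union[int, RoomStreamToken]) -> int:
    # Listeners order themselves by integer position; a RoomStreamToken
    # contributes its .stream part, everything else is already an int.
    if isinstance(new_token, RoomStreamToken):
        return new_token.stream
    return new_token


assert _as_stream_position(RoomStreamToken(None, 632)) == 632
assert _as_stream_position(7) == 7
```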
@@ -47,6 +47,9 @@ class Storage:
        # interfaces.
        self.main = stores.main

-        self.persistence = EventsPersistenceStorage(hs, stores)
        self.purge_events = PurgeEventsStorage(hs, stores)
        self.state = StateGroupStorage(hs, stores)
+
+        self.persistence = None
+        if stores.persist_events:
+            self.persistence = EventsPersistenceStorage(hs, stores)
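(Design note: only the instance configured to write events has a `persist_events` store, so `Storage.persistence` becomes optional and event-writer code paths assert it first — see the `EventCreationHandler` hunk above and the test helper at the bottom of this diff. A minimal sketch of that guard, with hypothetical names:)

```python
async def persist(storage, event, context):
    # Fail loudly on workers that are not the event persister.
    assert storage.persistence is not None, "not the event-persister instance"
    return await storage.persistence.persist_event(event, context)
```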
@@ -213,7 +213,7 @@ class PersistEventsStore:
        Returns:
            Filtered event ids
        """
-        results = []
+        results = []  # type: List[str]

        def _get_events_which_are_prevs_txn(txn, batch):
            sql = """
@@ -631,7 +631,9 @@ class PersistEventsStore:
        )

    @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+    def _filter_events_and_contexts_for_duplicates(
+        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+    ) -> List[Tuple[EventBase, EventContext]]:
        """Ensure that we don't have the same event twice.

        Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +643,9 @@ class PersistEventsStore:
        Returns:
            list[(EventBase, EventContext)]: filtered list
        """
-        new_events_and_contexts = OrderedDict()
+        new_events_and_contexts = (
+            OrderedDict()
+        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
        for event, context in events_and_contexts:
            prev_event_context = new_events_and_contexts.get(event.event_id)
            if prev_event_context:
@@ -655,7 +659,12 @@ class PersistEventsStore:
                new_events_and_contexts[event.event_id] = (event, context)
        return list(new_events_and_contexts.values())

-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+    def _update_room_depths_txn(
+        self,
+        txn,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ):
        """Update min_depth for each room

        Args:
@@ -664,7 +673,7 @@ class PersistEventsStore:
                we are persisting
            backfilled (bool): True if the events were backfilled
        """
-        depth_updates = {}
+        depth_updates = {}  # type: Dict[str, int]
        for event, context in events_and_contexts:
            # Remove the any existing cache entries for the event_ids
            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1436,7 +1445,7 @@ class PersistEventsStore:

        Forward extremities are handled when we first start persisting the events.
        """
-        events_by_room = {}
+        events_by_room = {}  # type: Dict[str, List[EventBase]]
        for ev in events:
            events_by_room.setdefault(ev.room_id, []).append(ev)
@@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
    async def get_room_events_stream_for_rooms(
        self,
        room_ids: Collection[str],
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
        limit: int = 0,
        order: str = "DESC",
-    ) -> Dict[str, Tuple[List[EventBase], str]]:
+    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
        """Get new room events in stream ordering since `from_key`.

        Args:
@@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                - list of recent events in the room
                - stream ordering key for the start of the chunk of events returned.
        """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
-        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+        room_ids = self._events_stream_cache.get_entities_changed(
+            room_ids, from_key.stream
+        )

        if not room_ids:
            return {}
@@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        return results

    def get_rooms_that_changed(
-        self, room_ids: Collection[str], from_key: str
+        self, room_ids: Collection[str], from_key: RoomStreamToken
    ) -> Set[str]:
        """Given a list of rooms and a token, return rooms where there may have
        been changes.
-
-        Args:
-            room_ids
-            from_key: The room_key portion of a StreamToken
        """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = from_key.stream
        return {
            room_id
            for room_id in room_ids
@@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
    async def get_room_events_stream_for_room(
        self,
        room_id: str,
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
        limit: int = 0,
        order: str = "DESC",
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
        """Get new room events in stream ordering since `from_key`.

        Args:
@@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        if from_key == to_key:
            return [], from_key

-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream

        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
            ret.reverse()

        if rows:
-            key = "s%d" % min(r.stream_ordering for r in rows)
+            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
        else:
            # Assume we didn't get anything because there was nothing to
            # get.
@@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        return ret, key

    async def get_membership_changes_for_user(
-        self, user_id: str, from_key: str, to_key: str
+        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
    ) -> List[EventBase]:
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream

        if from_key == to_key:
            return []
@@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        return ret

    async def get_recent_events_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[EventBase], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
        """Get the most recent events in the room in topological ordering.

        Args:
@@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        return (events, token)

    async def get_recent_event_ids_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[_EventDictReturn], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
        """Get the most recent events in the room in topological ordering.

        Args:
@@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        if limit == 0:
            return [], end_token

-        parsed_end_token = RoomStreamToken.parse(end_token)
-
        rows, token = await self.db_pool.runInteraction(
            "get_recent_event_ids_for_room",
            self._paginate_room_events_txn,
            room_id,
-            from_token=parsed_end_token,
+            from_token=end_token,
            limit=limit,
        )
@@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
            allow_none=allow_none,
        )

-    async def get_stream_token_for_event(self, event_id: str) -> str:
+    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
        """The stream token for an event
        Args:
            event_id: The id of the event to look up a stream token for.
        Raises:
            StoreError if the event wasn't in the database.
        Returns:
-            A "s%d" stream token.
+            A stream token.
        """
        stream_id = await self.get_stream_id_for_event(event_id)
-        return "s%d" % (stream_id,)
+        return RoomStreamToken(None, stream_id)

    async def get_topological_token_for_event(self, event_id: str) -> str:
        """The stream token for an event
@@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
        direction: str = "b",
        limit: int = -1,
        event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[_EventDictReturn], str]:
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
        """Returns list of events before or after a given token.

        Args:
@@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
            # TODO (erikj): We should work out what to do here instead.
            next_token = to_token if to_token else from_token

-        return rows, str(next_token)
+        return rows, next_token

    async def paginate_room_events(
        self,
        room_id: str,
-        from_key: str,
-        to_key: Optional[str] = None,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
        direction: str = "b",
        limit: int = -1,
        event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
        """Returns list of events before or after a given token.

        Args:
@@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
            and `to_key`).
        """

-        parsed_from_key = RoomStreamToken.parse(from_key)
-        parsed_to_key = None
-        if to_key:
-            parsed_to_key = RoomStreamToken.parse(to_key)
-
        rows, token = await self.db_pool.runInteraction(
            "paginate_room_events",
            self._paginate_room_events_txn,
            room_id,
-            parsed_from_key,
-            parsed_to_key,
+            from_key,
+            to_key,
            direction,
            limit,
            event_filter,
@@ -18,7 +18,7 @@
 import itertools
 import logging
 from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple

 from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
@@ -185,6 +185,8 @@ class EventsPersistenceStorage:
        # store for now.
        self.main_store = stores.main
        self.state_store = stores.state
+
+        assert stores.persist_events
        self.persist_events_store = stores.persist_events

        self._clock = hs.get_clock()
@@ -208,7 +210,7 @@ class EventsPersistenceStorage:
        Returns:
            the stream ordering of the latest persisted event
        """
-        partitioned = {}
+        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
        for event, ctx in events_and_contexts:
            partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -305,7 +307,9 @@ class EventsPersistenceStorage:
                # Work out the new "current state" for each room.
                # We do this by working out what the new extremities are and then
                # calculating the state from that.
-                events_by_room = {}
+                events_by_room = (
+                    {}
+                )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
                for event, context in chunk:
                    events_by_room.setdefault(event.room_id, []).append(
                        (event, context)
@@ -436,7 +440,7 @@ class EventsPersistenceStorage:
        self,
        room_id: str,
        event_contexts: List[Tuple[EventBase, EventContext]],
-        latest_event_ids: List[str],
+        latest_event_ids: Collection[str],
    ):
        """Calculates the new forward extremities for a room given events to
        persist.
@@ -470,7 +474,7 @@ class EventsPersistenceStorage:
        # Remove any events which are prev_events of any existing events.
        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
            result
-        )
+        )  # type: Collection[str]
        result.difference_update(existing_prevs)

        # Finally handle the case where the new events have soft-failed prev
@@ -425,7 +425,9 @@ class RoomStreamToken:

 @attr.s(slots=True, frozen=True)
 class StreamToken:
-    room_key = attr.ib(type=str)
+    room_key = attr.ib(
+        type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+    )
     presence_key = attr.ib(type=int)
     typing_key = attr.ib(type=int)
     receipt_key = attr.ib(type=int)
@@ -445,21 +447,16 @@ class StreamToken:
            while len(keys) < len(attr.fields(cls)):
                # i.e. old token from before receipt_key
                keys.append("0")
-            return cls(keys[0], *(int(k) for k in keys[1:]))
+            return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
        except Exception:
            raise SynapseError(400, "Invalid Token")

    def to_string(self):
-        return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
+        return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])

    @property
    def room_stream_id(self):
-        # TODO(markjh): Awful hack to work around hacks in the presence tests
-        # which assume that the keys are integers.
-        if type(self.room_key) is int:
-            return self.room_key
-        else:
-            return int(self.room_key[1:].split("-")[-1])
+        return self.room_key.stream

    def is_after(self, other):
        """Does this token contain events that the other doesn't?"""
@@ -475,7 +472,7 @@ class StreamToken:
            or (int(other.groups_key) < int(self.groups_key))
        )

-    def copy_and_advance(self, key, new_value):
+    def copy_and_advance(self, key, new_value) -> "StreamToken":
        """Advance the given key in the token to a new value if and only if the
        new value is after the old value.
        """
@@ -491,7 +488,7 @@ class StreamToken:
        else:
            return self

-    def copy_and_replace(self, key, new_value):
+    def copy_and_replace(self, key, new_value) -> "StreamToken":
        return attr.evolve(self, **{key: new_value})
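(`attr.evolve` returns a new frozen instance with the named field swapped, which is why the call sites earlier in this diff can now pass the `RoomStreamToken` itself instead of its serialized form. Continuing the sketches above:)

```python
now_token = StreamTokenSketch(RoomStreamToken(None, 632), 7)
prev_batch = attr.evolve(now_token, room_key=RoomStreamToken(None, 100))
assert prev_batch.room_key.stream == 100
assert now_token.room_key.stream == 632  # the original token is untouched
```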
@@ -71,7 +71,10 @@ async def inject_event(
    """
    event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)

-    await hs.get_storage().persistence.persist_event(event, context)
+    persistence = hs.get_storage().persistence
+    assert persistence is not None
+    await persistence.persist_event(event, context)

    return event