@@ -0,0 +1 @@ | |||
Rename storage classes. |
@@ -22,7 +22,7 @@ from synapse.events import EventBase | |||
from synapse.types import JsonDict, StateMap | |||
if TYPE_CHECKING: | |||
from synapse.storage import Storage | |||
from synapse.storage.controllers import StorageControllers | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.storage.state import StateFilter | |||
@@ -84,7 +84,7 @@ class EventContext: | |||
incomplete state. | |||
""" | |||
_storage: "Storage" | |||
_storage: "StorageControllers" | |||
rejected: Union[Literal[False], str] = False | |||
_state_group: Optional[int] = None | |||
state_group_before_event: Optional[int] = None | |||
@@ -97,7 +97,7 @@ class EventContext: | |||
@staticmethod | |||
def with_state( | |||
storage: "Storage", | |||
storage: "StorageControllers", | |||
state_group: Optional[int], | |||
state_group_before_event: Optional[int], | |||
state_delta_due_to_event: Optional[StateMap[str]], | |||
@@ -117,7 +117,7 @@ class EventContext: | |||
@staticmethod | |||
def for_outlier( | |||
storage: "Storage", | |||
storage: "StorageControllers", | |||
) -> "EventContext": | |||
"""Return an EventContext instance suitable for persisting an outlier event""" | |||
return EventContext(storage=storage) | |||
@@ -147,7 +147,7 @@ class EventContext: | |||
} | |||
@staticmethod | |||
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": | |||
def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext": | |||
"""Converts a dict that was produced by `serialize` back into a | |||
EventContext. | |||
@@ -109,7 +109,6 @@ class FederationServer(FederationBase): | |||
super().__init__(hs) | |||
self.handler = hs.get_federation_handler() | |||
self.storage = hs.get_storage() | |||
self._spam_checker = hs.get_spam_checker() | |||
self._federation_event_handler = hs.get_federation_event_handler() | |||
self.state = hs.get_state_handler() | |||
@@ -30,8 +30,8 @@ logger = logging.getLogger(__name__) | |||
class AdminHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
async def get_whois(self, user: UserID) -> JsonDict: | |||
connections = [] | |||
@@ -197,7 +197,9 @@ class AdminHandler: | |||
from_key = events[-1].internal_metadata.after | |||
events = await filter_events_for_client(self.storage, user_id, events) | |||
events = await filter_events_for_client( | |||
self._storage_controllers, user_id, events | |||
) | |||
writer.write_events(room_id, events) | |||
@@ -233,7 +235,9 @@ class AdminHandler: | |||
for event_id in extremities: | |||
if not event_to_unseen_prevs[event_id]: | |||
continue | |||
state = await self.state_storage.get_state_for_event(event_id) | |||
state = await self._state_storage_controller.get_state_for_event( | |||
event_id | |||
) | |||
writer.write_state(room_id, event_id, state) | |||
return writer.finished() | |||
@@ -71,7 +71,7 @@ class DeviceWorkerHandler: | |||
self.store = hs.get_datastores().main | |||
self.notifier = hs.get_notifier() | |||
self.state = hs.get_state_handler() | |||
self.state_storage = hs.get_storage().state | |||
self._state_storage = hs.get_storage_controllers().state | |||
self._auth_handler = hs.get_auth_handler() | |||
self.server_name = hs.hostname | |||
@@ -204,7 +204,7 @@ class DeviceWorkerHandler: | |||
continue | |||
# mapping from event_id -> state_dict | |||
prev_state_ids = await self.state_storage.get_state_ids_for_events( | |||
prev_state_ids = await self._state_storage.get_state_ids_for_events( | |||
event_ids | |||
) | |||
@@ -139,7 +139,7 @@ class EventStreamHandler: | |||
class EventHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
async def get_event( | |||
self, | |||
@@ -177,7 +177,7 @@ class EventHandler: | |||
is_peeking = user.to_string() not in users | |||
filtered = await filter_events_for_client( | |||
self.storage, user.to_string(), [event], is_peeking=is_peeking | |||
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking | |||
) | |||
if not filtered: | |||
@@ -125,8 +125,8 @@ class FederationHandler: | |||
self.hs = hs | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self.federation_client = hs.get_federation_client() | |||
self.state_handler = hs.get_state_handler() | |||
self.server_name = hs.hostname | |||
@@ -324,7 +324,7 @@ class FederationHandler: | |||
# We set `check_history_visibility_only` as we might otherwise get false | |||
# positives from users having been erased. | |||
filtered_extremities = await filter_events_for_server( | |||
self.storage, | |||
self._storage_controllers, | |||
self.server_name, | |||
events_to_check, | |||
redact=False, | |||
@@ -660,7 +660,7 @@ class FederationHandler: | |||
# in the invitee's sync stream. It is stripped out for all other local users. | |||
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] | |||
context = EventContext.for_outlier(self.storage) | |||
context = EventContext.for_outlier(self._storage_controllers) | |||
stream_id = await self._federation_event_handler.persist_events_and_notify( | |||
event.room_id, [(event, context)] | |||
) | |||
@@ -849,7 +849,7 @@ class FederationHandler: | |||
) | |||
) | |||
context = EventContext.for_outlier(self.storage) | |||
context = EventContext.for_outlier(self._storage_controllers) | |||
await self._federation_event_handler.persist_events_and_notify( | |||
event.room_id, [(event, context)] | |||
) | |||
@@ -878,7 +878,7 @@ class FederationHandler: | |||
await self.federation_client.send_leave(host_list, event) | |||
context = EventContext.for_outlier(self.storage) | |||
context = EventContext.for_outlier(self._storage_controllers) | |||
stream_id = await self._federation_event_handler.persist_events_and_notify( | |||
event.room_id, [(event, context)] | |||
) | |||
@@ -1027,7 +1027,7 @@ class FederationHandler: | |||
if event.internal_metadata.outlier: | |||
raise NotFoundError("State not known at event %s" % (event_id,)) | |||
state_groups = await self.state_storage.get_state_groups_ids( | |||
state_groups = await self._state_storage_controller.get_state_groups_ids( | |||
room_id, [event_id] | |||
) | |||
@@ -1078,7 +1078,9 @@ class FederationHandler: | |||
], | |||
) | |||
events = await filter_events_for_server(self.storage, origin, events) | |||
events = await filter_events_for_server( | |||
self._storage_controllers, origin, events | |||
) | |||
return events | |||
@@ -1109,7 +1111,9 @@ class FederationHandler: | |||
if not in_room: | |||
raise AuthError(403, "Host not in room.") | |||
events = await filter_events_for_server(self.storage, origin, [event]) | |||
events = await filter_events_for_server( | |||
self._storage_controllers, origin, [event] | |||
) | |||
event = events[0] | |||
return event | |||
else: | |||
@@ -1138,7 +1142,7 @@ class FederationHandler: | |||
) | |||
missing_events = await filter_events_for_server( | |||
self.storage, origin, missing_events | |||
self._storage_controllers, origin, missing_events | |||
) | |||
return missing_events | |||
@@ -1480,9 +1484,11 @@ class FederationHandler: | |||
# clear the lazy-loading flag. | |||
logger.info("Updating current state for %s", room_id) | |||
assert ( | |||
self.storage.persistence is not None | |||
self._storage_controllers.persistence is not None | |||
), "TODO(faster_joins): support for workers" | |||
await self.storage.persistence.update_current_state(room_id) | |||
await self._storage_controllers.persistence.update_current_state( | |||
room_id | |||
) | |||
logger.info("Clearing partial-state flag for %s", room_id) | |||
success = await self.store.clear_partial_state_room(room_id) | |||
@@ -98,8 +98,8 @@ class FederationEventHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self._store = hs.get_datastores().main | |||
self._storage = hs.get_storage() | |||
self._state_storage = self._storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self._state_handler = hs.get_state_handler() | |||
self._event_creation_handler = hs.get_event_creation_handler() | |||
@@ -535,7 +535,9 @@ class FederationEventHandler: | |||
) | |||
return | |||
await self._store.update_state_for_partial_state_event(event, context) | |||
self._state_storage.notify_event_un_partial_stated(event.event_id) | |||
self._state_storage_controller.notify_event_un_partial_stated( | |||
event.event_id | |||
) | |||
async def backfill( | |||
self, dest: str, room_id: str, limit: int, extremities: Collection[str] | |||
@@ -835,7 +837,9 @@ class FederationEventHandler: | |||
try: | |||
# Get the state of the events we know about | |||
ours = await self._state_storage.get_state_groups_ids(room_id, seen) | |||
ours = await self._state_storage_controller.get_state_groups_ids( | |||
room_id, seen | |||
) | |||
# state_maps is a list of mappings from (type, state_key) to event_id | |||
state_maps: List[StateMap[str]] = list(ours.values()) | |||
@@ -1436,7 +1440,7 @@ class FederationEventHandler: | |||
# we're not bothering about room state, so flag the event as an outlier. | |||
event.internal_metadata.outlier = True | |||
context = EventContext.for_outlier(self._storage) | |||
context = EventContext.for_outlier(self._storage_controllers) | |||
try: | |||
validate_event_for_room_version(room_version_obj, event) | |||
check_auth_rules_for_event(room_version_obj, event, auth) | |||
@@ -1613,7 +1617,7 @@ class FederationEventHandler: | |||
# given state at the event. This should correctly handle cases | |||
# like bans, especially with state res v2. | |||
state_sets_d = await self._state_storage.get_state_groups_ids( | |||
state_sets_d = await self._state_storage_controller.get_state_groups_ids( | |||
event.room_id, extrem_ids | |||
) | |||
state_sets: List[StateMap[str]] = list(state_sets_d.values()) | |||
@@ -1885,7 +1889,7 @@ class FederationEventHandler: | |||
# create a new state group as a delta from the existing one. | |||
prev_group = context.state_group | |||
state_group = await self._state_storage.store_state_group( | |||
state_group = await self._state_storage_controller.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=prev_group, | |||
@@ -1894,7 +1898,7 @@ class FederationEventHandler: | |||
) | |||
return EventContext.with_state( | |||
storage=self._storage, | |||
storage=self._storage_controllers, | |||
state_group=state_group, | |||
state_group_before_event=context.state_group_before_event, | |||
state_delta_due_to_event=state_updates, | |||
@@ -1984,11 +1988,14 @@ class FederationEventHandler: | |||
) | |||
return result["max_stream_id"] | |||
else: | |||
assert self._storage.persistence | |||
assert self._storage_controllers.persistence | |||
# Note that this returns the events that were persisted, which may not be | |||
# the same as were passed in if some were deduplicated due to transaction IDs. | |||
events, max_stream_token = await self._storage.persistence.persist_events( | |||
( | |||
events, | |||
max_stream_token, | |||
) = await self._storage_controllers.persistence.persist_events( | |||
event_and_contexts, backfilled=backfilled | |||
) | |||
@@ -67,8 +67,8 @@ class InitialSyncHandler: | |||
] | |||
] = ResponseCache(hs.get_clock(), "initial_sync_cache") | |||
self._event_serializer = hs.get_event_client_serializer() | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
async def snapshot_all_rooms( | |||
self, | |||
@@ -198,7 +198,8 @@ class InitialSyncHandler: | |||
event.stream_ordering, | |||
) | |||
deferred_room_state = run_in_background( | |||
self.state_storage.get_state_for_events, [event.event_id] | |||
self._state_storage_controller.get_state_for_events, | |||
[event.event_id], | |||
).addCallback( | |||
lambda states: cast(StateMap[EventBase], states[event.event_id]) | |||
) | |||
@@ -218,7 +219,7 @@ class InitialSyncHandler: | |||
).addErrback(unwrapFirstError) | |||
messages = await filter_events_for_client( | |||
self.storage, user_id, messages | |||
self._storage_controllers, user_id, messages | |||
) | |||
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) | |||
@@ -355,7 +356,9 @@ class InitialSyncHandler: | |||
member_event_id: str, | |||
is_peeking: bool, | |||
) -> JsonDict: | |||
room_state = await self.state_storage.get_state_for_event(member_event_id) | |||
room_state = await self._state_storage_controller.get_state_for_event( | |||
member_event_id | |||
) | |||
limit = pagin_config.limit if pagin_config else None | |||
if limit is None: | |||
@@ -369,7 +372,7 @@ class InitialSyncHandler: | |||
) | |||
messages = await filter_events_for_client( | |||
self.storage, user_id, messages, is_peeking=is_peeking | |||
self._storage_controllers, user_id, messages, is_peeking=is_peeking | |||
) | |||
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) | |||
@@ -474,7 +477,7 @@ class InitialSyncHandler: | |||
) | |||
messages = await filter_events_for_client( | |||
self.storage, user_id, messages, is_peeking=is_peeking | |||
self._storage_controllers, user_id, messages, is_peeking=is_peeking | |||
) | |||
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) | |||
@@ -84,8 +84,8 @@ class MessageHandler: | |||
self.clock = hs.get_clock() | |||
self.state = hs.get_state_handler() | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self._event_serializer = hs.get_event_client_serializer() | |||
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages | |||
@@ -132,7 +132,7 @@ class MessageHandler: | |||
assert ( | |||
membership_event_id is not None | |||
), "check_user_in_room_or_world_readable returned invalid data" | |||
room_state = await self.state_storage.get_state_for_events( | |||
room_state = await self._state_storage_controller.get_state_for_events( | |||
[membership_event_id], StateFilter.from_types([key]) | |||
) | |||
data = room_state[membership_event_id].get(key) | |||
@@ -193,7 +193,7 @@ class MessageHandler: | |||
# check whether the user is in the room at that time to determine | |||
# whether they should be treated as peeking. | |||
state_map = await self.state_storage.get_state_for_event( | |||
state_map = await self._state_storage_controller.get_state_for_event( | |||
last_event.event_id, | |||
StateFilter.from_types([(EventTypes.Member, user_id)]), | |||
) | |||
@@ -206,7 +206,7 @@ class MessageHandler: | |||
is_peeking = not joined | |||
visible_events = await filter_events_for_client( | |||
self.storage, | |||
self._storage_controllers, | |||
user_id, | |||
[last_event], | |||
filter_send_to_client=False, | |||
@@ -214,8 +214,10 @@ class MessageHandler: | |||
) | |||
if visible_events: | |||
room_state_events = await self.state_storage.get_state_for_events( | |||
[last_event.event_id], state_filter=state_filter | |||
room_state_events = ( | |||
await self._state_storage_controller.get_state_for_events( | |||
[last_event.event_id], state_filter=state_filter | |||
) | |||
) | |||
room_state: Mapping[Any, EventBase] = room_state_events[ | |||
last_event.event_id | |||
@@ -244,8 +246,10 @@ class MessageHandler: | |||
assert ( | |||
membership_event_id is not None | |||
), "check_user_in_room_or_world_readable returned invalid data" | |||
room_state_events = await self.state_storage.get_state_for_events( | |||
[membership_event_id], state_filter=state_filter | |||
room_state_events = ( | |||
await self._state_storage_controller.get_state_for_events( | |||
[membership_event_id], state_filter=state_filter | |||
) | |||
) | |||
room_state = room_state_events[membership_event_id] | |||
@@ -402,7 +406,7 @@ class EventCreationHandler: | |||
self.auth = hs.get_auth() | |||
self._event_auth_handler = hs.get_event_auth_handler() | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.state = hs.get_state_handler() | |||
self.clock = hs.get_clock() | |||
self.validator = EventValidator() | |||
@@ -1032,7 +1036,7 @@ class EventCreationHandler: | |||
# after it is created | |||
if builder.internal_metadata.outlier: | |||
event.internal_metadata.outlier = True | |||
context = EventContext.for_outlier(self.storage) | |||
context = EventContext.for_outlier(self._storage_controllers) | |||
elif ( | |||
event.type == EventTypes.MSC2716_INSERTION | |||
and state_event_ids | |||
@@ -1445,7 +1449,7 @@ class EventCreationHandler: | |||
""" | |||
extra_users = extra_users or [] | |||
assert self.storage.persistence is not None | |||
assert self._storage_controllers.persistence is not None | |||
assert self._events_shard_config.should_handle( | |||
self._instance_name, event.room_id | |||
) | |||
@@ -1679,7 +1683,7 @@ class EventCreationHandler: | |||
event, | |||
event_pos, | |||
max_stream_token, | |||
) = await self.storage.persistence.persist_event( | |||
) = await self._storage_controllers.persistence.persist_event( | |||
event, context=context, backfilled=backfilled | |||
) | |||
@@ -129,8 +129,8 @@ class PaginationHandler: | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self.clock = hs.get_clock() | |||
self._server_name = hs.hostname | |||
self._room_shutdown_handler = hs.get_room_shutdown_handler() | |||
@@ -352,7 +352,7 @@ class PaginationHandler: | |||
self._purges_in_progress_by_room.add(room_id) | |||
try: | |||
async with self.pagination_lock.write(room_id): | |||
await self.storage.purge_events.purge_history( | |||
await self._storage_controllers.purge_events.purge_history( | |||
room_id, token, delete_local_events | |||
) | |||
logger.info("[purge] complete") | |||
@@ -414,7 +414,7 @@ class PaginationHandler: | |||
if joined: | |||
raise SynapseError(400, "Users are still joined to this room") | |||
await self.storage.purge_events.purge_room(room_id) | |||
await self._storage_controllers.purge_events.purge_room(room_id) | |||
async def get_messages( | |||
self, | |||
@@ -529,7 +529,10 @@ class PaginationHandler: | |||
events = await event_filter.filter(events) | |||
events = await filter_events_for_client( | |||
self.storage, user_id, events, is_peeking=(member_event_id is None) | |||
self._storage_controllers, | |||
user_id, | |||
events, | |||
is_peeking=(member_event_id is None), | |||
) | |||
# if after the filter applied there are no more events | |||
@@ -550,7 +553,7 @@ class PaginationHandler: | |||
(EventTypes.Member, event.sender) for event in events | |||
) | |||
state_ids = await self.state_storage.get_state_ids_for_event( | |||
state_ids = await self._state_storage_controller.get_state_ids_for_event( | |||
events[0].event_id, state_filter=state_filter | |||
) | |||
@@ -664,7 +667,7 @@ class PaginationHandler: | |||
400, "Users are still joined to this room" | |||
) | |||
await self.storage.purge_events.purge_room(room_id) | |||
await self._storage_controllers.purge_events.purge_room(room_id) | |||
logger.info("complete") | |||
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE | |||
@@ -69,7 +69,7 @@ class BundledAggregations: | |||
class RelationsHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self._main_store = hs.get_datastores().main | |||
self._storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._auth = hs.get_auth() | |||
self._clock = hs.get_clock() | |||
self._event_handler = hs.get_event_handler() | |||
@@ -143,7 +143,10 @@ class RelationsHandler: | |||
) | |||
events = await filter_events_for_client( | |||
self._storage, user_id, events, is_peeking=(member_event_id is None) | |||
self._storage_controllers, | |||
user_id, | |||
events, | |||
is_peeking=(member_event_id is None), | |||
) | |||
now = self._clock.time_msec() | |||
@@ -1192,8 +1192,8 @@ class RoomContextHandler: | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self._relations_handler = hs.get_relations_handler() | |||
async def get_event_context( | |||
@@ -1236,7 +1236,10 @@ class RoomContextHandler: | |||
if use_admin_priviledge: | |||
return events | |||
return await filter_events_for_client( | |||
self.storage, user.to_string(), events, is_peeking=is_peeking | |||
self._storage_controllers, | |||
user.to_string(), | |||
events, | |||
is_peeking=is_peeking, | |||
) | |||
event = await self.store.get_event( | |||
@@ -1293,7 +1296,7 @@ class RoomContextHandler: | |||
# first? Shouldn't we be consistent with /sync? | |||
# https://github.com/matrix-org/matrix-doc/issues/687 | |||
state = await self.state_storage.get_state_for_events( | |||
state = await self._state_storage_controller.get_state_for_events( | |||
[last_event_id], state_filter=state_filter | |||
) | |||
@@ -17,7 +17,7 @@ class RoomBatchHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self.hs = hs | |||
self.store = hs.get_datastores().main | |||
self.state_storage = hs.get_storage().state | |||
self._state_storage_controller = hs.get_storage_controllers().state | |||
self.event_creation_handler = hs.get_event_creation_handler() | |||
self.room_member_handler = hs.get_room_member_handler() | |||
self.auth = hs.get_auth() | |||
@@ -141,7 +141,7 @@ class RoomBatchHandler: | |||
) = await self.store.get_max_depth_of(event_ids) | |||
# mapping from (type, state_key) -> state_event_id | |||
assert most_recent_event_id is not None | |||
prev_state_map = await self.state_storage.get_state_ids_for_event( | |||
prev_state_map = await self._state_storage_controller.get_state_ids_for_event( | |||
most_recent_event_id | |||
) | |||
# List of state event ID's | |||
@@ -55,8 +55,8 @@ class SearchHandler: | |||
self.hs = hs | |||
self._event_serializer = hs.get_event_client_serializer() | |||
self._relations_handler = hs.get_relations_handler() | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self.auth = hs.get_auth() | |||
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: | |||
@@ -460,7 +460,7 @@ class SearchHandler: | |||
filtered_events = await search_filter.filter([r["event"] for r in results]) | |||
events = await filter_events_for_client( | |||
self.storage, user.to_string(), filtered_events | |||
self._storage_controllers, user.to_string(), filtered_events | |||
) | |||
events.sort(key=lambda e: -rank_map[e.event_id]) | |||
@@ -559,7 +559,7 @@ class SearchHandler: | |||
filtered_events = await search_filter.filter([r["event"] for r in results]) | |||
events = await filter_events_for_client( | |||
self.storage, user.to_string(), filtered_events | |||
self._storage_controllers, user.to_string(), filtered_events | |||
) | |||
room_events.extend(events) | |||
@@ -644,11 +644,11 @@ class SearchHandler: | |||
) | |||
events_before = await filter_events_for_client( | |||
self.storage, user.to_string(), res.events_before | |||
self._storage_controllers, user.to_string(), res.events_before | |||
) | |||
events_after = await filter_events_for_client( | |||
self.storage, user.to_string(), res.events_after | |||
self._storage_controllers, user.to_string(), res.events_after | |||
) | |||
context: JsonDict = { | |||
@@ -677,7 +677,7 @@ class SearchHandler: | |||
[(EventTypes.Member, sender) for sender in senders] | |||
) | |||
state = await self.state_storage.get_state_for_event( | |||
state = await self._state_storage_controller.get_state_for_event( | |||
last_event_id, state_filter | |||
) | |||
@@ -238,8 +238,8 @@ class SyncHandler: | |||
self.clock = hs.get_clock() | |||
self.state = hs.get_state_handler() | |||
self.auth = hs.get_auth() | |||
self.storage = hs.get_storage() | |||
self.state_storage = self.storage.state | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
# TODO: flush cache entries on subsequent sync request. | |||
# Once we get the next /sync request (ie, one with the same access token | |||
@@ -512,7 +512,7 @@ class SyncHandler: | |||
current_state_ids = frozenset(current_state_ids_map.values()) | |||
recents = await filter_events_for_client( | |||
self.storage, | |||
self._storage_controllers, | |||
sync_config.user.to_string(), | |||
recents, | |||
always_include_ids=current_state_ids, | |||
@@ -580,7 +580,7 @@ class SyncHandler: | |||
current_state_ids = frozenset(current_state_ids_map.values()) | |||
loaded_recents = await filter_events_for_client( | |||
self.storage, | |||
self._storage_controllers, | |||
sync_config.user.to_string(), | |||
loaded_recents, | |||
always_include_ids=current_state_ids, | |||
@@ -630,7 +630,7 @@ class SyncHandler: | |||
event: event of interest | |||
state_filter: The state filter used to fetch state from the database. | |||
""" | |||
state_ids = await self.state_storage.get_state_ids_for_event( | |||
state_ids = await self._state_storage_controller.get_state_ids_for_event( | |||
event.event_id, state_filter=state_filter or StateFilter.all() | |||
) | |||
if event.is_state(): | |||
@@ -710,7 +710,7 @@ class SyncHandler: | |||
return None | |||
last_event = last_events[-1] | |||
state_ids = await self.state_storage.get_state_ids_for_event( | |||
state_ids = await self._state_storage_controller.get_state_ids_for_event( | |||
last_event.event_id, | |||
state_filter=StateFilter.from_types( | |||
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] | |||
@@ -889,13 +889,15 @@ class SyncHandler: | |||
if full_state: | |||
if batch: | |||
current_state_ids = ( | |||
await self.state_storage.get_state_ids_for_event( | |||
await self._state_storage_controller.get_state_ids_for_event( | |||
batch.events[-1].event_id, state_filter=state_filter | |||
) | |||
) | |||
state_ids = await self.state_storage.get_state_ids_for_event( | |||
batch.events[0].event_id, state_filter=state_filter | |||
state_ids = ( | |||
await self._state_storage_controller.get_state_ids_for_event( | |||
batch.events[0].event_id, state_filter=state_filter | |||
) | |||
) | |||
else: | |||
@@ -915,7 +917,7 @@ class SyncHandler: | |||
elif batch.limited: | |||
if batch: | |||
state_at_timeline_start = ( | |||
await self.state_storage.get_state_ids_for_event( | |||
await self._state_storage_controller.get_state_ids_for_event( | |||
batch.events[0].event_id, state_filter=state_filter | |||
) | |||
) | |||
@@ -950,7 +952,7 @@ class SyncHandler: | |||
if batch: | |||
current_state_ids = ( | |||
await self.state_storage.get_state_ids_for_event( | |||
await self._state_storage_controller.get_state_ids_for_event( | |||
batch.events[-1].event_id, state_filter=state_filter | |||
) | |||
) | |||
@@ -982,7 +984,7 @@ class SyncHandler: | |||
# So we fish out all the member events corresponding to the | |||
# timeline here, and then dedupe any redundant ones below. | |||
state_ids = await self.state_storage.get_state_ids_for_event( | |||
state_ids = await self._state_storage_controller.get_state_ids_for_event( | |||
batch.events[0].event_id, | |||
# we only want members! | |||
state_filter=StateFilter.from_types( | |||
@@ -221,7 +221,7 @@ class Notifier: | |||
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} | |||
self.hs = hs | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.event_sources = hs.get_event_sources() | |||
self.store = hs.get_datastores().main | |||
self.pending_new_room_events: List[_PendingRoomEventEntry] = [] | |||
@@ -623,7 +623,7 @@ class Notifier: | |||
if name == "room": | |||
new_events = await filter_events_for_client( | |||
self.storage, | |||
self._storage_controllers, | |||
user.to_string(), | |||
new_events, | |||
is_peeking=is_peeking, | |||
@@ -65,7 +65,7 @@ class HttpPusher(Pusher): | |||
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): | |||
super().__init__(hs, pusher_config) | |||
self.storage = self.hs.get_storage() | |||
self._storage_controllers = self.hs.get_storage_controllers() | |||
self.app_display_name = pusher_config.app_display_name | |||
self.device_display_name = pusher_config.device_display_name | |||
self.pushkey_ts = pusher_config.ts | |||
@@ -343,7 +343,9 @@ class HttpPusher(Pusher): | |||
} | |||
return d | |||
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) | |||
ctx = await push_tools.get_context_for_event( | |||
self._storage_controllers, event, self.user_id | |||
) | |||
d = { | |||
"notification": { | |||
@@ -114,10 +114,10 @@ class Mailer: | |||
self.send_email_handler = hs.get_send_email_handler() | |||
self.store = self.hs.get_datastores().main | |||
self.state_storage = self.hs.get_storage().state | |||
self._state_storage_controller = self.hs.get_storage_controllers().state | |||
self.macaroon_gen = self.hs.get_macaroon_generator() | |||
self.state_handler = self.hs.get_state_handler() | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.app_name = app_name | |||
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects | |||
@@ -456,7 +456,7 @@ class Mailer: | |||
} | |||
the_events = await filter_events_for_client( | |||
self.storage, user_id, results.events_before | |||
self._storage_controllers, user_id, results.events_before | |||
) | |||
the_events.append(notif_event) | |||
@@ -494,7 +494,7 @@ class Mailer: | |||
) | |||
else: | |||
# Attempt to check the historical state for the room. | |||
historical_state = await self.state_storage.get_state_for_event( | |||
historical_state = await self._state_storage_controller.get_state_for_event( | |||
event.event_id, StateFilter.from_types((type_state_key,)) | |||
) | |||
sender_state_event = historical_state.get(type_state_key) | |||
@@ -767,8 +767,10 @@ class Mailer: | |||
member_event_ids.append(sender_state_event_id) | |||
else: | |||
# Attempt to check the historical state for the room. | |||
historical_state = await self.state_storage.get_state_for_event( | |||
event_id, StateFilter.from_types((type_state_key,)) | |||
historical_state = ( | |||
await self._state_storage_controller.get_state_for_event( | |||
event_id, StateFilter.from_types((type_state_key,)) | |||
) | |||
) | |||
sender_state_event = historical_state.get(type_state_key) | |||
if sender_state_event: | |||
@@ -16,7 +16,7 @@ from typing import Dict | |||
from synapse.api.constants import ReceiptTypes | |||
from synapse.events import EventBase | |||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event | |||
from synapse.storage import Storage | |||
from synapse.storage.controllers import StorageControllers | |||
from synapse.storage.databases.main import DataStore | |||
@@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - | |||
async def get_context_for_event( | |||
storage: Storage, ev: EventBase, user_id: str | |||
storage: StorageControllers, ev: EventBase, user_id: str | |||
) -> Dict[str, str]: | |||
ctx = {} | |||
@@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): | |||
super().__init__(hs) | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.clock = hs.get_clock() | |||
self.federation_event_handler = hs.get_federation_event_handler() | |||
@@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): | |||
event.internal_metadata.outlier = event_payload["outlier"] | |||
context = EventContext.deserialize( | |||
self.storage, event_payload["context"] | |||
self._storage_controllers, event_payload["context"] | |||
) | |||
event_and_contexts.append((event, context)) | |||
@@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): | |||
self.event_creation_handler = hs.get_event_creation_handler() | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.clock = hs.get_clock() | |||
@staticmethod | |||
@@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): | |||
event.internal_metadata.outlier = content["outlier"] | |||
requester = Requester.deserialize(self.store, content["requester"]) | |||
context = EventContext.deserialize(self.storage, content["context"]) | |||
context = EventContext.deserialize( | |||
self._storage_controllers, content["context"] | |||
) | |||
ratelimit = content["ratelimit"] | |||
extra_users = [UserID.from_string(u) for u in content["extra_users"]] | |||
@@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import ( | |||
WorkerServerNoticesSender, | |||
) | |||
from synapse.state import StateHandler, StateResolutionHandler | |||
from synapse.storage import Databases, Storage | |||
from synapse.storage import Databases | |||
from synapse.storage.controllers import StorageControllers | |||
from synapse.streams.events import EventSources | |||
from synapse.types import DomainSpecificString, ISynapseReactor | |||
from synapse.util import Clock | |||
@@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
return PasswordPolicyHandler(self) | |||
@cache_in_self | |||
def get_storage(self) -> Storage: | |||
return Storage(self, self.get_datastores()) | |||
def get_storage_controllers(self) -> StorageControllers: | |||
return StorageControllers(self, self.get_datastores()) | |||
@cache_in_self | |||
def get_replication_streamer(self) -> ReplicationStreamer: | |||
@@ -127,10 +127,10 @@ class StateHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self.clock = hs.get_clock() | |||
self.store = hs.get_datastores().main | |||
self.state_storage = hs.get_storage().state | |||
self._state_storage_controller = hs.get_storage_controllers().state | |||
self.hs = hs | |||
self._state_resolution_handler = hs.get_state_resolution_handler() | |||
self._storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
@overload | |||
async def get_current_state( | |||
@@ -337,12 +337,14 @@ class StateHandler: | |||
# | |||
if not state_group_before_event: | |||
state_group_before_event = await self.state_storage.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=state_group_before_event_prev_group, | |||
delta_ids=deltas_to_state_group_before_event, | |||
current_state_ids=state_ids_before_event, | |||
state_group_before_event = ( | |||
await self._state_storage_controller.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=state_group_before_event_prev_group, | |||
delta_ids=deltas_to_state_group_before_event, | |||
current_state_ids=state_ids_before_event, | |||
) | |||
) | |||
# Assign the new state group to the cached state entry. | |||
@@ -359,7 +361,7 @@ class StateHandler: | |||
if not event.is_state(): | |||
return EventContext.with_state( | |||
storage=self._storage, | |||
storage=self._storage_controllers, | |||
state_group_before_event=state_group_before_event, | |||
state_group=state_group_before_event, | |||
state_delta_due_to_event={}, | |||
@@ -382,16 +384,18 @@ class StateHandler: | |||
state_ids_after_event[key] = event.event_id | |||
delta_ids = {key: event.event_id} | |||
state_group_after_event = await self.state_storage.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=state_group_before_event, | |||
delta_ids=delta_ids, | |||
current_state_ids=state_ids_after_event, | |||
state_group_after_event = ( | |||
await self._state_storage_controller.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=state_group_before_event, | |||
delta_ids=delta_ids, | |||
current_state_ids=state_ids_after_event, | |||
) | |||
) | |||
return EventContext.with_state( | |||
storage=self._storage, | |||
storage=self._storage_controllers, | |||
state_group=state_group_after_event, | |||
state_group_before_event=state_group_before_event, | |||
state_delta_due_to_event=delta_ids, | |||
@@ -416,7 +420,9 @@ class StateHandler: | |||
""" | |||
logger.debug("resolve_state_groups event_ids %s", event_ids) | |||
state_groups = await self.state_storage.get_state_group_for_events(event_ids) | |||
state_groups = await self._state_storage_controller.get_state_group_for_events( | |||
event_ids | |||
) | |||
state_group_ids = state_groups.values() | |||
@@ -424,8 +430,13 @@ class StateHandler: | |||
state_group_ids_set = set(state_group_ids) | |||
if len(state_group_ids_set) == 1: | |||
(state_group_id,) = state_group_ids_set | |||
state = await self.state_storage.get_state_for_groups(state_group_ids_set) | |||
prev_group, delta_ids = await self.state_storage.get_state_group_delta( | |||
state = await self._state_storage_controller.get_state_for_groups( | |||
state_group_ids_set | |||
) | |||
( | |||
prev_group, | |||
delta_ids, | |||
) = await self._state_storage_controller.get_state_group_delta( | |||
state_group_id | |||
) | |||
return _StateCacheEntry( | |||
@@ -439,7 +450,7 @@ class StateHandler: | |||
room_version = await self.store.get_room_version_id(room_id) | |||
state_to_resolve = await self.state_storage.get_state_for_groups( | |||
state_to_resolve = await self._state_storage_controller.get_state_for_groups( | |||
state_group_ids_set | |||
) | |||
@@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run | |||
against different configurations of databases (e.g. single or multiple | |||
databases). The `DatabasePool` class represents connections to a single physical | |||
database. The `databases` are classes that talk directly to a `DatabasePool` | |||
instance and have associated schemas, background updates, etc. On top of those | |||
there are classes that provide high level interfaces that combine calls to | |||
multiple `databases`. | |||
instance and have associated schemas, background updates, etc. | |||
On top of the databases are the StorageControllers, located in the | |||
`synapse.storage.controllers` module. These classes provide high level | |||
interfaces that combine calls to multiple `databases`. They are bundled into the | |||
`StorageControllers` singleton for ease of use, and exposed via | |||
`HomeServer.get_storage_controllers()`. | |||
There are also schemas that get applied to every database, regardless of the | |||
data stores associated with them (e.g. the schema version tables), which are | |||
stored in `synapse.storage.schema`. | |||
""" | |||
from typing import TYPE_CHECKING | |||
from synapse.storage.databases import Databases | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.storage.persist_events import EventsPersistenceStorage | |||
from synapse.storage.purge_events import PurgeEventsStorage | |||
from synapse.storage.state import StateGroupStorage | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
__all__ = ["Databases", "DataStore"] | |||
class Storage: | |||
"""The high level interfaces for talking to various storage layers.""" | |||
def __init__(self, hs: "HomeServer", stores: Databases): | |||
# We include the main data store here mainly so that we don't have to | |||
# rewrite all the existing code to split it into high vs low level | |||
# interfaces. | |||
self.main = stores.main | |||
self.purge_events = PurgeEventsStorage(hs, stores) | |||
self.state = StateGroupStorage(hs, stores) | |||
self.persistence = None | |||
if stores.persist_events: | |||
self.persistence = EventsPersistenceStorage(hs, stores) |
@@ -0,0 +1,46 @@ | |||
# Copyright 2022 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING | |||
from synapse.storage.controllers.persist_events import ( | |||
EventsPersistenceStorageController, | |||
) | |||
from synapse.storage.controllers.purge_events import PurgeEventsStorageController | |||
from synapse.storage.controllers.state import StateGroupStorageController | |||
from synapse.storage.databases import Databases | |||
from synapse.storage.databases.main import DataStore | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
__all__ = ["Databases", "DataStore"] | |||
class StorageControllers: | |||
"""The high level interfaces for talking to various storage controller layers.""" | |||
def __init__(self, hs: "HomeServer", stores: Databases): | |||
# We include the main data store here mainly so that we don't have to | |||
# rewrite all the existing code to split it into high vs low level | |||
# interfaces. | |||
self.main = stores.main | |||
self.purge_events = PurgeEventsStorageController(hs, stores) | |||
self.state = StateGroupStorageController(hs, stores) | |||
self.persistence = None | |||
if stores.persist_events: | |||
self.persistence = EventsPersistenceStorageController(hs, stores) |
@@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): | |||
pass | |||
class EventsPersistenceStorage: | |||
class EventsPersistenceStorageController: | |||
"""High level interface for handling persisting newly received events. | |||
Takes care of batching up events by room, and calculating the necessary |
@@ -24,7 +24,7 @@ if TYPE_CHECKING: | |||
logger = logging.getLogger(__name__) | |||
class PurgeEventsStorage: | |||
class PurgeEventsStorageController: | |||
"""High level interface for purging rooms and event history.""" | |||
def __init__(self, hs: "HomeServer", stores: Databases): |
@@ -0,0 +1,351 @@ | |||
# Copyright 2022 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Awaitable, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Tuple, | |||
) | |||
from synapse.events import EventBase | |||
from synapse.storage.state import StateFilter | |||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker | |||
from synapse.types import MutableStateMap, StateMap | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases import Databases | |||
logger = logging.getLogger(__name__) | |||
class StateGroupStorageController: | |||
"""High level interface to fetching state for event.""" | |||
def __init__(self, hs: "HomeServer", stores: "Databases"): | |||
self._is_mine_id = hs.is_mine_id | |||
self.stores = stores | |||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) | |||
def notify_event_un_partial_stated(self, event_id: str) -> None: | |||
self._partial_state_events_tracker.notify_un_partial_stated(event_id) | |||
async def get_state_group_delta( | |||
self, state_group: int | |||
) -> Tuple[Optional[int], Optional[StateMap[str]]]: | |||
"""Given a state group try to return a previous group and a delta between | |||
the old and the new. | |||
Args: | |||
state_group: The state group used to retrieve state deltas. | |||
Returns: | |||
A tuple of the previous group and a state map of the event IDs which | |||
make up the delta between the old and new state groups. | |||
""" | |||
state_group_delta = await self.stores.state.get_state_group_delta(state_group) | |||
return state_group_delta.prev_group, state_group_delta.delta_ids | |||
async def get_state_groups_ids( | |||
self, _room_id: str, event_ids: Collection[str] | |||
) -> Dict[int, MutableStateMap[str]]: | |||
"""Get the event IDs of all the state for the state groups for the given events | |||
Args: | |||
_room_id: id of the room for these events | |||
event_ids: ids of the events | |||
Returns: | |||
dict of state_group_id -> (dict of (type, state_key) -> event id) | |||
Raises: | |||
RuntimeError if we don't have a state group for one or more of the events | |||
(ie they are outliers or unknown) | |||
""" | |||
if not event_ids: | |||
return {} | |||
event_to_groups = await self.get_state_group_for_events(event_ids) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = await self.stores.state._get_state_for_groups(groups) | |||
return group_to_state | |||
async def get_state_ids_for_group( | |||
self, state_group: int, state_filter: Optional[StateFilter] = None | |||
) -> StateMap[str]: | |||
"""Get the event IDs of all the state in the given state group | |||
Args: | |||
state_group: A state group for which we want to get the state IDs. | |||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules | |||
Returns: | |||
Resolves to a map of (type, state_key) -> event_id | |||
""" | |||
group_to_state = await self.get_state_for_groups((state_group,), state_filter) | |||
return group_to_state[state_group] | |||
async def get_state_groups( | |||
self, room_id: str, event_ids: Collection[str] | |||
) -> Dict[int, List[EventBase]]: | |||
"""Get the state groups for the given list of event_ids | |||
Args: | |||
room_id: ID of the room for these events. | |||
event_ids: The event IDs to retrieve state for. | |||
Returns: | |||
dict of state_group_id -> list of state events. | |||
""" | |||
if not event_ids: | |||
return {} | |||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids) | |||
state_event_map = await self.stores.main.get_events( | |||
[ | |||
ev_id | |||
for group_ids in group_to_ids.values() | |||
for ev_id in group_ids.values() | |||
], | |||
get_prev_content=False, | |||
) | |||
return { | |||
group: [ | |||
state_event_map[v] | |||
for v in event_id_map.values() | |||
if v in state_event_map | |||
] | |||
for group, event_id_map in group_to_ids.items() | |||
} | |||
def _get_state_groups_from_groups( | |||
self, groups: List[int], state_filter: StateFilter | |||
) -> Awaitable[Dict[int, StateMap[str]]]: | |||
"""Returns the state groups for a given set of groups, filtering on | |||
types of state events. | |||
Args: | |||
groups: list of state group IDs to query | |||
state_filter: The state filter used to fetch state | |||
from the database. | |||
Returns: | |||
Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_groups_from_groups(groups, state_filter) | |||
async def get_state_for_events( | |||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None | |||
) -> Dict[str, StateMap[EventBase]]: | |||
"""Given a list of event_ids and type tuples, return a list of state | |||
dicts for each event. | |||
Args: | |||
event_ids: The events to fetch the state of. | |||
state_filter: The state filter used to fetch state. | |||
Returns: | |||
A dict of (event_id) -> (type, state_key) -> [state_events] | |||
Raises: | |||
RuntimeError if we don't have a state group for one or more of the events | |||
(ie they are outliers or unknown) | |||
""" | |||
await_full_state = True | |||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): | |||
await_full_state = False | |||
event_to_groups = await self.get_state_group_for_events( | |||
event_ids, await_full_state=await_full_state | |||
) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = await self.stores.state._get_state_for_groups( | |||
groups, state_filter or StateFilter.all() | |||
) | |||
state_event_map = await self.stores.main.get_events( | |||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], | |||
get_prev_content=False, | |||
) | |||
event_to_state = { | |||
event_id: { | |||
k: state_event_map[v] | |||
for k, v in group_to_state[group].items() | |||
if v in state_event_map | |||
} | |||
for event_id, group in event_to_groups.items() | |||
} | |||
return {event: event_to_state[event] for event in event_ids} | |||
async def get_state_ids_for_events( | |||
self, | |||
event_ids: Collection[str], | |||
state_filter: Optional[StateFilter] = None, | |||
) -> Dict[str, StateMap[str]]: | |||
""" | |||
Get the state dicts corresponding to a list of events, containing the event_ids | |||
of the state events (as opposed to the events themselves) | |||
Args: | |||
event_ids: events whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A dict from event_id -> (type, state_key) -> event_id | |||
Raises: | |||
RuntimeError if we don't have a state group for one or more of the events | |||
(ie they are outliers or unknown) | |||
""" | |||
await_full_state = True | |||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): | |||
await_full_state = False | |||
event_to_groups = await self.get_state_group_for_events( | |||
event_ids, await_full_state=await_full_state | |||
) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = await self.stores.state._get_state_for_groups( | |||
groups, state_filter or StateFilter.all() | |||
) | |||
event_to_state = { | |||
event_id: group_to_state[group] | |||
for event_id, group in event_to_groups.items() | |||
} | |||
return {event: event_to_state[event] for event in event_ids} | |||
async def get_state_for_event( | |||
self, event_id: str, state_filter: Optional[StateFilter] = None | |||
) -> StateMap[EventBase]: | |||
""" | |||
Get the state dict corresponding to a particular event | |||
Args: | |||
event_id: event whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A dict from (type, state_key) -> state_event | |||
Raises: | |||
RuntimeError if we don't have a state group for the event (ie it is an | |||
outlier or is unknown) | |||
""" | |||
state_map = await self.get_state_for_events( | |||
[event_id], state_filter or StateFilter.all() | |||
) | |||
return state_map[event_id] | |||
async def get_state_ids_for_event( | |||
self, event_id: str, state_filter: Optional[StateFilter] = None | |||
) -> StateMap[str]: | |||
""" | |||
Get the state dict corresponding to a particular event | |||
Args: | |||
event_id: event whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A dict from (type, state_key) -> state_event_id | |||
Raises: | |||
RuntimeError if we don't have a state group for the event (ie it is an | |||
outlier or is unknown) | |||
""" | |||
state_map = await self.get_state_ids_for_events( | |||
[event_id], state_filter or StateFilter.all() | |||
) | |||
return state_map[event_id] | |||
def get_state_for_groups( | |||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None | |||
) -> Awaitable[Dict[int, MutableStateMap[str]]]: | |||
"""Gets the state at each of a list of state groups, optionally | |||
filtering by type/state_key | |||
Args: | |||
groups: list of state groups for which we want to get the state. | |||
state_filter: The state filter used to fetch state. | |||
from the database. | |||
Returns: | |||
Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_for_groups( | |||
groups, state_filter or StateFilter.all() | |||
) | |||
async def get_state_group_for_events( | |||
self, | |||
event_ids: Collection[str], | |||
await_full_state: bool = True, | |||
) -> Mapping[str, int]: | |||
"""Returns mapping event_id -> state_group | |||
Args: | |||
event_ids: events to get state groups for | |||
await_full_state: if true, will block if we do not yet have complete | |||
state at these events. | |||
""" | |||
if await_full_state: | |||
await self._partial_state_events_tracker.await_full_state(event_ids) | |||
return await self.stores.main._get_state_group_for_events(event_ids) | |||
async def store_state_group( | |||
self, | |||
event_id: str, | |||
room_id: str, | |||
prev_group: Optional[int], | |||
delta_ids: Optional[StateMap[str]], | |||
current_state_ids: StateMap[str], | |||
) -> int: | |||
"""Store a new set of state, returning a newly assigned state group. | |||
Args: | |||
event_id: The event ID for which the state was calculated. | |||
room_id: ID of the room for which the state was calculated. | |||
prev_group: A previous state group for the room, optional. | |||
delta_ids: The delta between state at `prev_group` and | |||
`current_state_ids`, if `prev_group` was given. Same format as | |||
`current_state_ids`. | |||
current_state_ids: The state to store. Map of (type, state_key) | |||
to event_id. | |||
Returns: | |||
The state group ID | |||
""" | |||
return await self.stores.state.store_state_group( | |||
event_id, room_id, prev_group, delta_ids, current_state_ids | |||
) |
@@ -15,7 +15,6 @@ | |||
import logging | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Awaitable, | |||
Callable, | |||
Collection, | |||
Dict, | |||
@@ -32,15 +31,11 @@ import attr | |||
from frozendict import frozendict | |||
from synapse.api.constants import EventTypes | |||
from synapse.events import EventBase | |||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker | |||
from synapse.types import MutableStateMap, StateKey, StateMap | |||
if TYPE_CHECKING: | |||
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases import Databases | |||
logger = logging.getLogger(__name__) | |||
@@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter( | |||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True | |||
) | |||
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) | |||
class StateGroupStorage: | |||
"""High level interface to fetching state for event.""" | |||
def __init__(self, hs: "HomeServer", stores: "Databases"): | |||
self._is_mine_id = hs.is_mine_id | |||
self.stores = stores | |||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) | |||
def notify_event_un_partial_stated(self, event_id: str) -> None: | |||
self._partial_state_events_tracker.notify_un_partial_stated(event_id) | |||
async def get_state_group_delta( | |||
self, state_group: int | |||
) -> Tuple[Optional[int], Optional[StateMap[str]]]: | |||
"""Given a state group try to return a previous group and a delta between | |||
the old and the new. | |||
Args: | |||
state_group: The state group used to retrieve state deltas. | |||
Returns: | |||
A tuple of the previous group and a state map of the event IDs which | |||
make up the delta between the old and new state groups. | |||
""" | |||
state_group_delta = await self.stores.state.get_state_group_delta(state_group) | |||
return state_group_delta.prev_group, state_group_delta.delta_ids | |||
async def get_state_groups_ids( | |||
self, _room_id: str, event_ids: Collection[str] | |||
) -> Dict[int, MutableStateMap[str]]: | |||
"""Get the event IDs of all the state for the state groups for the given events | |||
Args: | |||
_room_id: id of the room for these events | |||
event_ids: ids of the events | |||
Returns: | |||
dict of state_group_id -> (dict of (type, state_key) -> event id) | |||
Raises: | |||
RuntimeError if we don't have a state group for one or more of the events | |||
(ie they are outliers or unknown) | |||
""" | |||
if not event_ids: | |||
return {} | |||
event_to_groups = await self.get_state_group_for_events(event_ids) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = await self.stores.state._get_state_for_groups(groups) | |||
return group_to_state | |||
async def get_state_ids_for_group( | |||
self, state_group: int, state_filter: Optional[StateFilter] = None | |||
) -> StateMap[str]: | |||
"""Get the event IDs of all the state in the given state group | |||
Args: | |||
state_group: A state group for which we want to get the state IDs. | |||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules | |||
Returns: | |||
Resolves to a map of (type, state_key) -> event_id | |||
""" | |||
group_to_state = await self.get_state_for_groups((state_group,), state_filter) | |||
return group_to_state[state_group] | |||
async def get_state_groups( | |||
self, room_id: str, event_ids: Collection[str] | |||
) -> Dict[int, List[EventBase]]: | |||
"""Get the state groups for the given list of event_ids | |||
Args: | |||
room_id: ID of the room for these events. | |||
event_ids: The event IDs to retrieve state for. | |||
Returns: | |||
dict of state_group_id -> list of state events. | |||
""" | |||
if not event_ids: | |||
return {} | |||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids) | |||
state_event_map = await self.stores.main.get_events( | |||
[ | |||
ev_id | |||
for group_ids in group_to_ids.values() | |||
for ev_id in group_ids.values() | |||
], | |||
get_prev_content=False, | |||
) | |||
return { | |||
group: [ | |||
state_event_map[v] | |||
for v in event_id_map.values() | |||
if v in state_event_map | |||
] | |||
for group, event_id_map in group_to_ids.items() | |||
} | |||
def _get_state_groups_from_groups( | |||
self, groups: List[int], state_filter: StateFilter | |||
) -> Awaitable[Dict[int, StateMap[str]]]: | |||
"""Returns the state groups for a given set of groups, filtering on | |||
types of state events. | |||
Args: | |||
groups: list of state group IDs to query | |||
state_filter: The state filter used to fetch state | |||
from the database. | |||
Returns: | |||
Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_groups_from_groups(groups, state_filter) | |||
async def get_state_for_events( | |||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None | |||
) -> Dict[str, StateMap[EventBase]]: | |||
"""Given a list of event_ids and type tuples, return a list of state | |||
dicts for each event. | |||
Args: | |||
event_ids: The events to fetch the state of. | |||
state_filter: The state filter used to fetch state. | |||
Returns: | |||
A dict of (event_id) -> (type, state_key) -> [state_events] | |||
Raises: | |||
RuntimeError if we don't have a state group for one or more of the events | |||
(ie they are outliers or unknown) | |||
""" | |||
await_full_state = True | |||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): | |||
await_full_state = False | |||
event_to_groups = await self.get_state_group_for_events( | |||
event_ids, await_full_state=await_full_state | |||
) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = await self.stores.state._get_state_for_groups( | |||
groups, state_filter or StateFilter.all() | |||
) | |||
state_event_map = await self.stores.main.get_events( | |||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], | |||
get_prev_content=False, | |||
) | |||
event_to_state = { | |||
event_id: { | |||
k: state_event_map[v] | |||
for k, v in group_to_state[group].items() | |||
if v in state_event_map | |||
} | |||
for event_id, group in event_to_groups.items() | |||
} | |||
return {event: event_to_state[event] for event in event_ids} | |||
async def get_state_ids_for_events( | |||
self, | |||
event_ids: Collection[str], | |||
state_filter: Optional[StateFilter] = None, | |||
) -> Dict[str, StateMap[str]]: | |||
""" | |||
Get the state dicts corresponding to a list of events, containing the event_ids | |||
of the state events (as opposed to the events themselves) | |||
Args: | |||
event_ids: events whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A dict from event_id -> (type, state_key) -> event_id | |||
Raises: | |||
RuntimeError if we don't have a state group for one or more of the events | |||
(ie they are outliers or unknown) | |||
""" | |||
await_full_state = True | |||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): | |||
await_full_state = False | |||
event_to_groups = await self.get_state_group_for_events( | |||
event_ids, await_full_state=await_full_state | |||
) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = await self.stores.state._get_state_for_groups( | |||
groups, state_filter or StateFilter.all() | |||
) | |||
event_to_state = { | |||
event_id: group_to_state[group] | |||
for event_id, group in event_to_groups.items() | |||
} | |||
return {event: event_to_state[event] for event in event_ids} | |||
async def get_state_for_event( | |||
self, event_id: str, state_filter: Optional[StateFilter] = None | |||
) -> StateMap[EventBase]: | |||
""" | |||
Get the state dict corresponding to a particular event | |||
Args: | |||
event_id: event whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A dict from (type, state_key) -> state_event | |||
Raises: | |||
RuntimeError if we don't have a state group for the event (ie it is an | |||
outlier or is unknown) | |||
""" | |||
state_map = await self.get_state_for_events( | |||
[event_id], state_filter or StateFilter.all() | |||
) | |||
return state_map[event_id] | |||
async def get_state_ids_for_event( | |||
self, event_id: str, state_filter: Optional[StateFilter] = None | |||
) -> StateMap[str]: | |||
""" | |||
Get the state dict corresponding to a particular event | |||
Args: | |||
event_id: event whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A dict from (type, state_key) -> state_event_id | |||
Raises: | |||
RuntimeError if we don't have a state group for the event (ie it is an | |||
outlier or is unknown) | |||
""" | |||
state_map = await self.get_state_ids_for_events( | |||
[event_id], state_filter or StateFilter.all() | |||
) | |||
return state_map[event_id] | |||
def get_state_for_groups( | |||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None | |||
) -> Awaitable[Dict[int, MutableStateMap[str]]]: | |||
"""Gets the state at each of a list of state groups, optionally | |||
filtering by type/state_key | |||
Args: | |||
groups: list of state groups for which we want to get the state. | |||
state_filter: The state filter used to fetch state. | |||
from the database. | |||
Returns: | |||
Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_for_groups( | |||
groups, state_filter or StateFilter.all() | |||
) | |||
async def get_state_group_for_events( | |||
self, | |||
event_ids: Collection[str], | |||
await_full_state: bool = True, | |||
) -> Mapping[str, int]: | |||
"""Returns mapping event_id -> state_group | |||
Args: | |||
event_ids: events to get state groups for | |||
await_full_state: if true, will block if we do not yet have complete | |||
state at these events. | |||
""" | |||
if await_full_state: | |||
await self._partial_state_events_tracker.await_full_state(event_ids) | |||
return await self.stores.main._get_state_group_for_events(event_ids) | |||
async def store_state_group( | |||
self, | |||
event_id: str, | |||
room_id: str, | |||
prev_group: Optional[int], | |||
delta_ids: Optional[StateMap[str]], | |||
current_state_ids: StateMap[str], | |||
) -> int: | |||
"""Store a new set of state, returning a newly assigned state group. | |||
Args: | |||
event_id: The event ID for which the state was calculated. | |||
room_id: ID of the room for which the state was calculated. | |||
prev_group: A previous state group for the room, optional. | |||
delta_ids: The delta between state at `prev_group` and | |||
`current_state_ids`, if `prev_group` was given. Same format as | |||
`current_state_ids`. | |||
current_state_ids: The state to store. Map of (type, state_key) | |||
to event_id. | |||
Returns: | |||
The state group ID | |||
""" | |||
return await self.stores.state.store_state_group( | |||
event_id, room_id, prev_group, delta_ids, current_state_ids | |||
) |
@@ -20,7 +20,7 @@ from typing_extensions import Final | |||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership | |||
from synapse.events import EventBase | |||
from synapse.events.utils import prune_event | |||
from synapse.storage import Storage | |||
from synapse.storage.controllers import StorageControllers | |||
from synapse.storage.state import StateFilter | |||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id | |||
@@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "" | |||
async def filter_events_for_client( | |||
storage: Storage, | |||
storage: StorageControllers, | |||
user_id: str, | |||
events: List[EventBase], | |||
is_peeking: bool = False, | |||
@@ -268,7 +268,7 @@ async def filter_events_for_client( | |||
async def filter_events_for_server( | |||
storage: Storage, | |||
storage: StorageControllers, | |||
server_name: str, | |||
events: List[EventBase], | |||
redact: bool = True, | |||
@@ -360,7 +360,7 @@ async def filter_events_for_server( | |||
async def _event_to_history_vis( | |||
storage: Storage, events: Collection[EventBase] | |||
storage: StorageControllers, events: Collection[EventBase] | |||
) -> Dict[str, str]: | |||
"""Get the history visibility at each of the given events | |||
@@ -407,7 +407,7 @@ async def _event_to_history_vis( | |||
async def _event_to_memberships( | |||
storage: Storage, events: Collection[EventBase], server_name: str | |||
storage: StorageControllers, events: Collection[EventBase], server_name: str | |||
) -> Dict[str, StateMap[EventBase]]: | |||
"""Get the remote membership list at each of the given events | |||
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.user_id = self.register_user("u1", "pass") | |||
self.user_tok = self.login("u1", "pass") | |||
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase): | |||
def _check_serialize_deserialize(self, event, context): | |||
serialized = self.get_success(context.serialize(event, self.store)) | |||
d_context = EventContext.deserialize(self.storage, serialized) | |||
d_context = EventContext.deserialize(self._storage_controllers, serialized) | |||
self.assertEqual(context.state_group, d_context.state_group) | |||
self.assertEqual(context.rejected, d_context.rejected) | |||
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): | |||
hs = self.setup_test_homeserver(federation_http_client=None) | |||
self.handler = hs.get_federation_handler() | |||
self.store = hs.get_datastores().main | |||
self.state_storage = hs.get_storage().state | |||
self.state_storage_controller = hs.get_storage_controllers().state | |||
self._event_auth_handler = hs.get_event_auth_handler() | |||
return hs | |||
@@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): | |||
# mapping from (type, state_key) -> state_event_id | |||
assert most_recent_prev_event_id is not None | |||
prev_state_map = self.get_success( | |||
self.state_storage.get_state_ids_for_event(most_recent_prev_event_id) | |||
self.state_storage_controller.get_state_ids_for_event( | |||
most_recent_prev_event_id | |||
) | |||
) | |||
# List of state event ID's | |||
prev_state_ids = list(prev_state_map.values()) | |||
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
) -> None: | |||
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" | |||
main_store = self.hs.get_datastores().main | |||
state_storage = self.hs.get_storage().state | |||
state_storage_controller = self.hs.get_storage_controllers().state | |||
# create the room | |||
user_id = self.register_user("kermit", "test") | |||
@@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
) | |||
if prev_exists_as_outlier: | |||
prev_event.internal_metadata.outlier = True | |||
persistence = self.hs.get_storage().persistence | |||
persistence = self.hs.get_storage_controllers().persistence | |||
self.get_success( | |||
persistence.persist_event( | |||
prev_event, EventContext.for_outlier(self.hs.get_storage()) | |||
prev_event, | |||
EventContext.for_outlier(self.hs.get_storage_controllers()), | |||
) | |||
) | |||
else: | |||
@@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
# check that the state at that event is as expected | |||
state = self.get_success( | |||
state_storage.get_state_ids_for_event(pulled_event.event_id) | |||
state_storage_controller.get_state_ids_for_event(pulled_event.event_id) | |||
) | |||
expected_state = { | |||
(e.type, e.state_key): e.event_id for e in state_at_prev_event | |||
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
self.handler = self.hs.get_event_creation_handler() | |||
self.persist_event_storage = self.hs.get_storage().persistence | |||
self._persist_event_storage_controller = ( | |||
self.hs.get_storage_controllers().persistence | |||
) | |||
self.user_id = self.register_user("tester", "foobar") | |||
self.access_token = self.login("tester", "foobar") | |||
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.get_success( | |||
self.persist_event_storage.persist_event(memberEvent, memberEventContext) | |||
self._persist_event_storage_controller.persist_event( | |||
memberEvent, memberEventContext | |||
) | |||
) | |||
return memberEvent, memberEventContext | |||
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
self.assertNotEqual(event1.event_id, event3.event_id) | |||
ret_event3, event_pos3, _ = self.get_success( | |||
self.persist_event_storage.persist_event(event3, context) | |||
self._persist_event_storage_controller.persist_event(event3, context) | |||
) | |||
# Assert that the returned values match those from the initial event | |||
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
self.assertNotEqual(event1.event_id, event3.event_id) | |||
events, _ = self.get_success( | |||
self.persist_event_storage.persist_events([(event3, context)]) | |||
self._persist_event_storage_controller.persist_events([(event3, context)]) | |||
) | |||
ret_event4 = events[0] | |||
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
self.assertNotEqual(event1.event_id, event2.event_id) | |||
events, _ = self.get_success( | |||
self.persist_event_storage.persist_events( | |||
self._persist_event_storage_controller.persist_events( | |||
[(event1, context1), (event2, context2)] | |||
) | |||
) | |||
@@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.get_success( | |||
self.hs.get_storage().persistence.persist_event(event, context) | |||
self.hs.get_storage_controllers().persistence.persist_event(event, context) | |||
) | |||
def test_local_user_leaving_room_remains_in_user_directory(self) -> None: | |||
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): | |||
self.master_store = hs.get_datastores().main | |||
self.slaved_store = self.worker_hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
def replicate(self): | |||
"""Tell the master side of replication that something has happened, and then | |||
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
) | |||
msg, msgctx = self.build_event() | |||
self.get_success( | |||
self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) | |||
self._storage_controllers.persistence.persist_events( | |||
[(j2, j2ctx), (msg, msgctx)] | |||
) | |||
) | |||
self.replicate() | |||
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
if backfill: | |||
self.get_success( | |||
self.storage.persistence.persist_events( | |||
self._storage_controllers.persistence.persist_events( | |||
[(event, context)], backfilled=True | |||
) | |||
) | |||
else: | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
) | |||
return event | |||
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): | |||
def prepare(self, reactor, clock, homeserver): | |||
super().prepare(reactor, clock, homeserver) | |||
self.room_creator = homeserver.get_room_creation_handler() | |||
self.persist_event_storage = self.hs.get_storage().persistence | |||
self.persist_event_storage_controller = ( | |||
self.hs.get_storage_controllers().persistence | |||
) | |||
# Create a test user | |||
self.ourUser = UserID.from_string(OUR_USER_ID) | |||
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): | |||
) | |||
) | |||
self.get_success( | |||
self.persist_event_storage.persist_event(memberEvent, memberEventContext) | |||
self.persist_event_storage_controller.persist_event( | |||
memberEvent, memberEventContext | |||
) | |||
) | |||
# Join the second user to the second room | |||
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): | |||
) | |||
) | |||
self.get_success( | |||
self.persist_event_storage.persist_event(memberEvent, memberEventContext) | |||
self.persist_event_storage_controller.persist_event( | |||
memberEvent, memberEventContext | |||
) | |||
) | |||
def test_return_empty_with_no_data(self): | |||
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): | |||
other_user_tok = self.login("user", "pass") | |||
event_builder_factory = self.hs.get_event_builder_factory() | |||
event_creation_handler = self.hs.get_event_creation_handler() | |||
storage = self.hs.get_storage() | |||
storage_controllers = self.hs.get_storage_controllers() | |||
# Create two rooms, one with a local user only and one with both a local | |||
# and remote user. | |||
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): | |||
event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(storage.persistence.persist_event(event, context)) | |||
self.get_success(storage_controllers.persistence.persist_event(event, context)) | |||
# Now get rooms | |||
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" | |||
@@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): | |||
We do this by setting a very long time between purge jobs. | |||
""" | |||
store = self.hs.get_datastores().main | |||
storage = self.hs.get_storage() | |||
storage_controllers = self.hs.get_storage_controllers() | |||
room_id = self.helper.create_room_as(self.user_id, tok=self.token) | |||
# Send a first event, which should be filtered out at the end of the test. | |||
@@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual(2, len(events), "events retrieved from database") | |||
filtered_events = self.get_success( | |||
filter_events_for_client(storage, self.user_id, events) | |||
filter_events_for_client(storage_controllers, self.user_id, events) | |||
) | |||
# We should only get one event back. | |||
@@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.clock = clock | |||
self.storage = hs.get_storage() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.virtual_user_id, _ = self.register_appservice_user( | |||
"as_user_potato", self.appservice.token | |||
@@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): | |||
# Fetch the state_groups | |||
state_group_map = self.get_success( | |||
self.storage.state.get_state_groups_ids(room_id, historical_event_ids) | |||
self._storage_controllers.state.get_state_groups_ids( | |||
room_id, historical_event_ids | |||
) | |||
) | |||
# We expect all of the historical events to be using the same state_group | |||
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||
# We need to persist the events to the events and state_events | |||
# tables. | |||
persist_events_store._store_event_txn( | |||
txn, [(e, EventContext(self.hs.get_storage())) for e in events] | |||
txn, | |||
[(e, EventContext(self.hs.get_storage_controllers())) for e in events], | |||
) | |||
# Actually call the function that calculates the auth chain stuff. | |||
@@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, homeserver): | |||
self.state = self.hs.get_state_handler() | |||
self.persistence = self.hs.get_storage().persistence | |||
self._persistence = self.hs.get_storage_controllers().persistence | |||
self.store = self.hs.get_datastores().main | |||
self.register_user("user", "pass") | |||
@@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
context = self.get_success( | |||
self.state.compute_event_context(event, state_ids_before_event=state) | |||
) | |||
self.get_success(self.persistence.persist_event(event, context)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
def assert_extremities(self, expected_extremities): | |||
"""Assert the current extremities for the room""" | |||
@@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||
) | |||
) | |||
self.get_success(self.persistence.persist_event(remote_event_2, context)) | |||
self.get_success(self._persistence.persist_event(remote_event_2, context)) | |||
# Check that we haven't dropped the old extremity. | |||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) | |||
@@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, homeserver): | |||
self.state = self.hs.get_state_handler() | |||
self.persistence = self.hs.get_storage().persistence | |||
self._persistence = self.hs.get_storage_controllers().persistence | |||
self.store = self.hs.get_datastores().main | |||
def test_remote_user_rooms_cache_invalidated(self): | |||
@@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): | |||
) | |||
context = self.get_success(self.state.compute_event_context(remote_event_1)) | |||
self.get_success(self.persistence.persist_event(remote_event_1, context)) | |||
self.get_success(self._persistence.persist_event(remote_event_1, context)) | |||
# Call `get_rooms_for_user` to add the remote user to the cache | |||
rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) | |||
@@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): | |||
) | |||
context = self.get_success(self.state.compute_event_context(remote_event_1)) | |||
self.get_success(self.persistence.persist_event(remote_event_1, context)) | |||
self.get_success(self._persistence.persist_event(remote_event_1, context)) | |||
# Call `get_users_in_room` to add the remote user to the cache | |||
users = self.get_success(self.store.get_users_in_room(room_id)) | |||
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase): | |||
self.room_id = self.helper.create_room_as(self.user_id) | |||
self.store = hs.get_datastores().main | |||
self.storage = self.hs.get_storage() | |||
self._storage_controllers = self.hs.get_storage_controllers() | |||
def test_purge_history(self): | |||
""" | |||
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase): | |||
# Purge everything before this topological token | |||
self.get_success( | |||
self.storage.purge_events.purge_history(self.room_id, token_str, True) | |||
self._storage_controllers.purge_events.purge_history( | |||
self.room_id, token_str, True | |||
) | |||
) | |||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted | |||
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase): | |||
# Purge everything before this topological token | |||
f = self.get_failure( | |||
self.storage.purge_events.purge_history(self.room_id, event, True), | |||
self._storage_controllers.purge_events.purge_history( | |||
self.room_id, event, True | |||
), | |||
SynapseError, | |||
) | |||
self.assertIn("greater than forward", f.value.args[0]) | |||
@@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase): | |||
self.assertIsNotNone(create_event) | |||
# Purge everything before this topological token | |||
self.get_success(self.storage.purge_events.purge_room(self.room_id)) | |||
self.get_success( | |||
self._storage_controllers.purge_events.purge_room(self.room_id) | |||
) | |||
# The events aren't found. | |||
self.store._invalidate_get_event_cache(create_event.event_id) | |||
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage = hs.get_storage_controllers() | |||
self.event_builder_factory = hs.get_event_builder_factory() | |||
self.event_creation_handler = hs.get_event_creation_handler() | |||
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success(self._storage.persistence.persist_event(event, context)) | |||
return event | |||
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success(self._storage.persistence.persist_event(event, context)) | |||
return event | |||
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success(self._storage.persistence.persist_event(event, context)) | |||
return event | |||
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event_1, context_1)) | |||
self.get_success(self._storage.persistence.persist_event(event_1, context_1)) | |||
event_2, context_2 = self.get_success( | |||
self.event_creation_handler.create_new_client_event( | |||
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event_2, context_2)) | |||
self.get_success(self._storage.persistence.persist_event(event_2, context_2)) | |||
# fetch one of the redactions | |||
fetched = self.get_success(self.store.get_event(redaction_event_id1)) | |||
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.get_success( | |||
self.storage.persistence.persist_event(redaction_event, context) | |||
self._storage.persistence.persist_event(redaction_event, context) | |||
) | |||
# Now lets jump to the future where we have censored the redaction event | |||
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): | |||
# Room events need the full datastore, for persist_event() and | |||
# get_room_state() | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self._storage = hs.get_storage_controllers() | |||
self.event_factory = hs.get_event_factory() | |||
self.room = RoomID.from_string("!abcde:test") | |||
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): | |||
def inject_room_event(self, **kwargs): | |||
self.get_success( | |||
self.storage.persistence.persist_event( | |||
self._storage.persistence.persist_event( | |||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) | |||
) | |||
) | |||
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase): | |||
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) | |||
prev_event = self.get_success(store.get_event(prev_event_ids[0])) | |||
prev_state_map = self.get_success( | |||
self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) | |||
self.hs.get_storage_controllers().state.get_state_ids_for_event( | |||
prev_event_ids[0] | |||
) | |||
) | |||
event_dict = { | |||
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) | |||
class StateStoreTestCase(HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
self.store = hs.get_datastores().main | |||
self.storage = hs.get_storage() | |||
self.storage = hs.get_storage_controllers() | |||
self.state_datastore = self.storage.state.stores.state | |||
self.event_builder_factory = hs.get_event_builder_factory() | |||
self.event_creation_handler = hs.get_event_creation_handler() | |||
@@ -179,12 +179,12 @@ class Graph: | |||
class StateTestCase(unittest.TestCase): | |||
def setUp(self): | |||
self.dummy_store = _DummyStore() | |||
storage = Mock(main=self.dummy_store, state=self.dummy_store) | |||
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) | |||
hs = Mock( | |||
spec_set=[ | |||
"config", | |||
"get_datastores", | |||
"get_storage", | |||
"get_storage_controllers", | |||
"get_auth", | |||
"get_state_handler", | |||
"get_clock", | |||
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): | |||
hs.get_clock.return_value = MockClock() | |||
hs.get_auth.return_value = Auth(hs) | |||
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) | |||
hs.get_storage.return_value = storage | |||
hs.get_storage_controllers.return_value = storage_controllers | |||
self.state = StateHandler(hs) | |||
self.event_id = 0 | |||
@@ -70,7 +70,7 @@ async def inject_event( | |||
""" | |||
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) | |||
persistence = hs.get_storage().persistence | |||
persistence = hs.get_storage_controllers().persistence | |||
assert persistence is not None | |||
await persistence.persist_event(event, context) | |||
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
super(FilterEventsForServerTestCase, self).setUp() | |||
self.event_creation_handler = self.hs.get_event_creation_handler() | |||
self.event_builder_factory = self.hs.get_event_builder_factory() | |||
self.storage = self.hs.get_storage() | |||
self._storage_controllers = self.hs.get_storage_controllers() | |||
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) | |||
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
events_to_filter.append(evt) | |||
filtered = self.get_success( | |||
filter_events_for_server(self.storage, "test_server", events_to_filter) | |||
filter_events_for_server( | |||
self._storage_controllers, "test_server", events_to_filter | |||
) | |||
) | |||
# the result should be 5 redacted events, and 5 unredacted events. | |||
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
outlier = self._inject_outlier() | |||
self.assertEqual( | |||
self.get_success( | |||
filter_events_for_server(self.storage, "remote_hs", [outlier]) | |||
filter_events_for_server( | |||
self._storage_controllers, "remote_hs", [outlier] | |||
) | |||
), | |||
[outlier], | |||
) | |||
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
evt = self._inject_message("@unerased:local_hs") | |||
filtered = self.get_success( | |||
filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) | |||
filter_events_for_server( | |||
self._storage_controllers, "remote_hs", [outlier, evt] | |||
) | |||
) | |||
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") | |||
self.assertEqual(filtered[0], outlier) | |||
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
# ... but other servers should only be able to see the outlier (the other should | |||
# be redacted) | |||
filtered = self.get_success( | |||
filter_events_for_server(self.storage, "other_server", [outlier, evt]) | |||
filter_events_for_server( | |||
self._storage_controllers, "other_server", [outlier, evt] | |||
) | |||
) | |||
self.assertEqual(filtered[0], outlier) | |||
self.assertEqual(filtered[1].event_id, evt.event_id) | |||
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
# ... and the filtering happens. | |||
filtered = self.get_success( | |||
filter_events_for_server(self.storage, "test_server", events_to_filter) | |||
filter_events_for_server( | |||
self._storage_controllers, "test_server", events_to_filter | |||
) | |||
) | |||
for i in range(0, len(events_to_filter)): | |||
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
event, context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
) | |||
return event | |||
def _inject_room_member( | |||
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
) | |||
return event | |||
def _inject_message( | |||
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
) | |||
return event | |||
def _inject_outlier(self) -> EventBase: | |||
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) | |||
event.internal_metadata.outlier = True | |||
self.get_success( | |||
self.storage.persistence.persist_event( | |||
event, EventContext.for_outlier(self.storage) | |||
self._storage_controllers.persistence.persist_event( | |||
event, EventContext.for_outlier(self._storage_controllers) | |||
) | |||
) | |||
return event | |||
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): | |||
self.assertEqual( | |||
self.get_success( | |||
filter_events_for_client( | |||
self.hs.get_storage(), "@user:test", [invite_event, reject_event] | |||
self.hs.get_storage_controllers(), | |||
"@user:test", | |||
[invite_event, reject_event], | |||
) | |||
), | |||
[invite_event, reject_event], | |||
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): | |||
self.assertEqual( | |||
self.get_success( | |||
filter_events_for_client( | |||
self.hs.get_storage(), "@other:test", [invite_event, reject_event] | |||
self.hs.get_storage_controllers(), | |||
"@other:test", | |||
[invite_event, reject_event], | |||
) | |||
), | |||
[], | |||
@@ -264,7 +264,7 @@ class MockClock: | |||
async def create_room(hs, room_id: str, creator_id: str): | |||
"""Creates and persist a creation event for the given room""" | |||
persistence_store = hs.get_storage().persistence | |||
persistence_store = hs.get_storage_controllers().persistence | |||
store = hs.get_datastores().main | |||
event_builder_factory = hs.get_event_builder_factory() | |||
event_creation_handler = hs.get_event_creation_handler() | |||