@@ -0,0 +1 @@ | |||
Convert various parts of the codebase to async/await. |
@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram | |||
from twisted.internet import defer | |||
from synapse.api.constants import EventTypes, Membership | |||
from synapse.events import FrozenEvent | |||
from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
@@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): | |||
self._event_persist_queue = _EventPeristenceQueue() | |||
self._state_resolution_handler = hs.get_state_resolution_handler() | |||
@defer.inlineCallbacks | |||
def persist_events( | |||
async def persist_events( | |||
self, | |||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], | |||
events_and_contexts: List[Tuple[EventBase, EventContext]], | |||
backfilled: bool = False, | |||
): | |||
) -> int: | |||
""" | |||
Write events to the database | |||
Args: | |||
@@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): | |||
which might update the current state etc. | |||
Returns: | |||
Deferred[int]: the stream ordering of the latest persisted event | |||
the stream ordering of the latest persisted event | |||
""" | |||
partitioned = {} | |||
for event, ctx in events_and_contexts: | |||
@@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): | |||
for room_id in partitioned: | |||
self._maybe_start_persisting(room_id) | |||
yield make_deferred_yieldable( | |||
await make_deferred_yieldable( | |||
defer.gatherResults(deferreds, consumeErrors=True) | |||
) | |||
max_persisted_id = yield self.main_store.get_current_events_token() | |||
return max_persisted_id | |||
return self.main_store.get_current_events_token() | |||
@defer.inlineCallbacks | |||
def persist_event( | |||
self, event: FrozenEvent, context: EventContext, backfilled: bool = False | |||
): | |||
async def persist_event( | |||
self, event: EventBase, context: EventContext, backfilled: bool = False | |||
) -> Tuple[int, int]: | |||
""" | |||
Returns: | |||
Deferred[Tuple[int, int]]: the stream ordering of ``event``, | |||
and the stream ordering of the latest persisted event | |||
The stream ordering of `event`, and the stream ordering of the | |||
latest persisted event | |||
""" | |||
deferred = self._event_persist_queue.add_to_queue( | |||
event.room_id, [(event, context)], backfilled=backfilled | |||
@@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): | |||
self._maybe_start_persisting(event.room_id) | |||
yield make_deferred_yieldable(deferred) | |||
await make_deferred_yieldable(deferred) | |||
max_persisted_id = yield self.main_store.get_current_events_token() | |||
max_persisted_id = self.main_store.get_current_events_token() | |||
return (event.internal_metadata.stream_ordering, max_persisted_id) | |||
def _maybe_start_persisting(self, room_id: str): | |||
@@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): | |||
async def _persist_events( | |||
self, | |||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], | |||
events_and_contexts: List[Tuple[EventBase, EventContext]], | |||
backfilled: bool = False, | |||
): | |||
"""Calculates the change to current state and forward extremities, and | |||
@@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): | |||
async def _calculate_new_extremities( | |||
self, | |||
room_id: str, | |||
event_contexts: List[Tuple[FrozenEvent, EventContext]], | |||
event_contexts: List[Tuple[EventBase, EventContext]], | |||
latest_event_ids: List[str], | |||
): | |||
"""Calculates the new forward extremities for a room given events to | |||
@@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): | |||
async def _get_new_state_after_events( | |||
self, | |||
room_id: str, | |||
events_context: List[Tuple[FrozenEvent, EventContext]], | |||
events_context: List[Tuple[EventBase, EventContext]], | |||
old_latest_event_ids: Iterable[str], | |||
new_latest_event_ids: Iterable[str], | |||
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: | |||
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): | |||
async def _is_server_still_joined( | |||
self, | |||
room_id: str, | |||
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], | |||
ev_ctx_rm: List[Tuple[EventBase, EventContext]], | |||
delta: DeltaState, | |||
current_state: Optional[StateMap[str]], | |||
potentially_left_users: Set[str], | |||
@@ -15,8 +15,7 @@ | |||
import itertools | |||
import logging | |||
from twisted.internet import defer | |||
from typing import Set | |||
logger = logging.getLogger(__name__) | |||
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object): | |||
def __init__(self, hs, stores): | |||
self.stores = stores | |||
@defer.inlineCallbacks | |||
def purge_room(self, room_id: str): | |||
async def purge_room(self, room_id: str): | |||
"""Deletes all record of a room | |||
""" | |||
state_groups_to_delete = yield self.stores.main.purge_room(room_id) | |||
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) | |||
state_groups_to_delete = await self.stores.main.purge_room(room_id) | |||
await self.stores.state.purge_room_state(room_id, state_groups_to_delete) | |||
@defer.inlineCallbacks | |||
def purge_history(self, room_id, token, delete_local_events): | |||
async def purge_history( | |||
self, room_id: str, token: str, delete_local_events: bool | |||
) -> None: | |||
"""Deletes room history before a certain point | |||
Args: | |||
room_id (str): | |||
room_id: The room ID | |||
token (str): A topological token to delete events before | |||
token: A topological token to delete events before | |||
delete_local_events (bool): | |||
delete_local_events: | |||
if True, we will delete local events as well as remote ones | |||
(instead of just marking them as outliers and deleting their | |||
state groups). | |||
""" | |||
state_groups = yield self.stores.main.purge_history( | |||
state_groups = await self.stores.main.purge_history( | |||
room_id, token, delete_local_events | |||
) | |||
logger.info("[purge] finding state groups that can be deleted") | |||
sg_to_delete = yield self._find_unreferenced_groups(state_groups) | |||
sg_to_delete = await self._find_unreferenced_groups(state_groups) | |||
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) | |||
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) | |||
@defer.inlineCallbacks | |||
def _find_unreferenced_groups(self, state_groups): | |||
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: | |||
"""Used when purging history to figure out which state groups can be | |||
deleted. | |||
Args: | |||
state_groups (set[int]): Set of state groups referenced by events | |||
state_groups: Set of state groups referenced by events | |||
that are going to be deleted. | |||
Returns: | |||
Deferred[set[int]] The set of state groups that can be deleted. | |||
The set of state groups that can be deleted. | |||
""" | |||
# Graph of state group -> previous group | |||
graph = {} | |||
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object): | |||
current_search = set(itertools.islice(next_to_search, 100)) | |||
next_to_search -= current_search | |||
referenced = yield self.stores.main.get_referenced_state_groups( | |||
referenced = await self.stores.main.get_referenced_state_groups( | |||
current_search | |||
) | |||
referenced_groups |= referenced | |||
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object): | |||
# groups that are referenced. | |||
current_search -= referenced | |||
edges = yield self.stores.state.get_previous_state_groups(current_search) | |||
edges = await self.stores.state.get_previous_state_groups(current_search) | |||
prevs = set(edges.values()) | |||
# We don't bother re-handling groups we've already seen | |||
@@ -14,13 +14,12 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Iterable, List, TypeVar | |||
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar | |||
import attr | |||
from twisted.internet import defer | |||
from synapse.api.constants import EventTypes | |||
from synapse.events import EventBase | |||
from synapse.types import StateMap | |||
logger = logging.getLogger(__name__) | |||
@@ -34,16 +33,16 @@ class StateFilter(object): | |||
"""A filter used when querying for state. | |||
Attributes: | |||
types (dict[str, set[str]|None]): Map from type to set of state keys (or | |||
None). This specifies which state_keys for the given type to fetch | |||
from the DB. If None then all events with that type are fetched. If | |||
the set is empty then no events with that type are fetched. | |||
include_others (bool): Whether to fetch events with types that do not | |||
types: Map from type to set of state keys (or None). This specifies | |||
which state_keys for the given type to fetch from the DB. If None | |||
then all events with that type are fetched. If the set is empty | |||
then no events with that type are fetched. | |||
include_others: Whether to fetch events with types that do not | |||
appear in `types`. | |||
""" | |||
types = attr.ib() | |||
include_others = attr.ib(default=False) | |||
types = attr.ib(type=Dict[str, Optional[Set[str]]]) | |||
include_others = attr.ib(default=False, type=bool) | |||
def __attrs_post_init__(self): | |||
# If `include_others` is set we canonicalise the filter by removing | |||
@@ -52,36 +51,35 @@ class StateFilter(object): | |||
self.types = {k: v for k, v in self.types.items() if v is not None} | |||
@staticmethod | |||
def all(): | |||
def all() -> "StateFilter": | |||
"""Creates a filter that fetches everything. | |||
Returns: | |||
StateFilter | |||
The new state filter. | |||
""" | |||
return StateFilter(types={}, include_others=True) | |||
@staticmethod | |||
def none(): | |||
def none() -> "StateFilter": | |||
"""Creates a filter that fetches nothing. | |||
Returns: | |||
StateFilter | |||
The new state filter. | |||
""" | |||
return StateFilter(types={}, include_others=False) | |||
@staticmethod | |||
def from_types(types): | |||
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": | |||
"""Creates a filter that only fetches the given types | |||
Args: | |||
types (Iterable[tuple[str, str|None]]): A list of type and state | |||
keys to fetch. A state_key of None fetches everything for | |||
that type | |||
types: A list of type and state keys to fetch. A state_key of None | |||
fetches everything for that type | |||
Returns: | |||
StateFilter | |||
The new state filter. | |||
""" | |||
type_dict = {} | |||
type_dict = {} # type: Dict[str, Optional[Set[str]]] | |||
for typ, s in types: | |||
if typ in type_dict: | |||
if type_dict[typ] is None: | |||
@@ -91,24 +89,24 @@ class StateFilter(object): | |||
type_dict[typ] = None | |||
continue | |||
type_dict.setdefault(typ, set()).add(s) | |||
type_dict.setdefault(typ, set()).add(s) # type: ignore | |||
return StateFilter(types=type_dict) | |||
@staticmethod | |||
def from_lazy_load_member_list(members): | |||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": | |||
"""Creates a filter that returns all non-member events, plus the member | |||
events for the given users | |||
Args: | |||
members (iterable[str]): Set of user IDs | |||
members: Set of user IDs | |||
Returns: | |||
StateFilter | |||
The new state filter | |||
""" | |||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) | |||
def return_expanded(self): | |||
def return_expanded(self) -> "StateFilter": | |||
"""Creates a new StateFilter where type wild cards have been removed | |||
(except for memberships). The returned filter is a superset of the | |||
current one, i.e. anything that passes the current filter will pass | |||
@@ -130,7 +128,7 @@ class StateFilter(object): | |||
return all non-member events | |||
Returns: | |||
StateFilter | |||
The new state filter. | |||
""" | |||
if self.is_full(): | |||
@@ -167,7 +165,7 @@ class StateFilter(object): | |||
include_others=True, | |||
) | |||
def make_sql_filter_clause(self): | |||
def make_sql_filter_clause(self) -> Tuple[str, List[str]]: | |||
"""Converts the filter to an SQL clause. | |||
For example: | |||
@@ -179,13 +177,12 @@ class StateFilter(object): | |||
Returns: | |||
tuple[str, list]: The SQL string (may be empty) and arguments. An | |||
empty SQL string is returned when the filter matches everything | |||
(i.e. is "full"). | |||
The SQL string (may be empty) and arguments. An empty SQL string is | |||
returned when the filter matches everything (i.e. is "full"). | |||
""" | |||
where_clause = "" | |||
where_args = [] | |||
where_args = [] # type: List[str] | |||
if self.is_full(): | |||
return where_clause, where_args | |||
@@ -221,7 +218,7 @@ class StateFilter(object): | |||
return where_clause, where_args | |||
def max_entries_returned(self): | |||
def max_entries_returned(self) -> Optional[int]: | |||
"""Returns the maximum number of entries this filter will return if | |||
known, otherwise returns None. | |||
@@ -260,33 +257,33 @@ class StateFilter(object): | |||
return filtered_state | |||
def is_full(self): | |||
def is_full(self) -> bool: | |||
"""Whether this filter fetches everything or not | |||
Returns: | |||
bool | |||
True if the filter fetches everything. | |||
""" | |||
return self.include_others and not self.types | |||
def has_wildcards(self): | |||
def has_wildcards(self) -> bool: | |||
"""Whether the filter includes wildcards or is attempting to fetch | |||
specific state. | |||
Returns: | |||
bool | |||
True if the filter includes wildcards. | |||
""" | |||
return self.include_others or any( | |||
state_keys is None for state_keys in self.types.values() | |||
) | |||
def concrete_types(self): | |||
def concrete_types(self) -> List[Tuple[str, str]]: | |||
"""Returns a list of concrete type/state_keys (i.e. not None) that | |||
will be fetched. This will be a complete list if `has_wildcards` | |||
returns False, but otherwise will be a subset (or even empty). | |||
Returns: | |||
list[tuple[str,str]] | |||
A list of type/state_keys tuples. | |||
""" | |||
return [ | |||
(t, s) | |||
@@ -295,7 +292,7 @@ class StateFilter(object): | |||
for s in state_keys | |||
] | |||
def get_member_split(self): | |||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: | |||
"""Return the filter split into two: one which assumes it's exclusively | |||
matching against member state, and one which assumes it's matching | |||
against non member state. | |||
@@ -307,7 +304,7 @@ class StateFilter(object): | |||
state caches). | |||
Returns: | |||
tuple[StateFilter, StateFilter]: The member and non member filters | |||
The member and non member filters | |||
""" | |||
if EventTypes.Member in self.types: | |||
@@ -340,6 +337,9 @@ class StateGroupStorage(object): | |||
"""Given a state group try to return a previous group and a delta between | |||
the old and the new. | |||
Args: | |||
state_group: The state group used to retrieve state deltas. | |||
Returns: | |||
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: | |||
(prev_group, delta_ids) | |||
@@ -347,55 +347,59 @@ class StateGroupStorage(object): | |||
return self.stores.state.get_state_group_delta(state_group) | |||
@defer.inlineCallbacks | |||
def get_state_groups_ids(self, _room_id, event_ids): | |||
async def get_state_groups_ids( | |||
self, _room_id: str, event_ids: Iterable[str] | |||
) -> Dict[int, StateMap[str]]: | |||
"""Get the event IDs of all the state for the state groups for the given events | |||
Args: | |||
_room_id (str): id of the room for these events | |||
event_ids (iterable[str]): ids of the events | |||
_room_id: id of the room for these events | |||
event_ids: ids of the events | |||
Returns: | |||
Deferred[dict[int, StateMap[str]]]: | |||
dict of state_group_id -> (dict of (type, state_key) -> event id) | |||
dict of state_group_id -> (dict of (type, state_key) -> event id) | |||
""" | |||
if not event_ids: | |||
return {} | |||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) | |||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = yield self.stores.state._get_state_for_groups(groups) | |||
group_to_state = await self.stores.state._get_state_for_groups(groups) | |||
return group_to_state | |||
@defer.inlineCallbacks | |||
def get_state_ids_for_group(self, state_group): | |||
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: | |||
"""Get the event IDs of all the state in the given state group | |||
Args: | |||
state_group (int) | |||
state_group: A state group for which we want to get the state IDs. | |||
Returns: | |||
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id | |||
Resolves to a map of (type, state_key) -> event_id | |||
""" | |||
group_to_state = yield self._get_state_for_groups((state_group,)) | |||
group_to_state = await self._get_state_for_groups((state_group,)) | |||
return group_to_state[state_group] | |||
@defer.inlineCallbacks | |||
def get_state_groups(self, room_id, event_ids): | |||
async def get_state_groups( | |||
self, room_id: str, event_ids: Iterable[str] | |||
) -> Dict[int, List[EventBase]]: | |||
""" Get the state groups for the given list of event_ids | |||
Args: | |||
room_id: ID of the room for these events. | |||
event_ids: The event IDs to retrieve state for. | |||
Returns: | |||
Deferred[dict[int, list[EventBase]]]: | |||
dict of state_group_id -> list of state events. | |||
dict of state_group_id -> list of state events. | |||
""" | |||
if not event_ids: | |||
return {} | |||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) | |||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids) | |||
state_event_map = yield self.stores.main.get_events( | |||
state_event_map = await self.stores.main.get_events( | |||
[ | |||
ev_id | |||
for group_ids in group_to_ids.values() | |||
@@ -423,31 +427,34 @@ class StateGroupStorage(object): | |||
groups: list of state group IDs to query | |||
state_filter: The state filter used to fetch state | |||
from the database. | |||
Returns: | |||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_groups_from_groups(groups, state_filter) | |||
@defer.inlineCallbacks | |||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): | |||
async def get_state_for_events( | |||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() | |||
): | |||
"""Given a list of event_ids and type tuples, return a list of state | |||
dicts for each event. | |||
Args: | |||
event_ids (list[string]) | |||
state_filter (StateFilter): The state filter used to fetch state | |||
from the database. | |||
event_ids: The events to fetch the state of. | |||
state_filter: The state filter used to fetch state. | |||
Returns: | |||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events] | |||
A dict of (event_id) -> (type, state_key) -> [state_events] | |||
""" | |||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) | |||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = yield self.stores.state._get_state_for_groups( | |||
group_to_state = await self.stores.state._get_state_for_groups( | |||
groups, state_filter | |||
) | |||
state_event_map = yield self.stores.main.get_events( | |||
state_event_map = await self.stores.main.get_events( | |||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], | |||
get_prev_content=False, | |||
) | |||
@@ -463,24 +470,24 @@ class StateGroupStorage(object): | |||
return {event: event_to_state[event] for event in event_ids} | |||
@defer.inlineCallbacks | |||
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): | |||
async def get_state_ids_for_events( | |||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() | |||
): | |||
""" | |||
Get the state dicts corresponding to a list of events, containing the event_ids | |||
of the state events (as opposed to the events themselves) | |||
Args: | |||
event_ids(list(str)): events whose state should be returned | |||
state_filter (StateFilter): The state filter used to fetch state | |||
from the database. | |||
event_ids: events whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A deferred dict from event_id -> (type, state_key) -> event_id | |||
A dict from event_id -> (type, state_key) -> event_id | |||
""" | |||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) | |||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) | |||
groups = set(event_to_groups.values()) | |||
group_to_state = yield self.stores.state._get_state_for_groups( | |||
group_to_state = await self.stores.state._get_state_for_groups( | |||
groups, state_filter | |||
) | |||
@@ -491,36 +498,36 @@ class StateGroupStorage(object): | |||
return {event: event_to_state[event] for event in event_ids} | |||
@defer.inlineCallbacks | |||
def get_state_for_event(self, event_id, state_filter=StateFilter.all()): | |||
async def get_state_for_event( | |||
self, event_id: str, state_filter: StateFilter = StateFilter.all() | |||
): | |||
""" | |||
Get the state dict corresponding to a particular event | |||
Args: | |||
event_id(str): event whose state should be returned | |||
state_filter (StateFilter): The state filter used to fetch state | |||
from the database. | |||
event_id: event whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A deferred dict from (type, state_key) -> state_event | |||
A dict from (type, state_key) -> state_event | |||
""" | |||
state_map = yield self.get_state_for_events([event_id], state_filter) | |||
state_map = await self.get_state_for_events([event_id], state_filter) | |||
return state_map[event_id] | |||
@defer.inlineCallbacks | |||
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): | |||
async def get_state_ids_for_event( | |||
self, event_id: str, state_filter: StateFilter = StateFilter.all() | |||
): | |||
""" | |||
Get the state dict corresponding to a particular event | |||
Args: | |||
event_id(str): event whose state should be returned | |||
state_filter (StateFilter): The state filter used to fetch state | |||
from the database. | |||
event_id: event whose state should be returned | |||
state_filter: The state filter used to fetch state from the database. | |||
Returns: | |||
A deferred dict from (type, state_key) -> state_event | |||
""" | |||
state_map = yield self.get_state_ids_for_events([event_id], state_filter) | |||
state_map = await self.get_state_ids_for_events([event_id], state_filter) | |||
return state_map[event_id] | |||
def _get_state_for_groups( | |||
@@ -530,9 +537,8 @@ class StateGroupStorage(object): | |||
filtering by type/state_key | |||
Args: | |||
groups (iterable[int]): list of state groups for which we want | |||
to get the state. | |||
state_filter (StateFilter): The state filter used to fetch state | |||
groups: list of state groups for which we want to get the state. | |||
state_filter: The state filter used to fetch state. | |||
from the database. | |||
Returns: | |||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. | |||
@@ -540,18 +546,23 @@ class StateGroupStorage(object): | |||
return self.stores.state._get_state_for_groups(groups, state_filter) | |||
def store_state_group( | |||
self, event_id, room_id, prev_group, delta_ids, current_state_ids | |||
self, | |||
event_id: str, | |||
room_id: str, | |||
prev_group: Optional[int], | |||
delta_ids: Optional[dict], | |||
current_state_ids: dict, | |||
): | |||
"""Store a new set of state, returning a newly assigned state group. | |||
Args: | |||
event_id (str): The event ID for which the state was calculated | |||
room_id (str) | |||
prev_group (int|None): A previous state group for the room, optional. | |||
delta_ids (dict|None): The delta between state at `prev_group` and | |||
event_id: The event ID for which the state was calculated. | |||
room_id: ID of the room for which the state was calculated. | |||
prev_group: A previous state group for the room, optional. | |||
delta_ids: The delta between state at `prev_group` and | |||
`current_state_ids`, if `prev_group` was given. Same format as | |||
`current_state_ids`. | |||
current_state_ids (dict): The state to store. Map of (type, state_key) | |||
current_state_ids: The state to store. Map of (type, state_key) | |||
to event_id. | |||
Returns: | |||
@@ -13,6 +13,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
from synapse.rest.client.v1 import room | |||
from tests.unittest import HomeserverTestCase | |||
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): | |||
event = self.successResultOf(event) | |||
# Purge everything before this topological token | |||
purge = storage.purge_events.purge_history(self.room_id, event, True) | |||
purge = defer.ensureDeferred( | |||
storage.purge_events.purge_history(self.room_id, event, True) | |||
) | |||
self.pump() | |||
self.assertEqual(self.successResultOf(purge), None) | |||
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): | |||
) | |||
# Purge everything before this topological token | |||
purge = storage.purge_history(self.room_id, event, True) | |||
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) | |||
self.pump() | |||
f = self.failureResultOf(purge) | |||
self.assertIn("greater than forward", f.value.args[0]) | |||
@@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def inject_room_event(self, **kwargs): | |||
yield self.storage.persistence.persist_event( | |||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) | |||
yield defer.ensureDeferred( | |||
self.storage.persistence.persist_event( | |||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) | |||
) | |||
) | |||
@defer.inlineCallbacks | |||
@@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
yield self.storage.persistence.persist_event(event, context) | |||
yield defer.ensureDeferred( | |||
self.storage.persistence.persist_event(event, context) | |||
) | |||
return event | |||
@@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | |||
) | |||
state_group_map = yield self.storage.state.get_state_groups_ids( | |||
self.room, [e2.event_id] | |||
state_group_map = yield defer.ensureDeferred( | |||
self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) | |||
) | |||
self.assertEqual(len(state_group_map), 1) | |||
state_map = list(state_group_map.values())[0] | |||
@@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | |||
) | |||
state_group_map = yield self.storage.state.get_state_groups( | |||
self.room, [e2.event_id] | |||
state_group_map = yield defer.ensureDeferred( | |||
self.storage.state.get_state_groups(self.room, [e2.event_id]) | |||
) | |||
self.assertEqual(len(state_group_map), 1) | |||
state_list = list(state_group_map.values())[0] | |||
@@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
) | |||
# check we get the full state as of the final event | |||
state = yield self.storage.state.get_state_for_event(e5.event_id) | |||
state = yield defer.ensureDeferred( | |||
self.storage.state.get_state_for_event(e5.event_id) | |||
) | |||
self.assertIsNotNone(e4) | |||
@@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
) | |||
# check we can filter to the m.room.name event (with a '' state key) | |||
state = yield self.storage.state.get_state_for_event( | |||
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) | |||
state = yield defer.ensureDeferred( | |||
self.storage.state.get_state_for_event( | |||
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) | |||
) | |||
) | |||
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) | |||
# check we can filter to the m.room.name event (with a wildcard None state key) | |||
state = yield self.storage.state.get_state_for_event( | |||
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) | |||
state = yield defer.ensureDeferred( | |||
self.storage.state.get_state_for_event( | |||
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) | |||
) | |||
) | |||
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) | |||
# check we can grab the m.room.member events (with a wildcard None state key) | |||
state = yield self.storage.state.get_state_for_event( | |||
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) | |||
state = yield defer.ensureDeferred( | |||
self.storage.state.get_state_for_event( | |||
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) | |||
) | |||
) | |||
self.assertStateMapEqual( | |||
@@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
# check we can grab a specific room member without filtering out the | |||
# other event types | |||
state = yield self.storage.state.get_state_for_event( | |||
e5.event_id, | |||
state_filter=StateFilter( | |||
types={EventTypes.Member: {self.u_alice.to_string()}}, | |||
include_others=True, | |||
), | |||
state = yield defer.ensureDeferred( | |||
self.storage.state.get_state_for_event( | |||
e5.event_id, | |||
state_filter=StateFilter( | |||
types={EventTypes.Member: {self.u_alice.to_string()}}, | |||
include_others=True, | |||
), | |||
) | |||
) | |||
self.assertStateMapEqual( | |||
@@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
) | |||
# check that we can grab everything except members | |||
state = yield self.storage.state.get_state_for_event( | |||
e5.event_id, | |||
state_filter=StateFilter( | |||
types={EventTypes.Member: set()}, include_others=True | |||
), | |||
state = yield defer.ensureDeferred( | |||
self.storage.state.get_state_for_event( | |||
e5.event_id, | |||
state_filter=StateFilter( | |||
types={EventTypes.Member: set()}, include_others=True | |||
), | |||
) | |||
) | |||
self.assertStateMapEqual( | |||
@@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
####################################################### | |||
room_id = self.room.to_string() | |||
group_ids = yield self.storage.state.get_state_groups_ids( | |||
room_id, [e5.event_id] | |||
group_ids = yield defer.ensureDeferred( | |||
self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) | |||
) | |||
group = list(group_ids.keys())[0] | |||
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||
self.store = self.hs.get_datastore() | |||
self.storage = self.hs.get_storage() | |||
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") | |||
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) | |||
@defer.inlineCallbacks | |||
def test_filtering(self): | |||
@@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||
event, context = yield defer.ensureDeferred( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
yield self.storage.persistence.persist_event(event, context) | |||
yield defer.ensureDeferred( | |||
self.storage.persistence.persist_event(event, context) | |||
) | |||
return event | |||
@defer.inlineCallbacks | |||
@@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
yield self.storage.persistence.persist_event(event, context) | |||
yield defer.ensureDeferred( | |||
self.storage.persistence.persist_event(event, context) | |||
) | |||
return event | |||
@defer.inlineCallbacks | |||
@@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
yield self.storage.persistence.persist_event(event, context) | |||
yield defer.ensureDeferred( | |||
self.storage.persistence.persist_event(event, context) | |||
) | |||
return event | |||
@defer.inlineCallbacks | |||
@@ -638,14 +638,8 @@ class DeferredMockCallable(object): | |||
) | |||
@defer.inlineCallbacks | |||
def create_room(hs, room_id, creator_id): | |||
async def create_room(hs, room_id: str, creator_id: str): | |||
"""Creates and persist a creation event for the given room | |||
Args: | |||
hs | |||
room_id (str) | |||
creator_id (str) | |||
""" | |||
persistence_store = hs.get_storage().persistence | |||
@@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): | |||
event_builder_factory = hs.get_event_builder_factory() | |||
event_creation_handler = hs.get_event_creation_handler() | |||
yield store.store_room( | |||
await store.store_room( | |||
room_id=room_id, | |||
room_creator_user_id=creator_id, | |||
is_public=False, | |||
@@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id): | |||
}, | |||
) | |||
event, context = yield defer.ensureDeferred( | |||
event_creation_handler.create_new_client_event(builder) | |||
) | |||
event, context = await event_creation_handler.create_new_client_event(builder) | |||
yield persistence_store.persist_event(event, context) | |||
await persistence_store.persist_event(event, context) |
@@ -206,6 +206,7 @@ commands = mypy \ | |||
synapse/storage/data_stores/main/ui_auth.py \ | |||
synapse/storage/database.py \ | |||
synapse/storage/engines \ | |||
synapse/storage/state.py \ | |||
synapse/storage/util \ | |||
synapse/streams \ | |||
synapse/util/caches/stream_change_cache.py \ | |||