@@ -0,0 +1 @@ | |||||
Convert various parts of the codebase to async/await. |
@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram | |||||
from twisted.internet import defer | from twisted.internet import defer | ||||
from synapse.api.constants import EventTypes, Membership | from synapse.api.constants import EventTypes, Membership | ||||
from synapse.events import FrozenEvent | |||||
from synapse.events import EventBase | |||||
from synapse.events.snapshot import EventContext | from synapse.events.snapshot import EventContext | ||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable | from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable | ||||
from synapse.metrics.background_process_metrics import run_as_background_process | from synapse.metrics.background_process_metrics import run_as_background_process | ||||
@@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): | |||||
self._event_persist_queue = _EventPeristenceQueue() | self._event_persist_queue = _EventPeristenceQueue() | ||||
self._state_resolution_handler = hs.get_state_resolution_handler() | self._state_resolution_handler = hs.get_state_resolution_handler() | ||||
@defer.inlineCallbacks | |||||
def persist_events( | |||||
async def persist_events( | |||||
self, | self, | ||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], | |||||
events_and_contexts: List[Tuple[EventBase, EventContext]], | |||||
backfilled: bool = False, | backfilled: bool = False, | ||||
): | |||||
) -> int: | |||||
""" | """ | ||||
Write events to the database | Write events to the database | ||||
Args: | Args: | ||||
@@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): | |||||
which might update the current state etc. | which might update the current state etc. | ||||
Returns: | Returns: | ||||
Deferred[int]: the stream ordering of the latest persisted event | |||||
the stream ordering of the latest persisted event | |||||
""" | """ | ||||
partitioned = {} | partitioned = {} | ||||
for event, ctx in events_and_contexts: | for event, ctx in events_and_contexts: | ||||
@@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): | |||||
for room_id in partitioned: | for room_id in partitioned: | ||||
self._maybe_start_persisting(room_id) | self._maybe_start_persisting(room_id) | ||||
yield make_deferred_yieldable( | |||||
await make_deferred_yieldable( | |||||
defer.gatherResults(deferreds, consumeErrors=True) | defer.gatherResults(deferreds, consumeErrors=True) | ||||
) | ) | ||||
max_persisted_id = yield self.main_store.get_current_events_token() | |||||
return max_persisted_id | |||||
return self.main_store.get_current_events_token() | |||||
@defer.inlineCallbacks | |||||
def persist_event( | |||||
self, event: FrozenEvent, context: EventContext, backfilled: bool = False | |||||
): | |||||
async def persist_event( | |||||
self, event: EventBase, context: EventContext, backfilled: bool = False | |||||
) -> Tuple[int, int]: | |||||
""" | """ | ||||
Returns: | Returns: | ||||
Deferred[Tuple[int, int]]: the stream ordering of ``event``, | |||||
and the stream ordering of the latest persisted event | |||||
The stream ordering of `event`, and the stream ordering of the | |||||
latest persisted event | |||||
""" | """ | ||||
deferred = self._event_persist_queue.add_to_queue( | deferred = self._event_persist_queue.add_to_queue( | ||||
event.room_id, [(event, context)], backfilled=backfilled | event.room_id, [(event, context)], backfilled=backfilled | ||||
@@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): | |||||
self._maybe_start_persisting(event.room_id) | self._maybe_start_persisting(event.room_id) | ||||
yield make_deferred_yieldable(deferred) | |||||
await make_deferred_yieldable(deferred) | |||||
max_persisted_id = yield self.main_store.get_current_events_token() | |||||
max_persisted_id = self.main_store.get_current_events_token() | |||||
return (event.internal_metadata.stream_ordering, max_persisted_id) | return (event.internal_metadata.stream_ordering, max_persisted_id) | ||||
def _maybe_start_persisting(self, room_id: str): | def _maybe_start_persisting(self, room_id: str): | ||||
@@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): | |||||
async def _persist_events( | async def _persist_events( | ||||
self, | self, | ||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], | |||||
events_and_contexts: List[Tuple[EventBase, EventContext]], | |||||
backfilled: bool = False, | backfilled: bool = False, | ||||
): | ): | ||||
"""Calculates the change to current state and forward extremities, and | """Calculates the change to current state and forward extremities, and | ||||
@@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): | |||||
async def _calculate_new_extremities( | async def _calculate_new_extremities( | ||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
event_contexts: List[Tuple[FrozenEvent, EventContext]], | |||||
event_contexts: List[Tuple[EventBase, EventContext]], | |||||
latest_event_ids: List[str], | latest_event_ids: List[str], | ||||
): | ): | ||||
"""Calculates the new forward extremities for a room given events to | """Calculates the new forward extremities for a room given events to | ||||
@@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): | |||||
async def _get_new_state_after_events( | async def _get_new_state_after_events( | ||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
events_context: List[Tuple[FrozenEvent, EventContext]], | |||||
events_context: List[Tuple[EventBase, EventContext]], | |||||
old_latest_event_ids: Iterable[str], | old_latest_event_ids: Iterable[str], | ||||
new_latest_event_ids: Iterable[str], | new_latest_event_ids: Iterable[str], | ||||
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: | ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: | ||||
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): | |||||
async def _is_server_still_joined( | async def _is_server_still_joined( | ||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], | |||||
ev_ctx_rm: List[Tuple[EventBase, EventContext]], | |||||
delta: DeltaState, | delta: DeltaState, | ||||
current_state: Optional[StateMap[str]], | current_state: Optional[StateMap[str]], | ||||
potentially_left_users: Set[str], | potentially_left_users: Set[str], | ||||
@@ -15,8 +15,7 @@ | |||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from twisted.internet import defer | |||||
from typing import Set | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object): | |||||
def __init__(self, hs, stores): | def __init__(self, hs, stores): | ||||
self.stores = stores | self.stores = stores | ||||
@defer.inlineCallbacks | |||||
def purge_room(self, room_id: str): | |||||
async def purge_room(self, room_id: str): | |||||
"""Deletes all record of a room | """Deletes all record of a room | ||||
""" | """ | ||||
state_groups_to_delete = yield self.stores.main.purge_room(room_id) | |||||
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) | |||||
state_groups_to_delete = await self.stores.main.purge_room(room_id) | |||||
await self.stores.state.purge_room_state(room_id, state_groups_to_delete) | |||||
@defer.inlineCallbacks | |||||
def purge_history(self, room_id, token, delete_local_events): | |||||
async def purge_history( | |||||
self, room_id: str, token: str, delete_local_events: bool | |||||
) -> None: | |||||
"""Deletes room history before a certain point | """Deletes room history before a certain point | ||||
Args: | Args: | ||||
room_id (str): | |||||
room_id: The room ID | |||||
token (str): A topological token to delete events before | |||||
token: A topological token to delete events before | |||||
delete_local_events (bool): | |||||
delete_local_events: | |||||
if True, we will delete local events as well as remote ones | if True, we will delete local events as well as remote ones | ||||
(instead of just marking them as outliers and deleting their | (instead of just marking them as outliers and deleting their | ||||
state groups). | state groups). | ||||
""" | """ | ||||
state_groups = yield self.stores.main.purge_history( | |||||
state_groups = await self.stores.main.purge_history( | |||||
room_id, token, delete_local_events | room_id, token, delete_local_events | ||||
) | ) | ||||
logger.info("[purge] finding state groups that can be deleted") | logger.info("[purge] finding state groups that can be deleted") | ||||
sg_to_delete = yield self._find_unreferenced_groups(state_groups) | |||||
sg_to_delete = await self._find_unreferenced_groups(state_groups) | |||||
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) | |||||
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) | |||||
@defer.inlineCallbacks | |||||
def _find_unreferenced_groups(self, state_groups): | |||||
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: | |||||
"""Used when purging history to figure out which state groups can be | """Used when purging history to figure out which state groups can be | ||||
deleted. | deleted. | ||||
Args: | Args: | ||||
state_groups (set[int]): Set of state groups referenced by events | |||||
state_groups: Set of state groups referenced by events | |||||
that are going to be deleted. | that are going to be deleted. | ||||
Returns: | Returns: | ||||
Deferred[set[int]] The set of state groups that can be deleted. | |||||
The set of state groups that can be deleted. | |||||
""" | """ | ||||
# Graph of state group -> previous group | # Graph of state group -> previous group | ||||
graph = {} | graph = {} | ||||
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object): | |||||
current_search = set(itertools.islice(next_to_search, 100)) | current_search = set(itertools.islice(next_to_search, 100)) | ||||
next_to_search -= current_search | next_to_search -= current_search | ||||
referenced = yield self.stores.main.get_referenced_state_groups( | |||||
referenced = await self.stores.main.get_referenced_state_groups( | |||||
current_search | current_search | ||||
) | ) | ||||
referenced_groups |= referenced | referenced_groups |= referenced | ||||
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object): | |||||
# groups that are referenced. | # groups that are referenced. | ||||
current_search -= referenced | current_search -= referenced | ||||
edges = yield self.stores.state.get_previous_state_groups(current_search) | |||||
edges = await self.stores.state.get_previous_state_groups(current_search) | |||||
prevs = set(edges.values()) | prevs = set(edges.values()) | ||||
# We don't bother re-handling groups we've already seen | # We don't bother re-handling groups we've already seen | ||||
@@ -14,13 +14,12 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
import logging | import logging | ||||
from typing import Iterable, List, TypeVar | |||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar | |||||
import attr | import attr | ||||
from twisted.internet import defer | |||||
from synapse.api.constants import EventTypes | from synapse.api.constants import EventTypes | ||||
from synapse.events import EventBase | |||||
from synapse.types import StateMap | from synapse.types import StateMap | ||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
@@ -34,16 +33,16 @@ class StateFilter(object): | |||||
"""A filter used when querying for state. | """A filter used when querying for state. | ||||
Attributes: | Attributes: | ||||
types (dict[str, set[str]|None]): Map from type to set of state keys (or | |||||
None). This specifies which state_keys for the given type to fetch | |||||
from the DB. If None then all events with that type are fetched. If | |||||
the set is empty then no events with that type are fetched. | |||||
include_others (bool): Whether to fetch events with types that do not | |||||
types: Map from type to set of state keys (or None). This specifies | |||||
which state_keys for the given type to fetch from the DB. If None | |||||
then all events with that type are fetched. If the set is empty | |||||
then no events with that type are fetched. | |||||
include_others: Whether to fetch events with types that do not | |||||
appear in `types`. | appear in `types`. | ||||
""" | """ | ||||
types = attr.ib() | |||||
include_others = attr.ib(default=False) | |||||
types = attr.ib(type=Dict[str, Optional[Set[str]]]) | |||||
include_others = attr.ib(default=False, type=bool) | |||||
def __attrs_post_init__(self): | def __attrs_post_init__(self): | ||||
# If `include_others` is set we canonicalise the filter by removing | # If `include_others` is set we canonicalise the filter by removing | ||||
@@ -52,36 +51,35 @@ class StateFilter(object): | |||||
self.types = {k: v for k, v in self.types.items() if v is not None} | self.types = {k: v for k, v in self.types.items() if v is not None} | ||||
@staticmethod | @staticmethod | ||||
def all(): | |||||
def all() -> "StateFilter": | |||||
"""Creates a filter that fetches everything. | """Creates a filter that fetches everything. | ||||
Returns: | Returns: | ||||
StateFilter | |||||
The new state filter. | |||||
""" | """ | ||||
return StateFilter(types={}, include_others=True) | return StateFilter(types={}, include_others=True) | ||||
@staticmethod | @staticmethod | ||||
def none(): | |||||
def none() -> "StateFilter": | |||||
"""Creates a filter that fetches nothing. | """Creates a filter that fetches nothing. | ||||
Returns: | Returns: | ||||
StateFilter | |||||
The new state filter. | |||||
""" | """ | ||||
return StateFilter(types={}, include_others=False) | return StateFilter(types={}, include_others=False) | ||||
@staticmethod | @staticmethod | ||||
def from_types(types): | |||||
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": | |||||
"""Creates a filter that only fetches the given types | """Creates a filter that only fetches the given types | ||||
Args: | Args: | ||||
types (Iterable[tuple[str, str|None]]): A list of type and state | |||||
keys to fetch. A state_key of None fetches everything for | |||||
that type | |||||
types: A list of type and state keys to fetch. A state_key of None | |||||
fetches everything for that type | |||||
Returns: | Returns: | ||||
StateFilter | |||||
The new state filter. | |||||
""" | """ | ||||
type_dict = {} | |||||
type_dict = {} # type: Dict[str, Optional[Set[str]]] | |||||
for typ, s in types: | for typ, s in types: | ||||
if typ in type_dict: | if typ in type_dict: | ||||
if type_dict[typ] is None: | if type_dict[typ] is None: | ||||
@@ -91,24 +89,24 @@ class StateFilter(object): | |||||
type_dict[typ] = None | type_dict[typ] = None | ||||
continue | continue | ||||
type_dict.setdefault(typ, set()).add(s) | |||||
type_dict.setdefault(typ, set()).add(s) # type: ignore | |||||
return StateFilter(types=type_dict) | return StateFilter(types=type_dict) | ||||
@staticmethod | @staticmethod | ||||
def from_lazy_load_member_list(members): | |||||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": | |||||
"""Creates a filter that returns all non-member events, plus the member | """Creates a filter that returns all non-member events, plus the member | ||||
events for the given users | events for the given users | ||||
Args: | Args: | ||||
members (iterable[str]): Set of user IDs | |||||
members: Set of user IDs | |||||
Returns: | Returns: | ||||
StateFilter | |||||
The new state filter | |||||
""" | """ | ||||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) | return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) | ||||
def return_expanded(self): | |||||
def return_expanded(self) -> "StateFilter": | |||||
"""Creates a new StateFilter where type wild cards have been removed | """Creates a new StateFilter where type wild cards have been removed | ||||
(except for memberships). The returned filter is a superset of the | (except for memberships). The returned filter is a superset of the | ||||
current one, i.e. anything that passes the current filter will pass | current one, i.e. anything that passes the current filter will pass | ||||
@@ -130,7 +128,7 @@ class StateFilter(object): | |||||
return all non-member events | return all non-member events | ||||
Returns: | Returns: | ||||
StateFilter | |||||
The new state filter. | |||||
""" | """ | ||||
if self.is_full(): | if self.is_full(): | ||||
@@ -167,7 +165,7 @@ class StateFilter(object): | |||||
include_others=True, | include_others=True, | ||||
) | ) | ||||
def make_sql_filter_clause(self): | |||||
def make_sql_filter_clause(self) -> Tuple[str, List[str]]: | |||||
"""Converts the filter to an SQL clause. | """Converts the filter to an SQL clause. | ||||
For example: | For example: | ||||
@@ -179,13 +177,12 @@ class StateFilter(object): | |||||
Returns: | Returns: | ||||
tuple[str, list]: The SQL string (may be empty) and arguments. An | |||||
empty SQL string is returned when the filter matches everything | |||||
(i.e. is "full"). | |||||
The SQL string (may be empty) and arguments. An empty SQL string is | |||||
returned when the filter matches everything (i.e. is "full"). | |||||
""" | """ | ||||
where_clause = "" | where_clause = "" | ||||
where_args = [] | |||||
where_args = [] # type: List[str] | |||||
if self.is_full(): | if self.is_full(): | ||||
return where_clause, where_args | return where_clause, where_args | ||||
@@ -221,7 +218,7 @@ class StateFilter(object): | |||||
return where_clause, where_args | return where_clause, where_args | ||||
def max_entries_returned(self): | |||||
def max_entries_returned(self) -> Optional[int]: | |||||
"""Returns the maximum number of entries this filter will return if | """Returns the maximum number of entries this filter will return if | ||||
known, otherwise returns None. | known, otherwise returns None. | ||||
@@ -260,33 +257,33 @@ class StateFilter(object): | |||||
return filtered_state | return filtered_state | ||||
def is_full(self): | |||||
def is_full(self) -> bool: | |||||
"""Whether this filter fetches everything or not | """Whether this filter fetches everything or not | ||||
Returns: | Returns: | ||||
bool | |||||
True if the filter fetches everything. | |||||
""" | """ | ||||
return self.include_others and not self.types | return self.include_others and not self.types | ||||
def has_wildcards(self): | |||||
def has_wildcards(self) -> bool: | |||||
"""Whether the filter includes wildcards or is attempting to fetch | """Whether the filter includes wildcards or is attempting to fetch | ||||
specific state. | specific state. | ||||
Returns: | Returns: | ||||
bool | |||||
True if the filter includes wildcards. | |||||
""" | """ | ||||
return self.include_others or any( | return self.include_others or any( | ||||
state_keys is None for state_keys in self.types.values() | state_keys is None for state_keys in self.types.values() | ||||
) | ) | ||||
def concrete_types(self): | |||||
def concrete_types(self) -> List[Tuple[str, str]]: | |||||
"""Returns a list of concrete type/state_keys (i.e. not None) that | """Returns a list of concrete type/state_keys (i.e. not None) that | ||||
will be fetched. This will be a complete list if `has_wildcards` | will be fetched. This will be a complete list if `has_wildcards` | ||||
returns False, but otherwise will be a subset (or even empty). | returns False, but otherwise will be a subset (or even empty). | ||||
Returns: | Returns: | ||||
list[tuple[str,str]] | |||||
A list of type/state_keys tuples. | |||||
""" | """ | ||||
return [ | return [ | ||||
(t, s) | (t, s) | ||||
@@ -295,7 +292,7 @@ class StateFilter(object): | |||||
for s in state_keys | for s in state_keys | ||||
] | ] | ||||
def get_member_split(self): | |||||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: | |||||
"""Return the filter split into two: one which assumes it's exclusively | """Return the filter split into two: one which assumes it's exclusively | ||||
matching against member state, and one which assumes it's matching | matching against member state, and one which assumes it's matching | ||||
against non member state. | against non member state. | ||||
@@ -307,7 +304,7 @@ class StateFilter(object): | |||||
state caches). | state caches). | ||||
Returns: | Returns: | ||||
tuple[StateFilter, StateFilter]: The member and non member filters | |||||
The member and non member filters | |||||
""" | """ | ||||
if EventTypes.Member in self.types: | if EventTypes.Member in self.types: | ||||
@@ -340,6 +337,9 @@ class StateGroupStorage(object): | |||||
"""Given a state group try to return a previous group and a delta between | """Given a state group try to return a previous group and a delta between | ||||
the old and the new. | the old and the new. | ||||
Args: | |||||
state_group: The state group used to retrieve state deltas. | |||||
Returns: | Returns: | ||||
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: | Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: | ||||
(prev_group, delta_ids) | (prev_group, delta_ids) | ||||
@@ -347,55 +347,59 @@ class StateGroupStorage(object): | |||||
return self.stores.state.get_state_group_delta(state_group) | return self.stores.state.get_state_group_delta(state_group) | ||||
@defer.inlineCallbacks | |||||
def get_state_groups_ids(self, _room_id, event_ids): | |||||
async def get_state_groups_ids( | |||||
self, _room_id: str, event_ids: Iterable[str] | |||||
) -> Dict[int, StateMap[str]]: | |||||
"""Get the event IDs of all the state for the state groups for the given events | """Get the event IDs of all the state for the state groups for the given events | ||||
Args: | Args: | ||||
_room_id (str): id of the room for these events | |||||
event_ids (iterable[str]): ids of the events | |||||
_room_id: id of the room for these events | |||||
event_ids: ids of the events | |||||
Returns: | Returns: | ||||
Deferred[dict[int, StateMap[str]]]: | |||||
dict of state_group_id -> (dict of (type, state_key) -> event id) | |||||
dict of state_group_id -> (dict of (type, state_key) -> event id) | |||||
""" | """ | ||||
if not event_ids: | if not event_ids: | ||||
return {} | return {} | ||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) | |||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) | |||||
groups = set(event_to_groups.values()) | groups = set(event_to_groups.values()) | ||||
group_to_state = yield self.stores.state._get_state_for_groups(groups) | |||||
group_to_state = await self.stores.state._get_state_for_groups(groups) | |||||
return group_to_state | return group_to_state | ||||
@defer.inlineCallbacks | |||||
def get_state_ids_for_group(self, state_group): | |||||
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: | |||||
"""Get the event IDs of all the state in the given state group | """Get the event IDs of all the state in the given state group | ||||
Args: | Args: | ||||
state_group (int) | |||||
state_group: A state group for which we want to get the state IDs. | |||||
Returns: | Returns: | ||||
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id | |||||
Resolves to a map of (type, state_key) -> event_id | |||||
""" | """ | ||||
group_to_state = yield self._get_state_for_groups((state_group,)) | |||||
group_to_state = await self._get_state_for_groups((state_group,)) | |||||
return group_to_state[state_group] | return group_to_state[state_group] | ||||
@defer.inlineCallbacks | |||||
def get_state_groups(self, room_id, event_ids): | |||||
async def get_state_groups( | |||||
self, room_id: str, event_ids: Iterable[str] | |||||
) -> Dict[int, List[EventBase]]: | |||||
""" Get the state groups for the given list of event_ids | """ Get the state groups for the given list of event_ids | ||||
Args: | |||||
room_id: ID of the room for these events. | |||||
event_ids: The event IDs to retrieve state for. | |||||
Returns: | Returns: | ||||
Deferred[dict[int, list[EventBase]]]: | |||||
dict of state_group_id -> list of state events. | |||||
dict of state_group_id -> list of state events. | |||||
""" | """ | ||||
if not event_ids: | if not event_ids: | ||||
return {} | return {} | ||||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) | |||||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids) | |||||
state_event_map = yield self.stores.main.get_events( | |||||
state_event_map = await self.stores.main.get_events( | |||||
[ | [ | ||||
ev_id | ev_id | ||||
for group_ids in group_to_ids.values() | for group_ids in group_to_ids.values() | ||||
@@ -423,31 +427,34 @@ class StateGroupStorage(object): | |||||
groups: list of state group IDs to query | groups: list of state group IDs to query | ||||
state_filter: The state filter used to fetch state | state_filter: The state filter used to fetch state | ||||
from the database. | from the database. | ||||
Returns: | Returns: | ||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. | Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. | ||||
""" | """ | ||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter) | return self.stores.state._get_state_groups_from_groups(groups, state_filter) | ||||
@defer.inlineCallbacks | |||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): | |||||
async def get_state_for_events( | |||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() | |||||
): | |||||
"""Given a list of event_ids and type tuples, return a list of state | """Given a list of event_ids and type tuples, return a list of state | ||||
dicts for each event. | dicts for each event. | ||||
Args: | Args: | ||||
event_ids (list[string]) | |||||
state_filter (StateFilter): The state filter used to fetch state | |||||
from the database. | |||||
event_ids: The events to fetch the state of. | |||||
state_filter: The state filter used to fetch state. | |||||
Returns: | Returns: | ||||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events] | |||||
A dict of (event_id) -> (type, state_key) -> [state_events] | |||||
""" | """ | ||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) | |||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) | |||||
groups = set(event_to_groups.values()) | groups = set(event_to_groups.values()) | ||||
group_to_state = yield self.stores.state._get_state_for_groups( | |||||
group_to_state = await self.stores.state._get_state_for_groups( | |||||
groups, state_filter | groups, state_filter | ||||
) | ) | ||||
state_event_map = yield self.stores.main.get_events( | |||||
state_event_map = await self.stores.main.get_events( | |||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], | [ev_id for sd in group_to_state.values() for ev_id in sd.values()], | ||||
get_prev_content=False, | get_prev_content=False, | ||||
) | ) | ||||
@@ -463,24 +470,24 @@ class StateGroupStorage(object): | |||||
return {event: event_to_state[event] for event in event_ids} | return {event: event_to_state[event] for event in event_ids} | ||||
@defer.inlineCallbacks | |||||
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): | |||||
async def get_state_ids_for_events( | |||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() | |||||
): | |||||
""" | """ | ||||
Get the state dicts corresponding to a list of events, containing the event_ids | Get the state dicts corresponding to a list of events, containing the event_ids | ||||
of the state events (as opposed to the events themselves) | of the state events (as opposed to the events themselves) | ||||
Args: | Args: | ||||
event_ids(list(str)): events whose state should be returned | |||||
state_filter (StateFilter): The state filter used to fetch state | |||||
from the database. | |||||
event_ids: events whose state should be returned | |||||
state_filter: The state filter used to fetch state from the database. | |||||
Returns: | Returns: | ||||
A deferred dict from event_id -> (type, state_key) -> event_id | |||||
A dict from event_id -> (type, state_key) -> event_id | |||||
""" | """ | ||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) | |||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) | |||||
groups = set(event_to_groups.values()) | groups = set(event_to_groups.values()) | ||||
group_to_state = yield self.stores.state._get_state_for_groups( | |||||
group_to_state = await self.stores.state._get_state_for_groups( | |||||
groups, state_filter | groups, state_filter | ||||
) | ) | ||||
@@ -491,36 +498,36 @@ class StateGroupStorage(object): | |||||
return {event: event_to_state[event] for event in event_ids} | return {event: event_to_state[event] for event in event_ids} | ||||
@defer.inlineCallbacks | |||||
def get_state_for_event(self, event_id, state_filter=StateFilter.all()): | |||||
async def get_state_for_event( | |||||
self, event_id: str, state_filter: StateFilter = StateFilter.all() | |||||
): | |||||
""" | """ | ||||
Get the state dict corresponding to a particular event | Get the state dict corresponding to a particular event | ||||
Args: | Args: | ||||
event_id(str): event whose state should be returned | |||||
state_filter (StateFilter): The state filter used to fetch state | |||||
from the database. | |||||
event_id: event whose state should be returned | |||||
state_filter: The state filter used to fetch state from the database. | |||||
Returns: | Returns: | ||||
A deferred dict from (type, state_key) -> state_event | |||||
A dict from (type, state_key) -> state_event | |||||
""" | """ | ||||
state_map = yield self.get_state_for_events([event_id], state_filter) | |||||
state_map = await self.get_state_for_events([event_id], state_filter) | |||||
return state_map[event_id] | return state_map[event_id] | ||||
@defer.inlineCallbacks | |||||
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): | |||||
async def get_state_ids_for_event( | |||||
self, event_id: str, state_filter: StateFilter = StateFilter.all() | |||||
): | |||||
""" | """ | ||||
Get the state dict corresponding to a particular event | Get the state dict corresponding to a particular event | ||||
Args: | Args: | ||||
event_id(str): event whose state should be returned | |||||
state_filter (StateFilter): The state filter used to fetch state | |||||
from the database. | |||||
event_id: event whose state should be returned | |||||
state_filter: The state filter used to fetch state from the database. | |||||
Returns: | Returns: | ||||
A deferred dict from (type, state_key) -> state_event | A deferred dict from (type, state_key) -> state_event | ||||
""" | """ | ||||
state_map = yield self.get_state_ids_for_events([event_id], state_filter) | |||||
state_map = await self.get_state_ids_for_events([event_id], state_filter) | |||||
return state_map[event_id] | return state_map[event_id] | ||||
def _get_state_for_groups( | def _get_state_for_groups( | ||||
@@ -530,9 +537,8 @@ class StateGroupStorage(object): | |||||
filtering by type/state_key | filtering by type/state_key | ||||
Args: | Args: | ||||
groups (iterable[int]): list of state groups for which we want | |||||
to get the state. | |||||
state_filter (StateFilter): The state filter used to fetch state | |||||
groups: list of state groups for which we want to get the state. | |||||
state_filter: The state filter used to fetch state. | |||||
from the database. | from the database. | ||||
Returns: | Returns: | ||||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. | Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. | ||||
@@ -540,18 +546,23 @@ class StateGroupStorage(object): | |||||
return self.stores.state._get_state_for_groups(groups, state_filter) | return self.stores.state._get_state_for_groups(groups, state_filter) | ||||
def store_state_group( | def store_state_group( | ||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids | |||||
self, | |||||
event_id: str, | |||||
room_id: str, | |||||
prev_group: Optional[int], | |||||
delta_ids: Optional[dict], | |||||
current_state_ids: dict, | |||||
): | ): | ||||
"""Store a new set of state, returning a newly assigned state group. | """Store a new set of state, returning a newly assigned state group. | ||||
Args: | Args: | ||||
event_id (str): The event ID for which the state was calculated | |||||
room_id (str) | |||||
prev_group (int|None): A previous state group for the room, optional. | |||||
delta_ids (dict|None): The delta between state at `prev_group` and | |||||
event_id: The event ID for which the state was calculated. | |||||
room_id: ID of the room for which the state was calculated. | |||||
prev_group: A previous state group for the room, optional. | |||||
delta_ids: The delta between state at `prev_group` and | |||||
`current_state_ids`, if `prev_group` was given. Same format as | `current_state_ids`, if `prev_group` was given. Same format as | ||||
`current_state_ids`. | `current_state_ids`. | ||||
current_state_ids (dict): The state to store. Map of (type, state_key) | |||||
current_state_ids: The state to store. Map of (type, state_key) | |||||
to event_id. | to event_id. | ||||
Returns: | Returns: | ||||
@@ -13,6 +13,8 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from twisted.internet import defer | |||||
from synapse.rest.client.v1 import room | from synapse.rest.client.v1 import room | ||||
from tests.unittest import HomeserverTestCase | from tests.unittest import HomeserverTestCase | ||||
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): | |||||
event = self.successResultOf(event) | event = self.successResultOf(event) | ||||
# Purge everything before this topological token | # Purge everything before this topological token | ||||
purge = storage.purge_events.purge_history(self.room_id, event, True) | |||||
purge = defer.ensureDeferred( | |||||
storage.purge_events.purge_history(self.room_id, event, True) | |||||
) | |||||
self.pump() | self.pump() | ||||
self.assertEqual(self.successResultOf(purge), None) | self.assertEqual(self.successResultOf(purge), None) | ||||
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): | |||||
) | ) | ||||
# Purge everything before this topological token | # Purge everything before this topological token | ||||
purge = storage.purge_history(self.room_id, event, True) | |||||
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) | |||||
self.pump() | self.pump() | ||||
f = self.failureResultOf(purge) | f = self.failureResultOf(purge) | ||||
self.assertIn("greater than forward", f.value.args[0]) | self.assertIn("greater than forward", f.value.args[0]) | ||||
@@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): | |||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
def inject_room_event(self, **kwargs): | def inject_room_event(self, **kwargs): | ||||
yield self.storage.persistence.persist_event( | |||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) | |||||
yield defer.ensureDeferred( | |||||
self.storage.persistence.persist_event( | |||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) | |||||
) | |||||
) | ) | ||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
@@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
self.event_creation_handler.create_new_client_event(builder) | self.event_creation_handler.create_new_client_event(builder) | ||||
) | ) | ||||
yield self.storage.persistence.persist_event(event, context) | |||||
yield defer.ensureDeferred( | |||||
self.storage.persistence.persist_event(event, context) | |||||
) | |||||
return event | return event | ||||
@@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | ||||
) | ) | ||||
state_group_map = yield self.storage.state.get_state_groups_ids( | |||||
self.room, [e2.event_id] | |||||
state_group_map = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) | |||||
) | ) | ||||
self.assertEqual(len(state_group_map), 1) | self.assertEqual(len(state_group_map), 1) | ||||
state_map = list(state_group_map.values())[0] | state_map = list(state_group_map.values())[0] | ||||
@@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} | ||||
) | ) | ||||
state_group_map = yield self.storage.state.get_state_groups( | |||||
self.room, [e2.event_id] | |||||
state_group_map = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_groups(self.room, [e2.event_id]) | |||||
) | ) | ||||
self.assertEqual(len(state_group_map), 1) | self.assertEqual(len(state_group_map), 1) | ||||
state_list = list(state_group_map.values())[0] | state_list = list(state_group_map.values())[0] | ||||
@@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
) | ) | ||||
# check we get the full state as of the final event | # check we get the full state as of the final event | ||||
state = yield self.storage.state.get_state_for_event(e5.event_id) | |||||
state = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_for_event(e5.event_id) | |||||
) | |||||
self.assertIsNotNone(e4) | self.assertIsNotNone(e4) | ||||
@@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
) | ) | ||||
# check we can filter to the m.room.name event (with a '' state key) | # check we can filter to the m.room.name event (with a '' state key) | ||||
state = yield self.storage.state.get_state_for_event( | |||||
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) | |||||
state = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_for_event( | |||||
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) | |||||
) | |||||
) | ) | ||||
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) | self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) | ||||
# check we can filter to the m.room.name event (with a wildcard None state key) | # check we can filter to the m.room.name event (with a wildcard None state key) | ||||
state = yield self.storage.state.get_state_for_event( | |||||
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) | |||||
state = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_for_event( | |||||
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) | |||||
) | |||||
) | ) | ||||
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) | self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) | ||||
# check we can grab the m.room.member events (with a wildcard None state key) | # check we can grab the m.room.member events (with a wildcard None state key) | ||||
state = yield self.storage.state.get_state_for_event( | |||||
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) | |||||
state = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_for_event( | |||||
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) | |||||
) | |||||
) | ) | ||||
self.assertStateMapEqual( | self.assertStateMapEqual( | ||||
@@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
# check we can grab a specific room member without filtering out the | # check we can grab a specific room member without filtering out the | ||||
# other event types | # other event types | ||||
state = yield self.storage.state.get_state_for_event( | |||||
e5.event_id, | |||||
state_filter=StateFilter( | |||||
types={EventTypes.Member: {self.u_alice.to_string()}}, | |||||
include_others=True, | |||||
), | |||||
state = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_for_event( | |||||
e5.event_id, | |||||
state_filter=StateFilter( | |||||
types={EventTypes.Member: {self.u_alice.to_string()}}, | |||||
include_others=True, | |||||
), | |||||
) | |||||
) | ) | ||||
self.assertStateMapEqual( | self.assertStateMapEqual( | ||||
@@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
) | ) | ||||
# check that we can grab everything except members | # check that we can grab everything except members | ||||
state = yield self.storage.state.get_state_for_event( | |||||
e5.event_id, | |||||
state_filter=StateFilter( | |||||
types={EventTypes.Member: set()}, include_others=True | |||||
), | |||||
state = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_for_event( | |||||
e5.event_id, | |||||
state_filter=StateFilter( | |||||
types={EventTypes.Member: set()}, include_others=True | |||||
), | |||||
) | |||||
) | ) | ||||
self.assertStateMapEqual( | self.assertStateMapEqual( | ||||
@@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||||
####################################################### | ####################################################### | ||||
room_id = self.room.to_string() | room_id = self.room.to_string() | ||||
group_ids = yield self.storage.state.get_state_groups_ids( | |||||
room_id, [e5.event_id] | |||||
group_ids = yield defer.ensureDeferred( | |||||
self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) | |||||
) | ) | ||||
group = list(group_ids.keys())[0] | group = list(group_ids.keys())[0] | ||||
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||||
self.store = self.hs.get_datastore() | self.store = self.hs.get_datastore() | ||||
self.storage = self.hs.get_storage() | self.storage = self.hs.get_storage() | ||||
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") | |||||
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) | |||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
def test_filtering(self): | def test_filtering(self): | ||||
@@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||||
event, context = yield defer.ensureDeferred( | event, context = yield defer.ensureDeferred( | ||||
self.event_creation_handler.create_new_client_event(builder) | self.event_creation_handler.create_new_client_event(builder) | ||||
) | ) | ||||
yield self.storage.persistence.persist_event(event, context) | |||||
yield defer.ensureDeferred( | |||||
self.storage.persistence.persist_event(event, context) | |||||
) | |||||
return event | return event | ||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
@@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||||
self.event_creation_handler.create_new_client_event(builder) | self.event_creation_handler.create_new_client_event(builder) | ||||
) | ) | ||||
yield self.storage.persistence.persist_event(event, context) | |||||
yield defer.ensureDeferred( | |||||
self.storage.persistence.persist_event(event, context) | |||||
) | |||||
return event | return event | ||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
@@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): | |||||
self.event_creation_handler.create_new_client_event(builder) | self.event_creation_handler.create_new_client_event(builder) | ||||
) | ) | ||||
yield self.storage.persistence.persist_event(event, context) | |||||
yield defer.ensureDeferred( | |||||
self.storage.persistence.persist_event(event, context) | |||||
) | |||||
return event | return event | ||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
@@ -638,14 +638,8 @@ class DeferredMockCallable(object): | |||||
) | ) | ||||
@defer.inlineCallbacks | |||||
def create_room(hs, room_id, creator_id): | |||||
async def create_room(hs, room_id: str, creator_id: str): | |||||
"""Creates and persist a creation event for the given room | """Creates and persist a creation event for the given room | ||||
Args: | |||||
hs | |||||
room_id (str) | |||||
creator_id (str) | |||||
""" | """ | ||||
persistence_store = hs.get_storage().persistence | persistence_store = hs.get_storage().persistence | ||||
@@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): | |||||
event_builder_factory = hs.get_event_builder_factory() | event_builder_factory = hs.get_event_builder_factory() | ||||
event_creation_handler = hs.get_event_creation_handler() | event_creation_handler = hs.get_event_creation_handler() | ||||
yield store.store_room( | |||||
await store.store_room( | |||||
room_id=room_id, | room_id=room_id, | ||||
room_creator_user_id=creator_id, | room_creator_user_id=creator_id, | ||||
is_public=False, | is_public=False, | ||||
@@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id): | |||||
}, | }, | ||||
) | ) | ||||
event, context = yield defer.ensureDeferred( | |||||
event_creation_handler.create_new_client_event(builder) | |||||
) | |||||
event, context = await event_creation_handler.create_new_client_event(builder) | |||||
yield persistence_store.persist_event(event, context) | |||||
await persistence_store.persist_event(event, context) |
@@ -206,6 +206,7 @@ commands = mypy \ | |||||
synapse/storage/data_stores/main/ui_auth.py \ | synapse/storage/data_stores/main/ui_auth.py \ | ||||
synapse/storage/database.py \ | synapse/storage/database.py \ | ||||
synapse/storage/engines \ | synapse/storage/engines \ | ||||
synapse/storage/state.py \ | |||||
synapse/storage/util \ | synapse/storage/util \ | ||||
synapse/streams \ | synapse/streams \ | ||||
synapse/util/caches/stream_change_cache.py \ | synapse/util/caches/stream_change_cache.py \ | ||||