Преглед изворни кода

Convert storage layer to async/await. (#7963)

tags/v1.19.0rc1
Patrick Cloke пре 3 година
committed by GitHub
родитељ
комит
3345c166a4
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
10 измењених фајлова са 210 додато и 185 уклоњено
  1. +1
    -0
      changelog.d/7963.misc
  2. +18
    -22
      synapse/storage/persist_events.py
  3. +18
    -20
      synapse/storage/purge_events.py
  4. +109
    -98
      synapse/storage/state.py
  5. +6
    -2
      tests/storage/test_purge.py
  6. +4
    -2
      tests/storage/test_room.py
  7. +39
    -25
      tests/storage/test_state.py
  8. +10
    -4
      tests/test_visibility.py
  9. +4
    -12
      tests/utils.py
  10. +1
    -0
      tox.ini

+ 1
- 0
changelog.d/7963.misc Прегледај датотеку

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

+ 18
- 22
synapse/storage/persist_events.py Прегледај датотеку

@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer from twisted.internet import defer


from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import FrozenEvent
from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue() self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()


@defer.inlineCallbacks
def persist_events(
async def persist_events(
self, self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
):
) -> int:
""" """
Write events to the database Write events to the database
Args: Args:
@@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc. which might update the current state etc.


Returns: Returns:
Deferred[int]: the stream ordering of the latest persisted event
the stream ordering of the latest persisted event
""" """
partitioned = {} partitioned = {}
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
@@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)


yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )


max_persisted_id = yield self.main_store.get_current_events_token()

return max_persisted_id
return self.main_store.get_current_events_token()


@defer.inlineCallbacks
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[int, int]:
""" """
Returns: Returns:
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
and the stream ordering of the latest persisted event
The stream ordering of `event`, and the stream ordering of the
latest persisted event
""" """
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled event.room_id, [(event, context)], backfilled=backfilled
@@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):


self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)


yield make_deferred_yieldable(deferred)
await make_deferred_yieldable(deferred)


max_persisted_id = yield self.main_store.get_current_events_token()
max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id) return (event.internal_metadata.stream_ordering, max_persisted_id)


def _maybe_start_persisting(self, room_id: str): def _maybe_start_persisting(self, room_id: str):
@@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):


async def _persist_events( async def _persist_events(
self, self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
): ):
"""Calculates the change to current state and forward extremities, and """Calculates the change to current state and forward extremities, and
@@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities( async def _calculate_new_extremities(
self, self,
room_id: str, room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]],
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str], latest_event_ids: List[str],
): ):
"""Calculates the new forward extremities for a room given events to """Calculates the new forward extremities for a room given events to
@@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events( async def _get_new_state_after_events(
self, self,
room_id: str, room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]],
events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str], old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined( async def _is_server_still_joined(
self, self,
room_id: str, room_id: str,
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState, delta: DeltaState,
current_state: Optional[StateMap[str]], current_state: Optional[StateMap[str]],
potentially_left_users: Set[str], potentially_left_users: Set[str],


+ 18
- 20
synapse/storage/purge_events.py Прегледај датотеку

@@ -15,8 +15,7 @@


import itertools import itertools
import logging import logging

from twisted.internet import defer
from typing import Set


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


@@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores): def __init__(self, hs, stores):
self.stores = stores self.stores = stores


@defer.inlineCallbacks
def purge_room(self, room_id: str):
async def purge_room(self, room_id: str):
"""Deletes all record of a room """Deletes all record of a room
""" """


state_groups_to_delete = yield self.stores.main.purge_room(room_id)
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)


@defer.inlineCallbacks
def purge_history(self, room_id, token, delete_local_events):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Deletes room history before a certain point """Deletes room history before a certain point


Args: Args:
room_id (str):
room_id: The room ID


token (str): A topological token to delete events before
token: A topological token to delete events before


delete_local_events (bool):
delete_local_events:
if True, we will delete local events as well as remote ones if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their (instead of just marking them as outliers and deleting their
state groups). state groups).
""" """
state_groups = yield self.stores.main.purge_history(
state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events room_id, token, delete_local_events
) )


logger.info("[purge] finding state groups that can be deleted") logger.info("[purge] finding state groups that can be deleted")


sg_to_delete = yield self._find_unreferenced_groups(state_groups)
sg_to_delete = await self._find_unreferenced_groups(state_groups)


yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)


@defer.inlineCallbacks
def _find_unreferenced_groups(self, state_groups):
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
"""Used when purging history to figure out which state groups can be """Used when purging history to figure out which state groups can be
deleted. deleted.


Args: Args:
state_groups (set[int]): Set of state groups referenced by events
state_groups: Set of state groups referenced by events
that are going to be deleted. that are going to be deleted.


Returns: Returns:
Deferred[set[int]] The set of state groups that can be deleted.
The set of state groups that can be deleted.
""" """
# Graph of state group -> previous group # Graph of state group -> previous group
graph = {} graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100)) current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search next_to_search -= current_search


referenced = yield self.stores.main.get_referenced_state_groups(
referenced = await self.stores.main.get_referenced_state_groups(
current_search current_search
) )
referenced_groups |= referenced referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced. # groups that are referenced.
current_search -= referenced current_search -= referenced


edges = yield self.stores.state.get_previous_state_groups(current_search)
edges = await self.stores.state.get_previous_state_groups(current_search)


prevs = set(edges.values()) prevs = set(edges.values())
# We don't bother re-handling groups we've already seen # We don't bother re-handling groups we've already seen


+ 109
- 98
synapse/storage/state.py Прегледај датотеку

@@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.


import logging import logging
from typing import Iterable, List, TypeVar
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar


import attr import attr


from twisted.internet import defer

from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap from synapse.types import StateMap


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state. """A filter used when querying for state.


Attributes: Attributes:
types (dict[str, set[str]|None]): Map from type to set of state keys (or
None). This specifies which state_keys for the given type to fetch
from the DB. If None then all events with that type are fetched. If
the set is empty then no events with that type are fetched.
include_others (bool): Whether to fetch events with types that do not
types: Map from type to set of state keys (or None). This specifies
which state_keys for the given type to fetch from the DB. If None
then all events with that type are fetched. If the set is empty
then no events with that type are fetched.
include_others: Whether to fetch events with types that do not
appear in `types`. appear in `types`.
""" """


types = attr.ib()
include_others = attr.ib(default=False)
types = attr.ib(type=Dict[str, Optional[Set[str]]])
include_others = attr.ib(default=False, type=bool)


def __attrs_post_init__(self): def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
@@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None} self.types = {k: v for k, v in self.types.items() if v is not None}


@staticmethod @staticmethod
def all():
def all() -> "StateFilter":
"""Creates a filter that fetches everything. """Creates a filter that fetches everything.


Returns: Returns:
StateFilter
The new state filter.
""" """
return StateFilter(types={}, include_others=True) return StateFilter(types={}, include_others=True)


@staticmethod @staticmethod
def none():
def none() -> "StateFilter":
"""Creates a filter that fetches nothing. """Creates a filter that fetches nothing.


Returns: Returns:
StateFilter
The new state filter.
""" """
return StateFilter(types={}, include_others=False) return StateFilter(types={}, include_others=False)


@staticmethod @staticmethod
def from_types(types):
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types """Creates a filter that only fetches the given types


Args: Args:
types (Iterable[tuple[str, str|None]]): A list of type and state
keys to fetch. A state_key of None fetches everything for
that type
types: A list of type and state keys to fetch. A state_key of None
fetches everything for that type


Returns: Returns:
StateFilter
The new state filter.
""" """
type_dict = {}
type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types: for typ, s in types:
if typ in type_dict: if typ in type_dict:
if type_dict[typ] is None: if type_dict[typ] is None:
@@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None type_dict[typ] = None
continue continue


type_dict.setdefault(typ, set()).add(s)
type_dict.setdefault(typ, set()).add(s) # type: ignore


return StateFilter(types=type_dict) return StateFilter(types=type_dict)


@staticmethod @staticmethod
def from_lazy_load_member_list(members):
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member """Creates a filter that returns all non-member events, plus the member
events for the given users events for the given users


Args: Args:
members (iterable[str]): Set of user IDs
members: Set of user IDs


Returns: Returns:
StateFilter
The new state filter
""" """
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)


def return_expanded(self):
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the (except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass current one, i.e. anything that passes the current filter will pass
@@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events return all non-member events


Returns: Returns:
StateFilter
The new state filter.
""" """


if self.is_full(): if self.is_full():
@@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True, include_others=True,
) )


def make_sql_filter_clause(self):
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause. """Converts the filter to an SQL clause.


For example: For example:
@@ -179,13 +177,12 @@ class StateFilter(object):




Returns: Returns:
tuple[str, list]: The SQL string (may be empty) and arguments. An
empty SQL string is returned when the filter matches everything
(i.e. is "full").
The SQL string (may be empty) and arguments. An empty SQL string is
returned when the filter matches everything (i.e. is "full").
""" """


where_clause = "" where_clause = ""
where_args = []
where_args = [] # type: List[str]


if self.is_full(): if self.is_full():
return where_clause, where_args return where_clause, where_args
@@ -221,7 +218,7 @@ class StateFilter(object):


return where_clause, where_args return where_clause, where_args


def max_entries_returned(self):
def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if """Returns the maximum number of entries this filter will return if
known, otherwise returns None. known, otherwise returns None.


@@ -260,33 +257,33 @@ class StateFilter(object):


return filtered_state return filtered_state


def is_full(self):
def is_full(self) -> bool:
"""Whether this filter fetches everything or not """Whether this filter fetches everything or not


Returns: Returns:
bool
True if the filter fetches everything.
""" """
return self.include_others and not self.types return self.include_others and not self.types


def has_wildcards(self):
def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch """Whether the filter includes wildcards or is attempting to fetch
specific state. specific state.


Returns: Returns:
bool
True if the filter includes wildcards.
""" """


return self.include_others or any( return self.include_others or any(
state_keys is None for state_keys in self.types.values() state_keys is None for state_keys in self.types.values()
) )


def concrete_types(self):
def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that """Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards` will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty). returns False, but otherwise will be a subset (or even empty).


Returns: Returns:
list[tuple[str,str]]
A list of type/state_keys tuples.
""" """
return [ return [
(t, s) (t, s)
@@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys for s in state_keys
] ]


def get_member_split(self):
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively """Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching matching against member state, and one which assumes it's matching
against non member state. against non member state.
@@ -307,7 +304,7 @@ class StateFilter(object):
state caches). state caches).


Returns: Returns:
tuple[StateFilter, StateFilter]: The member and non member filters
The member and non member filters
""" """


if EventTypes.Member in self.types: if EventTypes.Member in self.types:
@@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between """Given a state group try to return a previous group and a delta between
the old and the new. the old and the new.


Args:
state_group: The state group used to retrieve state deltas.

Returns: Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids) (prev_group, delta_ids)
@@ -347,55 +347,59 @@ class StateGroupStorage(object):


return self.stores.state.get_state_group_delta(state_group) return self.stores.state.get_state_group_delta(state_group)


@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events


Args: Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
_room_id: id of the room for these events
event_ids: ids of the events


Returns: Returns:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
if not event_ids: if not event_ids:
return {} return {}


event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)


groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
group_to_state = await self.stores.state._get_state_for_groups(groups)


return group_to_state return group_to_state


@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group """Get the event IDs of all the state in the given state group


Args: Args:
state_group (int)
state_group: A state group for which we want to get the state IDs.


Returns: Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
Resolves to a map of (type, state_key) -> event_id
""" """
group_to_state = yield self._get_state_for_groups((state_group,))
group_to_state = await self._get_state_for_groups((state_group,))


return group_to_state[state_group] return group_to_state[state_group]


@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids

Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.

Returns: Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
dict of state_group_id -> list of state events.
""" """
if not event_ids: if not event_ids:
return {} return {}


group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)


state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[ [
ev_id ev_id
for group_ids in group_to_ids.values() for group_ids in group_to_ids.values()
@@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query groups: list of state group IDs to query
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.

Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
""" """


return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)


@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.

Args: Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns: Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
A dict of (event_id) -> (type, state_key) -> [state_events]
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)


groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )


state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False, get_prev_content=False,
) )
@@ -463,24 +470,24 @@ class StateGroupStorage(object):


return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}


@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves) of the state events (as opposed to the events themselves)


Args: Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.


Returns: Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
A dict from event_id -> (type, state_key) -> event_id
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)


groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )


@@ -491,36 +498,36 @@ class StateGroupStorage(object):


return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}


@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event


Args: Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.


Returns: Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], state_filter)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]


@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event


Args: Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.


Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]


def _get_state_for_groups( def _get_state_for_groups(
@@ -530,9 +537,8 @@ class StateGroupStorage(object):
filtering by type/state_key filtering by type/state_key


Args: Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database. from the database.
Returns: Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
@@ -540,18 +546,23 @@ class StateGroupStorage(object):
return self.stores.state._get_state_for_groups(groups, state_filter) return self.stores.state._get_state_for_groups(groups, state_filter)


def store_state_group( def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
): ):
"""Store a new set of state, returning a newly assigned state group. """Store a new set of state, returning a newly assigned state group.


Args: Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`. `current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
current_state_ids: The state to store. Map of (type, state_key)
to event_id. to event_id.


Returns: Returns:


+ 6
- 2
tests/storage/test_purge.py Прегледај датотеку

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.


from twisted.internet import defer

from synapse.rest.client.v1 import room from synapse.rest.client.v1 import room


from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
event = self.successResultOf(event) event = self.successResultOf(event)


# Purge everything before this topological token # Purge everything before this topological token
purge = storage.purge_events.purge_history(self.room_id, event, True)
purge = defer.ensureDeferred(
storage.purge_events.purge_history(self.room_id, event, True)
)
self.pump() self.pump()
self.assertEqual(self.successResultOf(purge), None) self.assertEqual(self.successResultOf(purge), None)


@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
) )


# Purge everything before this topological token # Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump() self.pump()
f = self.failureResultOf(purge) f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0]) self.assertIn("greater than forward", f.value.args[0])


+ 4
- 2
tests/storage/test_room.py Прегледај датотеку

@@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase):


@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
yield self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
) )


@defer.inlineCallbacks @defer.inlineCallbacks


+ 39
- 25
tests/storage/test_state.py Прегледај датотеку

@@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )


yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)


return event return event


@@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )


state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0] state_map = list(state_group_map.values())[0]
@@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )


state_group_map = yield self.storage.state.get_state_groups(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0] state_list = list(state_group_map.values())[0]
@@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )


# check we get the full state as of the final event # check we get the full state as of the final event
state = yield self.storage.state.get_state_for_event(e5.event_id)
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(e5.event_id)
)


self.assertIsNotNone(e4) self.assertIsNotNone(e4)


@@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )


# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
) )


self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)


# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
) )


self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)


# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
) )


self.assertStateMapEqual( self.assertStateMapEqual(
@@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase):


# check we can grab a specific room member without filtering out the # check we can grab a specific room member without filtering out the
# other event types # other event types
state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
)
) )


self.assertStateMapEqual( self.assertStateMapEqual(
@@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )


# check that we can grab everything except members # check that we can grab everything except members
state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
)
) )


self.assertStateMapEqual( self.assertStateMapEqual(
@@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################


room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield self.storage.state.get_state_groups_ids(
room_id, [e5.event_id]
group_ids = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
) )
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]




+ 10
- 4
tests/test_visibility.py Прегледај датотеку

@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage() self.storage = self.hs.get_storage()


yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))


@defer.inlineCallbacks @defer.inlineCallbacks
def test_filtering(self): def test_filtering(self):
@@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield defer.ensureDeferred( event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event


@defer.inlineCallbacks @defer.inlineCallbacks
@@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )


yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event


@defer.inlineCallbacks @defer.inlineCallbacks
@@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )


yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event


@defer.inlineCallbacks @defer.inlineCallbacks


+ 4
- 12
tests/utils.py Прегледај датотеку

@@ -638,14 +638,8 @@ class DeferredMockCallable(object):
) )




@defer.inlineCallbacks
def create_room(hs, room_id, creator_id):
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room """Creates and persist a creation event for the given room

Args:
hs
room_id (str)
creator_id (str)
""" """


persistence_store = hs.get_storage().persistence persistence_store = hs.get_storage().persistence
@@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory() event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()


yield store.store_room(
await store.store_room(
room_id=room_id, room_id=room_id,
room_creator_user_id=creator_id, room_creator_user_id=creator_id,
is_public=False, is_public=False,
@@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id):
}, },
) )


event, context = yield defer.ensureDeferred(
event_creation_handler.create_new_client_event(builder)
)
event, context = await event_creation_handler.create_new_client_event(builder)


yield persistence_store.persist_event(event, context)
await persistence_store.persist_event(event, context)

+ 1
- 0
tox.ini Прегледај датотеку

@@ -206,6 +206,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \ synapse/storage/database.py \
synapse/storage/engines \ synapse/storage/engines \
synapse/storage/state.py \
synapse/storage/util \ synapse/storage/util \
synapse/streams \ synapse/streams \
synapse/util/caches/stream_change_cache.py \ synapse/util/caches/stream_change_cache.py \


Loading…
Откажи
Сачувај